torch_mist.critic.joint

Module Contents

Classes

JointCritic

class torch_mist.critic.joint.JointCritic(joint_net: torch.nn.Module)

Bases: torch_mist.critic.base.Critic

forward(x, y) → torch.Tensor

Compute the value of the critic evaluated at the pair (x, y).

Parameters:
    x: a tensor representing x
    y: a tensor representing y

Returns:
    The value of the ratio estimator on the given pair.
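
The following is a minimal sketch of how a critic of this shape could be written and called. It is not the torch_mist implementation: the class name SketchJointCritic, the concatenation of x and y along the last axis, the final squeeze, and the layer sizes are all assumptions made for illustration. The only details taken from the signature above are that the constructor accepts a single joint_net module and that forward(x, y) returns a tensor.

```python
import torch
from torch import nn


class SketchJointCritic(nn.Module):
    """Hypothetical stand-in for a joint critic (not the torch_mist class)."""

    def __init__(self, joint_net: nn.Module):
        super().__init__()
        self.joint_net = joint_net

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Assumption: x and y share their leading (batch) dimensions,
        # so they can be concatenated along the last (feature) axis.
        xy = torch.cat([x, y], dim=-1)
        # Assumption: joint_net maps each concatenated pair to one scalar,
        # so the trailing dimension of size 1 is dropped.
        return self.joint_net(xy).squeeze(-1)


# Illustrative dimensions and network; none of these values come from the library.
x_dim, y_dim = 3, 2
joint_net = nn.Sequential(
    nn.Linear(x_dim + y_dim, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)
critic = SketchJointCritic(joint_net)

x = torch.randn(8, x_dim)
y = torch.randn(8, y_dim)
values = critic(x, y)  # one critic value per (x, y) pair
```

In this sketch a single network sees both inputs at once, which is what distinguishes a joint critic from one that embeds x and y separately. The cost is that scoring every x against every y requires one forward pass per pair.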