torch_mist.critic


Package Contents

Classes

Critic

JointCritic

SeparableCritic

Functions

critic_nn(x_dim, y_dim, hidden_dims, critic_type, **kwargs) → torch_mist.critic.base.Critic

joint_critic(x_dim, y_dim, hidden_dims, nonlinearity) → torch_mist.critic.joint.JointCritic

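separable_critic(x_dim, y_dim, hidden_dims, projection_head, nonlinearity, normalize, temperature, k_dim) → torch_mist.critic.separable.SeparableCritic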

class torch_mist.critic.Critic

Bases: torch.nn.Module

abstract forward(x: torch.Tensor, y: torch.Tensor) → torch.Tensor

Compute the value of the critic evaluated at the pair (x, y).

Parameters:
x: a tensor representing x
y: a tensor representing y

Returns:
The value of the ratio estimator on the given pair.
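
As a minimal sketch of this contract, assuming only what the signature above states (a torch.nn.Module whose abstract forward maps a pair of tensors to a tensor of scores), a custom critic subclasses the base and implements forward. The base class below is a local stand-in for torch_mist.critic.Critic so that the example runs without torch_mist installed, and BilinearCritic is a hypothetical example, not part of the library.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class Critic(nn.Module, ABC):
    """Local stand-in mirroring the documented torch_mist.critic.Critic interface."""

    @abstractmethod
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return one critic value per (x, y) pair."""


class BilinearCritic(Critic):
    """Hypothetical critic computing f(x, y) = x^T W y."""

    def __init__(self, x_dim: int, y_dim: int):
        super().__init__()
        # Small random initialization of the bilinear weight matrix.
        self.W = nn.Parameter(0.01 * torch.randn(x_dim, y_dim))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x: [..., x_dim], y: [..., y_dim] -> scores: [...]
        return ((x @ self.W) * y).sum(dim=-1)


critic = BilinearCritic(x_dim=3, y_dim=4)
scores = critic(torch.randn(5, 3), torch.randn(5, 4))
print(scores.shape)  # torch.Size([5])
```

Because forward is abstract, the stand-in base cannot be instantiated on its own; only subclasses that implement forward can.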

class torch_mist.critic.JointCritic(joint_net: torch.nn.Module)

Bases: torch_mist.critic.base.Critic

forward(x, y) → torch.Tensor

Compute the value of the critic evaluated at the pair (x, y).

Parameters:
x: a tensor representing x
y: a tensor representing y

Returns:
The value of the ratio estimator on the given pair.
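
A joint critic scores each pair with a single network, joint_net, that sees x and y together. The sketch below assumes the pair is presented to joint_net as the concatenation of x and y along the last dimension, and that x is broadcast against y when y carries extra leading dimensions (for example, a stack of negative samples). Both are assumptions about the behavior, and ConcatJointCritic is a hypothetical name, not the library class.

```python
import torch
import torch.nn as nn


class ConcatJointCritic(nn.Module):
    """Hypothetical joint critic: f(x, y) = joint_net([x, y])."""

    def __init__(self, joint_net: nn.Module):
        super().__init__()
        self.joint_net = joint_net

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Broadcast x to y's leading shape, e.g. x: [N, x_dim], y: [M, N, y_dim].
        x = x.expand(*y.shape[:-1], x.shape[-1])
        # Concatenate the features and map each pair to a scalar score.
        xy = torch.cat([x, y], dim=-1)
        return self.joint_net(xy).squeeze(-1)


x_dim, y_dim = 3, 2
net = nn.Sequential(nn.Linear(x_dim + y_dim, 16), nn.ReLU(), nn.Linear(16, 1))
critic = ConcatJointCritic(net)

x = torch.randn(8, x_dim)
y = torch.randn(8, y_dim)          # matched pairs
y_neg = torch.randn(4, 8, y_dim)   # 4 hypothetical negative samples per x
print(critic(x, y).shape)      # torch.Size([8])
print(critic(x, y_neg).shape)  # torch.Size([4, 8])
```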

class torch_mist.critic.SeparableCritic(f_x: torch.nn.Module | None = None, f_y: torch.nn.Module | None = None, temperature: float = 1.0)

Bases: torch_mist.critic.base.Critic

forward(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
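
A separable critic factorizes the score through two embeddings, f(x, y) = f_x(x)^T f_y(y) / temperature. The sketch below assumes that reading of the constructor arguments and that a missing f_x or f_y (None) means the identity map; DotProductCritic and pairwise_scores are hypothetical names, not the library's.

```python
from typing import Optional

import torch
import torch.nn as nn


class DotProductCritic(nn.Module):
    """Hypothetical separable critic: f(x, y) = <f_x(x), f_y(y)> / temperature."""

    def __init__(
        self,
        f_x: Optional[nn.Module] = None,
        f_y: Optional[nn.Module] = None,
        temperature: float = 1.0,
    ):
        super().__init__()
        # Assumption: a missing head means the identity map.
        self.f_x = f_x if f_x is not None else nn.Identity()
        self.f_y = f_y if f_y is not None else nn.Identity()
        self.temperature = temperature

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Score matched pairs: [..., k] and [..., k] -> [...]
        return (self.f_x(x) * self.f_y(y)).sum(dim=-1) / self.temperature

    def pairwise_scores(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Embed each batch once, then score all N x N pairs with one matmul.
        return self.f_x(x) @ self.f_y(y).T / self.temperature


critic = DotProductCritic(temperature=2.0)
x = torch.ones(2, 3)
y = 2 * torch.ones(2, 3)
print(critic(x, y))  # tensor([3., 3.])
```

The pairwise form is what makes separable critics cheap for contrastive estimators: every x is compared with every y through one N x N matrix product, whereas a joint critic has to run its network on each of the N^2 concatenated pairs.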

torch_mist.critic.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) → torch_mist.critic.base.Critic
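
critic_nn takes a critic_type string plus **kwargs, which suggests a dispatcher that forwards the remaining arguments to a type-specific builder. The sketch below shows only that pattern: the type strings "joint" and "separable", the helper names make_critic, build_joint, build_separable and mlp, and the exact keyword handling are all hypothetical, and the builders return bare networks rather than critics to keep the example short.

```python
from typing import Callable, Dict, List

import torch
import torch.nn as nn


def mlp(dims: List[int]) -> nn.Sequential:
    # Linear layers with a ReLU between consecutive layers (none after the last).
    layers: List[nn.Module] = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def build_joint(x_dim: int, y_dim: int, hidden_dims: List[int]) -> nn.Module:
    # One network over the concatenated pair, ending in a scalar.
    return mlp([x_dim + y_dim, *hidden_dims, 1])


def build_separable(x_dim: int, y_dim: int, hidden_dims: List[int], k_dim: int = 16) -> nn.Module:
    # Two independent heads embedding x and y into a shared k_dim space.
    return nn.ModuleDict({
        "f_x": mlp([x_dim, *hidden_dims, k_dim]),
        "f_y": mlp([y_dim, *hidden_dims, k_dim]),
    })


BUILDERS: Dict[str, Callable[..., nn.Module]] = {
    "joint": build_joint,
    "separable": build_separable,
}


def make_critic(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) -> nn.Module:
    if critic_type not in BUILDERS:
        raise ValueError(f"Unknown critic_type {critic_type!r}; expected one of {sorted(BUILDERS)}")
    # Forward the remaining keyword arguments to the selected builder.
    return BUILDERS[critic_type](x_dim, y_dim, hidden_dims, **kwargs)


sep = make_critic(3, 2, [8], "separable", k_dim=4)
print(sep["f_x"](torch.randn(5, 3)).shape)  # torch.Size([5, 4])
```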

torch_mist.critic.joint_critic(x_dim: int, y_dim: int, hidden_dims: List[int] | None, nonlinearity: torch.nn.Module = nn.ReLU(True)) → torch_mist.critic.joint.JointCritic
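
Given the signature, a plausible reading of joint_critic is an MLP from x_dim + y_dim inputs through hidden_dims to a single output, with nonlinearity between layers, wrapped in a joint critic. The sketch below follows that reading; treating hidden_dims=None as "no hidden layers" (a single linear map), concatenating x and y along the last dimension, and the names build_joint_critic and ConcatJointCritic are assumptions.

```python
from typing import List, Optional

import torch
import torch.nn as nn


class ConcatJointCritic(nn.Module):
    """Hypothetical stand-in for JointCritic: scores joint_net([x, y])."""

    def __init__(self, joint_net: nn.Module):
        super().__init__()
        self.joint_net = joint_net

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.joint_net(torch.cat([x, y], dim=-1)).squeeze(-1)


def build_joint_critic(
    x_dim: int,
    y_dim: int,
    hidden_dims: Optional[List[int]],
    nonlinearity: nn.Module = nn.ReLU(True),
) -> ConcatJointCritic:
    # Assumption: None means no hidden layers, i.e. one linear map to a scalar.
    dims = [x_dim + y_dim, *(hidden_dims or []), 1]
    layers: List[nn.Module] = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            # Stateless activations such as ReLU can be shared between layers.
            layers.append(nonlinearity)
    return ConcatJointCritic(nn.Sequential(*layers))


critic = build_joint_critic(x_dim=3, y_dim=2, hidden_dims=[32, 32])
print(critic(torch.randn(10, 3), torch.randn(10, 2)).shape)  # torch.Size([10])
```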

torch_mist.critic.separable_critic(x_dim: int, y_dim: int, hidden_dims: List[int], projection_head: str = ASYMMETRIC_HEADS, nonlinearity: torch.nn.Module = nn.ReLU(True), normalize: bool = False, temperature: float = 1.0, k_dim: int | None = None) → torch_mist.critic.separable.SeparableCritic
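
A sketch of what separable_critic plausibly assembles: two independent MLP heads (the case we assume ASYMMETRIC_HEADS selects) mapping x and y through hidden_dims to a k_dim-dimensional embedding, optional L2 normalization so that the dot product becomes a cosine similarity, and division by temperature. Falling back to the last hidden layer as the embedding when k_dim is None, and every name in the code, are assumptions rather than the library's documented behavior.

```python
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def mlp(dims: List[int], nonlinearity: nn.Module) -> nn.Sequential:
    # Linear layers with the nonlinearity between them (none after the last).
    layers: List[nn.Module] = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers.append(nonlinearity)
    return nn.Sequential(*layers)


class SeparableSketch(nn.Module):
    """Hypothetical separable critic with two independent (asymmetric) heads."""

    def __init__(
        self,
        x_dim: int,
        y_dim: int,
        hidden_dims: List[int],
        nonlinearity: nn.Module = nn.ReLU(True),
        normalize: bool = False,
        temperature: float = 1.0,
        k_dim: Optional[int] = None,
    ):
        super().__init__()
        # Assumption: without k_dim, the last hidden layer is the embedding.
        tail = [k_dim] if k_dim is not None else []
        self.f_x = mlp([x_dim, *hidden_dims, *tail], nonlinearity)
        self.f_y = mlp([y_dim, *hidden_dims, *tail], nonlinearity)
        self.normalize = normalize
        self.temperature = temperature

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        zx, zy = self.f_x(x), self.f_y(y)
        if self.normalize:
            # Unit-norm embeddings turn the dot product into a cosine similarity.
            zx, zy = F.normalize(zx, dim=-1), F.normalize(zy, dim=-1)
        return (zx * zy).sum(dim=-1) / self.temperature


critic = SeparableSketch(x_dim=3, y_dim=5, hidden_dims=[16], normalize=True, temperature=0.1, k_dim=8)
scores = critic(torch.randn(6, 3), torch.randn(6, 5))
print(scores.shape)  # torch.Size([6])
```

With normalize=True every score lies in [-1/temperature, 1/temperature], so lowering the temperature stretches the range of the scores rather than changing their ordering.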