torch_mist.critic
Package Contents
- class torch_mist.critic.Critic
  Bases: torch.nn.Module
  - abstract forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
    Compute the value of the critic evaluated at the pair (x, y).
    :param x: a tensor representing x
    :param y: a tensor representing y
    :return: the value of the ratio estimator on the given pair
- class torch_mist.critic.JointCritic(joint_net: torch.nn.Module)
  Bases: torch_mist.critic.base.Critic
  - forward(x, y) -> torch.Tensor
    Compute the value of the critic evaluated at the pair (x, y).
    :param x: a tensor representing x
    :param y: a tensor representing y
    :return: the value of the ratio estimator on the given pair
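To illustrate how a critic built around a single `joint_net` might score a pair, here is a minimal sketch in plain PyTorch. It does not import torch_mist: `JointCriticSketch` is a hypothetical stand-in, and the choice to concatenate `x` and `y` on the last axis before calling the network is an assumption, not something this page confirms.

```python
import torch
from torch import nn


class JointCriticSketch(nn.Module):
    """Hypothetical stand-in for a joint critic; not the torch_mist code."""

    def __init__(self, joint_net: nn.Module):
        super().__init__()
        self.joint_net = joint_net

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Assumption: the joint network consumes [x, y] concatenated on the last axis.
        xy = torch.cat([x, y], dim=-1)
        # Drop the trailing singleton dimension so each pair maps to one scalar.
        return self.joint_net(xy).squeeze(-1)


x_dim, y_dim = 3, 2
joint_net = nn.Sequential(
    nn.Linear(x_dim + y_dim, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)
critic = JointCriticSketch(joint_net)

x = torch.randn(8, x_dim)
y = torch.randn(8, y_dim)
scores = critic(x, y)
print(scores.shape)  # torch.Size([8])
```

Because the network sees both inputs at once, it can model arbitrary interactions between x and y, at the cost of one forward pass per pair.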
- class torch_mist.critic.SeparableCritic(f_x: torch.nn.Module | None = None, f_y: torch.nn.Module | None = None, temperature: float = 1.0)
  Bases: torch_mist.critic.base.Critic
  - forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
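The constructor signature (`f_x`, `f_y`, `temperature`) suggests a critic that embeds each input separately and compares the embeddings. The sketch below shows one common way to do this in plain PyTorch: a dot product of the two embeddings divided by the temperature. `SeparableCriticSketch` is a hypothetical name, and both the dot-product form and the treatment of `None` as an identity map are assumptions rather than documented behaviour.

```python
import torch
from torch import nn


class SeparableCriticSketch(nn.Module):
    """Hypothetical stand-in for a separable critic; not the torch_mist code."""

    def __init__(self, f_x=None, f_y=None, temperature: float = 1.0):
        super().__init__()
        # Assumption: a missing encoder leaves its input unchanged.
        self.f_x = f_x if f_x is not None else nn.Identity()
        self.f_y = f_y if f_y is not None else nn.Identity()
        self.temperature = temperature

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        h_x = self.f_x(x)
        h_y = self.f_y(y)
        # Assumption: the score is the temperature-scaled inner product.
        return (h_x * h_y).sum(dim=-1) / self.temperature


critic = SeparableCriticSketch(temperature=2.0)
x = torch.tensor([[1.0, 2.0]])
y = torch.tensor([[3.0, 4.0]])
score = critic(x, y)
print(score)  # tensor([5.5000])

# Scores for every (x_i, y_j) combination need only one encoder pass per input.
xs = torch.randn(4, 2)
ys = torch.randn(6, 2)
all_pairs = critic.f_x(xs) @ critic.f_y(ys).T / critic.temperature
print(all_pairs.shape)  # torch.Size([4, 6])
```

The last lines show the usual motivation for this design: with separate encoders, scoring all N x M combinations costs N + M encoder passes instead of N x M.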
- torch_mist.critic.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) -> torch_mist.critic.base.Critic
- torch_mist.critic.joint_critic(x_dim: int, y_dim: int, hidden_dims: List[int] | None, nonlinearity: torch.nn.Module = nn.ReLU(True)) -> torch_mist.critic.joint.JointCritic
- torch_mist.critic.separable_critic(x_dim: int, y_dim: int, hidden_dims: List[int], projection_head: str = ASYMMETRIC_HEADS, nonlinearity: torch.nn.Module = nn.ReLU(True), normalize: bool = False, temperature: float = 1.0, k_dim: int | None = None) -> torch_mist.critic.separable.SeparableCritic
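The factory functions above take input sizes and a `hidden_dims` list, which suggests they assemble the underlying networks for the caller. As a rough illustration only, the sketch below builds a multilayer perceptron from such a list in plain PyTorch. `mlp_sketch` is a hypothetical helper, and reading `hidden_dims` as the widths of successive hidden layers (with `None` meaning no hidden layers) is an assumption that this page does not confirm.

```python
from typing import List, Optional

import torch
from torch import nn


def mlp_sketch(in_dim: int, hidden_dims: Optional[List[int]], out_dim: int) -> nn.Sequential:
    """Hypothetical helper: stack Linear + ReLU layers, then a linear output layer."""
    layers: List[nn.Module] = []
    prev = in_dim
    # Assumption: None or an empty list yields a single linear layer.
    for width in hidden_dims or []:
        layers.append(nn.Linear(prev, width))
        layers.append(nn.ReLU())
        prev = width
    layers.append(nn.Linear(prev, out_dim))
    return nn.Sequential(*layers)


x_dim, y_dim = 3, 2
# A joint network would take x and y together and emit one score per pair.
net = mlp_sketch(x_dim + y_dim, [16, 8], 1)
out = net(torch.randn(4, x_dim + y_dim))
print(out.shape)  # torch.Size([4, 1])
```

A network of this shape could be passed as `joint_net` to `JointCritic`, whose constructor accepts any `torch.nn.Module`.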