torch_mist.critic.factories
Module Contents
Attributes
- torch_mist.critic.factories.SYMMETRIC_HEADS = 'symmetric'
- torch_mist.critic.factories.ASYMMETRIC_HEADS = 'asymmetric'
- torch_mist.critic.factories.ONE_HEAD = 'one'
- torch_mist.critic.factories.POSSIBLE_HEADS
Functions
- torch_mist.critic.factories.separable_critic(x_dim: int, y_dim: int, hidden_dims: List[int], projection_head: str = ASYMMETRIC_HEADS, nonlinearity: torch.nn.Module = nn.ReLU(True), normalize: bool = False, temperature: float = 1.0, k_dim: int | None = None) -> torch_mist.critic.separable.SeparableCritic
- torch_mist.critic.factories.joint_critic(x_dim: int, y_dim: int, hidden_dims: List[int] | None, nonlinearity: torch.nn.Module = nn.ReLU(True)) -> torch_mist.critic.joint.JointCritic
- torch_mist.critic.factories.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) -> torch_mist.critic.base.Critic
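The factory signatures above come without descriptions. As a rough, dependency-free illustration of what the two critic families compute, the sketch below assumes that a separable critic scores a pair as f(x, y) = g(x)·h(y) / temperature (with optional L2 normalization of both embeddings), and that a joint critic applies one network to the concatenation [x; y]. The head constants reuse the documented string values. Everything else is an assumption for illustration and is not torch_mist's implementation, which builds PyTorch modules. This includes the `toy_`-prefixed functions, `make_mlp`, `unit`, `score_matrix`, the hypothetical `"separable"`/`"joint"` values for `critic_type`, the reading of `ONE_HEAD` as "project x only", and the default used for `k_dim`.

```python
import math
import random
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Critic = Callable[[Vector, Vector], float]

# Same string values as the documented module attributes.
SYMMETRIC_HEADS = "symmetric"
ASYMMETRIC_HEADS = "asymmetric"
ONE_HEAD = "one"


def make_mlp(in_dim: int, hidden_dims: Sequence[int], out_dim: int,
             rng: random.Random) -> Callable[[Vector], Vector]:
    """Bias-free random dense net with ReLU after every layer but the last."""
    dims = [in_dim, *hidden_dims, out_dim]
    layers = [
        [[rng.gauss(0.0, 1.0 / math.sqrt(d_in)) for _ in range(d_in)]
         for _ in range(d_out)]
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]

    def forward(v: Vector) -> Vector:
        for i, weights in enumerate(layers):
            v = [sum(w * a for w, a in zip(row, v)) for row in weights]
            if i < len(layers) - 1:
                v = [max(0.0, a) for a in v]  # ReLU on hidden layers only
        return v

    return forward


def unit(v: Vector) -> Vector:
    """L2-normalise v (a zero vector stays zero)."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / max(norm, 1e-12) for a in v]


def toy_separable_critic(x_dim: int, y_dim: int, hidden_dims: Sequence[int],
                         projection_head: str = ASYMMETRIC_HEADS,
                         normalize: bool = False, temperature: float = 1.0,
                         k_dim: Optional[int] = None, seed: int = 0) -> Critic:
    """f(x, y) = <g(x), h(y)> / temperature, with g, h chosen by projection_head."""
    rng = random.Random(seed)
    k = y_dim if k_dim is None else k_dim  # assumed default embedding size
    if projection_head == SYMMETRIC_HEADS:
        if x_dim != y_dim:
            raise ValueError("symmetric heads need x_dim == y_dim")
        g = h = make_mlp(x_dim, hidden_dims, k, rng)  # one shared network
    elif projection_head == ASYMMETRIC_HEADS:
        g = make_mlp(x_dim, hidden_dims, k, rng)      # separate networks
        h = make_mlp(y_dim, hidden_dims, k, rng)
    elif projection_head == ONE_HEAD:
        g = make_mlp(x_dim, hidden_dims, y_dim, rng)  # assumed: project x only
        h = list                                      # y is used as-is
    else:
        raise ValueError(f"unknown projection_head: {projection_head!r}")

    def critic(x: Vector, y: Vector) -> float:
        gx, hy = g(x), h(y)
        if normalize:  # cosine similarity, so |score| <= 1 / temperature
            gx, hy = unit(gx), unit(hy)
        return sum(a * b for a, b in zip(gx, hy)) / temperature

    return critic


def toy_joint_critic(x_dim: int, y_dim: int,
                     hidden_dims: Optional[Sequence[int]], seed: int = 0) -> Critic:
    """f(x, y) = MLP([x; y]): a single network scores the concatenated pair."""
    net = make_mlp(x_dim + y_dim, hidden_dims or [], 1, random.Random(seed))
    return lambda x, y: net(list(x) + list(y))[0]


def toy_critic_nn(x_dim: int, y_dim: int, hidden_dims: Optional[Sequence[int]],
                  critic_type: str, **kwargs) -> Critic:
    """Dispatch on critic_type; the type strings here are hypothetical."""
    builders = {"separable": toy_separable_critic, "joint": toy_joint_critic}
    if critic_type not in builders:
        raise ValueError(f"unknown critic_type: {critic_type!r}")
    return builders[critic_type](x_dim, y_dim, hidden_dims, **kwargs)


def score_matrix(critic: Critic, xs: List[Vector],
                 ys: List[Vector]) -> List[List[float]]:
    """Entry [i][j] = critic(xs[i], ys[j]); the diagonal holds the aligned pairs."""
    return [[critic(x, y) for y in ys] for x in xs]


# Usage: 4 index-aligned (x_i, y_i) pairs with 3-d x and 3-d y.
rng = random.Random(1)
xs = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]
ys = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]
sep = toy_critic_nn(3, 3, [8], "separable", projection_head=SYMMETRIC_HEADS,
                    normalize=True, temperature=0.5)
joint = toy_critic_nn(3, 3, [8], "joint")
S = score_matrix(sep, xs, ys)
J = score_matrix(joint, xs, ys)
```

One design consequence of the separable form is that each sample goes through its head network once, so the N×N score matrix reduces to a product of two N×k embedding matrices. A joint critic, by contrast, needs one network pass per (x, y) pair. This cost difference is a common reason separable critics are preferred for batch-based mutual information bounds.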