torch_mist.critic.factories

Module Contents

Functions

shared_joint_critics(x_dim, y_dim, hidden_dims, nonlinearity, n_shared_layers, n_critics, ...) → Tuple[torch_mist.critic.joint.JointCritic, ...]

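shared_separable_critics(x_dim, y_dim, hidden_dims, projection_head, nonlinearity, normalize, temperature, ...) → Tuple[torch_mist.critic.separable.SeparableCritic, ...]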

shared_critic_nns(x_dim, y_dim, hidden_dims, critic_type, n_critics, ...) → Tuple[torch_mist.critic.base.Critic, ...]

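separable_critic(x_dim, y_dim, hidden_dims, projection_head, nonlinearity, normalize, temperature, k_dim) → torch_mist.critic.separable.SeparableCritic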

joint_critic(x_dim, y_dim, hidden_dims, nonlinearity) → torch_mist.critic.joint.JointCritic

critic_nn(x_dim, y_dim, hidden_dims, critic_type, **kwargs) → torch_mist.critic.base.Critic

Attributes

SYMMETRIC_HEADS

ASYMMETRIC_HEADS

ONE_HEAD

POSSIBLE_HEADS

torch_mist.critic.factories.SYMMETRIC_HEADS = 'symmetric'
torch_mist.critic.factories.ASYMMETRIC_HEADS = 'asymmetric'
torch_mist.critic.factories.ONE_HEAD = 'one'
torch_mist.critic.factories.POSSIBLE_HEADS
torch_mist.critic.factories.shared_joint_critics(x_dim: int, y_dim: int, hidden_dims: List[int], nonlinearity: torch.nn.Module = nn.ReLU(True), n_shared_layers: int = -1, n_critics: int = 2, **kwargs) → Tuple[torch_mist.critic.joint.JointCritic, ...]
torch_mist.critic.factories.shared_separable_critics(x_dim: int, y_dim: int, hidden_dims: List[int], projection_head: str = ASYMMETRIC_HEADS, nonlinearity: torch.nn.Module = nn.ReLU(True), normalize: bool = False, temperature: float = 1.0, n_critics: int = 2, n_shared_layers: int = -1, k_dim: int | None = None, **kwargs) → Tuple[torch_mist.critic.separable.SeparableCritic, ...]
torch_mist.critic.factories.shared_critic_nns(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, n_critics: int, n_shared_layers: int = -1, **kwargs) → Tuple[torch_mist.critic.base.Critic, ...]
torch_mist.critic.factories.separable_critic(x_dim: int, y_dim: int, hidden_dims: List[int], projection_head: str = ASYMMETRIC_HEADS, nonlinearity: torch.nn.Module = nn.ReLU(True), normalize: bool = False, temperature: float = 1.0, k_dim: int | None = None) → torch_mist.critic.separable.SeparableCritic
torch_mist.critic.factories.joint_critic(x_dim: int, y_dim: int, hidden_dims: List[int] | None, nonlinearity: torch.nn.Module = nn.ReLU(True)) → torch_mist.critic.joint.JointCritic
torch_mist.critic.factories.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) → torch_mist.critic.base.Critic