torch_mist.estimators.discriminative.factories
Module Contents
Functions
- torch_mist.estimators.discriminative.factories.alpha_tuba(x_dim: int, y_dim: int, hidden_dims: List[int], alpha: float = 0.01, learnable_baseline: bool = True, critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) -> torch_mist.estimators.discriminative.implementations.AlphaTUBA
- torch_mist.estimators.discriminative.factories.flo(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = -1, n_shared_layers: int = -1, critic_type: str = SEPARABLE_CRITIC, **kwargs) -> torch_mist.estimators.discriminative.implementations.FLO
- torch_mist.estimators.discriminative.factories.infonce(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) -> torch_mist.estimators.discriminative.implementations.InfoNCE
- torch_mist.estimators.discriminative.factories.js(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) -> torch_mist.estimators.discriminative.implementations.JS
- torch_mist.estimators.discriminative.factories.dummy_discriminative(neg_samples: int = 1, **kwargs) -> torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator
- torch_mist.estimators.discriminative.factories.mine(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = JOINT_CRITIC, neg_samples: int = 1, gamma: float = 0.9, **kwargs) -> torch_mist.estimators.discriminative.implementations.MINE
- torch_mist.estimators.discriminative.factories.nwj(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) -> torch_mist.estimators.discriminative.implementations.NWJ
- torch_mist.estimators.discriminative.factories.smile(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, tau: float = 5.0, critic_type: str = JOINT_CRITIC, **kwargs) -> torch_mist.estimators.discriminative.implementations.SMILE
- torch_mist.estimators.discriminative.factories.tuba(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) -> torch_mist.estimators.discriminative.implementations.TUBA
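
The factories listed here differ mainly in their defaults for `critic_type` and `neg_samples`. As a rough aid for comparing them, the sketch below transcribes those defaults into plain Python data and groups the factories by default critic. It does not import or call torch_mist. The names `FACTORY_DEFAULTS` and `factories_with_critic` are hypothetical helpers invented for this illustration, and the critic constants are written as strings because their actual values are not shown on this page. `dummy_discriminative` is left out because its signature has no `critic_type` parameter.

```python
# Defaults transcribed by hand from the factory signatures in this module.
# ASSUMPTION: critic names are kept as plain strings; the real values of the
# SEPARABLE_CRITIC and JOINT_CRITIC constants are not shown in this page.
FACTORY_DEFAULTS = {
    "alpha_tuba": {"critic_type": "SEPARABLE_CRITIC", "neg_samples": 0},
    "flo": {"critic_type": "SEPARABLE_CRITIC", "neg_samples": -1},
    "infonce": {"critic_type": "SEPARABLE_CRITIC", "neg_samples": 0},
    "js": {"critic_type": "JOINT_CRITIC", "neg_samples": 1},
    "mine": {"critic_type": "JOINT_CRITIC", "neg_samples": 1},
    "nwj": {"critic_type": "JOINT_CRITIC", "neg_samples": 1},
    "smile": {"critic_type": "JOINT_CRITIC", "neg_samples": 1},
    "tuba": {"critic_type": "JOINT_CRITIC", "neg_samples": 1},
}


def factories_with_critic(critic_name):
    """Return the sorted factory names whose default critic_type matches."""
    return sorted(
        name
        for name, defaults in FACTORY_DEFAULTS.items()
        if defaults["critic_type"] == critic_name
    )


if __name__ == "__main__":
    for critic in ("SEPARABLE_CRITIC", "JOINT_CRITIC"):
        print(critic, factories_with_critic(critic))
```

Read this way, `alpha_tuba`, `flo`, and `infonce` default to a separable critic, while the remaining critic-based factories default to a joint critic with `neg_samples = 1`.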