torch_mist.estimators.discriminative.factories

Module Contents

Functions

alpha_tuba(...)

flo(...)

infonce(...)

js(...)

dummy_discriminative(...)

mine(...)

nwj(...)

smile(...)

tuba(...)

torch_mist.estimators.discriminative.factories.alpha_tuba(x_dim: int, y_dim: int, hidden_dims: List[int], alpha: float = 0.01, learnable_baseline: bool = True, critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) → torch_mist.estimators.discriminative.implementations.AlphaTUBA
torch_mist.estimators.discriminative.factories.flo(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = -1, n_shared_layers: int = -1, critic_type: str = SEPARABLE_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.FLO
torch_mist.estimators.discriminative.factories.infonce(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) → torch_mist.estimators.discriminative.implementations.InfoNCE
torch_mist.estimators.discriminative.factories.js(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.JS
torch_mist.estimators.discriminative.factories.dummy_discriminative(neg_samples: int = 1, **kwargs) → torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator
torch_mist.estimators.discriminative.factories.mine(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = JOINT_CRITIC, neg_samples: int = 1, gamma: float = 0.9, **kwargs) → torch_mist.estimators.discriminative.implementations.MINE
torch_mist.estimators.discriminative.factories.nwj(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.NWJ
torch_mist.estimators.discriminative.factories.smile(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, tau: float = 5.0, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.SMILE
torch_mist.estimators.discriminative.factories.tuba(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.TUBA