torch_mist.estimators.discriminative
Subpackages
torch_mist.estimators.discriminative.implementations
- torch_mist.estimators.discriminative.implementations.alpha_tuba
- torch_mist.estimators.discriminative.implementations.dummy
- torch_mist.estimators.discriminative.implementations.flo
- torch_mist.estimators.discriminative.implementations.infonce
- torch_mist.estimators.discriminative.implementations.js
- torch_mist.estimators.discriminative.implementations.mine
- torch_mist.estimators.discriminative.implementations.nwj
- torch_mist.estimators.discriminative.implementations.smile
- torch_mist.estimators.discriminative.implementations.tuba
Submodules
Package Contents
Classes
Functions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Attributes
- class torch_mist.estimators.discriminative.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.base.MIEstimator
- lower_bound: bool = True
- infomax_gradient: Dict[str, bool]
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- n_negatives_to_use(N: int)
- sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- __repr__()
- class torch_mist.estimators.discriminative.BaselineDiscriminativeMIEstimator(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.Baseline, neg_samples: int = 1)
Bases:
DiscriminativeMIEstimator
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- __repr__()
- class torch_mist.estimators.discriminative.ConstantBaseline(value: float = 0)
Bases:
Baseline
- forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor
- torch_mist.estimators.discriminative.baseline_nn(x_dim: int, hidden_dims: List[int], nonlinearity: Callable = nn.ReLU(True)) → torch_mist.baseline.base.LearnableBaseline
- torch_mist.estimators.discriminative.JOINT_CRITIC = 'joint'
- torch_mist.estimators.discriminative.SEPARABLE_CRITIC = 'separable'
- torch_mist.estimators.discriminative.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) → torch_mist.critic.base.Critic
- class torch_mist.estimators.discriminative.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- __repr__()
- class torch_mist.estimators.discriminative.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
- class torch_mist.estimators.discriminative.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | None
- class torch_mist.estimators.discriminative.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- lower_bound = False
- train_baseline()
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- class torch_mist.estimators.discriminative.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)
Bases:
torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- class torch_mist.estimators.discriminative.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.DummyDiscriminativeMIEstimator(neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- torch_mist.estimators.discriminative.alpha_tuba(x_dim: int, y_dim: int, hidden_dims: List[int], alpha: float = 0.01, learnable_baseline: bool = True, critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) → torch_mist.estimators.discriminative.implementations.AlphaTUBA
- torch_mist.estimators.discriminative.flo(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = -1, n_shared_layers: int = -1, critic_type: str = SEPARABLE_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.FLO
- torch_mist.estimators.discriminative.infonce(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) → torch_mist.estimators.discriminative.implementations.InfoNCE
- torch_mist.estimators.discriminative.js(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.JS
- torch_mist.estimators.discriminative.dummy_discriminative(neg_samples: int = 1, **kwargs) → torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator
- torch_mist.estimators.discriminative.mine(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = JOINT_CRITIC, neg_samples: int = 1, gamma: float = 0.9, **kwargs) → torch_mist.estimators.discriminative.implementations.MINE
- torch_mist.estimators.discriminative.nwj(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.NWJ
- torch_mist.estimators.discriminative.smile(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, tau: float = 5.0, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.SMILE
- torch_mist.estimators.discriminative.tuba(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.TUBA
- class torch_mist.estimators.discriminative.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- __repr__()
- class torch_mist.estimators.discriminative.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
- class torch_mist.estimators.discriminative.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | None
- class torch_mist.estimators.discriminative.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- lower_bound = False
- train_baseline()
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- class torch_mist.estimators.discriminative.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)
Bases:
torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
- class torch_mist.estimators.discriminative.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.DummyDiscriminativeMIEstimator(neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor