torch_mist.estimators.discriminative.implementations
Submodules
- torch_mist.estimators.discriminative.implementations.alpha_tuba
- torch_mist.estimators.discriminative.implementations.dummy
- torch_mist.estimators.discriminative.implementations.flo
- torch_mist.estimators.discriminative.implementations.infonce
- torch_mist.estimators.discriminative.implementations.js
- torch_mist.estimators.discriminative.implementations.mine
- torch_mist.estimators.discriminative.implementations.nwj
- torch_mist.estimators.discriminative.implementations.smile
- torch_mist.estimators.discriminative.implementations.tuba
Package Contents
Classes
- class torch_mist.estimators.discriminative.implementations.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.implementations.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor
- __repr__()
- class torch_mist.estimators.discriminative.implementations.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
- class torch_mist.estimators.discriminative.implementations.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor | None
- class torch_mist.estimators.discriminative.implementations.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- lower_bound = False
- train_baseline()
- batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.discriminative.implementations.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.implementations.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)
Bases:
torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor
- class torch_mist.estimators.discriminative.implementations.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator(neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor