torch_mist.estimators.discriminative.implementations

Submodules

Package Contents

Classes

AlphaTUBA

FLO

InfoNCE

JS

MINE

NWJ

SMILE

TUBA

DummyDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.implementations.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.implementations.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
__repr__()
class torch_mist.estimators.discriminative.implementations.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
class torch_mist.estimators.discriminative.implementations.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | None
class torch_mist.estimators.discriminative.implementations.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

lower_bound = False
train_baseline()
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.discriminative.implementations.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.implementations.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)

Bases: torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
class torch_mist.estimators.discriminative.implementations.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator(neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor