torch_mist.estimators.discriminative

Subpackages

Submodules

Package Contents

Classes

DiscriminativeMIEstimator

BaselineDiscriminativeMIEstimator

ConstantBaseline

AlphaTUBA

FLO

InfoNCE

JS

MINE

NWJ

SMILE

TUBA

DummyDiscriminativeMIEstimator

AlphaTUBA

FLO

InfoNCE

JS

MINE

NWJ

SMILE

TUBA

DummyDiscriminativeMIEstimator

Functions

baseline_nn(→ torch_mist.baseline.base.LearnableBaseline)

critic_nn(→ torch_mist.critic.base.Critic)

shared_critic_nns(...)

alpha_tuba(...)

flo(...)

infonce(...)

js(...)

dummy_discriminative(...)

mine(...)

nwj(...)

smile(...)

tuba(...)

Attributes

JOINT_CRITIC

SEPARABLE_CRITIC

class torch_mist.estimators.discriminative.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.base.MIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
n_negatives_to_use(N: int)
sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.discriminative.BaselineDiscriminativeMIEstimator(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.Baseline, neg_samples: int = 1)

Bases: DiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
__repr__()
class torch_mist.estimators.discriminative.ConstantBaseline(value: float = 0)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor
torch_mist.estimators.discriminative.baseline_nn(x_dim: int, hidden_dims: List[int], nonlinearity: Callable = nn.ReLU(True)) → torch_mist.baseline.base.LearnableBaseline
torch_mist.estimators.discriminative.JOINT_CRITIC = 'joint'
torch_mist.estimators.discriminative.SEPARABLE_CRITIC = 'separable'
torch_mist.estimators.discriminative.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) → torch_mist.critic.base.Critic
torch_mist.estimators.discriminative.shared_critic_nns(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, n_critics: int, n_shared_layers: int = -1, **kwargs) → Tuple[torch_mist.critic.base.Critic, ...]
class torch_mist.estimators.discriminative.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
__repr__()
class torch_mist.estimators.discriminative.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None)
class torch_mist.estimators.discriminative.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | None
class torch_mist.estimators.discriminative.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

lower_bound = False
train_baseline()
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.discriminative.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)

Bases: torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
class torch_mist.estimators.discriminative.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.DummyDiscriminativeMIEstimator(neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
torch_mist.estimators.discriminative.alpha_tuba(x_dim: int, y_dim: int, hidden_dims: List[int], alpha: float = 0.01, learnable_baseline: bool = True, critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) → torch_mist.estimators.discriminative.implementations.AlphaTUBA
torch_mist.estimators.discriminative.flo(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = -1, n_shared_layers: int = -1, critic_type: str = SEPARABLE_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.FLO
torch_mist.estimators.discriminative.infonce(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) → torch_mist.estimators.discriminative.implementations.InfoNCE
torch_mist.estimators.discriminative.js(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.JS
torch_mist.estimators.discriminative.dummy_discriminative(neg_samples: int = 1, **kwargs) → torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator
torch_mist.estimators.discriminative.mine(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = JOINT_CRITIC, neg_samples: int = 1, gamma: float = 0.9, **kwargs) → torch_mist.estimators.discriminative.implementations.MINE
torch_mist.estimators.discriminative.nwj(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.NWJ
torch_mist.estimators.discriminative.smile(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, tau: float = 5.0, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.SMILE
torch_mist.estimators.discriminative.tuba(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) → torch_mist.estimators.discriminative.implementations.TUBA
class torch_mist.estimators.discriminative.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
__repr__()
class torch_mist.estimators.discriminative.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None)
class torch_mist.estimators.discriminative.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | None
class torch_mist.estimators.discriminative.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

lower_bound = False
train_baseline()
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.discriminative.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)

Bases: torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
class torch_mist.estimators.discriminative.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.DummyDiscriminativeMIEstimator(neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor