torch_mist.estimators

Subpackages

Submodules

Package Contents

Classes

MIEstimator

BA

CLUB

DoE

GM

L1Out

DummyGenerativeMIEstimator

AlphaTUBA

FLO

InfoNCE

JS

MINE

NWJ

SMILE

TUBA

DummyDiscriminativeMIEstimator

BinnedMIEstimator

PQ

TransformedMIEstimator

MIEstimator

DiscriminativeMIEstimator

HybridMIEstimator

PQHybridMIEstimator

ResampledHybridMIEstimator

ReweighedHybridMIEstimator

GenerativeMIEstimator

QuantizationFunction

MIEstimator

JointDistribution

BA

CLUB

DoE

GM

L1Out

DummyGenerativeMIEstimator

ConstantBaseline

AlphaTUBA

FLO

InfoNCE

JS

MINE

NWJ

SMILE

TUBA

DummyDiscriminativeMIEstimator

PQ

BinnedMIEstimator

QuantizationFunction

FlippedMIEstimator

MultiMIEstimator

Functions

cached_method(→ Callable[Ellipsis, T])

is_trainable(→ bool)

instantiate_estimator(...)

instantiate_quantization(...)

conditional_categorical(n_classes, context_dim, ...[, ...])

hybrid_pq(...)

resampled_hybrid(...)

reweighed_hybrid(...)

instantiate_estimator(...)

joint_transformed_normal(...)

ba(→ torch_mist.estimators.generative.implementations.BA)

club(...)

doe(→ torch_mist.estimators.generative.implementations.DoE)

dummy_generative(...)

gm(→ torch_mist.estimators.generative.implementations.GM)

l1out(...)

baseline_nn(→ torch_mist.baseline.base.LearnableBaseline)

critic_nn(→ torch_mist.critic.base.Critic)

shared_critic_nns(...)

alpha_tuba(...)

flo(...)

infonce(...)

js(...)

dummy_discriminative(...)

mine(...)

nwj(...)

smile(...)

tuba(...)

instantiate_quantization(...)

binned(...)

pq(→ torch_mist.estimators.transformed.implementations.PQ)

flip_estimator(→ torch_mist.estimators.MIEstimator)

Attributes

DEFAULT_HIDDEN_DIMS

JOINT_CRITIC

SEPARABLE_CRITIC

class torch_mist.estimators.MIEstimator

Bases: torch_mist.nn.Model

infomax_gradient: Dict[str, bool]
abstract log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
abstract batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
forward(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
class torch_mist.estimators.BA(q_Y_given_X: pyro.distributions.ConditionalDistribution, entropy_y: torch.Tensor | None = None)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.CLUB(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = 0)

Bases: torch_mist.estimators.generative.implementations.l1out.L1Out

infomax_gradient: Dict[str, bool]
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.DoE(q_Y_given_X: pyro.distributions.ConditionalDistribution, q_Y: torch.distributions.Distribution)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
approx_log_p_y(y: torch.Tensor, x: torch.Tensor | None = None) torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
__repr__()
class torch_mist.estimators.GM(q_XY: torch_mist.distributions.joint.base.JointDistribution, q_Y: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution, q_X: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution)

Bases: torch_mist.estimators.generative.base.JointGenerativeMIEstimator

property q_X
property q_Y
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.L1Out(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = -1)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
_broadcast_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) torch.Tensor
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.DummyGenerativeMIEstimator

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
__repr__()
class torch_mist.estimators.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
class torch_mist.estimators.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor | None
class torch_mist.estimators.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

lower_bound = False
train_baseline()
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)

Bases: torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
class torch_mist.estimators.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.DummyDiscriminativeMIEstimator(neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
class torch_mist.estimators.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

lower_bound = True
class torch_mist.estimators.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

class torch_mist.estimators.TransformedMIEstimator(base_estimator: torch_mist.estimators.base.MIEstimator, transforms: Dict[str, Callable[[Any], Any]] | None = None, transforms_rename: Dict[Tuple[str, str], Callable[[Any], Any]] | None = None)

Bases: torch_mist.estimators.base.MIEstimator

transform(**variables) Dict[str, torch.Tensor]
_unfold_variables(*args, **variables) Dict[str, Any]
batch_loss(*args, **variables) torch.Tensor
log_ratio(*args, **variables) torch.Tensor | List[torch.Tensor]
unnormalized_log_ratio(*args, **variables) torch.Tensor
loss(*args, **variables) torch.Tensor
mutual_information(*args, **variables) torch.Tensor | Dict[str, torch.Tensor]
forward(*args, **kwargs) torch.Tensor
class torch_mist.estimators.MIEstimator

Bases: torch_mist.nn.Model

infomax_gradient: Dict[str, bool]
abstract log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
abstract batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
forward(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
class torch_mist.estimators.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.base.MIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
n_negatives_to_use(N: int)
sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
__repr__()
torch_mist.estimators.cached_method(method: Callable[Ellipsis, T]) Callable[Ellipsis, T]
torch_mist.estimators.is_trainable(function: Any) bool
class torch_mist.estimators.HybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
abstract sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor | None]
resampling_strategy()
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
__repr__()
class torch_mist.estimators.PQHybridMIEstimator(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.QuantizationFunction, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator, temperature: float = 1.0)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

property quantize_y: Callable[[torch.Tensor], torch.LongTensor]
disable_batch_validation()
enable_batch_validation()
sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
class torch_mist.estimators.ResampledHybridMIEstimator(generative_estimator: torch_mist.estimators.generative.base.GenerativeMIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
class torch_mist.estimators.ReweighedHybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
torch_mist.estimators.instantiate_estimator(estimator_name: str, x_dim: int | None = None, y_dim: int | None = None, **kwargs) torch_mist.estimators.base.MIEstimator
class torch_mist.estimators.GenerativeMIEstimator

Bases: torch_mist.estimators.base.MIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
abstract approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) torch.Tensor
abstract approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.QuantizationFunction

Bases: torch.nn.Module

abstract property n_bins: int
abstract quantize(x: torch.Tensor) torch.LongTensor
forward(x: torch.Tensor) torch.LongTensor
torch_mist.estimators.instantiate_quantization(name: str, n_bins: int, **kwargs) torch_mist.quantization.functions.QuantizationFunction
torch_mist.estimators.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)
torch_mist.estimators.hybrid_pq(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, quantize_y: torch_mist.quantization.QuantizationFunction | str | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, quantization_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.PQHybridMIEstimator
torch_mist.estimators.resampled_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.ResampledHybridMIEstimator
torch_mist.estimators.reweighed_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.ReweighedHybridMIEstimator
class torch_mist.estimators.MIEstimator

Bases: torch_mist.nn.Model

infomax_gradient: Dict[str, bool]
abstract log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
abstract batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
forward(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
torch_mist.estimators.DEFAULT_HIDDEN_DIMS = [128]
torch_mist.estimators.instantiate_estimator(estimator_name: str, x_dim: int | None = None, y_dim: int | None = None, **kwargs) torch_mist.estimators.base.MIEstimator
torch_mist.estimators.joint_transformed_normal(input_dims: Dict[str, int], transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) torch_mist.distributions.joint.base.JointDistribution
class torch_mist.estimators.JointDistribution(variables: List[str], name: str = 'p')

Bases: torch.nn.Module, torch.distributions.Distribution, pyro.distributions.ConditionalDistribution

full_name()
abstract _log_prob(**kwargs) torch.Tensor
log_prob(*args, **kwargs) torch.Tensor
abstract _marginal(variables: List[str]) T
marginal(*variables) T
conditional(*variables) pyro.distributions.ConditionalDistribution
condition(**conditioning) T
_mutual_information(variable_1: str, variable_2: str) torch.Tensor
mutual_information(variable_1: str | None = None, variable_2: str | None = None) torch.Tensor
abstract _entropy(variables: List[str]) torch.Tensor
entropy(*variables) torch.Tensor
class torch_mist.estimators.BA(q_Y_given_X: pyro.distributions.ConditionalDistribution, entropy_y: torch.Tensor | None = None)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.CLUB(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = 0)

Bases: torch_mist.estimators.generative.implementations.l1out.L1Out

infomax_gradient: Dict[str, bool]
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.DoE(q_Y_given_X: pyro.distributions.ConditionalDistribution, q_Y: torch.distributions.Distribution)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
approx_log_p_y(y: torch.Tensor, x: torch.Tensor | None = None) torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
__repr__()
class torch_mist.estimators.GM(q_XY: torch_mist.distributions.joint.base.JointDistribution, q_Y: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution, q_X: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution)

Bases: torch_mist.estimators.generative.base.JointGenerativeMIEstimator

property q_X
property q_Y
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.L1Out(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = -1)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
_broadcast_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) torch.Tensor
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.DummyGenerativeMIEstimator

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
torch_mist.estimators.ba(entropy_y: float | torch.Tensor, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) torch_mist.estimators.generative.implementations.BA
torch_mist.estimators.club(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) torch_mist.estimators.generative.implementations.CLUB
torch_mist.estimators.doe(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, q_Y: torch.distributions.Distribution | None = None, conditional_transform_name: str = 'conditional_linear', n_conditional_transforms: int = 1, marginal_transform_name: str = 'linear', n_marginal_transforms: int = 1) torch_mist.estimators.generative.implementations.DoE
torch_mist.estimators.dummy_generative(**kwargs) torch_mist.estimators.generative.implementations.DummyGenerativeMIEstimator
torch_mist.estimators.gm(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] = None, q_XY: torch_mist.distributions.joint.base.JointDistribution | None = None, q_Y: torch.distributions.Distribution | None = None, q_X: torch.distributions.Distribution | None = None, joint_transform_name: str = 'affine_autoregressive', n_joint_transforms: int = 1, marginal_transform_name: str = 'linear', n_marginal_transforms: int = 1) torch_mist.estimators.generative.implementations.GM
torch_mist.estimators.l1out(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) torch_mist.estimators.generative.implementations.L1Out
class torch_mist.estimators.ConstantBaseline(value: float = 0)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) torch.Tensor
torch_mist.estimators.baseline_nn(x_dim: int, hidden_dims: List[int], nonlinearity: Callable = nn.ReLU(True)) torch_mist.baseline.base.LearnableBaseline
torch_mist.estimators.JOINT_CRITIC = 'joint'
torch_mist.estimators.SEPARABLE_CRITIC = 'separable'
torch_mist.estimators.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) torch_mist.critic.base.Critic
torch_mist.estimators.shared_critic_nns(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, n_critics: int, n_shared_layers: int = -1, **kwargs) Tuple[torch_mist.critic.base.Critic, Ellipsis]
class torch_mist.estimators.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
__repr__()
class torch_mist.estimators.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
class torch_mist.estimators.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor | None
class torch_mist.estimators.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

lower_bound = False
train_baseline()
batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
class torch_mist.estimators.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)

Bases: torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
class torch_mist.estimators.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

class torch_mist.estimators.DummyDiscriminativeMIEstimator(neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
torch_mist.estimators.alpha_tuba(x_dim: int, y_dim: int, hidden_dims: List[int], alpha: float = 0.01, learnable_baseline: bool = True, critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) torch_mist.estimators.discriminative.implementations.AlphaTUBA
torch_mist.estimators.flo(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = -1, n_shared_layers: int = -1, critic_type: str = SEPARABLE_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.FLO
torch_mist.estimators.infonce(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) torch_mist.estimators.discriminative.implementations.InfoNCE
torch_mist.estimators.js(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.JS
torch_mist.estimators.dummy_discriminative(neg_samples: int = 1, **kwargs) torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator
torch_mist.estimators.mine(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = JOINT_CRITIC, neg_samples: int = 1, gamma: float = 0.9, **kwargs) torch_mist.estimators.discriminative.implementations.MINE
torch_mist.estimators.nwj(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.NWJ
torch_mist.estimators.smile(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, tau: float = 5.0, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.SMILE
torch_mist.estimators.tuba(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.TUBA
class torch_mist.estimators.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

class torch_mist.estimators.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

lower_bound = True
class torch_mist.estimators.QuantizationFunction

Bases: torch.nn.Module

abstract property n_bins: int
abstract quantize(x: torch.Tensor) torch.LongTensor
forward(x: torch.Tensor) torch.LongTensor
torch_mist.estimators.instantiate_quantization(name: str, n_bins: int, **kwargs) torch_mist.quantization.functions.QuantizationFunction
torch_mist.estimators.binned(quantize_x: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', temperature: float = 0.1, n_bins: int | None = 32, x_dim: int | None = None, y_dim: int | None = None, **kwargs) torch_mist.estimators.transformed.implementations.BinnedMIEstimator
torch_mist.estimators.pq(quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', x_dim: int | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, y_dim: int | None = None, **kwargs) torch_mist.estimators.transformed.implementations.PQ
class torch_mist.estimators.FlippedMIEstimator(estimator: torch_mist.estimators.MIEstimator)

Bases: torch_mist.estimators.MIEstimator

log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor
torch_mist.estimators.flip_estimator(estimator: torch_mist.estimators.MIEstimator) torch_mist.estimators.MIEstimator
class torch_mist.estimators.MultiMIEstimator(estimators: Dict[Tuple[str, str], torch_mist.estimators.base.MIEstimator])

Bases: torch_mist.estimators.base.MIEstimator

broadcast_function(function_name: str, **variables) Dict[Tuple[str, str], torch.Tensor]
loss(**variables) torch.Tensor
batch_loss(**variables) torch.Tensor
mutual_information(**variables) Dict[Tuple[str, str], torch.Tensor]
log_ratio(**variables) Dict[Tuple[str, str], torch.Tensor]
forward(**variables) torch.Tensor