torch_mist.estimators
Subpackages
- torch_mist.estimators.discriminative
- torch_mist.estimators.discriminative.implementations
- torch_mist.estimators.discriminative.implementations.alpha_tuba
- torch_mist.estimators.discriminative.implementations.dummy
- torch_mist.estimators.discriminative.implementations.flo
- torch_mist.estimators.discriminative.implementations.infonce
- torch_mist.estimators.discriminative.implementations.js
- torch_mist.estimators.discriminative.implementations.mine
- torch_mist.estimators.discriminative.implementations.nwj
- torch_mist.estimators.discriminative.implementations.smile
- torch_mist.estimators.discriminative.implementations.tuba
- torch_mist.estimators.discriminative.base
- torch_mist.estimators.discriminative.factories
- torch_mist.estimators.generative
- torch_mist.estimators.generative.implementations
- torch_mist.estimators.generative.implementations.ba
- torch_mist.estimators.generative.implementations.club
- torch_mist.estimators.generative.implementations.doe
- torch_mist.estimators.generative.implementations.dummy
- torch_mist.estimators.generative.implementations.gm
- torch_mist.estimators.generative.implementations.l1out
- torch_mist.estimators.generative.base
- torch_mist.estimators.generative.factories
- torch_mist.estimators.hybrid
- torch_mist.estimators.transformed
Submodules
Package Contents
Classes
Functions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Attributes
- class torch_mist.estimators.MIEstimator
  Bases: torch_mist.nn.Model
  - infomax_gradient: Dict[str, bool]
  - abstract log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - mutual_information(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor | Dict[str, torch.Tensor]
  - abstract batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor | Dict[str, torch.Tensor]
- class torch_mist.estimators.BA(q_Y_given_X: pyro.distributions.ConditionalDistribution, entropy_y: torch.Tensor | None = None)
  Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator
  - lower_bound: bool = True
  - infomax_gradient: Dict[str, bool]
  - mutual_information(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.CLUB(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = 0)
  Bases: torch_mist.estimators.generative.implementations.l1out.L1Out
  - infomax_gradient: Dict[str, bool]
  - approx_log_p_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.DoE(q_Y_given_X: pyro.distributions.ConditionalDistribution, q_Y: torch.distributions.Distribution)
  Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator
  - infomax_gradient: Dict[str, bool]
  - approx_log_p_y(y: torch.Tensor, x: torch.Tensor | None = None) -> torch.Tensor
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - __repr__()
- class torch_mist.estimators.GM(q_XY: torch_mist.distributions.joint.base.JointDistribution, q_Y: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution, q_X: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution)
  Bases: torch_mist.estimators.generative.base.JointGenerativeMIEstimator
  - property q_X
  - property q_Y
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.L1Out(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = -1)
  Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator
  - infomax_gradient: Dict[str, bool]
  - _broadcast_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - approx_log_p_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.DummyGenerativeMIEstimator
  Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator
  - lower_bound: bool = True
  - infomax_gradient: Dict[str, bool]
  - log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)
  Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
  - unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor
  - __repr__()
- class torch_mist.estimators.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)
  Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
  - _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
- class torch_mist.estimators.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)
  Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor | None
- class torch_mist.estimators.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)
  Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
  - lower_bound = False
  - train_baseline()
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)
  Bases: torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor
- class torch_mist.estimators.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.DummyDiscriminativeMIEstimator(neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- class torch_mist.estimators.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator- lower_bound = True
- class torch_mist.estimators.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator
- class torch_mist.estimators.TransformedMIEstimator(base_estimator: torch_mist.estimators.base.MIEstimator, transforms: Dict[str, Callable[[Any], Any]] | None = None, transforms_rename: Dict[Tuple[str, str], Callable[[Any], Any]] | None = None)
  Bases: torch_mist.estimators.base.MIEstimator
  - transform(**variables) -> Dict[str, torch.Tensor]
  - _unfold_variables(*args, **variables) -> Dict[str, Any]
  - batch_loss(*args, **variables) -> torch.Tensor
  - log_ratio(*args, **variables) -> torch.Tensor | List[torch.Tensor]
  - unnormalized_log_ratio(*args, **variables) -> torch.Tensor
  - loss(*args, **variables) -> torch.Tensor
  - mutual_information(*args, **variables) -> torch.Tensor | Dict[str, torch.Tensor]
  - forward(*args, **kwargs) -> torch.Tensor
- class torch_mist.estimators.MIEstimator
Bases:
torch_mist.nn.Model- infomax_gradient: Dict[str, bool]
- abstract log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- abstract batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- forward(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- class torch_mist.estimators.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)
  Bases: torch_mist.estimators.base.MIEstimator
  - lower_bound: bool = True
  - infomax_gradient: Dict[str, bool]
  - unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - n_negatives_to_use(N: int)
  - sample_negatives(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
  - log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor
  - approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor
  - batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - __repr__()
- torch_mist.estimators.cached_method(method: Callable[..., T]) -> Callable[..., T]
- torch_mist.estimators.is_trainable(function: Any) -> bool
- class torch_mist.estimators.HybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor | None]
- resampling_strategy()
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- __repr__()
- class torch_mist.estimators.PQHybridMIEstimator(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.QuantizationFunction, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator, temperature: float = 1.0)
Bases:
torch_mist.estimators.hybrid.base.HybridMIEstimator- property quantize_y: Callable[[torch.Tensor], torch.LongTensor]
- disable_batch_validation()
- enable_batch_validation()
- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- class torch_mist.estimators.ResampledHybridMIEstimator(generative_estimator: torch_mist.estimators.generative.base.GenerativeMIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)
Bases:
torch_mist.estimators.hybrid.base.HybridMIEstimator- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- class torch_mist.estimators.ReweighedHybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)
Bases:
torch_mist.estimators.hybrid.base.HybridMIEstimator- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- torch_mist.estimators.instantiate_estimator(estimator_name: str, x_dim: int | None = None, y_dim: int | None = None, **kwargs) torch_mist.estimators.base.MIEstimator
- class torch_mist.estimators.GenerativeMIEstimator
Bases:
torch_mist.estimators.base.MIEstimator- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- class torch_mist.estimators.QuantizationFunction
Bases:
torch.nn.Module- abstract property n_bins: int
- abstract quantize(x: torch.Tensor) torch.LongTensor
- forward(x: torch.Tensor) torch.LongTensor
- torch_mist.estimators.instantiate_quantization(name: str, n_bins: int, **kwargs) torch_mist.quantization.functions.QuantizationFunction
- torch_mist.estimators.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)
- torch_mist.estimators.hybrid_pq(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, quantize_y: torch_mist.quantization.QuantizationFunction | str | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, quantization_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.PQHybridMIEstimator
- torch_mist.estimators.resampled_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.ResampledHybridMIEstimator
- torch_mist.estimators.reweighed_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.ReweighedHybridMIEstimator
- class torch_mist.estimators.MIEstimator
Bases:
torch_mist.nn.Model- infomax_gradient: Dict[str, bool]
- abstract log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- abstract batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- forward(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- torch_mist.estimators.DEFAULT_HIDDEN_DIMS = [128]
- torch_mist.estimators.instantiate_estimator(estimator_name: str, x_dim: int | None = None, y_dim: int | None = None, **kwargs) torch_mist.estimators.base.MIEstimator
- torch_mist.estimators.joint_transformed_normal(input_dims: Dict[str, int], transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) torch_mist.distributions.joint.base.JointDistribution
- class torch_mist.estimators.JointDistribution(variables: List[str], name: str = 'p')
  Bases: torch.nn.Module, torch.distributions.Distribution, pyro.distributions.ConditionalDistribution
  - full_name()
  - abstract _log_prob(**kwargs) -> torch.Tensor
  - log_prob(*args, **kwargs) -> torch.Tensor
  - abstract _marginal(variables: List[str]) -> T
  - marginal(*variables) -> T
  - conditional(*variables) -> pyro.distributions.ConditionalDistribution
  - condition(**conditioning) -> T
  - _mutual_information(variable_1: str, variable_2: str) -> torch.Tensor
  - mutual_information(variable_1: str | None = None, variable_2: str | None = None) -> torch.Tensor
  - abstract _entropy(variables: List[str]) -> torch.Tensor
  - entropy(*variables) -> torch.Tensor
- class torch_mist.estimators.BA(q_Y_given_X: pyro.distributions.ConditionalDistribution, entropy_y: torch.Tensor | None = None)
Bases:
torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator- lower_bound: bool = True
- infomax_gradient: Dict[str, bool]
- mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- class torch_mist.estimators.CLUB(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = 0)
Bases:
torch_mist.estimators.generative.implementations.l1out.L1Out- infomax_gradient: Dict[str, bool]
- approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- class torch_mist.estimators.DoE(q_Y_given_X: pyro.distributions.ConditionalDistribution, q_Y: torch.distributions.Distribution)
Bases:
torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator- infomax_gradient: Dict[str, bool]
- approx_log_p_y(y: torch.Tensor, x: torch.Tensor | None = None) torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- __repr__()
- class torch_mist.estimators.GM(q_XY: torch_mist.distributions.joint.base.JointDistribution, q_Y: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution, q_X: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution)
Bases:
torch_mist.estimators.generative.base.JointGenerativeMIEstimator- property q_X
- property q_Y
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- class torch_mist.estimators.L1Out(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = -1)
Bases:
torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator- infomax_gradient: Dict[str, bool]
- _broadcast_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- class torch_mist.estimators.DummyGenerativeMIEstimator
Bases:
torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator- lower_bound: bool = True
- infomax_gradient: Dict[str, bool]
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- torch_mist.estimators.ba(entropy_y: float | torch.Tensor, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) torch_mist.estimators.generative.implementations.BA
- torch_mist.estimators.club(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) torch_mist.estimators.generative.implementations.CLUB
- torch_mist.estimators.doe(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, q_Y: torch.distributions.Distribution | None = None, conditional_transform_name: str = 'conditional_linear', n_conditional_transforms: int = 1, marginal_transform_name: str = 'linear', n_marginal_transforms: int = 1) torch_mist.estimators.generative.implementations.DoE
- torch_mist.estimators.dummy_generative(**kwargs) torch_mist.estimators.generative.implementations.DummyGenerativeMIEstimator
- torch_mist.estimators.gm(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] = None, q_XY: torch_mist.distributions.joint.base.JointDistribution | None = None, q_Y: torch.distributions.Distribution | None = None, q_X: torch.distributions.Distribution | None = None, joint_transform_name: str = 'affine_autoregressive', n_joint_transforms: int = 1, marginal_transform_name: str = 'linear', n_marginal_transforms: int = 1) torch_mist.estimators.generative.implementations.GM
- torch_mist.estimators.l1out(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) torch_mist.estimators.generative.implementations.L1Out
- class torch_mist.estimators.ConstantBaseline(value: float = 0)
Bases:
Baseline- forward(x: torch.Tensor, f_: torch.Tensor) torch.Tensor
- torch_mist.estimators.baseline_nn(x_dim: int, hidden_dims: List[int], nonlinearity: Callable = nn.ReLU(True)) torch_mist.baseline.base.LearnableBaseline
- torch_mist.estimators.JOINT_CRITIC = 'joint'
- torch_mist.estimators.SEPARABLE_CRITIC = 'separable'
- torch_mist.estimators.critic_nn(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str, **kwargs) torch_mist.critic.base.Critic
- class torch_mist.estimators.AlphaTUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, alpha: float = 0.01, neg_samples: int = -1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- __repr__()
- class torch_mist.estimators.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.tensor, log_w: torch.Tensor | None)
- class torch_mist.estimators.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor | None
- class torch_mist.estimators.MINE(critic: torch_mist.critic.Critic, neg_samples: int = 1, gamma: float = 0.9)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator- lower_bound = False
- train_baseline()
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- class torch_mist.estimators.NWJ(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)
Bases:
torch_mist.estimators.discriminative.implementations.js.JS,torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- class torch_mist.estimators.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
- class torch_mist.estimators.DummyDiscriminativeMIEstimator(neg_samples: int = 1)
Bases:
torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator- _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- torch_mist.estimators.alpha_tuba(x_dim: int, y_dim: int, hidden_dims: List[int], alpha: float = 0.01, learnable_baseline: bool = True, critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) torch_mist.estimators.discriminative.implementations.AlphaTUBA
- torch_mist.estimators.flo(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = -1, n_shared_layers: int = -1, critic_type: str = SEPARABLE_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.FLO
- torch_mist.estimators.infonce(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = SEPARABLE_CRITIC, neg_samples: int = 0, **kwargs) torch_mist.estimators.discriminative.implementations.InfoNCE
- torch_mist.estimators.js(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.JS
- torch_mist.estimators.dummy_discriminative(neg_samples: int = 1, **kwargs) torch_mist.estimators.discriminative.implementations.DummyDiscriminativeMIEstimator
- torch_mist.estimators.mine(x_dim: int, y_dim: int, hidden_dims: List[int], critic_type: str = JOINT_CRITIC, neg_samples: int = 1, gamma: float = 0.9, **kwargs) torch_mist.estimators.discriminative.implementations.MINE
- torch_mist.estimators.nwj(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.NWJ
- torch_mist.estimators.smile(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, tau: float = 5.0, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.SMILE
- torch_mist.estimators.tuba(x_dim: int, y_dim: int, hidden_dims: List[int], neg_samples: int = 1, critic_type: str = JOINT_CRITIC, **kwargs) torch_mist.estimators.discriminative.implementations.TUBA
- class torch_mist.estimators.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator
- class torch_mist.estimators.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator- lower_bound = True
- class torch_mist.estimators.QuantizationFunction
Bases:
torch.nn.Module- abstract property n_bins: int
- abstract quantize(x: torch.Tensor) torch.LongTensor
- forward(x: torch.Tensor) torch.LongTensor
- torch_mist.estimators.instantiate_quantization(name: str, n_bins: int, **kwargs) torch_mist.quantization.functions.QuantizationFunction
- torch_mist.estimators.binned(quantize_x: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', temperature: float = 0.1, n_bins: int | None = 32, x_dim: int | None = None, y_dim: int | None = None, **kwargs) torch_mist.estimators.transformed.implementations.BinnedMIEstimator
- torch_mist.estimators.pq(quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', x_dim: int | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, y_dim: int | None = None, **kwargs) torch_mist.estimators.transformed.implementations.PQ
- class torch_mist.estimators.FlippedMIEstimator(estimator: torch_mist.estimators.MIEstimator)
  Bases: torch_mist.estimators.MIEstimator
  - log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
  - mutual_information(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- torch_mist.estimators.flip_estimator(estimator: torch_mist.estimators.MIEstimator) -> torch_mist.estimators.MIEstimator
- class torch_mist.estimators.MultiMIEstimator(estimators: Dict[Tuple[str, str], torch_mist.estimators.base.MIEstimator])
  Bases: torch_mist.estimators.base.MIEstimator
  - broadcast_function(function_name: str, **variables) -> Dict[Tuple[str, str], torch.Tensor]
  - loss(**variables) -> torch.Tensor
  - batch_loss(**variables) -> torch.Tensor
  - mutual_information(**variables) -> Dict[Tuple[str, str], torch.Tensor]
  - log_ratio(**variables) -> Dict[Tuple[str, str], torch.Tensor]
  - forward(**variables) -> torch.Tensor