torch_mist.estimators.hybrid

Subpackages

Submodules

Package Contents

Classes

MIEstimator

DiscriminativeMIEstimator

HybridMIEstimator

PQHybridMIEstimator

ResampledHybridMIEstimator

ReweighedHybridMIEstimator

MIEstimator

DiscriminativeMIEstimator

GenerativeMIEstimator

QuantizationFunction

Functions

cached_method(→ Callable[..., T])

is_trainable(→ bool)

instantiate_estimator(...)

instantiate_quantization(...)

conditional_categorical(n_classes, context_dim, ...[, ...])

hybrid_pq(...)

resampled_hybrid(...)

reweighed_hybrid(...)

class torch_mist.estimators.hybrid.MIEstimator

Bases: torch_mist.nn.Model

infomax_gradient: Dict[str, bool]
abstract log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
mutual_information(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | Dict[str, torch.Tensor]
abstract batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
forward(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | Dict[str, torch.Tensor]
class torch_mist.estimators.hybrid.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.base.MIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
n_negatives_to_use(N: int)
sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
torch_mist.estimators.hybrid.cached_method(method: Callable[..., T]) → Callable[..., T]
torch_mist.estimators.hybrid.is_trainable(function: Any) → bool
class torch_mist.estimators.hybrid.HybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor | None]
resampling_strategy()
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.hybrid.PQHybridMIEstimator(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.QuantizationFunction, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator, temperature: float = 1.0)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

property quantize_y: Callable[[torch.Tensor], torch.LongTensor]
disable_batch_validation()
enable_batch_validation()
sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
class torch_mist.estimators.hybrid.ResampledHybridMIEstimator(generative_estimator: torch_mist.estimators.generative.base.GenerativeMIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
class torch_mist.estimators.hybrid.ReweighedHybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
class torch_mist.estimators.hybrid.MIEstimator

Bases: torch_mist.nn.Model

infomax_gradient: Dict[str, bool]
abstract log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
mutual_information(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | Dict[str, torch.Tensor]
abstract batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
forward(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | Dict[str, torch.Tensor]
torch_mist.estimators.hybrid.instantiate_estimator(estimator_name: str, x_dim: int | None = None, y_dim: int | None = None, **kwargs) → torch_mist.estimators.base.MIEstimator
class torch_mist.estimators.hybrid.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.base.MIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
n_negatives_to_use(N: int)
sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.hybrid.GenerativeMIEstimator

Bases: torch_mist.estimators.base.MIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.hybrid.QuantizationFunction

Bases: torch.nn.Module

abstract property n_bins: int
abstract quantize(x: torch.Tensor) → torch.LongTensor
forward(x: torch.Tensor) → torch.LongTensor
torch_mist.estimators.hybrid.instantiate_quantization(name: str, n_bins: int, **kwargs) → torch_mist.quantization.functions.QuantizationFunction
torch_mist.estimators.hybrid.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)
torch_mist.estimators.hybrid.hybrid_pq(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, quantize_y: torch_mist.quantization.QuantizationFunction | str | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, quantization_params: Dict[str, Any] | None = None, **kwargs) → torch_mist.estimators.hybrid.PQHybridMIEstimator
torch_mist.estimators.hybrid.resampled_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) → torch_mist.estimators.hybrid.ResampledHybridMIEstimator
torch_mist.estimators.hybrid.reweighed_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) → torch_mist.estimators.hybrid.ReweighedHybridMIEstimator