torch_mist.estimators.hybrid
Subpackages
Submodules
Package Contents
Classes
Functions
|
|
|
|
|
|
|
|
|
|
|
- class torch_mist.estimators.hybrid.MIEstimator
Bases: torch_mist.nn.Model
- infomax_gradient: Dict[str, bool]
- abstract log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- abstract batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- forward(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- class torch_mist.estimators.hybrid.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases: torch_mist.estimators.base.MIEstimator
- lower_bound: bool = True
- infomax_gradient: Dict[str, bool]
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- n_negatives_to_use(N: int)
- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- __repr__()
- torch_mist.estimators.hybrid.cached_method(method: Callable[..., T]) Callable[..., T]
- torch_mist.estimators.hybrid.is_trainable(function: Any) bool
- class torch_mist.estimators.hybrid.HybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)
Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor | None]
- resampling_strategy()
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- __repr__()
- class torch_mist.estimators.hybrid.PQHybridMIEstimator(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.QuantizationFunction, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator, temperature: float = 1.0)
Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator
- property quantize_y: Callable[[torch.Tensor], torch.LongTensor]
- disable_batch_validation()
- enable_batch_validation()
- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- class torch_mist.estimators.hybrid.ResampledHybridMIEstimator(generative_estimator: torch_mist.estimators.generative.base.GenerativeMIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)
Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator
- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- class torch_mist.estimators.hybrid.ReweighedHybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)
Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator
- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- class torch_mist.estimators.hybrid.MIEstimator
Bases: torch_mist.nn.Model
- infomax_gradient: Dict[str, bool]
- abstract log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- mutual_information(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- abstract batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- forward(x: torch.Tensor, y: torch.Tensor) torch.Tensor | Dict[str, torch.Tensor]
- torch_mist.estimators.hybrid.instantiate_estimator(estimator_name: str, x_dim: int | None = None, y_dim: int | None = None, **kwargs) torch_mist.estimators.base.MIEstimator
- class torch_mist.estimators.hybrid.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)
Bases: torch_mist.estimators.base.MIEstimator
- lower_bound: bool = True
- infomax_gradient: Dict[str, bool]
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- n_negatives_to_use(N: int)
- sample_negatives(x: torch.Tensor, y: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- __repr__()
- class torch_mist.estimators.hybrid.GenerativeMIEstimator
Bases: torch_mist.estimators.base.MIEstimator
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- abstract approx_log_p_y(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) torch.Tensor
- class torch_mist.estimators.hybrid.QuantizationFunction
Bases: torch.nn.Module
- abstract property n_bins: int
- abstract quantize(x: torch.Tensor) torch.LongTensor
- forward(x: torch.Tensor) torch.LongTensor
- torch_mist.estimators.hybrid.instantiate_quantization(name: str, n_bins: int, **kwargs) torch_mist.quantization.functions.QuantizationFunction
- torch_mist.estimators.hybrid.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)
- torch_mist.estimators.hybrid.hybrid_pq(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, quantize_y: torch_mist.quantization.QuantizationFunction | str | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, quantization_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.PQHybridMIEstimator
- torch_mist.estimators.hybrid.resampled_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.ResampledHybridMIEstimator
- torch_mist.estimators.hybrid.reweighed_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) torch_mist.estimators.hybrid.ReweighedHybridMIEstimator