torch_mist.estimators.transformed
Subpackages
Submodules
Package Contents
Classes
Functions
|
|
|
- class torch_mist.estimators.transformed.TransformedMIEstimator(base_estimator: torch_mist.estimators.base.MIEstimator, transforms: Dict[str, Callable[[Any], Any]] | None = None, transforms_rename: Dict[Tuple[str, str], Callable[[Any], Any]] | None = None)
Bases:
torch_mist.estimators.base.MIEstimator
- transform(**variables) → Dict[str, torch.Tensor]
- _unfold_variables(*args, **variables) → Dict[str, Any]
- batch_loss(*args, **variables) → torch.Tensor
- log_ratio(*args, **variables) → torch.Tensor | List[torch.Tensor]
- unnormalized_log_ratio(*args, **variables) → torch.Tensor
- loss(*args, **variables) → torch.Tensor
- mutual_information(*args, **variables) → torch.Tensor | Dict[str, torch.Tensor]
- forward(*args, **kwargs) → torch.Tensor
- class torch_mist.estimators.transformed.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator
- lower_bound = True
- class torch_mist.estimators.transformed.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator
- class torch_mist.estimators.transformed.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator
- class torch_mist.estimators.transformed.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)
Bases:
torch_mist.estimators.transformed.base.TransformedMIEstimator
- lower_bound = True
- class torch_mist.estimators.transformed.QuantizationFunction
Bases:
torch.nn.Module
- abstract property n_bins: int
- abstract quantize(x: torch.Tensor) → torch.LongTensor
- forward(x: torch.Tensor) → torch.LongTensor
- torch_mist.estimators.transformed.instantiate_quantization(name: str, n_bins: int, **kwargs) → torch_mist.quantization.functions.QuantizationFunction
- torch_mist.estimators.transformed.binned(quantize_x: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', temperature: float = 0.1, n_bins: int | None = 32, x_dim: int | None = None, y_dim: int | None = None, **kwargs) → torch_mist.estimators.transformed.implementations.BinnedMIEstimator
- torch_mist.estimators.transformed.pq(quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', x_dim: int | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, y_dim: int | None = None, **kwargs) → torch_mist.estimators.transformed.implementations.PQ