torch_mist.estimators.transformed

Subpackages

Submodules

Package Contents

Classes

TransformedMIEstimator

BinnedMIEstimator

PQ

PQ

BinnedMIEstimator

QuantizationFunction

Functions

instantiate_quantization(...)

binned(...)

pq(...) → torch_mist.estimators.transformed.implementations.PQ

class torch_mist.estimators.transformed.TransformedMIEstimator(base_estimator: torch_mist.estimators.base.MIEstimator, transforms: Dict[str, Callable[[Any], Any]] | None = None, transforms_rename: Dict[Tuple[str, str], Callable[[Any], Any]] | None = None)

Bases: torch_mist.estimators.base.MIEstimator

transform(**variables) → Dict[str, torch.Tensor]
_unfold_variables(*args, **variables) → Dict[str, Any]
batch_loss(*args, **variables) → torch.Tensor
log_ratio(*args, **variables) → torch.Tensor | List[torch.Tensor]
unnormalized_log_ratio(*args, **variables) → torch.Tensor
loss(*args, **variables) → torch.Tensor
mutual_information(*args, **variables) → torch.Tensor | Dict[str, torch.Tensor]
forward(*args, **kwargs) → torch.Tensor

class torch_mist.estimators.transformed.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

lower_bound = True

class torch_mist.estimators.transformed.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

class torch_mist.estimators.transformed.PQ(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.functions.QuantizationFunction, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

class torch_mist.estimators.transformed.BinnedMIEstimator(quantize_x: torch_mist.quantization.functions.QuantizationFunction | None = None, quantize_y: torch_mist.quantization.functions.QuantizationFunction | None = None, temperature: float = 1.0)

Bases: torch_mist.estimators.transformed.base.TransformedMIEstimator

lower_bound = True

class torch_mist.estimators.transformed.QuantizationFunction

Bases: torch.nn.Module

abstract property n_bins: int
abstract quantize(x: torch.Tensor) → torch.LongTensor
forward(x: torch.Tensor) → torch.LongTensor

torch_mist.estimators.transformed.instantiate_quantization(name: str, n_bins: int, **kwargs) → torch_mist.quantization.functions.QuantizationFunction

torch_mist.estimators.transformed.binned(quantize_x: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', temperature: float = 0.1, n_bins: int | None = 32, x_dim: int | None = None, y_dim: int | None = None, **kwargs) → torch_mist.estimators.transformed.implementations.BinnedMIEstimator

torch_mist.estimators.transformed.pq(quantize_y: torch_mist.quantization.QuantizationFunction | str | None = 'kmeans', x_dim: int | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, y_dim: int | None = None, **kwargs) → torch_mist.estimators.transformed.implementations.PQ