torch_mist.quantization

Submodules

Package Contents

Classes

QuantizationFunction

LearnableQuantization

FixedQuantization

VectorQuantization

LearnableVectorQuantization

ClusterQuantization

QuantizationFunction

VQVAE

Functions

conditional_transformed_normal(...)

dense_nn(→ torch.nn.Module)

kmeans_quantization(...)

vqvae(→ torch_mist.models.vqvae.VQVAE)

instantiate_quantization(...)

class torch_mist.quantization.QuantizationFunction

Bases: torch.nn.Module

abstract property n_bins: int
abstract quantize(x: torch.Tensor) → torch.LongTensor
forward(x: torch.Tensor) → torch.LongTensor
class torch_mist.quantization.LearnableQuantization(**train_params)

Bases: QuantizationFunction

_fit(dataloader: torch.utils.data.DataLoader) → Any | None
fit(dataloader: Any) → Any | None
forward(x: torch.Tensor) → torch.LongTensor
abstract loss(*args, **kwargs) → torch.Tensor
class torch_mist.quantization.FixedQuantization(input_dim: int, thresholds: torch.Tensor)

Bases: QuantizationFunction

property n_bins: int
quantize(x: torch.Tensor) → torch.LongTensor
class torch_mist.quantization.VectorQuantization(vectors: torch.Tensor)

Bases: QuantizationFunction

property n_bins: int
codebook_lookup(x: torch.Tensor) → torch.Tensor
quantize(x: torch.Tensor) → torch.LongTensor
class torch_mist.quantization.LearnableVectorQuantization(transform: torch.nn.Module | None = None, vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, **train_params)

Bases: VectorQuantization, LearnableQuantization

quantize(x: torch.Tensor) → torch.LongTensor
torch_mist.quantization.conditional_transformed_normal(input_dim: int, context_dim: int, transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch_mist.distributions.transforms.ConditionalDistributionModule
torch_mist.quantization.dense_nn(input_dim: int, output_dim: int, hidden_dims: List[int] | None = None, nonlinearity: Callable[[torch.Tensor], torch.Tensor] | None = None) → torch.nn.Module
class torch_mist.quantization.ClusterQuantization(clustering: sklearn.base.TransformerMixin, **train_params)

Bases: LearnableQuantization

property n_bins: int
_fit(dataloader: torch.utils.data.DataLoader)
quantize(x: torch.Tensor) → torch.LongTensor
class torch_mist.quantization.QuantizationFunction

Bases: torch.nn.Module

abstract property n_bins: int
abstract quantize(x: torch.Tensor) → torch.LongTensor
forward(x: torch.Tensor) → torch.LongTensor
class torch_mist.quantization.VQVAE(encoder: torch.nn.Module, decoder: pyro.distributions.ConditionalDistribution, initial_vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, beta: float = 0.25, gamma: float = 0.99, version: str = 'v2', **train_params)

Bases: torch_mist.quantization.functions.LearnableVectorQuantization, torch_mist.nn.Model

upper_bound: bool = True
quantize(x: torch.Tensor) → torch.LongTensor
_update_vectors(indices: torch.LongTensor, z: torch.Tensor)
loss(x: torch.Tensor) → torch.Tensor
torch_mist.quantization.kmeans_quantization(n_bins: int, n_init='auto', **kwargs) → torch_mist.quantization.functions.ClusterQuantization
torch_mist.quantization.vqvae(input_dim: int, quantization_dim: int, n_bins: int, hidden_dims: List[int], beta: float = 0.2, nonlinearity: Callable | None = None, version: str = 'v2', **train_params) → torch_mist.models.vqvae.VQVAE
torch_mist.quantization.instantiate_quantization(name: str, n_bins: int, **kwargs) → torch_mist.quantization.functions.QuantizationFunction