torch_mist.quantization.functions

Module Contents

Classes

QuantizationFunction

LearnableQuantization

ClusterQuantization

VectorQuantization

LearnableVectorQuantization

FixedQuantization

class torch_mist.quantization.functions.QuantizationFunction

Bases: torch.nn.Module

abstract property n_bins: int
abstract quantize(x: torch.Tensor) -> torch.LongTensor
forward(x: torch.Tensor) -> torch.LongTensor
exception torch_mist.quantization.functions.NotTrainedError(message: str, model_to_train: torch.nn.Module)

Bases: Exception

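Exception constructed with a message and the module to be trained (model_to_train); judging by its name, it presumably signals that a model was used before being trained. The autogenerated description "Common base class for all non-exit exceptions." is the docstring inherited from Python's built-in Exception and is not specific to this class.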

class torch_mist.quantization.functions.LearnableQuantization(**train_params)

Bases: QuantizationFunction

_fit(dataloader: torch.utils.data.DataLoader) -> Any | None
fit(dataloader: Any) -> Any | None
forward(x: torch.Tensor) -> torch.LongTensor
abstract loss(*args, **kwargs) -> torch.Tensor
class torch_mist.quantization.functions.ClusterQuantization(clustering: sklearn.base.TransformerMixin, **train_params)

Bases: LearnableQuantization

property n_bins: int
_fit(dataloader: torch.utils.data.DataLoader)
quantize(x: torch.Tensor) -> torch.LongTensor
class torch_mist.quantization.functions.VectorQuantization(vectors: torch.Tensor)

Bases: QuantizationFunction

property n_bins: int
codebook_lookup(x: torch.Tensor) -> torch.Tensor
quantize(x: torch.Tensor) -> torch.LongTensor
class torch_mist.quantization.functions.LearnableVectorQuantization(transform: torch.nn.Module | None = None, vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, **train_params)

Bases: VectorQuantization, LearnableQuantization

quantize(x: torch.Tensor) -> torch.LongTensor
class torch_mist.quantization.functions.FixedQuantization(input_dim: int, thresholds: torch.Tensor)

Bases: QuantizationFunction

property n_bins: int
quantize(x: torch.Tensor) -> torch.LongTensor