torch_mist.quantization.functions
Module Contents
Classes
- class torch_mist.quantization.functions.QuantizationFunction
Bases: torch.nn.Module
- abstract property n_bins: int
- abstract quantize(x: torch.Tensor) -> torch.LongTensor
- forward(x: torch.Tensor) -> torch.LongTensor
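The signatures above suggest a small contract: a subclass supplies `n_bins` and `quantize`, and calling the object returns integer bin indices. The following is a minimal, library-independent sketch of that pattern in plain Python (no torch). `BinQuantizer` and `SignQuantizer` are hypothetical names, and the delegation from `__call__` to `quantize` is an assumption; the source does not show how `forward` is implemented.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class BinQuantizer(ABC):
    """Hypothetical stand-in for the QuantizationFunction contract."""

    @property
    @abstractmethod
    def n_bins(self) -> int:
        """Number of distinct bin indices quantize() can return."""

    @abstractmethod
    def quantize(self, x: Sequence[float]) -> List[int]:
        """Map each value to an integer bin index in [0, n_bins)."""

    def __call__(self, x: Sequence[float]) -> List[int]:
        # Assumption: calling the object simply delegates to quantize().
        return self.quantize(x)


class SignQuantizer(BinQuantizer):
    """Hypothetical two-bin quantizer: negatives -> 0, everything else -> 1."""

    @property
    def n_bins(self) -> int:
        return 2

    def quantize(self, x: Sequence[float]) -> List[int]:
        return [0 if v < 0 else 1 for v in x]


indices = SignQuantizer()([-1.5, 0.0, 2.0])
```

Because both members are abstract, instantiating the base class directly raises `TypeError`; only concrete subclasses such as the toy `SignQuantizer` can be used.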
- exception torch_mist.quantization.functions.NotTrainedError(message: str, model_to_train: torch.nn.Module)
Bases: Exception
Common base class for all non-exit exceptions.
- class torch_mist.quantization.functions.LearnableQuantization(**train_params)
Bases: QuantizationFunction
- _fit(dataloader: torch.utils.data.DataLoader) -> Any | None
- fit(dataloader: Any) -> Any | None
- forward(x: torch.Tensor) -> torch.LongTensor
- abstract loss(*args, **kwargs) -> torch.Tensor
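A learnable quantizer has to be fitted before its bins are meaningful, and the `model_to_train` argument of `NotTrainedError` hints that the error carries the model that still needs training. The sketch below illustrates that fit-before-use pattern in plain Python. It assumes (the source does not confirm this) that quantizing before fitting raises the error; `UntrainedError` and `MeanSplitQuantizer` are hypothetical names, not part of torch_mist.

```python
from typing import List, Optional, Sequence


class UntrainedError(Exception):
    """Hypothetical analogue of NotTrainedError(message, model_to_train)."""

    def __init__(self, message: str, model_to_train: object):
        super().__init__(message)
        self.model_to_train = model_to_train


class MeanSplitQuantizer:
    """Hypothetical learnable quantizer: two bins split at the training mean."""

    def __init__(self) -> None:
        self.threshold: Optional[float] = None  # set by fit()

    def fit(self, batches: Sequence[Sequence[float]]) -> None:
        # Flatten the batches and learn a single split point.
        values = [v for batch in batches for v in batch]
        self.threshold = sum(values) / len(values)

    def quantize(self, x: Sequence[float]) -> List[int]:
        if self.threshold is None:
            # Assumption: use before fitting is reported with the model attached.
            raise UntrainedError("call fit() before quantize()", self)
        return [int(v >= self.threshold) for v in x]


quantizer = MeanSplitQuantizer()
try:
    quantizer.quantize([0.0])
except UntrainedError as err:
    quantizer_to_fit = err.model_to_train  # the caller can train it and retry

quantizer_to_fit.fit([[0.0, 1.0], [2.0, 3.0]])  # mean of the data is 1.5
bins = quantizer.quantize([1.0, 2.0])
```

Attaching the untrained model to the exception lets a caller higher up the stack decide how to train it, rather than forcing every call site to check training state.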
- class torch_mist.quantization.functions.ClusterQuantization(clustering: sklearn.base.TransformerMixin, **train_params)
Bases: LearnableQuantization
- property n_bins: int
- _fit(dataloader: torch.utils.data.DataLoader)
- quantize(x: torch.Tensor) -> torch.LongTensor
- class torch_mist.quantization.functions.VectorQuantization(vectors: torch.Tensor)
Bases: QuantizationFunction
- property n_bins: int
- codebook_lookup(x: torch.Tensor) -> torch.Tensor
- quantize(x: torch.Tensor) -> torch.LongTensor
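Vector quantization conventionally assigns each input to its nearest codebook vector. The sketch below shows that idea in plain Python, assuming Euclidean distance and assuming that `codebook_lookup` returns the selected vectors while `quantize` returns their indices. Neither assumption is confirmed by the signatures above, and the real class operates on tensors.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def quantize(x: Sequence[Vector], vectors: Sequence[Vector]) -> List[int]:
    """Index of the nearest codebook vector (Euclidean) for each input."""
    return [
        min(range(len(vectors)), key=lambda k: math.dist(point, vectors[k]))
        for point in x
    ]


def codebook_lookup(x: Sequence[Vector], vectors: Sequence[Vector]) -> List[Vector]:
    """Assumed behaviour: replace each input by its nearest codebook vector."""
    return [vectors[k] for k in quantize(x, vectors)]


codebook = [(0.0, 0.0), (1.0, 1.0), (4.0, 0.0)]
points = [(0.2, -0.1), (0.9, 1.3), (3.0, 0.5)]

n_bins = len(codebook)                       # one bin per codebook vector
indices = quantize(points, codebook)         # [0, 1, 2]
nearest = codebook_lookup(points, codebook)  # the matching codebook rows
```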
- class torch_mist.quantization.functions.LearnableVectorQuantization(transform: torch.nn.Module | None = None, vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, **train_params)
Bases: VectorQuantization, LearnableQuantization
- quantize(x: torch.Tensor) -> torch.LongTensor
- class torch_mist.quantization.functions.FixedQuantization(input_dim: int, thresholds: torch.Tensor)
Bases: QuantizationFunction
- property n_bins: int
- quantize(x: torch.Tensor) -> torch.LongTensor
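Threshold-based quantization can be illustrated for the one-dimensional case: sorted thresholds split the real line into `len(thresholds) + 1` intervals, and each value maps to the index of its interval. This is a plain-Python sketch of that idea only. The boundary convention (a value equal to a threshold falls into the upper bin) and the bin count are assumptions, and how `FixedQuantization` combines bins when `input_dim > 1` is not stated in the source.

```python
from bisect import bisect_right
from typing import List, Sequence


def quantize(x: Sequence[float], thresholds: Sequence[float]) -> List[int]:
    """Interval index of each value; thresholds must be sorted ascending."""
    # bisect_right counts the thresholds that are <= v, which is the bin index.
    return [bisect_right(thresholds, v) for v in x]


thresholds = [-1.0, 0.0, 1.0]
n_bins = len(thresholds) + 1  # assumed: 3 thresholds give 4 intervals

indices = quantize([-2.0, -0.5, 0.0, 0.3, 5.0], thresholds)
# indices == [0, 1, 2, 2, 3]
```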