torch_mist.quantization
Submodules
Package Contents
- class torch_mist.quantization.QuantizationFunction
Bases: torch.nn.Module
- abstract property n_bins: int
- abstract quantize(x: torch.Tensor) → torch.LongTensor
- forward(x: torch.Tensor) → torch.LongTensor
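As a hedged illustration of the QuantizationFunction contract, the plain-Python sketch below (no torch, so it stays self-contained) requires subclasses to provide n_bins and quantize, and assumes that forward simply delegates to quantize. QuantizationSketch and SignQuantization are hypothetical names, and the sketch quantizes one item per call where torch_mist operates on torch.Tensor batches.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class QuantizationSketch(ABC):
    """Hypothetical plain-Python stand-in for the QuantizationFunction contract."""

    @property
    @abstractmethod
    def n_bins(self) -> int:
        """Number of distinct indices quantize() can return."""

    @abstractmethod
    def quantize(self, x: Sequence[float]) -> int:
        """Map one input to an integer bin index in [0, n_bins)."""

    def forward(self, xs: List[Sequence[float]]) -> List[int]:
        # Assumption: forward just applies quantize to every item of the batch.
        return [self.quantize(x) for x in xs]


class SignQuantization(QuantizationSketch):
    """Toy subclass: two bins, chosen by the sign of the first coordinate."""

    @property
    def n_bins(self) -> int:
        return 2

    def quantize(self, x: Sequence[float]) -> int:
        return int(x[0] >= 0)


q = SignQuantization()
print(q.forward([[-1.0, 3.0], [0.5, -2.0]]))  # [0, 1]
```

Because n_bins and quantize are abstract, the base class cannot be instantiated on its own; only concrete quantizers can.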
- class torch_mist.quantization.LearnableQuantization(**train_params)
Bases: QuantizationFunction
- _fit(dataloader: torch.utils.data.DataLoader) → Any | None
- fit(dataloader: Any) → Any | None
- forward(x: torch.Tensor) → torch.LongTensor
- abstract loss(*args, **kwargs) → torch.Tensor
- class torch_mist.quantization.FixedQuantization(input_dim: int, thresholds: torch.Tensor)
Bases: QuantizationFunction
- property n_bins: int
- quantize(x: torch.Tensor) → torch.LongTensor
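FixedQuantization is built from fixed thresholds. A minimal scalar sketch, assuming k sorted thresholds split a single dimension into k + 1 bins; how input_dim and multi-dimensional thresholds combine in torch_mist is not stated on this page, and FixedThresholdSketch is a hypothetical name.

```python
from bisect import bisect_right
from typing import Sequence


class FixedThresholdSketch:
    """Hypothetical scalar quantizer with fixed thresholds (not torch_mist code)."""

    def __init__(self, thresholds: Sequence[float]):
        self.thresholds = sorted(thresholds)

    @property
    def n_bins(self) -> int:
        # Assumption: k thresholds cut the real line into k + 1 intervals.
        return len(self.thresholds) + 1

    def quantize(self, x: float) -> int:
        # Index of the interval containing x; a value equal to a threshold
        # falls into the bin above it (bisect_right semantics).
        return bisect_right(self.thresholds, x)


q = FixedThresholdSketch([-1.0, 0.0, 1.0])
print(q.n_bins)                                          # 4
print([q.quantize(v) for v in (-2.0, -0.5, 0.0, 3.0)])   # [0, 1, 2, 3]
```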
- class torch_mist.quantization.VectorQuantization(vectors: torch.Tensor)
Bases: QuantizationFunction
- property n_bins: int
- codebook_lookup(x: torch.Tensor) → torch.Tensor
- quantize(x: torch.Tensor) → torch.LongTensor
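VectorQuantization is built from a codebook of vectors. The sketch below assumes the usual nearest-neighbour rule under squared Euclidean distance, with quantize returning the index of the closest codebook vector and codebook_lookup returning that vector itself; the distance and return shapes used by torch_mist are assumptions, and VectorQuantizationSketch is a hypothetical name.

```python
from typing import List, Sequence


class VectorQuantizationSketch:
    """Hypothetical nearest-codeword quantizer over a fixed codebook."""

    def __init__(self, vectors: Sequence[Sequence[float]]):
        self.vectors = [list(v) for v in vectors]

    @property
    def n_bins(self) -> int:
        # One bin per codebook vector.
        return len(self.vectors)

    def quantize(self, x: Sequence[float]) -> int:
        # Index of the codebook vector with the smallest squared Euclidean distance.
        def sq_dist(v: Sequence[float]) -> float:
            return sum((a - b) ** 2 for a, b in zip(x, v))

        return min(range(self.n_bins), key=lambda i: sq_dist(self.vectors[i]))

    def codebook_lookup(self, x: Sequence[float]) -> List[float]:
        # Assumption: replace x by its nearest codebook vector.
        return self.vectors[self.quantize(x)]


vq = VectorQuantizationSketch([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
print(vq.quantize([0.9, 0.8]))          # 1
print(vq.codebook_lookup([-0.7, 1.5]))  # [-1.0, 2.0]
```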
- class torch_mist.quantization.LearnableVectorQuantization(transform: torch.nn.Module | None = None, vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, **train_params)
Bases: VectorQuantization, LearnableQuantization
- quantize(x: torch.Tensor) → torch.LongTensor
- torch_mist.quantization.conditional_transformed_normal(input_dim: int, context_dim: int, transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch_mist.distributions.transforms.ConditionalDistributionModule
- torch_mist.quantization.dense_nn(input_dim: int, output_dim: int, hidden_dims: List[int] | None = None, nonlinearity: Callable[[torch.Tensor], torch.Tensor] | None = None) → torch.nn.Module
- class torch_mist.quantization.ClusterQuantization(clustering: sklearn.base.TransformerMixin, **train_params)
Bases: LearnableQuantization
- property n_bins: int
- _fit(dataloader: torch.utils.data.DataLoader)
- quantize(x: torch.Tensor) → torch.LongTensor
- class torch_mist.quantization.VQVAE(encoder: torch.nn.Module, decoder: pyro.distributions.ConditionalDistribution, initial_vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, beta: float = 0.25, gamma: float = 0.99, version: str = 'v2', **train_params)
Bases: torch_mist.quantization.functions.LearnableVectorQuantization, torch_mist.nn.Model
- upper_bound: bool = True
- quantize(x: torch.Tensor) → torch.LongTensor
- _update_vectors(indices: torch.LongTensor, z: torch.Tensor)
- loss(x: torch.Tensor) → torch.Tensor
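The gamma = 0.99 parameter and the _update_vectors(indices, z) method suggest, though this page does not say so, that VQVAE updates its codebook vectors as exponential moving averages of the encodings assigned to them, as in the EMA variant of VQ-VAE. A sketch of one such update step in plain Python under that assumption; ema_codebook_update is a hypothetical helper, and the eps guard against bins that never received an assignment is an added detail.

```python
from typing import List, Sequence


def ema_codebook_update(
    counts: List[float],
    sums: List[List[float]],
    indices: Sequence[int],
    z: Sequence[Sequence[float]],
    gamma: float = 0.99,
    eps: float = 1e-5,
) -> List[List[float]]:
    """One EMA codebook step (hypothetical helper, not torch_mist's _update_vectors).

    counts[i] and sums[i] hold the decayed assignment count and encoding sum of
    codebook vector i; both are updated in place. Returns the new codebook.
    """
    n_bins, dim = len(sums), len(sums[0])
    # Batch statistics: how many encodings chose each bin, and their sum.
    batch_counts = [0.0] * n_bins
    batch_sums = [[0.0] * dim for _ in range(n_bins)]
    for i, vec in zip(indices, z):
        batch_counts[i] += 1.0
        for d in range(dim):
            batch_sums[i][d] += vec[d]
    # Exponential moving averages with decay gamma (assumed meaning of VQVAE's gamma).
    for i in range(n_bins):
        counts[i] = gamma * counts[i] + (1.0 - gamma) * batch_counts[i]
        for d in range(dim):
            sums[i][d] = gamma * sums[i][d] + (1.0 - gamma) * batch_sums[i][d]
    # Each vector becomes the running mean of its encodings; eps avoids division by zero.
    return [[s / max(counts[i], eps) for s in sums[i]] for i in range(n_bins)]


counts = [1.0, 1.0]
sums = [[0.0, 0.0], [1.0, 1.0]]  # initial codebook = sums / counts
vectors = ema_codebook_update(
    counts, sums, indices=[0, 0, 1], z=[[2.0, 0.0], [0.0, 2.0], [3.0, 3.0]], gamma=0.5
)
print(vectors)  # [[0.6666666666666666, 0.6666666666666666], [2.0, 2.0]]
```

With gamma close to 1 the codebook drifts slowly; with gamma = 0 every vector that received assignments jumps to the mean of its current batch encodings.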
- torch_mist.quantization.kmeans_quantization(n_bins: int, n_init='auto', **kwargs) → torch_mist.quantization.functions.ClusterQuantization
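kmeans_quantization(n_bins, n_init='auto', **kwargs) returns a ClusterQuantization. Its n_init='auto' default matches scikit-learn's KMeans estimator, so it presumably wraps that class; this is an inference, not something this page states. To show what the fitted codebook does, here is plain Lloyd's k-means in pure Python with deterministic first-k seeding (scikit-learn instead defaults to k-means++ seeding and may run several restarts). kmeans_codebook and assign are hypothetical names.

```python
from typing import List, Sequence


def assign(x: Sequence[float], centroids: Sequence[Sequence[float]]) -> int:
    # Index of the closest centroid under squared Euclidean distance (the bin index).
    return min(
        range(len(centroids)),
        key=lambda k: sum((a - b) ** 2 for a, b in zip(x, centroids[k])),
    )


def kmeans_codebook(
    points: Sequence[Sequence[float]], n_bins: int, n_iter: int = 20
) -> List[List[float]]:
    """Plain Lloyd's k-means; the first n_bins points seed the centroids."""
    centroids = [list(p) for p in points[:n_bins]]
    for _ in range(n_iter):
        # Assignment step: group every point with its nearest centroid.
        groups: List[List[Sequence[float]]] = [[] for _ in range(n_bins)]
        for p in points:
            groups[assign(p, centroids)].append(p)
        # Update step: move each non-empty cluster's centroid to the mean of its points.
        for k, group in enumerate(groups):
            if group:
                centroids[k] = [sum(coord) / len(group) for coord in zip(*group)]
    return centroids


codebook = kmeans_codebook([[0.0], [1.0], [10.0], [12.0]], n_bins=2)
print(codebook)                 # [[0.5], [11.0]]
print(assign([9.0], codebook))  # 1
```

Once fitted, quantizing an input amounts to reporting which cluster it falls into, so n_bins equals the number of clusters.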
- torch_mist.quantization.vqvae(input_dim: int, quantization_dim: int, n_bins: int, hidden_dims: List[int], beta: float = 0.2, nonlinearity: Callable | None = None, version: str = 'v2', **train_params) → torch_mist.models.vqvae.VQVAE
- torch_mist.quantization.instantiate_quantization(name: str, n_bins: int, **kwargs) → torch_mist.quantization.functions.QuantizationFunction
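instantiate_quantization(name, n_bins, **kwargs) builds a QuantizationFunction from a string name. This page does not list the accepted names; kmeans_quantization and vqvae both take n_bins, so they are plausible targets, but that is an assumption. The sketch below shows one common way such name-based dispatch is written, a registry of factory functions; _REGISTRY, register, uniform_thresholds and instantiate are all hypothetical, and the toy factory returns a plain dict instead of a real quantizer.

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping names to factory functions.
_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Decorator that records a factory function under a string name.
    def decorator(factory: Callable[..., Any]) -> Callable[..., Any]:
        _REGISTRY[name] = factory
        return factory

    return decorator


@register("uniform")
def uniform_thresholds(n_bins: int, low: float = 0.0, high: float = 1.0) -> Dict[str, Any]:
    # Toy factory standing in for a QuantizationFunction constructor:
    # n_bins - 1 evenly spaced thresholds between low and high.
    step = (high - low) / n_bins
    return {"n_bins": n_bins, "thresholds": [low + step * i for i in range(1, n_bins)]}


def instantiate(name: str, n_bins: int, **kwargs: Any) -> Any:
    # Look up the factory by name and forward n_bins plus any extra keyword arguments.
    if name not in _REGISTRY:
        raise ValueError(f"Unknown quantization {name!r}; available: {sorted(_REGISTRY)}")
    return _REGISTRY[name](n_bins=n_bins, **kwargs)


print(instantiate("uniform", 4))
# {'n_bins': 4, 'thresholds': [0.25, 0.5, 0.75]}
```

Keeping n_bins as an explicit argument while passing everything else through **kwargs lets one entry point serve quantizers with very different constructor parameters.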