torch_mist.models.vqvae
Module Contents
Attributes
- torch_mist.models.vqvae.VERSIONS = ['v1', 'v2']
Classes
- class torch_mist.models.vqvae.VQVAE(encoder: torch.nn.Module, decoder: pyro.distributions.ConditionalDistribution, initial_vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, beta: float = 0.25, gamma: float = 0.99, version: str = 'v2', **train_params)
Bases: torch_mist.quantization.functions.LearnableVectorQuantization, torch_mist.nn.Model
- upper_bound: bool = True
- quantize(x: torch.Tensor) → torch.LongTensor
- _update_vectors(indices: torch.LongTensor, z: torch.Tensor)
- loss(x: torch.Tensor) → torch.Tensor
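
The signatures above give no implementation details, so the following is only a minimal sketch of how vector quantization with a learnable codebook is commonly done in plain PyTorch. It is not torch_mist's code. The helper names (`nearest_code`, `ema_update`, `commitment_loss`) are hypothetical, and treating `beta` as a commitment weight and `gamma` as an exponential-moving-average decay is an assumption based on the usual VQ-VAE formulation, with defaults copied from the constructor signature.

```python
import torch
import torch.nn.functional as F


def nearest_code(z: torch.Tensor, codebook: torch.Tensor) -> torch.LongTensor:
    """Index of the closest codebook vector for each row of z.

    z: [N, D] encodings, codebook: [K, D] quantization vectors.
    """
    distances = torch.cdist(z, codebook)  # pairwise Euclidean distances, [N, K]
    return distances.argmin(dim=-1)       # [N], dtype int64


def ema_update(
    codebook: torch.Tensor,
    z: torch.Tensor,
    indices: torch.LongTensor,
    gamma: float = 0.99,
) -> torch.Tensor:
    """Simplified moving-average update (assumption, not the library's rule).

    Each codebook vector that received at least one encoding moves towards
    the mean of the encodings assigned to it; unused vectors are unchanged.
    """
    one_hot = F.one_hot(indices, codebook.shape[0]).to(z.dtype)  # [N, K]
    counts = one_hot.sum(dim=0)                                  # [K]
    sums = one_hot.T @ z                                         # [K, D]
    used = counts > 0
    means = sums[used] / counts[used].unsqueeze(-1)
    updated = codebook.clone()
    updated[used] = gamma * codebook[used] + (1 - gamma) * means
    return updated


def commitment_loss(
    z: torch.Tensor,
    codebook: torch.Tensor,
    indices: torch.LongTensor,
    beta: float = 0.25,
) -> torch.Tensor:
    """Penalty pulling encodings towards their (detached) assigned vectors."""
    quantized = codebook[indices].detach()
    return beta * ((z - quantized) ** 2).sum(dim=-1).mean()


# Toy usage with a two-entry codebook in two dimensions.
codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
z = torch.tensor([[0.1, -0.1], [0.9, 1.2]])
indices = nearest_code(z, codebook)
new_codebook = ema_update(codebook, z, indices, gamma=0.5)
penalty = commitment_loss(z, codebook, indices, beta=0.25)
```

In this sketch the codebook is updated outside of backpropagation, which is why the assigned vectors are detached in the penalty term. Whether either entry of `VERSIONS` corresponds to this scheme is not stated in this page.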