torch_mist.models.vqvae

Module Contents

Classes

VQVAE

Attributes

VERSIONS

torch_mist.models.vqvae.VERSIONS = ['v1', 'v2']
class torch_mist.models.vqvae.VQVAE(encoder: torch.nn.Module, decoder: pyro.distributions.ConditionalDistribution, initial_vectors: torch.Tensor | None = None, n_bins: int | None = None, quantization_dim: int | None = None, beta: float = 0.25, gamma: float = 0.99, version: str = 'v2', **train_params)

Bases: torch_mist.quantization.functions.LearnableVectorQuantization, torch_mist.nn.Model

upper_bound: bool = True
quantize(x: torch.Tensor) → torch.LongTensor
_update_vectors(indices: torch.LongTensor, z: torch.Tensor)
loss(x: torch.Tensor) → torch.Tensor