torch_mist.distributions.categorical
Module Contents
Classes
- class torch_mist.distributions.categorical.CategoricalModule(logits: torch.tensor, temperature: float = 1.0)
Bases: torch.distributions.Distribution, torch.nn.Module
- rsample(sample_shape=torch.Size())
- log_prob(value)
- __repr__()
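
The listing above gives only signatures, so the following is a minimal sketch of what a temperature-parameterised categorical with `rsample` and `log_prob` might look like. It uses only `torch.distributions` classes and does not import or reproduce `CategoricalModule`. Pairing a relaxed distribution for sampling with a one-hot distribution for scoring is an assumption made for illustration, not a description of the module's internals.

```python
import torch
from torch.distributions import OneHotCategorical, RelaxedOneHotCategorical

torch.manual_seed(0)

# Unnormalised log-probabilities over 4 classes (uniform here).
logits = torch.zeros(4, requires_grad=True)
temperature = torch.tensor(1.0)

# Reparameterised (differentiable) samples on the probability simplex.
# Assumption: this is one plausible way to back an rsample() method.
relaxed = RelaxedOneHotCategorical(temperature=temperature, logits=logits)
soft_sample = relaxed.rsample(torch.Size([3]))  # shape (3, 4)

# Gradients flow from the samples back to the logits.
soft_sample[:, 0].sum().backward()

# Log-probability of exact one-hot values under the categorical itself.
hard = OneHotCategorical(logits=logits)
one_hot = torch.nn.functional.one_hot(
    torch.tensor([0, 2, 3]), num_classes=4
).float()
log_p = hard.log_prob(one_hot)  # shape (3,)
```

Lower temperatures push the relaxed samples closer to one-hot vectors, at the cost of higher-variance gradients.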
- class torch_mist.distributions.categorical.ConditionalCategoricalModule(net: torch.nn.Module, temperature: float = 1.0)
Bases: torch_mist.distributions.transforms.ConditionalDistributionModule
- condition(x)
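
As a rough illustration of the conditional pattern suggested by the signature (a network `net` plus a `condition(x)` method), the sketch below maps an input batch to logits and returns a distribution over classes. `ConditionalCategoricalSketch` is a hypothetical stand-in written for this example. It is not the library class, and the choice of `RelaxedOneHotCategorical` as the returned distribution is an assumption.

```python
import torch
from torch import nn
from torch.distributions import RelaxedOneHotCategorical


class ConditionalCategoricalSketch(nn.Module):
    """Hypothetical stand-in: a network that produces logits from x."""

    def __init__(self, net: nn.Module, temperature: float = 1.0):
        super().__init__()
        self.net = net
        self.temperature = temperature

    def condition(self, x: torch.Tensor) -> RelaxedOneHotCategorical:
        # The network output is interpreted as unnormalised log-probabilities.
        logits = self.net(x)
        return RelaxedOneHotCategorical(
            temperature=torch.tensor(self.temperature), logits=logits
        )


torch.manual_seed(0)
net = nn.Linear(5, 3)                      # 5 input features, 3 classes
module = ConditionalCategoricalSketch(net)
dist = module.condition(torch.randn(8, 5))  # one distribution per input row
sample = dist.rsample()                     # shape (8, 3)
```

Because the logits come from `net`, gradients of any loss on `sample` propagate to the network parameters.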