torch_mist.distributions.categorical

Module Contents

Classes

CategoricalModule

ConditionalCategoricalModule

class torch_mist.distributions.categorical.CategoricalModule(logits: torch.Tensor, temperature: float = 1.0)

Bases: torch.distributions.Distribution, torch.nn.Module

rsample(sample_shape=torch.Size())
log_prob(value)
__repr__()

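The listing above gives only signatures. To illustrate roughly how a logits-plus-temperature categorical can offer a differentiable `rsample`, the sketch below assumes the temperature drives a Gumbel-softmax (Concrete) relaxation, which PyTorch provides as `torch.distributions.RelaxedOneHotCategorical`. `SketchCategorical` is a hypothetical name. For brevity it subclasses only `torch.nn.Module`, while the documented class also inherits from `torch.distributions.Distribution`. Its `log_prob` convention for one-hot inputs is an assumption, not necessarily what torch_mist does.

```python
import torch
import torch.nn as nn
from torch.distributions import RelaxedOneHotCategorical


class SketchCategorical(nn.Module):
    """Hypothetical logits + temperature categorical (not torch_mist's code)."""

    def __init__(self, logits: torch.Tensor, temperature: float = 1.0):
        super().__init__()
        # Assumption: the logits are learnable, so register them as a parameter.
        self.logits = nn.Parameter(logits.clone().float())
        self.temperature = temperature

    def rsample(self, sample_shape=torch.Size()):
        # Gumbel-softmax relaxation: differentiable samples on the simplex that
        # approach one-hot vectors as the temperature goes to zero.
        relaxed = RelaxedOneHotCategorical(
            torch.tensor(self.temperature), logits=self.logits
        )
        return relaxed.rsample(sample_shape)

    def log_prob(self, value):
        # Assumed convention: `value` is a (possibly relaxed) one-hot vector scored
        # against the normalized log-probabilities. For an exact one-hot vector this
        # equals the categorical log-probability of the selected class.
        log_p = torch.log_softmax(self.logits, dim=-1)
        return (value * log_p).sum(dim=-1)

    def __repr__(self):
        return (
            f"SketchCategorical(n_classes={self.logits.shape[-1]}, "
            f"temperature={self.temperature})"
        )
```

Lower temperatures push samples toward exact one-hot vectors at the cost of higher-variance gradients. Higher temperatures give smoother samples that sit closer to the uniform point of the simplex.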
class torch_mist.distributions.categorical.ConditionalCategoricalModule(net: torch.nn.Module, temperature: float = 1.0)

Bases: torch_mist.distributions.transforms.ConditionalDistributionModule

condition(x)
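One plausible reading of the conditional signature is that `net` maps an input `x` to per-example logits and that `condition(x)` returns the corresponding categorical distribution. The sketch below assumes exactly that. It again uses `RelaxedOneHotCategorical` as the relaxed categorical and introduces the hypothetical name `SketchConditionalCategorical`. It does not reproduce the `ConditionalDistributionModule` base class.

```python
import torch
import torch.nn as nn
from torch.distributions import RelaxedOneHotCategorical


class SketchConditionalCategorical(nn.Module):
    """Hypothetical conditional categorical p(y | x) (not torch_mist's code)."""

    def __init__(self, net: nn.Module, temperature: float = 1.0):
        super().__init__()
        # Assumption: `net` maps inputs of shape (..., d) to logits of shape (..., k).
        self.net = net
        self.temperature = temperature

    def condition(self, x: torch.Tensor) -> RelaxedOneHotCategorical:
        # Compute per-example logits, then wrap them in a relaxed categorical
        # whose batch shape matches the leading dimensions of `x`.
        logits = self.net(x)
        return RelaxedOneHotCategorical(torch.tensor(self.temperature), logits=logits)


# Example: a linear layer producing logits over 3 classes from 2-dimensional inputs.
torch.manual_seed(0)
cond = SketchConditionalCategorical(nn.Linear(2, 3), temperature=0.5)
dist = cond.condition(torch.randn(5, 2))  # a batch of 5 distributions
sample = dist.rsample()  # shape (5, 3); each row lies on the probability simplex
```

Returning a fresh distribution object from `condition` keeps the learnable state in `net`, so gradients from `rsample` flow back into the network's weights.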