torch_mist.distributions.joint.categorical

Module Contents

Classes

JointCategorical

class torch_mist.distributions.joint.categorical.JointCategorical(variables: List[str], bins: List[int], temperature: float = 1.0, logits: torch.Tensor | None = None, name='p')

Bases: torch_mist.distributions.joint.base.JointDistribution

property categorical: torch.distributions.Categorical
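The sketch below shows one way the `categorical` property could follow from the constructor arguments. It assumes three things: the joint is stored as a single flat `torch.distributions.Categorical` over all `prod(bins)` outcomes, `logits=None` means uniform (all-zero) logits, and `temperature` divides the logits. The helper `make_categorical` is hypothetical and is not part of torch_mist.

```python
import math
from typing import List, Optional

import torch


def make_categorical(
    bins: List[int],
    logits: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
) -> torch.distributions.Categorical:
    # Hypothetical helper: one category per joint outcome (x_1, ..., x_k)
    n_outcomes = math.prod(bins)
    if logits is None:
        # Assumption: missing logits default to a uniform joint distribution
        logits = torch.zeros(n_outcomes)
    # Accept flat logits or logits shaped like `bins`; dividing by temperature is assumed
    return torch.distributions.Categorical(logits=logits.reshape(n_outcomes) / temperature)


cat = make_categorical([2, 3])
print(cat.probs.shape)  # torch.Size([6])
```

Flattening keeps the joint as one distribution over `prod(bins)` outcomes. The cost is that the number of logits grows multiplicatively with the number of variables.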
_tensor_to_dict(tensor: torch.LongTensor) → Dict[str, torch.LongTensor]
_dict_to_tensor(tensor_dict: Dict[str, torch.Tensor]) → torch.Tensor
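Judging by their names, `_tensor_to_dict` and `_dict_to_tensor` convert between a flat outcome index and one bin index per variable. The sketch below treats that conversion as mixed-radix arithmetic. It assumes row-major order, meaning the last entry of `variables` varies fastest. `flat_to_dict` and `dict_to_flat` are hypothetical stand-ins, not the library's code.

```python
from typing import Dict, List

import torch


def flat_to_dict(
    flat: torch.Tensor, variables: List[str], bins: List[int]
) -> Dict[str, torch.Tensor]:
    # Peel off one "digit" per variable, starting from the fastest-varying (last) one
    out = {}
    remainder = flat.long()
    for name, n in zip(reversed(variables), reversed(bins)):
        out[name] = remainder % n
        remainder = torch.div(remainder, n, rounding_mode="floor")
    # Return the entries in the original variable order
    return {name: out[name] for name in variables}


def dict_to_flat(
    values: Dict[str, torch.Tensor], variables: List[str], bins: List[int]
) -> torch.Tensor:
    # Horner-style accumulation: flat = ((x_1 * b_2 + x_2) * b_3 + x_3) ...
    flat = torch.zeros_like(values[variables[0]], dtype=torch.long)
    for name, n in zip(variables, bins):
        flat = flat * n + values[name].long()
    return flat


d = flat_to_dict(torch.arange(6), ["x", "y"], [2, 3])
print(d["x"].tolist(), d["y"].tolist())  # [0, 0, 0, 1, 1, 1] [0, 1, 2, 0, 1, 2]
```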
_log_prob(**kwargs) → torch.Tensor
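`_log_prob` takes keyword arguments, so each keyword plausibly names a variable and carries its observed bin indices. The sketch below works under that assumption, with the logits laid out as one axis per variable (shape `bins`). `joint_log_prob` is a hypothetical helper.

```python
from typing import List

import torch


def joint_log_prob(logits: torch.Tensor, variables: List[str], **values: torch.Tensor) -> torch.Tensor:
    # Normalize over all joint outcomes, then restore one axis per variable
    log_p = torch.log_softmax(logits.flatten(), dim=-1).view(logits.shape)
    # Advanced indexing with one index tensor per variable selects log p(x_1, ..., x_k)
    return log_p[tuple(values[name].long() for name in variables)]


logits = torch.zeros(2, 3)  # uniform over the 6 joint outcomes
lp = joint_log_prob(logits, ["x", "y"], x=torch.tensor([0, 1]), y=torch.tensor([2, 0]))
```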
_entropy(variables: List[str]) → torch.Tensor
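`_entropy` takes a list of variables, which suggests it returns the entropy of the marginal over that subset, H(X_S) = -Σ p(x_S) log p(x_S). The sketch below works under that assumption. It sums out the remaining variables in log-space with `torch.logsumexp` and passes the result to `torch.distributions.Categorical.entropy`. The helper `subset_entropy` is hypothetical.

```python
from typing import List

import torch


def subset_entropy(logits: torch.Tensor, variables: List[str], subset: List[str]) -> torch.Tensor:
    # Joint log-probabilities with one axis per variable
    log_p = torch.log_softmax(logits.flatten(), dim=-1).view(logits.shape)
    # Sum out (in log-space) every variable that is not in the requested subset
    dropped = tuple(i for i, name in enumerate(variables) if name not in subset)
    if dropped:
        log_p = torch.logsumexp(log_p, dim=dropped)
    # Entropy of the resulting marginal categorical distribution, in nats
    return torch.distributions.Categorical(logits=log_p.flatten()).entropy()


logits = torch.zeros(2, 3)
h_x = subset_entropy(logits, ["x", "y"], ["x"])  # log(2) nats for a uniform joint
```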
sample(sample_shape: torch.Size = torch.Size()) → Dict[str, torch.Tensor]
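The sketch below shows how `sample` could return one tensor per variable. It assumes the joint is sampled as a single flat `Categorical` and that each flat outcome is then mapped back to per-variable bins. The mapping uses lookup grids from `torch.meshgrid(..., indexing="ij")`, whose `indexing` argument requires PyTorch 1.10 or later. `joint_sample` is a hypothetical helper.

```python
from typing import Dict, List

import torch


def joint_sample(
    logits: torch.Tensor, variables: List[str], sample_shape: torch.Size = torch.Size()
) -> Dict[str, torch.Tensor]:
    # Draw flat outcome indices from one categorical over all prod(bins) joint outcomes
    flat = torch.distributions.Categorical(logits=logits.flatten()).sample(sample_shape)
    # grids[i].flatten()[k] is the bin of variable i in flat outcome k (row-major layout)
    grids = torch.meshgrid(*[torch.arange(n) for n in logits.shape], indexing="ij")
    return {name: grid.flatten()[flat] for name, grid in zip(variables, grids)}


samples = joint_sample(torch.zeros(2, 3), ["x", "y"], torch.Size([4]))
```

Each returned tensor has shape `sample_shape`, and the values of different variables at the same position come from the same joint draw.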
_marginal(variables: List[str]) → torch_mist.distributions.joint.base.T
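Given its `T` return annotation, `_marginal` appears to return a distribution of the same kind, restricted to the requested variables. The sketch below shows one possible underlying computation, assuming one logit axis per variable. It moves the kept axes to the front in the requested order, flattens the rest, and sums them out with `torch.logsumexp`. The helper `joint_marginal` and its returned triple are hypothetical, and the sketch does not cover how torch_mist wraps the result in a new object.

```python
from typing import List, Tuple

import torch


def joint_marginal(
    logits: torch.Tensor, variables: List[str], keep: List[str]
) -> Tuple[List[str], List[int], torch.Tensor]:
    # Normalized joint log-probabilities, one axis per variable
    log_p = torch.log_softmax(logits.flatten(), dim=-1).view(logits.shape)
    kept_axes = [variables.index(name) for name in keep]
    dropped_axes = [i for i in range(len(variables)) if i not in kept_axes]
    kept_bins = [logits.shape[i] for i in kept_axes]
    # Kept axes first (in the requested order), dropped axes flattened into one trailing axis
    moved = log_p.permute(*kept_axes, *dropped_axes).reshape(*kept_bins, -1)
    # Summing probabilities over the trailing axis is a logsumexp over log-probabilities
    return list(keep), kept_bins, torch.logsumexp(moved, dim=-1)


names, sizes, marginal_logits = joint_marginal(torch.zeros(2, 3), ["x", "y"], ["y"])
```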