torch_mist.distributions.joint.categorical
Module Contents
Classes
- class torch_mist.distributions.joint.categorical.JointCategorical(variables: List[str], bins: List[int], temperature: float = 1.0, logits: torch.Tensor | None = None, name='p')
Bases: torch_mist.distributions.joint.base.JointDistribution
- property categorical: torch.distributions.Categorical
- _tensor_to_dict(tensor: torch.LongTensor) → Dict[str, torch.LongTensor]
- _dict_to_tensor(tensor_dict: Dict[str, torch.Tensor]) → torch.Tensor
- _log_prob(**kwargs) → torch.Tensor
- _entropy(variables: List[str]) → torch.Tensor
- sample(sample_shape: torch.Size = torch.Size()) → Dict[str, torch.Tensor]
- _marginal(variables: List[str]) → torch_mist.distributions.joint.base.T
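The `_tensor_to_dict` and `_dict_to_tensor` signatures indicate that joint outcomes are converted between a single tensor and a dictionary keyed by variable name. This page does not document the encoding. One common approach, assumed here, is to treat the joint over `variables` with sizes `bins` as a single categorical over `prod(bins)` outcomes, flattened in row-major (mixed-radix) order. The pure-Python sketch below uses plain ints instead of tensors, and the helper names `dict_to_flat_index` and `flat_index_to_dict` are hypothetical, not part of torch_mist.

```python
from typing import Dict, List

def dict_to_flat_index(values: Dict[str, int], variables: List[str], bins: List[int]) -> int:
    """Encode per-variable bin indices as one row-major (mixed-radix) index.

    The last variable in ``variables`` varies fastest.
    """
    index = 0
    for name, n_bins in zip(variables, bins):
        value = values[name]
        if not 0 <= value < n_bins:
            raise ValueError(f"{name}={value} is outside [0, {n_bins})")
        index = index * n_bins + value
    return index

def flat_index_to_dict(index: int, variables: List[str], bins: List[int]) -> Dict[str, int]:
    """Decode a row-major index back into per-variable bin indices."""
    values: Dict[str, int] = {}
    # Peel off one "digit" per variable, starting from the fastest-varying one.
    for name, n_bins in zip(reversed(variables), reversed(bins)):
        index, values[name] = divmod(index, n_bins)
    # Return the entries in the same order as ``variables``.
    return {name: values[name] for name in variables}

# Two hypothetical variables with 3 and 4 bins: 3 * 4 = 12 joint outcomes.
variables, bins = ["x", "y"], [3, 4]
flat = dict_to_flat_index({"x": 2, "y": 1}, variables, bins)  # 2 * 4 + 1
decoded = flat_index_to_dict(flat, variables, bins)
```

Under this assumed encoding, every one of the 12 flat indices decodes to a distinct `(x, y)` pair and re-encodes to itself. A batched tensor version would apply the same `//` and `%` arithmetic elementwise.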
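The `_log_prob`, `_marginal` and `_entropy` methods correspond to standard operations on a discrete joint: look up the probability of one cell, sum out the variables that are not kept, and compute Shannon entropy. The sketch below shows that arithmetic on a small probability table stored as a dict keyed by outcome tuples. It assumes the logits are normalized with a plain softmax and flattened in row-major order. It deliberately ignores the `temperature` argument, whose role is not documented here. The helpers `joint_table_from_logits`, `marginal_table` and `entropy` are hypothetical and do not reproduce torch_mist's implementation. In particular, `_marginal` returns a distribution object, not a table.

```python
import itertools
import math
from typing import Dict, List, Tuple

Table = Dict[Tuple[int, ...], float]

def joint_table_from_logits(logits: List[float], bins: List[int]) -> Table:
    """Softmax over flattened logits, keyed by per-variable outcome tuples (row-major)."""
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(l - peak) for l in logits]
    total = sum(weights)
    outcomes = itertools.product(*(range(n) for n in bins))
    return {outcome: w / total for outcome, w in zip(outcomes, weights)}

def marginal_table(table: Table, variables: List[str], keep: List[str]) -> Table:
    """Sum the joint table over every variable that is not in ``keep``."""
    positions = [variables.index(name) for name in keep]
    out: Table = {}
    for outcome, p in table.items():
        key = tuple(outcome[i] for i in positions)
        out[key] = out.get(key, 0.0) + p
    return out

def entropy(table: Table) -> float:
    """Shannon entropy in nats; zero-probability cells contribute nothing."""
    return -sum(p * math.log(p) for p in table.values() if p > 0)

# Hypothetical 2 x 2 joint: cells (0,0), (0,1), (1,0), (1,1) get 1/6, 1/6, 1/6, 1/2.
variables, bins = ["x", "y"], [2, 2]
joint = joint_table_from_logits([0.0, 0.0, 0.0, math.log(3.0)], bins)
log_p_11 = math.log(joint[(1, 1)])  # analogue of a log-probability lookup
p_x = marginal_table(joint, variables, keep=["x"])  # {(0,): 1/3, (1,): 2/3}
```

Summing out `y` leaves a marginal over `x` whose entropy can never exceed the joint entropy. The sketch keeps the table explicit so that each step is visible. A tensor implementation would typically reshape the flat probabilities to `bins` and sum over the dropped axes.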