torch_mist.distributions.joint.wrapper

Module Contents

Classes

TorchJointDistribution

class torch_mist.distributions.joint.wrapper.TorchJointDistribution(torch_dist: torch.distributions.Distribution, variables: List[str], splits: List[int] | None = None, split_dim: int = -1, name: str | None = 'p')

Bases: torch_mist.distributions.joint.base.JointDistribution

_tensor_to_dict(tensor: torch.Tensor) -> Dict[str, torch.Tensor]
_dict_to_tensor(dict: Dict[str, torch.Tensor]) -> torch.Tensor
_log_prob(**kwargs) -> torch.Tensor
_entropy(variables: List[str]) -> torch.Tensor
sample(sample_shape: torch.Size = torch.Size()) -> Dict[str, torch.Tensor]
rsample(sample_shape: torch.Size = torch.Size()) -> Dict[str, torch.Tensor]
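
The signatures above suggest that the wrapper converts between a single tensor drawn from torch_dist and a dictionary keyed by the names in variables, using splits and split_dim to decide where each variable starts and ends. The following is a minimal sketch of that round trip written with plain PyTorch only. It does not call torch_mist, and the helper names tensor_to_dict and dict_to_tensor, as well as the use of torch.split and torch.cat, are assumptions made for illustration, not the library's actual implementation.

```python
from typing import Dict, List

import torch
from torch.distributions import Independent, Normal


def tensor_to_dict(
    tensor: torch.Tensor,
    variables: List[str],
    splits: List[int],
    split_dim: int = -1,
) -> Dict[str, torch.Tensor]:
    # Hypothetical helper: cut the tensor into consecutive chunks of the
    # given sizes along split_dim and label each chunk with a variable name.
    chunks = torch.split(tensor, splits, dim=split_dim)
    return dict(zip(variables, chunks))


def dict_to_tensor(
    values: Dict[str, torch.Tensor],
    variables: List[str],
    split_dim: int = -1,
) -> torch.Tensor:
    # Hypothetical helper: concatenate the named chunks back together,
    # in the order given by `variables`.
    return torch.cat([values[name] for name in variables], dim=split_dim)


# A 5-dimensional standard normal, treated as a joint over x (2 dims) and y (3 dims).
torch_dist = Independent(Normal(torch.zeros(5), torch.ones(5)), 1)
variables = ["x", "y"]
splits = [2, 3]

flat = torch_dist.rsample(torch.Size([4]))       # shape (4, 5)
named = tensor_to_dict(flat, variables, splits)  # x: (4, 2), y: (4, 3)
restored = dict_to_tensor(named, variables)      # shape (4, 5) again
log_prob = torch_dist.log_prob(restored)         # shape (4,)
```

In this sketch the entries of splits must sum to the size of the tensor along split_dim, otherwise torch.split raises an error. Whether the library enforces the same constraint is not stated in this reference.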