torch_mist.data.multivariate

Module Contents

Classes

JointMultivariateNormal

class torch_mist.data.multivariate.JointMultivariateNormal(n_dim: int, rho: float = 0.9, sigma: float = 1, device: torch.device = torch.device('cpu'))

Bases: torch_mist.distributions.joint.wrapper.TorchJointDistribution
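The signature only exposes n_dim, rho and sigma. The sketch below shows what such parameters commonly control in a correlated-Gaussian benchmark. It is a pure-Python illustration, not the torch_mist implementation. It assumes two variables x and y of size n_dim, where each coordinate pair (x_i, y_i) is zero-mean bivariate normal with standard deviation sigma and correlation rho, independent across i. The names sample_joint and gaussian_mutual_information are hypothetical helpers.

```python
import math
import random


def sample_joint(n_dim: int, rho: float = 0.9, sigma: float = 1.0, rng=None):
    """Draw one (x, y) pair, each a list of length n_dim.

    Hypothetical model (an assumption, not taken from torch_mist): every
    pair (x_i, y_i) is bivariate normal with zero mean, standard deviation
    sigma and correlation rho, independent across coordinates.
    """
    if rng is None:
        rng = random.Random()
    x, y = [], []
    for _ in range(n_dim):
        z1 = rng.gauss(0.0, 1.0)
        z2 = rng.gauss(0.0, 1.0)
        # Apply the Cholesky factor of [[1, rho], [rho, 1]] to (z1, z2),
        # then scale both coordinates by sigma.
        x.append(sigma * z1)
        y.append(sigma * (rho * z1 + math.sqrt(1.0 - rho ** 2) * z2))
    return x, y


def gaussian_mutual_information(n_dim: int, rho: float) -> float:
    """Closed-form I(x; y) in nats for n_dim independent pairs with correlation rho."""
    return -0.5 * n_dim * math.log(1.0 - rho ** 2)


if __name__ == "__main__":
    x, y = sample_joint(5, rho=0.9, sigma=1.0, rng=random.Random(0))
    print(x, y)
    print(gaussian_mutual_information(5, 0.9))
```

Under this assumed model the mutual information depends only on n_dim and rho, not on sigma. This is why correlated Gaussians are a common ground-truth benchmark for mutual-information estimators.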

_marginal(variables: List[str]) → torch_mist.distributions.joint.base.JointDistribution
_entropy(variables: List[str]) → torch.Tensor
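_marginal and _entropy take a list of variable names. For a Gaussian, both have closed forms. The following pure-Python sketch computes the differential entropy a method like _entropy would be expected to return. It uses the standard Gaussian identity H = ½ log det(2πe Σ). The variable names "x" and "y" and the per-pair covariance (standard deviation sigma, correlation rho) are assumptions, and gaussian_entropy is a hypothetical helper, not part of torch_mist.

```python
import math


def gaussian_entropy(variables, n_dim: int, rho: float = 0.9, sigma: float = 1.0) -> float:
    """Differential entropy (nats) of the selected variables.

    Hypothetical helper: assumes the variables are named "x" and "y" and that
    each coordinate pair (x_i, y_i) has standard deviation sigma and
    correlation rho, independent across coordinates.
    """
    names = set(variables)
    if not names or not names <= {"x", "y"}:
        raise ValueError(f"unknown variables: {variables!r}")
    # Entropy of a single N(0, sigma^2) coordinate: 0.5 * log(2*pi*e*sigma^2).
    h_single = 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)
    if len(names) == 1:
        # The marginal of x (or y) is N(0, sigma^2 I): n_dim independent coordinates.
        return n_dim * h_single
    # Each (x_i, y_i) block has determinant sigma^4 * (1 - rho^2), so its entropy
    # is log(2*pi*e*sigma^2) + 0.5 * log(1 - rho^2).
    return n_dim * (2 * h_single + 0.5 * math.log(1.0 - rho ** 2))


if __name__ == "__main__":
    h_x = gaussian_entropy(["x"], 3, rho=0.5, sigma=2.0)
    h_y = gaussian_entropy(["y"], 3, rho=0.5, sigma=2.0)
    h_xy = gaussian_entropy(["x", "y"], 3, rho=0.5, sigma=2.0)
    # Mutual information recovered from entropies: I(x; y) = H(x) + H(y) - H(x, y)
    print(h_x + h_y - h_xy)
```

Combining marginal and joint entropies as H(x) + H(y) − H(x, y) recovers the closed-form mutual information −(n_dim/2) log(1 − rho²). This identity is why exposing both marginalization and entropy is enough to derive the ground-truth value for this distribution.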