torch_mist.distributions.normal

Module Contents

Classes

NormalModule

StandardNormalModule

ConditionalStandardNormalModule

TransformedNormalModule

ConditionalTransformedNormalModule

JointTransformedNormalModule

class torch_mist.distributions.normal.NormalModule(loc: torch.Tensor, scale: torch.Tensor, learnable: bool = False)

Bases: torch.distributions.Distribution, torch.nn.Module

rsample(sample_shape=torch.Size())

log_prob(value)

__repr__()

class torch_mist.distributions.normal.StandardNormalModule(n_dim: int)

Bases: NormalModule

class torch_mist.distributions.normal.ConditionalStandardNormalModule(n_dim: int)

Bases: torch_mist.distributions.transforms.ConditionalDistributionModule

condition(context)

class torch_mist.distributions.normal.TransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform])

Bases: torch_mist.distributions.transforms.TransformedDistributionModule

class torch_mist.distributions.normal.ConditionalTransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])

Bases: torch_mist.distributions.transforms.ConditionalTransformedDistributionModule

class torch_mist.distributions.normal.JointTransformedNormalModule(input_dims: Dict[str, int], transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])

Bases: torch_mist.distributions.joint.wrapper.TorchJointDistribution