torch_mist.distributions.normal
Module Contents
Classes
- class torch_mist.distributions.normal.NormalModule(loc: torch.Tensor, scale: torch.Tensor, learnable: bool = False)
Bases:
torch.distributions.Distribution, torch.nn.Module
- rsample(sample_shape=torch.Size())
- log_prob(value)
- __repr__()
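The bases indicate an object that is both a `torch.distributions.Distribution` and a `torch.nn.Module`. The following is a hypothetical re-implementation in plain PyTorch, not torch_mist's source. Several details are assumptions: the class name `SketchNormalModule`, the log-scale parameterization, treating the last dimension of `loc` as the event dimension, and the reading of `learnable` (trainable parameters when `True`, fixed buffers when `False`).

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Normal


class SketchNormalModule(Distribution, nn.Module):
    """Hypothetical stand-in for NormalModule (not the torch_mist implementation)."""

    has_rsample = True
    arg_constraints = {}  # argument validation is switched off in __init__

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, learnable: bool = False):
        nn.Module.__init__(self)
        # Assumption: the last dimension of loc is the event dimension.
        Distribution.__init__(
            self,
            batch_shape=loc.shape[:-1],
            event_shape=loc.shape[-1:],
            validate_args=False,
        )
        if learnable:
            # Trainable tensors: listed by .parameters() and updated by an optimizer.
            self.loc = nn.Parameter(loc.clone())
            self.log_scale = nn.Parameter(scale.log())
        else:
            # Fixed tensors: buffers follow .to(device) but are not trained.
            self.register_buffer("loc", loc.clone())
            self.register_buffer("log_scale", scale.log())

    def _normal(self) -> Normal:
        return Normal(self.loc, self.log_scale.exp())

    def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
        # Reparameterized draw (loc + scale * eps), so gradients reach loc and log_scale.
        return self._normal().rsample(sample_shape)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        # Diagonal Gaussian: sum per-dimension log-densities over the event dimension.
        return self._normal().log_prob(value).sum(-1)

    def __repr__(self) -> str:
        learnable = isinstance(self.loc, nn.Parameter)
        return f"SketchNormalModule(n_dim={self.event_shape[0]}, learnable={learnable})"


loc, scale = torch.zeros(3), torch.ones(3)
d = SketchNormalModule(loc, scale, learnable=True)
x = d.rsample((5,))   # shape (5, 3)
lp = d.log_prob(x)    # shape (5,): one log-density per sample
```

This version calls `nn.Module.__init__` before `Distribution.__init__` and registers tensors only after both have run. Registration therefore survives no matter how the two initializers chain into each other.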
- class torch_mist.distributions.normal.StandardNormalModule(n_dim: int)
Bases:
NormalModule
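The signature and base class suggest that `StandardNormalModule(n_dim)` fixes `loc = 0` and `scale = 1` in `n_dim` dimensions. This is an assumption, since the page does not say so. Under it, the log-density has the closed form log p(x) = -½(‖x‖² + n log 2π). The check below compares that formula against plain `torch.distributions` (an `Independent` wrapped around a `Normal`), not against torch_mist. The helper `standard_normal_log_prob` is hypothetical.

```python
import math

import torch
from torch.distributions import Independent, Normal


def standard_normal_log_prob(x: torch.Tensor) -> torch.Tensor:
    """Closed-form log N(x; 0, I) over the last dimension (hypothetical helper)."""
    n_dim = x.shape[-1]
    return -0.5 * (x.pow(2).sum(-1) + n_dim * math.log(2 * math.pi))


n_dim = 4
# Independent(..., 1) turns n_dim scalar Normals into one n_dim-dimensional event.
reference = Independent(Normal(torch.zeros(n_dim), torch.ones(n_dim)), 1)
x = reference.sample((8,))
print(torch.allclose(standard_normal_log_prob(x), reference.log_prob(x), atol=1e-5))  # True
```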
- class torch_mist.distributions.normal.ConditionalStandardNormalModule(n_dim: int)
Bases:
torch_mist.distributions.transforms.ConditionalDistributionModule
- condition(context)
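`condition(context)` appears to follow the Pyro convention, in which a conditional module returns an ordinary distribution for a given context. For a *standard* normal, one plausible reading is that the context sets only batch shape, dtype and device, and leaves the density unchanged. This is an assumption and is not confirmed by this page. The sketch uses plain PyTorch, and the class name `SketchConditionalStandardNormal` is hypothetical.

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Independent, Normal


class SketchConditionalStandardNormal(nn.Module):
    """Hypothetical: condition(context) -> N(0, I), one batch entry per context row."""

    def __init__(self, n_dim: int):
        super().__init__()
        self.n_dim = n_dim

    def condition(self, context: torch.Tensor) -> Distribution:
        # Assumption: context fixes batch shape, dtype and device only.
        batch_shape = context.shape[:-1]
        loc = torch.zeros(*batch_shape, self.n_dim, dtype=context.dtype, device=context.device)
        return Independent(Normal(loc, torch.ones_like(loc)), 1)


cond = SketchConditionalStandardNormal(n_dim=2)
context = torch.randn(5, 7)  # 5 contexts with 7 features each
dist = cond.condition(context)
print(dist.batch_shape, dist.event_shape)  # torch.Size([5]) torch.Size([2])
```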
- class torch_mist.distributions.normal.TransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform])
Bases:
torch_mist.distributions.transforms.TransformedDistributionModule
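One reading of the signature is that `TransformedNormalModule` pushes an `input_dim`-dimensional standard normal through `transforms`. This is an assumption, not something the page states. Under that reading the density follows the change-of-variables rule log p(y) = log p_base(T⁻¹(y)) - log|det J_T(T⁻¹(y))|, and plain PyTorch's `TransformedDistribution` does the same bookkeeping. A single elementwise affine map y = a + b ⊙ x must produce N(a, diag(b²)), which makes for an exact check. The values of `a` and `b` are arbitrary.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

input_dim = 3
base = Independent(Normal(torch.zeros(input_dim), torch.ones(input_dim)), 1)

# Arbitrary example transform y = a + b * x acting on a 1-D event.
a = torch.tensor([1.0, -2.0, 0.5])
b = torch.tensor([2.0, 0.5, 3.0])
flow = TransformedDistribution(base, [AffineTransform(a, b, event_dim=1)])

# Change of variables by hand: log p(y) = log p_base((y - a) / b) - sum(log |b|).
y = flow.sample((10,))
manual = base.log_prob((y - a) / b) - b.abs().log().sum()
expected = Independent(Normal(a, b), 1)

print(torch.allclose(flow.log_prob(y), manual, atol=1e-5))             # True
print(torch.allclose(flow.log_prob(y), expected.log_prob(y), atol=1e-5))  # True
```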
- class torch_mist.distributions.normal.ConditionalTransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])
Bases:
torch_mist.distributions.transforms.ConditionalTransformedDistributionModule
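The conditional variant accepts Pyro `ConditionalTransform`s, which are transforms whose parameters depend on a context. The following sketch captures that idea in plain PyTorch and deliberately does not use Pyro's API. A linear layer maps the context to a shift and a log-scale, and `condition(context)` returns the resulting `TransformedDistribution`. The class name `SketchConditionalAffineNormal`, the `context_dim` argument, and the single affine transform are all hypothetical.

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform


class SketchConditionalAffineNormal(nn.Module):
    """Hypothetical: N(0, I) pushed through an affine map whose shift/scale depend on context."""

    def __init__(self, input_dim: int, context_dim: int):
        super().__init__()
        # One linear layer predicts a shift and a log-scale per output dimension.
        self.net = nn.Linear(context_dim, 2 * input_dim)

    def condition(self, context: torch.Tensor) -> TransformedDistribution:
        shift, log_scale = self.net(context).chunk(2, dim=-1)
        base = Independent(Normal(torch.zeros_like(shift), torch.ones_like(shift)), 1)
        # exp keeps the scale strictly positive, so the map stays invertible.
        return TransformedDistribution(base, [AffineTransform(shift, log_scale.exp(), event_dim=1)])


torch.manual_seed(0)
model = SketchConditionalAffineNormal(input_dim=2, context_dim=4)
context = torch.randn(6, 4)
dist = model.condition(context)
y = dist.rsample()                          # shape (6, 2): one draw per context row
loss = -dist.log_prob(y.detach()).mean()    # negative log-likelihood of the draws
loss.backward()                             # gradients flow into model.net
```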
- class torch_mist.distributions.normal.JointTransformedNormalModule(input_dims: Dict[str, int], transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])
Bases:
torch_mist.distributions.joint.wrapper.TorchJointDistribution
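The `input_dims: Dict[str, int]` argument suggests a single normal defined over the concatenation of several named variables, with results keyed by name. The sketch below assumes exactly that. The helpers `split_by_name` and `join_by_name` are hypothetical, and so is the fixed dict order. The coupling between variables is illustrated with `MultivariateNormal(..., scale_tril=L)`, which is the law of N(0, I) pushed through z ↦ L z. This stands in for whatever `transforms` torch_mist actually applies.

```python
from typing import Dict

import torch
from torch.distributions import MultivariateNormal


def split_by_name(z: torch.Tensor, input_dims: Dict[str, int]) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: cut the last dimension of z into named blocks, in dict order."""
    blocks = torch.split(z, list(input_dims.values()), dim=-1)
    return dict(zip(input_dims.keys(), blocks))


def join_by_name(values: Dict[str, torch.Tensor], input_dims: Dict[str, int]) -> torch.Tensor:
    """Hypothetical inverse of split_by_name."""
    return torch.cat([values[name] for name in input_dims], dim=-1)


input_dims = {"x": 2, "y": 1}
total = sum(input_dims.values())

# Lower-triangular L: y = 0.8 * x[0] + 0.6 * noise, so Var(y) = 1 and Corr(x[0], y) = 0.8.
L = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.8, 0.0, 0.6]])
joint = MultivariateNormal(torch.zeros(total), scale_tril=L)

samples = split_by_name(joint.sample((1000,)), input_dims)
print({name: tuple(v.shape) for name, v in samples.items()})  # {'x': (1000, 2), 'y': (1000, 1)}
log_p = joint.log_prob(join_by_name(samples, input_dims))       # one joint log-density per draw
```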