torch_mist.distributions
Subpackages
- torch_mist.distributions.conditional
- torch_mist.distributions.joint
- torch_mist.distributions.parametrizations
- torch_mist.distributions.transforms
Package Contents
- class torch_mist.distributions.ConditionalCategoricalModule(net: torch.nn.Module, temperature: float = 1.0)
  Bases: torch_mist.distributions.transforms.ConditionalDistributionModule
  - condition(x)
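ConditionalCategoricalModule pairs a network `net` with a `temperature`. The plain-Python sketch below shows one plausible reading of `condition(x)`: it assumes that `net(x)` produces one logit per class and that `temperature` divides those logits before a softmax. Both points are assumptions, not confirmed by this reference, and `condition_probs` and `toy_net` are hypothetical stand-ins for the torch objects.

```python
import math
from typing import Callable, List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def condition_probs(net: Callable[[List[float]], List[float]],
                    x: List[float],
                    temperature: float = 1.0) -> List[float]:
    # Hypothetical: net maps the context x to one logit per class,
    # and the temperature rescales the logits before the softmax.
    logits = net(x)
    return softmax([v / temperature for v in logits])


def toy_net(x: List[float]) -> List[float]:
    # Hypothetical stand-in for a torch.nn.Module with 3 output classes.
    s = sum(x)
    return [s, 0.0, -s]


p_sharp = condition_probs(toy_net, [1.0, 1.0], temperature=0.5)
p_flat = condition_probs(toy_net, [1.0, 1.0], temperature=10.0)
```

Under this reading, a lower temperature sharpens the class probabilities for the same context and a higher one flattens them, without changing which class is most likely.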
- class torch_mist.distributions.JointDistribution(variables: List[str], name: str = 'p')
  Bases: torch.nn.Module, torch.distributions.Distribution, pyro.distributions.ConditionalDistribution
  - full_name()
  - abstract _log_prob(**kwargs) -> torch.Tensor
  - log_prob(*args, **kwargs) -> torch.Tensor
  - abstract _marginal(variables: List[str]) -> T
  - marginal(*variables) -> T
  - conditional(*variables) -> pyro.distributions.ConditionalDistribution
  - condition(**conditioning) -> T
  - _mutual_information(variable_1: str, variable_2: str) -> torch.Tensor
  - mutual_information(variable_1: str | None = None, variable_2: str | None = None) -> torch.Tensor
  - abstract _entropy(variables: List[str]) -> torch.Tensor
  - entropy(*variables) -> torch.Tensor
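The `marginal`, `entropy` and `mutual_information` methods are related by the standard identity I(X;Y) = H(X) + H(Y) - H(X,Y). Whether `JointDistribution._mutual_information` is actually computed this way is an assumption. The sketch below only checks the identity on a discrete joint table in plain Python; `marginal`, `entropy` and `mutual_information` here are hypothetical helpers that mirror the method names, not the library's code.

```python
import math
from typing import Dict, Iterable, Tuple

Joint = Dict[Tuple[int, int], float]


def entropy(probs: Iterable[float]) -> float:
    # Shannon entropy in nats; zero-probability outcomes contribute nothing.
    return -sum(p * math.log(p) for p in probs if p > 0)


def marginal(joint: Joint, axis: int) -> Dict[int, float]:
    # Sum out every variable except the one at position `axis`.
    out: Dict[int, float] = {}
    for outcome, p in joint.items():
        out[outcome[axis]] = out.get(outcome[axis], 0.0) + p
    return out


def mutual_information(joint: Joint) -> float:
    h_x = entropy(marginal(joint, 0).values())
    h_y = entropy(marginal(joint, 1).values())
    h_xy = entropy(joint.values())
    return h_x + h_y - h_xy


# Two perfectly correlated fair bits share log(2) nats of information.
correlated = {(0, 0): 0.5, (1, 1): 0.5}
# Two independent fair bits share none.
independent = {(x, y): 0.25 for x in (0, 1) for y in (0, 1)}
```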
- class torch_mist.distributions.ConditionalTransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])
  Bases: torch_mist.distributions.transforms.ConditionalTransformedDistributionModule
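A conditional transformed normal pushes a standard normal base sample through transforms whose parameters may depend on a context. The sketch below illustrates this in one dimension under stated assumptions: a single affine map y = shift(c) + scale(c) * z with hypothetical linear conditioners. It is not the library's `conditional_linear` transform. The log-density follows from the change of variables log p(y | c) = log N(z; 0, 1) - log|scale(c)|, with z = (y - shift(c)) / scale(c).

```python
import math
import random


def std_normal_log_prob(z: float) -> float:
    return -0.5 * (z * z + math.log(2 * math.pi))


class ConditionalAffine1D:
    """Hypothetical conditional affine map y = shift(c) + scale(c) * z."""

    def __init__(self, a: float, b: float, log_scale_weight: float):
        self.a, self.b = a, b          # shift(c) = a * c + b
        self.w = log_scale_weight      # scale(c) = exp(w * c), always > 0

    def params(self, c: float):
        return self.a * c + self.b, math.exp(self.w * c)

    def sample(self, c: float, rng: random.Random) -> float:
        shift, scale = self.params(c)
        return shift + scale * rng.gauss(0.0, 1.0)

    def log_prob(self, y: float, c: float) -> float:
        shift, scale = self.params(c)
        z = (y - shift) / scale        # invert the transform
        return std_normal_log_prob(z) - math.log(scale)
```

With a = 2, b = 1 and w = 0.5, conditioning on c = 1 yields a normal with mean 3 and standard deviation exp(0.5). Changing the context therefore moves and rescales the whole distribution.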
- class torch_mist.distributions.TransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform])
  Bases: torch_mist.distributions.transforms.TransformedDistributionModule
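TransformedNormalModule takes a list of unconditional transforms. Composing transforms applies them in order, and the log-density subtracts each step's log-absolute-Jacobian. The plain-Python sketch below uses hypothetical 1D `Affine` and `Exp` transforms. It mirrors the change-of-variables rule used by `torch.distributions.TransformedDistribution`, but it is not that class's code.

```python
import math
from typing import List, Protocol


class Transform(Protocol):
    def forward(self, z: float) -> float: ...
    def inverse(self, y: float) -> float: ...
    def log_abs_det_jacobian(self, z: float) -> float: ...


class Affine:
    def __init__(self, loc: float, scale: float):
        self.loc, self.scale = loc, scale

    def forward(self, z: float) -> float:
        return self.loc + self.scale * z

    def inverse(self, y: float) -> float:
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, z: float) -> float:
        return math.log(abs(self.scale))


class Exp:
    def forward(self, z: float) -> float:
        return math.exp(z)

    def inverse(self, y: float) -> float:
        return math.log(y)

    def log_abs_det_jacobian(self, z: float) -> float:
        return z  # log|d exp(z)/dz| = z


def log_prob(y: float, transforms: List[Transform]) -> float:
    # Walk the transforms backwards to recover the base sample, subtracting
    # each step's log|dy/dz| evaluated at that step's input.
    total = 0.0
    for t in reversed(transforms):
        z = t.inverse(y)
        total -= t.log_abs_det_jacobian(z)
        y = z
    return total - 0.5 * (y * y + math.log(2 * math.pi))
```

For example, `[Affine(0.0, 1.0), Exp()]` turns the standard normal into a log-normal, so `log_prob(y, ...)` matches the log-normal density -log y - (log y)^2 / 2 - log(2*pi) / 2.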
- class torch_mist.distributions.JointTransformedNormalModule(input_dims: Dict[str, int], transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])
  Bases: torch_mist.distributions.joint.wrapper.TorchJointDistribution
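JointTransformedNormalModule takes `input_dims`, a mapping from variable names to sizes. One plausible reading, which is an assumption, is that a single vector of size `sum(input_dims.values())` is transformed jointly and then split into named blocks. The hypothetical helper `split_by_dims` below shows only that bookkeeping.

```python
from typing import Dict, List


def split_by_dims(flat: List[float],
                  input_dims: Dict[str, int]) -> Dict[str, List[float]]:
    # Relies on dict insertion order, which Python guarantees since 3.7.
    expected = sum(input_dims.values())
    if len(flat) != expected:
        raise ValueError(f"expected {expected} values, got {len(flat)}")
    out: Dict[str, List[float]] = {}
    start = 0
    for name, dim in input_dims.items():
        out[name] = flat[start:start + dim]
        start += dim
    return out
```

For example, `split_by_dims([1.0, 2.0, 3.0, 4.0, 5.0], {"x": 2, "y": 3})` assigns the first two values to `x` and the remaining three to `y`.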
- class torch_mist.distributions.ConditionalDistributionModule
  Bases: pyro.distributions.ConditionalDistribution, torch.nn.Module, abc.ABC
  Abstract base class for conditional distributions that are also torch.nn.Module instances.
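Because ConditionalDistributionModule derives from `abc.ABC`, a subclass cannot be instantiated until it implements the abstract methods. The plain-Python sketch below illustrates that pattern. It assumes the abstract hook is `condition(context)` returning a distribution, as in `pyro.distributions.ConditionalDistribution`. `ConditionalBase`, `Bernoulli` and `ConditionalBernoulli` are hypothetical stand-ins, not torch or pyro classes.

```python
import abc
import math


class Bernoulli:
    def __init__(self, p: float):
        self.p = p

    def log_prob(self, value: int) -> float:
        return math.log(self.p if value == 1 else 1.0 - self.p)


class ConditionalBase(abc.ABC):
    @abc.abstractmethod
    def condition(self, context):
        """Return a distribution conditioned on `context`."""


class ConditionalBernoulli(ConditionalBase):
    def condition(self, context: float) -> Bernoulli:
        # Logistic link: the success probability depends on the context.
        return Bernoulli(1.0 / (1.0 + math.exp(-context)))
```

Instantiating `ConditionalBase()` directly raises `TypeError`, while `ConditionalBernoulli().condition(0.0)` returns a fair coin.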
- torch_mist.distributions.fetch_transform(transform_name: str)
- torch_mist.distributions.delete_unused_kwargs(transform_factory: Callable[[Any], Any], all_kwargs: Dict[str, Any], warnings: bool = True)
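Judging by their names and signatures, `fetch_transform` looks up a transform factory by name, and `delete_unused_kwargs` drops keyword arguments that the chosen factory does not accept, optionally warning. The sketch below implements that filtering with `inspect.signature`. The registry, the two factories and `filter_kwargs` are all hypothetical, and the real helpers may behave differently.

```python
import inspect
import warnings
from typing import Any, Callable, Dict


def filter_kwargs(factory: Callable[..., Any], all_kwargs: Dict[str, Any],
                  warn: bool = True) -> Dict[str, Any]:
    params = inspect.signature(factory).parameters
    # A factory that itself takes **kwargs accepts everything.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(all_kwargs)
    kept = {k: v for k, v in all_kwargs.items() if k in params}
    dropped = sorted(set(all_kwargs) - set(kept))
    if warn and dropped:
        warnings.warn(f"Ignoring unused arguments: {dropped}")
    return kept


def linear_factory(input_dim: int):
    # Hypothetical factory standing in for a real transform constructor.
    return ("linear", input_dim)


def spline_factory(input_dim: int, count_bins: int = 8):
    # Hypothetical factory with an extra option.
    return ("spline", input_dim, count_bins)


# Hypothetical name -> factory registry, in the spirit of fetch_transform.
REGISTRY = {"linear": linear_factory, "spline": spline_factory}
```

A shared `**kwargs` dictionary can then be passed to any factory safely, for example `factory = REGISTRY["linear"]` followed by `factory(**filter_kwargs(factory, kwargs))`.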
- torch_mist.distributions.make_transforms(input_dim: int, transform_name: str = 'conditional_linear', normalization: str | None = None, n_transforms: int = 1, **kwargs) -> List[torch.distributions.Transform | pyro.distributions.ConditionalTransform]
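make_transforms builds `n_transforms` transforms of kind `transform_name`, with an optional `normalization`. This reference does not document where the normalization layers go. The sketch below assumes one normalization between each pair of consecutive transforms, and it uses strings in place of real transform objects. `build_transform_names` and the name `"batch_norm"` are hypothetical.

```python
from typing import List, Optional


def build_transform_names(transform_name: str = "conditional_linear",
                          normalization: Optional[str] = None,
                          n_transforms: int = 1) -> List[str]:
    if n_transforms < 1:
        raise ValueError("n_transforms must be at least 1")
    layers: List[str] = []
    for i in range(n_transforms):
        layers.append(transform_name)
        # Assumed placement: normalize between consecutive transforms only.
        if normalization is not None and i < n_transforms - 1:
            layers.append(normalization)
    return layers
```

With `normalization=None`, the result is simply `n_transforms` copies of the transform. With a normalization, the layers alternate and the list ends on a transform.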
- torch_mist.distributions.transformed_normal(input_dim: int, transform_name: str = 'linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) -> torch.distributions.Distribution
- torch_mist.distributions.conditional_transformed_normal(input_dim: int, context_dim: int, transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) -> torch_mist.distributions.transforms.ConditionalDistributionModule
- torch_mist.distributions.joint_transformed_normal(input_dims: Dict[str, int], transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) -> torch_mist.distributions.joint.base.JointDistribution
- torch_mist.distributions.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)
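conditional_categorical takes `context_dim`, `hidden_dims` and `n_classes`, which suggests a network that maps the context to `n_classes` logits. Assuming a plain fully connected stack (an assumption, not stated here), the layer shapes follow directly from those arguments. `mlp_layer_shapes` and `n_parameters` are hypothetical helpers.

```python
from typing import List, Tuple


def mlp_layer_shapes(context_dim: int, hidden_dims: List[int],
                     n_classes: int) -> List[Tuple[int, int]]:
    # (in_features, out_features) for each consecutive pair of sizes.
    sizes = [context_dim, *hidden_dims, n_classes]
    return list(zip(sizes[:-1], sizes[1:]))


def n_parameters(shapes: List[Tuple[int, int]]) -> int:
    # Weights plus biases for each fully connected layer.
    return sum(i * o + o for i, o in shapes)
```

For `context_dim=4`, `hidden_dims=[32, 32]` and `n_classes=10`, this gives the shapes (4, 32), (32, 32) and (32, 10), for 1546 weights and biases in total.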
- class torch_mist.distributions.NormalModule(loc: torch.Tensor, scale: torch.Tensor, learnable: bool = False)
  Bases: torch.distributions.Distribution, torch.nn.Module
  - rsample(sample_shape=torch.Size())
  - log_prob(value)
  - __repr__()
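NormalModule exposes `rsample` and `log_prob`. For a normal distribution, the standard reparameterized sample is loc + scale * eps with eps ~ N(0, 1), which keeps samples differentiable with respect to `loc` and `scale`. The plain-Python sketch below assumes independent dimensions and sums `log_prob` over them. That may differ from the module's actual event shape. `rsample` and `log_prob` here are hypothetical functions, not the module's methods.

```python
import math
import random
from typing import List


def rsample(loc: List[float], scale: List[float],
            rng: random.Random) -> List[float]:
    # Reparameterization: the noise is drawn independently of the parameters.
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]


def log_prob(value: List[float], loc: List[float],
             scale: List[float]) -> float:
    total = 0.0
    for x, m, s in zip(value, loc, scale):
        z = (x - m) / s
        total += -0.5 * z * z - math.log(s) - 0.5 * math.log(2 * math.pi)
    return total
```

StandardNormalModule(n_dim) presumably corresponds to loc = 0 and scale = 1 in `n_dim` dimensions. In that case the log-density at the origin is -(n_dim / 2) * log(2 * pi).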
- class torch_mist.distributions.StandardNormalModule(n_dim: int)
  Bases: NormalModule
- class torch_mist.distributions.CategoricalModule(logits: torch.tensor, temperature: float = 1.0)
  Bases: torch.distributions.Distribution, torch.nn.Module
  - rsample(sample_shape=torch.Size())
  - log_prob(value)
  - __repr__()
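A categorical distribution has no ordinary reparameterized sample. An `rsample` method combined with a `temperature` therefore suggests a relaxed (Gumbel-softmax) sample, although that is an assumption. In the Gumbel-softmax relaxation, y_i = softmax((logits_i + g_i) / temperature), with g_i = -log(-log u_i) and u_i ~ Uniform(0, 1). `gumbel_softmax_sample` below is a hypothetical plain-Python version.

```python
import math
import random
from typing import List


def gumbel_softmax_sample(logits: List[float], temperature: float,
                          rng: random.Random) -> List[float]:
    gumbels = []
    for _ in logits:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbels.append(-math.log(-math.log(u)))  # Gumbel(0, 1) via inverse CDF
    scaled = [(l + g) / temperature for l, g in zip(logits, gumbels)]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

Each sample is a probability vector that approaches a one-hot vector as the temperature goes to 0. By the Gumbel-max trick, the argmax of a sample is distributed exactly as a categorical with probabilities softmax(logits), whatever the temperature.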