torch_mist.distributions.transforms
Package Contents
Classes
DistributionModule, ConditionalDistributionModule, TransformedDistributionModule, ConditionalTransformedDistributionModule, Linear, ConditionalLinear, Permute, EMANormalize
Functions
linear, conditional_linear, permute, emanorm, fetch_transform
- class torch_mist.distributions.transforms.DistributionModule(validate_args: bool = False)
Bases:
torch.distributions.Distribution, torch.nn.Module, abc.ABC
Abstract base class for distributions that are also torch.nn.Module instances.
- __repr__()
Return repr(self).
- class torch_mist.distributions.transforms.ConditionalDistributionModule
Bases:
pyro.distributions.ConditionalDistribution, torch.nn.Module, abc.ABC
Abstract base class for conditional (Pyro) distributions that are also torch.nn.Module instances.
- class torch_mist.distributions.transforms.TransformedDistributionModule(base_dist: torch.distributions.Distribution, transforms: torch.distributions.Transform | List[torch.distributions.Transform] | Dict[str, torch.distributions.Transform] | None, cached: bool = True)
Bases:
DistributionModule
Distribution module obtained by applying one or more transforms to base_dist.
- rsample(sample_shape=torch.Size())
- log_prob(value)
- __repr__()
Return repr(self).
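TransformedDistributionModule provides rsample and log_prob for a base distribution pushed through transforms. Both follow the change-of-variables rule: sampling applies the transforms in order, while log_prob inverts them and subtracts each step's log|det J|, i.e. log p_Y(y) = log p_X(f^-1(y)) - log|det J_f(f^-1(y))|. The pure-Python sketch below illustrates that rule with a standard Normal base and a hypothetical AffineTransform class. It is not the torch_mist implementation, and its sample method draws plain (non-differentiable) samples rather than reparameterized ones.

```python
import math
import random

def normal_log_prob(x, mu=0.0, sigma=1.0):
    """Log-density of a univariate Normal(mu, sigma) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

class AffineTransform:
    """Hypothetical invertible map y = loc + scale * x (scale != 0)."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def __call__(self, x):
        return self.loc + self.scale * x

    def inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        # For a 1-D affine map the Jacobian is the constant `scale`.
        return math.log(abs(self.scale))

class TransformedDistributionSketch:
    """Standard Normal base distribution pushed through a chain of transforms."""

    def __init__(self, transforms, rng=None):
        self.transforms = list(transforms)
        self.rng = rng or random.Random()

    def sample(self):
        # Draw from the base, then apply the transforms in order.
        x = self.rng.gauss(0.0, 1.0)
        for t in self.transforms:
            x = t(x)
        return x

    def log_prob(self, y):
        # Invert the chain from the last transform to the first,
        # accumulating the log|det J| of each step.
        log_det = 0.0
        for t in reversed(self.transforms):
            x = t.inverse(y)
            log_det += t.log_abs_det_jacobian(x, y)
            y = x
        return normal_log_prob(y) - log_det
```

As a sanity check, a single AffineTransform(1.0, 2.0) turns the standard Normal into Normal(1, 2), so both log-densities agree.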
- class torch_mist.distributions.transforms.ConditionalTransformedDistributionModule(base_dist: pyro.distributions.ConditionalDistribution | torch.distributions.Distribution, transforms: pyro.distributions.ConditionalTransform | List[pyro.distributions.ConditionalTransform] | Dict[str, pyro.distributions.ConditionalTransform | torch.distributions.Transform] | torch.distributions.Transform | List[torch.distributions.Transform] | None, cached: bool = True)
Bases:
pyro.distributions.ConditionalTransformedDistribution, torch.nn.Module
- condition(context)
- clear_cache()
- __repr__()
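In a conditional transformed distribution, condition(context) fixes the context-dependent transform parameters and returns an ordinary (unconditional) distribution. The pure-Python sketch below illustrates that pattern with a standard Normal base and an affine map. The loc_fn and log_scale_fn callables are hypothetical stand-ins for a learned network, and the class names are illustrative, not the torch_mist or Pyro API.

```python
import math

def std_normal_log_prob(x):
    """Log-density of Normal(0, 1) at x."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

class ConditionedAffineNormal:
    """Unconditional distribution: Normal(0, 1) pushed through y = loc + scale * x."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def log_prob(self, y):
        x = (y - self.loc) / self.scale
        return std_normal_log_prob(x) - math.log(self.scale)

class ConditionalAffineNormalSketch:
    """Conditional distribution whose affine transform depends on a context value.

    `loc_fn` and `log_scale_fn` are hypothetical stand-ins for a learned network.
    """

    def __init__(self, loc_fn, log_scale_fn):
        self.loc_fn, self.log_scale_fn = loc_fn, log_scale_fn

    def condition(self, context):
        # Fix the transform parameters for this context and return a plain distribution.
        loc = self.loc_fn(context)
        scale = math.exp(self.log_scale_fn(context))
        return ConditionedAffineNormal(loc, scale)
```

Predicting a log-scale and exponentiating it keeps the scale positive for any network output, which keeps the transform invertible.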
- class torch_mist.distributions.transforms.Linear(input_dim, loc=None, scale=None, initial_scale=None, epsilon=1e-06)
Bases:
ConditionedLinear, pyro.distributions.TransformModule
- _params()
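Linear takes input_dim, optional loc and scale, an initial_scale and an epsilon. A minimal sketch, assuming it is an elementwise (diagonal) affine map y = loc + scale * x whose scale is kept strictly positive as softplus(raw) + epsilon; that parametrization is an assumption for illustration, not taken from torch_mist. With a diagonal Jacobian, log|det J| is simply the sum of log(scale_i) and does not depend on x.

```python
import math

def softplus(v):
    # Numerically stable softplus: log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

class ElementwiseAffineSketch:
    """y_i = loc_i + scale_i * x_i with scale_i = softplus(raw_i) + epsilon (assumed)."""

    def __init__(self, input_dim, epsilon=1e-6):
        self.loc = [0.0] * input_dim
        self.raw_scale = [0.0] * input_dim  # unconstrained parameters
        self.epsilon = epsilon

    def scale(self):
        return [softplus(r) + self.epsilon for r in self.raw_scale]

    def forward(self, x):
        return [m + s * xi for m, s, xi in zip(self.loc, self.scale(), x)]

    def inverse(self, y):
        return [(yi - m) / s for m, s, yi in zip(self.loc, self.scale(), y)]

    def log_abs_det_jacobian(self):
        # Diagonal Jacobian: log|det J| = sum_i log(scale_i), independent of x.
        return sum(math.log(s) for s in self.scale())
```

The epsilon floor keeps the scale away from zero, so the inverse and the log-determinant stay finite even when the raw parameter becomes very negative.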
- class torch_mist.distributions.transforms.ConditionalLinear(net, loc=None, scale=None, initial_scale=None, epsilon=1e-06, skip_connection=False)
Bases:
pyro.distributions.ConditionalTransformModule
- domain
- codomain
- bijective = True
- _params(context)
- condition(context)
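ConditionalLinear receives a net and exposes condition(context), which suggests the affine parameters are predicted from the context. A minimal sketch, assuming net maps a context to a (loc, raw_scale) pair and that the scale is softplus(raw_scale) + epsilon; both the output layout and the parametrization are assumptions, and skip_connection is not modelled here.

```python
import math

def softplus(v):
    # Numerically stable softplus: log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

class ConditionedAffine:
    """Elementwise affine map with parameters already fixed for one context."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def __call__(self, x):
        return [m + s * xi for m, s, xi in zip(self.loc, self.scale, x)]

    def inverse(self, y):
        return [(yi - m) / s for m, s, yi in zip(self.loc, self.scale, y)]

    def log_abs_det_jacobian(self, x, y):
        return sum(math.log(s) for s in self.scale)

class ConditionalAffineSketch:
    """Hypothetical conditional transform: `net(context)` -> (loc, raw_scale)."""

    def __init__(self, net, epsilon=1e-6):
        self.net = net
        self.epsilon = epsilon

    def condition(self, context):
        # Predict the parameters for this context, then freeze them in a plain transform.
        loc, raw_scale = self.net(context)
        scale = [softplus(r) + self.epsilon for r in raw_scale]
        return ConditionedAffine(loc, scale)
```

Each call to condition yields an independent transform, so the same module can serve many contexts without mutating its own state.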
- class torch_mist.distributions.transforms.Permute(permutation, *, dim=-1, cache_size=1)
Bases:
pyro.distributions.transforms.Permute, pyro.distributions.TransformModule
- update_device()
- _call(x)
- _inverse(y)
- log_abs_det_jacobian(x, y)
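Permute reorders the event dimension: the forward call indexes the input with the permutation, the inverse indexes with the inverse permutation, and because a permutation only reorders coordinates, log|det J| is zero. The pure-Python sketch below illustrates those three methods on plain lists, assuming the y = x[..., permutation] convention used by Pyro's Permute; the function names are illustrative.

```python
def apply_permutation(x, permutation):
    # Forward: y[i] = x[permutation[i]].
    return [x[p] for p in permutation]

def inverse_permutation(permutation):
    # inv[permutation[i]] = i, so applying inv undoes applying permutation.
    inv = [0] * len(permutation)
    for i, p in enumerate(permutation):
        inv[p] = i
    return inv

def undo_permutation(y, permutation):
    # Inverse: x = y indexed by the inverse permutation.
    return apply_permutation(y, inverse_permutation(permutation))

def permutation_log_abs_det_jacobian(x, y):
    # A permutation matrix has determinant +1 or -1, so log|det J| = 0.
    return 0.0
```

Because the log-determinant is zero, permutation layers are cheap to interleave between other transforms so that every coordinate eventually gets transformed.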
- class torch_mist.distributions.transforms.EMANormalize(input_dim: int, epsilon: float = 1e-06, gamma: float = 0.99, normalize_inverse=True)
Bases:
torch_mist.distributions.transforms.implementations.linear.Linear
- _update_params(t: torch.Tensor)
- _inverse(y: torch.Tensor) → torch.Tensor
- _call(x: torch.Tensor) → torch.Tensor
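EMANormalize is a Linear transform configured by gamma and epsilon, and it has an _update_params(t) hook. A minimal sketch, assuming it tracks an exponential moving average of the per-dimension mean and variance (stat = gamma * stat + (1 - gamma) * batch_stat) and normalizes as (x - mean) / sqrt(var + epsilon). The exact update rule, and which direction normalize_inverse selects, are assumptions not confirmed by this page.

```python
import math

class EMANormalizeSketch:
    """Per-dimension normalization with exponentially averaged statistics (assumed rule)."""

    def __init__(self, input_dim, gamma=0.99, epsilon=1e-6):
        self.gamma, self.epsilon = gamma, epsilon
        self.mean = [0.0] * input_dim
        self.var = [1.0] * input_dim

    def update(self, batch):
        # batch: list of samples, each a list of length input_dim.
        n = len(batch)
        for i in range(len(self.mean)):
            col = [row[i] for row in batch]
            m = sum(col) / n
            v = sum((c - m) ** 2 for c in col) / n
            self.mean[i] = self.gamma * self.mean[i] + (1 - self.gamma) * m
            self.var[i] = self.gamma * self.var[i] + (1 - self.gamma) * v

    def normalize(self, x):
        return [(xi - m) / math.sqrt(v + self.epsilon)
                for xi, m, v in zip(x, self.mean, self.var)]

    def denormalize(self, z):
        return [m + zi * math.sqrt(v + self.epsilon)
                for zi, m, v in zip(z, self.mean, self.var)]
```

Viewed as an affine map, denormalize has loc = mean and scale = sqrt(var + epsilon), which is consistent with EMANormalize subclassing Linear; gamma close to 1 makes the statistics change slowly between batches.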
- torch_mist.distributions.transforms.linear(input_dim: int, loc: float | None = None, scale: float | None = None, initial_scale: float | None = None) → torch_mist.distributions.transforms.implementations.linear.Linear
- torch_mist.distributions.transforms.conditional_linear(input_dim: int, context_dim: int, hidden_dims: List[int] | None = None, scale: float | None = None, initial_scale: float | None = None, nonlinearity: conditional_linear.nn = nn.ReLU(True)) → torch_mist.distributions.transforms.implementations.linear.ConditionalLinear
- torch_mist.distributions.transforms.permute(input_dim: int, permutation: List[int] | None = None, dim: int = -1) → torch_mist.distributions.transforms.implementations.permute.Permute
A helper function to create a Permute object, for consistency with the other helpers.
Parameters:
input_dim (int) – Dimension(s) of the input variable to permute. Note that when dim < -1 this must be a tuple corresponding to the event shape.
permutation (List[int] | None) – Integer indices representing the permutation. Defaults to a random permutation.
dim (int) – The tensor dimension to permute. This value must be negative and defines the event dim as abs(dim).
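When permutation is omitted, permute falls back to a random permutation. The sketch below shows one way to implement that default together with a validity check. The make_permutation name and its seed parameter are hypothetical, added only to make the example reproducible; they are not part of torch_mist.

```python
import random
from typing import List, Optional

def make_permutation(input_dim: int,
                     permutation: Optional[List[int]] = None,
                     seed: Optional[int] = None) -> List[int]:
    """Return a validated permutation, or a random one when none is given."""
    if permutation is None:
        permutation = list(range(input_dim))
        random.Random(seed).shuffle(permutation)
    # Every index in range(input_dim) must appear exactly once.
    if sorted(permutation) != list(range(input_dim)):
        raise ValueError("permutation must contain each index in range(input_dim) exactly once")
    return list(permutation)
```

Validating the indices up front turns a silent, non-invertible reindexing (for example a repeated index) into an immediate error.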
- torch_mist.distributions.transforms.emanorm(input_dim: int, gamma: float = 0.99, normalize_inverse: bool = True) → torch_mist.distributions.transforms.implementations.normalize.EMANormalize
- torch_mist.distributions.transforms.fetch_transform(transform_name: str)
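fetch_transform takes only a transform name, which suggests a lookup from names to transform factories. A minimal sketch of such a registry, assuming the lookup returns a factory callable. The registry, the "linear" and "permute" keys, and the tuple placeholders are all hypothetical; this page does not list the names the library actually accepts.

```python
from typing import Callable, Dict

# Hypothetical registry mapping transform names to factory callables.
_TRANSFORM_REGISTRY: Dict[str, Callable[..., object]] = {}

def register_transform(name: str):
    def decorator(factory):
        _TRANSFORM_REGISTRY[name] = factory
        return factory
    return decorator

@register_transform("linear")
def _linear_factory(input_dim: int):
    return ("linear", input_dim)  # placeholder standing in for a real transform

@register_transform("permute")
def _permute_factory(input_dim: int):
    return ("permute", input_dim)  # placeholder standing in for a real transform

def fetch_transform_sketch(transform_name: str) -> Callable[..., object]:
    """Look up a factory by name, failing with a list of known names."""
    try:
        return _TRANSFORM_REGISTRY[transform_name]
    except KeyError:
        known = ", ".join(sorted(_TRANSFORM_REGISTRY))
        raise ValueError(f"Unknown transform {transform_name!r}; expected one of: {known}") from None
```

Returning the factory rather than an instance lets the caller supply input_dim and other arguments later, for example when building a flow from a configuration string.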