torch_mist.distributions.transforms

Package Contents

Classes

DistributionModule

Abstract base class for distributions that are also torch.nn.Module instances.

ConditionalDistributionModule

Abstract base class for Pyro conditional distributions that are also torch.nn.Module instances.

TransformedDistributionModule

A DistributionModule that applies one or more transforms to a base distribution.

ConditionalTransformedDistributionModule

Linear

ConditionalLinear

Permute

EMANormalize

Functions

linear(...)

conditional_linear(...)

permute(...)

A helper function to create a Permute object for consistency with other helpers.

emanorm(...)

fetch_transform(transform_name)

class torch_mist.distributions.transforms.DistributionModule(validate_args: bool = False)

Bases: torch.distributions.Distribution, torch.nn.Module, abc.ABC

Abstract base class that combines torch.distributions.Distribution with torch.nn.Module.

__repr__()

Return repr(self).

class torch_mist.distributions.transforms.ConditionalDistributionModule

Bases: pyro.distributions.ConditionalDistribution, torch.nn.Module, abc.ABC

Abstract base class that combines pyro.distributions.ConditionalDistribution with torch.nn.Module.

class torch_mist.distributions.transforms.TransformedDistributionModule(base_dist: torch.distributions.Distribution, transforms: torch.distributions.Transform | List[torch.distributions.Transform] | Dict[str, torch.distributions.Transform] | None, cached: bool = True)

Bases: DistributionModule

A DistributionModule that applies transforms to base_dist. The transforms argument accepts a single transform, a list of transforms, a dict of named transforms, or None.

rsample(sample_shape=torch.Size())
log_prob(value)
__repr__()

Return repr(self).
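Transformed distributions conventionally compute log_prob with the change-of-variables rule used by torch.distributions.TransformedDistribution: for y = f(x), log p_Y(y) = log p_X(f^-1(y)) - log|det J_f(f^-1(y))|. The sketch below assumes TransformedDistributionModule.log_prob follows this standard rule. It checks the rule in plain Python for a 1-D standard normal base and f(x) = exp(x), where the result must match the closed-form LogNormal(0, 1) density. All function names are hypothetical and are not part of torch_mist.

```python
import math

def normal_log_prob(x: float) -> float:
    # Log density of a standard normal base distribution.
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)

def exp_inverse(y: float) -> float:
    # Inverse of the forward transform f(x) = exp(x).
    return math.log(y)

def exp_log_abs_det_jacobian(x: float) -> float:
    # d/dx exp(x) = exp(x), so log|f'(x)| = x.
    return x

def transformed_log_prob(y: float) -> float:
    # Change of variables: map y back to the base space, then subtract
    # the log-Jacobian of the forward transform evaluated at x = f^-1(y).
    x = exp_inverse(y)
    return normal_log_prob(x) - exp_log_abs_det_jacobian(x)

def lognormal_log_prob(y: float) -> float:
    # Closed-form log density of LogNormal(0, 1), used as a reference.
    return -math.log(y) - 0.5 * math.log(2 * math.pi) - 0.5 * math.log(y) ** 2

print(transformed_log_prob(2.0), lognormal_log_prob(2.0))
```

With a chain of transforms, the same rule applies step by step, and the log-Jacobian terms of the individual transforms add up.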

class torch_mist.distributions.transforms.ConditionalTransformedDistributionModule(base_dist: pyro.distributions.ConditionalDistribution | torch.distributions.Distribution, transforms: pyro.distributions.ConditionalTransform | List[pyro.distributions.ConditionalTransform] | Dict[str, pyro.distributions.ConditionalTransform | torch.distributions.Transform] | torch.distributions.Transform | List[torch.distributions.Transform] | None, cached: bool = True)

Bases: pyro.distributions.ConditionalTransformedDistribution, torch.nn.Module

condition(context)
clear_cache()
__repr__()
class torch_mist.distributions.transforms.Linear(input_dim, loc=None, scale=None, initial_scale=None, epsilon=1e-06)

Bases: ConditionedLinear, pyro.distributions.TransformModule

_params()
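The constructor arguments (loc, scale) suggest that Linear is an element-wise affine transform y = loc + scale * x. That reading is an assumption and is not stated in this reference. Under it, the sketch below shows the forward map, its inverse, and the log-absolute-determinant of the Jacobian. Because the Jacobian of this map is diagonal, that term is the sum of log|scale_i|. The sketch uses plain lists instead of tensors. AffineSketch is a hypothetical name. The role of epsilon is not documented here, so the sketch leaves it out.

```python
import math
from typing import List

class AffineSketch:
    """Hypothetical element-wise affine map y = loc + scale * x."""

    def __init__(self, loc: List[float], scale: List[float]):
        if any(s == 0 for s in scale):
            raise ValueError("scale must be non-zero for the map to be invertible")
        self.loc = loc
        self.scale = scale

    def forward(self, x: List[float]) -> List[float]:
        return [m + s * xi for m, s, xi in zip(self.loc, self.scale, x)]

    def inverse(self, y: List[float]) -> List[float]:
        return [(yi - m) / s for m, s, yi in zip(self.loc, self.scale, y)]

    def log_abs_det_jacobian(self) -> float:
        # The Jacobian is diagonal with entries scale_i, so
        # log|det J| = sum_i log|scale_i|, independent of x.
        return sum(math.log(abs(s)) for s in self.scale)

t = AffineSketch(loc=[1.0, -2.0], scale=[2.0, 0.5])
y = t.forward([3.0, 4.0])   # [7.0, 0.0]
x = t.inverse(y)            # [3.0, 4.0]
```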
class torch_mist.distributions.transforms.ConditionalLinear(net, loc=None, scale=None, initial_scale=None, epsilon=1e-06, skip_connection=False)

Bases: pyro.distributions.ConditionalTransformModule

domain
codomain
bijective = True
_params(context)
condition(context)
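ConditionalLinear takes a net and exposes condition(context). In Pyro's conditional-transform pattern, condition returns an ordinary, unconditional transform whose parameters are computed from the context. The sketch below assumes that net maps a context to a (loc, scale) pair. The names toy_net and ConditionalAffineSketch are hypothetical, and the sketch does not model skip_connection.

```python
import math
from typing import Callable, List, Tuple

Params = Tuple[List[float], List[float]]

def toy_net(context: List[float]) -> Params:
    # Hypothetical conditioner: loc depends linearly on the context and
    # scale = exp(.) keeps every entry strictly positive.
    c = sum(context)
    return [c, -c], [math.exp(0.1 * c), math.exp(-0.1 * c)]

class ConditionalAffineSketch:
    def __init__(self, net: Callable[[List[float]], Params]):
        self.net = net

    def condition(self, context: List[float]) -> Callable[[List[float]], List[float]]:
        # Compute the affine parameters once for this context and return
        # a plain (unconditional) forward map that reuses them.
        loc, scale = self.net(context)

        def forward(x: List[float]) -> List[float]:
            return [m + s * xi for m, s, xi in zip(loc, scale, x)]

        return forward

cond = ConditionalAffineSketch(toy_net)
f = cond.condition([0.0, 0.0])  # sum = 0, so loc = [0, 0] and scale = [1, 1]
print(f([3.0, 4.0]))            # identity for this context: [3.0, 4.0]
```

Each call to condition with a different context produces a different affine map, while the conditioner stays the same.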
class torch_mist.distributions.transforms.Permute(permutation, *, dim=-1, cache_size=1)

Bases: pyro.distributions.transforms.Permute, pyro.distributions.TransformModule

update_device()
_call(x)
_inverse(y)
log_abs_det_jacobian(x, y)
class torch_mist.distributions.transforms.EMANormalize(input_dim: int, epsilon: float = 1e-06, gamma: float = 0.99, normalize_inverse=True)

Bases: torch_mist.distributions.transforms.implementations.linear.Linear

_update_params(t: torch.Tensor)
_inverse(y: torch.Tensor) → torch.Tensor
_call(x: torch.Tensor) → torch.Tensor
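The gamma and epsilon arguments, together with the _update_params method, suggest that EMANormalize standardizes inputs using exponential-moving-average estimates of their statistics. The sketch below assumes per-dimension running mean and variance, updated as m = gamma * m + (1 - gamma) * batch_mean, and a forward map of (x - m) / sqrt(v + epsilon). The exact statistics and update schedule in torch_mist may differ. EMANormSketch is a hypothetical name.

```python
import math
from typing import List

class EMANormSketch:
    """Hypothetical per-dimension EMA standardizer for 1-D features."""

    def __init__(self, input_dim: int, gamma: float = 0.99, epsilon: float = 1e-6):
        self.gamma = gamma
        self.epsilon = epsilon
        self.mean = [0.0] * input_dim
        self.var = [1.0] * input_dim

    def update(self, batch: List[List[float]]) -> None:
        # Blend the current batch statistics into the running estimates.
        n = len(batch)
        for d in range(len(self.mean)):
            col = [row[d] for row in batch]
            bm = sum(col) / n
            bv = sum((c - bm) ** 2 for c in col) / n
            self.mean[d] = self.gamma * self.mean[d] + (1 - self.gamma) * bm
            self.var[d] = self.gamma * self.var[d] + (1 - self.gamma) * bv

    def forward(self, x: List[float]) -> List[float]:
        return [(xi - m) / math.sqrt(v + self.epsilon)
                for xi, m, v in zip(x, self.mean, self.var)]

    def inverse(self, y: List[float]) -> List[float]:
        return [yi * math.sqrt(v + self.epsilon) + m
                for yi, m, v in zip(y, self.mean, self.var)]

norm = EMANormSketch(1, gamma=0.5)
norm.update([[2.0], [4.0]])  # batch mean 3, batch variance 1
print(norm.mean, norm.var)   # [1.5] [1.0]
```

When gamma is close to 1, as with the default 0.99, the running estimates change slowly. This smooths out noise from individual batches.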
torch_mist.distributions.transforms.linear(input_dim: int, loc: float | None = None, scale: float | None = None, initial_scale: float | None = None) → torch_mist.distributions.transforms.implementations.linear.Linear
torch_mist.distributions.transforms.conditional_linear(input_dim: int, context_dim: int, hidden_dims: List[int] | None = None, scale: float | None = None, initial_scale: float | None = None, nonlinearity: conditional_linear.nn = nn.ReLU(True)) → torch_mist.distributions.transforms.implementations.linear.ConditionalLinear
torch_mist.distributions.transforms.permute(input_dim: int, permutation: List[int] | None = None, dim: int = -1) → torch_mist.distributions.transforms.implementations.permute.Permute

A helper function to create a Permute object for consistency with other helpers.

Parameters:
  • input_dim (int) – Dimension(s) of input variable to permute. Note that when dim < -1 this must be a tuple corresponding to the event shape.

  • permutation (List[int] | None) – Integer indices representing the permutation. Defaults to a random permutation.

  • dim (int) – The tensor dimension to permute. This value must be negative; the number of event dimensions is abs(dim).
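A permutation transform only reorders coordinates. Its inverse applies the inverse permutation, and its log-absolute-determinant Jacobian is zero, because the Jacobian is a permutation matrix with determinant +1 or -1. The sketch below mirrors this behaviour for the last dimension (dim=-1) using plain lists. When no permutation is given, it draws a random one, matching the documented default. make_permutation, permute_forward, and permute_inverse are hypothetical helpers, not torch_mist functions.

```python
import random
from typing import List, Optional

def make_permutation(input_dim: int, permutation: Optional[List[int]] = None,
                     rng: Optional[random.Random] = None) -> List[int]:
    # Use the given indices, or draw a random permutation of range(input_dim).
    if permutation is None:
        rng = rng or random.Random()
        permutation = rng.sample(range(input_dim), input_dim)
    if sorted(permutation) != list(range(input_dim)):
        raise ValueError("permutation must contain each index 0..input_dim-1 exactly once")
    return list(permutation)

def permute_forward(x: List[float], perm: List[int]) -> List[float]:
    # y[i] = x[perm[i]], i.e. index selection along the last dimension.
    return [x[p] for p in perm]

def permute_inverse(y: List[float], perm: List[int]) -> List[float]:
    # Scatter back: x[perm[i]] = y[i].
    x = [0.0] * len(y)
    for i, p in enumerate(perm):
        x[p] = y[i]
    return x

LOG_ABS_DET_JACOBIAN = 0.0  # a permutation never changes volume

perm = make_permutation(3, [2, 0, 1])
print(permute_forward([10.0, 20.0, 30.0], perm))  # [30.0, 10.0, 20.0]
```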

torch_mist.distributions.transforms.emanorm(input_dim: int, gamma: float = 0.99, normalize_inverse: bool = True) → torch_mist.distributions.transforms.implementations.normalize.EMANormalize
torch_mist.distributions.transforms.fetch_transform(transform_name: str)