torch_mist.distributions.transforms.implementations

Submodules

Package Contents

Classes

Linear

ConditionalLinear

Permute

EMANormalize

class torch_mist.distributions.transforms.implementations.Linear(input_dim, loc=None, scale=None, initial_scale=None, epsilon=1e-06)

Bases: ConditionedLinear, pyro.distributions.TransformModule

_params()
class torch_mist.distributions.transforms.implementations.ConditionalLinear(net, loc=None, scale=None, initial_scale=None, epsilon=1e-06, skip_connection=False)

Bases: pyro.distributions.ConditionalTransformModule

domain
codomain
bijective = True
_params(context)
condition(context)
class torch_mist.distributions.transforms.implementations.Permute(permutation, *, dim=-1, cache_size=1)

Bases: pyro.distributions.transforms.Permute, pyro.distributions.TransformModule

update_device()
_call(x)
_inverse(y)
log_abs_det_jacobian(x, y)
class torch_mist.distributions.transforms.implementations.EMANormalize(input_dim: int, epsilon: float = 1e-06, gamma: float = 0.99, normalize_inverse=True)

Bases: torch_mist.distributions.transforms.implementations.linear.Linear

_update_params(t: torch.Tensor)
_inverse(y: torch.Tensor) → torch.Tensor
_call(x: torch.Tensor) → torch.Tensor