torch_mist.distributions.transforms.implementations
Package Contents
Classes
- class torch_mist.distributions.transforms.implementations.Linear(input_dim, loc=None, scale=None, initial_scale=None, epsilon=1e-06)
Bases: ConditionedLinear, pyro.distributions.TransformModule
- _params()
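The signature of `Linear` (a per-dimension `loc` and `scale` over `input_dim` features) suggests an elementwise affine bijection y = loc + scale * x. The sketch below shows that map in plain PyTorch, assuming that reading. It is not torch_mist's implementation: the `loc` and `scale` values are hypothetical, and `initial_scale` and `epsilon` are left out because this page does not document how they are used. It cross-checks the result against torch's own `AffineTransform`.

```python
import torch
from torch.distributions.transforms import AffineTransform

# Hypothetical sketch of an elementwise affine map y = loc + scale * x.
# Assumes loc and scale are per-dimension vectors (not torch_mist's code).
torch.manual_seed(0)
input_dim = 3
loc = torch.tensor([0.5, -1.0, 2.0])
scale = torch.tensor([2.0, 0.5, 1.5])  # must be non-zero for invertibility

x = torch.randn(4, input_dim)
y = loc + scale * x            # forward direction
x_rec = (y - loc) / scale      # inverse direction

# The Jacobian is diagonal, so log|det J| = sum(log|scale|), identical per sample
log_det = torch.log(scale.abs()).sum()

# Cross-check with torch's AffineTransform; event_dim=1 sums over the last dim
ref = AffineTransform(loc, scale, event_dim=1)
```

Because the log-determinant does not depend on `x`, an elementwise affine layer is cheap to evaluate inside a normalizing flow.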
- class torch_mist.distributions.transforms.implementations.ConditionalLinear(net, loc=None, scale=None, initial_scale=None, epsilon=1e-06, skip_connection=False)
Bases: pyro.distributions.ConditionalTransformModule
- domain
- codomain
- bijective = True
- _params(context)
- condition(context)
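`ConditionalLinear` takes a `net` and exposes `condition(context)`, which in Pyro's `ConditionalTransformModule` convention returns an ordinary transform for a given context. The sketch below assumes that `net(context)` yields both the shift and the scale of an affine map. `ConditionalAffineSketch` is a hypothetical name, and the split of the network output plus the softplus-plus-`epsilon` positivity trick are assumptions. `skip_connection` is not modelled. To avoid a Pyro dependency, it uses plain `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.transforms import AffineTransform


class ConditionalAffineSketch(nn.Module):
    """Hypothetical: net(context) -> (loc, raw_scale) -> AffineTransform."""

    def __init__(self, net: nn.Module, epsilon: float = 1e-6):
        super().__init__()
        self.net = net
        self.epsilon = epsilon

    def condition(self, context: torch.Tensor) -> AffineTransform:
        # Split the network output into shift and unconstrained scale (assumption)
        loc, raw_scale = self.net(context).chunk(2, dim=-1)
        # softplus + epsilon keeps the scale strictly positive (assumption)
        scale = F.softplus(raw_scale) + self.epsilon
        return AffineTransform(loc, scale, event_dim=1)


torch.manual_seed(0)
input_dim, context_dim = 3, 5
transform = ConditionalAffineSketch(nn.Linear(context_dim, 2 * input_dim))

context = torch.randn(8, context_dim)
x = torch.randn(8, input_dim)
t = transform.condition(context)   # one affine map per context row
y = t(x)
x_rec = t.inv(y)
log_det = t.log_abs_det_jacobian(x, y)  # one value per sample, shape (8,)
```

Unlike the unconditional case, each sample now gets its own `loc` and `scale`. The log-determinant therefore varies across the batch.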
- class torch_mist.distributions.transforms.implementations.Permute(permutation, *, dim=-1, cache_size=1)
Bases: pyro.distributions.transforms.Permute, pyro.distributions.TransformModule
- update_device()
- _call(x)
- _inverse(y)
- log_abs_det_jacobian(x, y)
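A permutation transform reorders features along `dim`. It is volume-preserving, so its log-Jacobian is zero. The sketch below assumes the forward direction gathers `x` along the last dimension with `permutation` and that the inverse uses the argsort of the permutation. That is the usual convention for permutation layers in flows, but it is an assumption here. It also assumes `update_device()` exists to keep the index tensors on the input's device, which this page does not state.

```python
import torch

# Hypothetical sketch of a permutation bijection along the last dimension
x = torch.arange(6.0).reshape(2, 3)
perm = torch.tensor([2, 0, 1]).to(x.device)  # index tensor must match x's device
inv_perm = torch.argsort(perm)               # inverse permutation

y = x[..., perm]          # forward: y[..., i] = x[..., perm[i]]
x_rec = y[..., inv_perm]  # inverse undoes the reordering exactly

# Pure reordering has |det J| = 1, so log|det J| = 0 for every sample
log_det = torch.zeros(x.shape[:-1], dtype=x.dtype)
```

Flows interleave such permutations between other layers so that every feature eventually gets transformed.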
- class torch_mist.distributions.transforms.implementations.EMANormalize(input_dim: int, epsilon: float = 1e-06, gamma: float = 0.99, normalize_inverse=True)
Bases: torch_mist.distributions.transforms.implementations.linear.Linear
- _update_params(t: torch.Tensor)
- _inverse(y: torch.Tensor) -> torch.Tensor
- _call(x: torch.Tensor) -> torch.Tensor
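Judging from its name and signature, `EMANormalize` tracks feature statistics with an exponential moving average (decay `gamma`, stabilizer `epsilon`). It also builds on `Linear`, which is consistent with normalization being an elementwise affine map. The sketch below makes these assumptions explicit. `EMANormalizeSketch` is hypothetical. It updates a running mean and variance from batch statistics and normalizes with them. The exact update rule, the moment at which `_update_params` is called, and the meaning of `normalize_inverse` are not documented on this page, so they are not modelled.

```python
import torch


class EMANormalizeSketch:
    """Hypothetical EMA normalizer; not torch_mist's implementation."""

    def __init__(self, input_dim: int, epsilon: float = 1e-6, gamma: float = 0.99):
        self.epsilon = epsilon
        self.gamma = gamma
        # Running statistics start at a standard normal (assumption)
        self.mean = torch.zeros(input_dim)
        self.var = torch.ones(input_dim)

    @torch.no_grad()
    def update(self, x: torch.Tensor) -> None:
        # Blend batch statistics into the running ones with decay gamma
        batch_mean = x.mean(dim=0)
        batch_var = ((x - batch_mean) ** 2).mean(dim=0)
        self.mean = self.gamma * self.mean + (1 - self.gamma) * batch_mean
        self.var = self.gamma * self.var + (1 - self.gamma) * batch_var

    def _std(self) -> torch.Tensor:
        return torch.sqrt(self.var + self.epsilon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Affine map with scale 1/std and shift -mean/std
        return (x - self.mean) / self._std()

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        return y * self._std() + self.mean

    def log_abs_det_jacobian(self, x: torch.Tensor) -> torch.Tensor:
        # Diagonal Jacobian: log|det J| = -sum(log std), same for every sample
        return -torch.log(self._std()).sum().expand(x.shape[:-1])


torch.manual_seed(0)
norm = EMANormalizeSketch(input_dim=2, gamma=0.5)
x = torch.randn(16, 2) * 3.0 + 1.0
norm.update(x)
z = norm.forward(x)
x_rec = norm.inverse(z)
```

Using a moving average rather than per-batch statistics keeps the map fixed for a given state. That means `forward` and `inverse` remain exact inverses of each other, which batch normalization computed on the fly would not guarantee.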