torch_mist.distributions.transforms.factories
Module Contents
Functions
- torch_mist.distributions.transforms.factories.linear(input_dim: int, loc: float | None = None, scale: float | None = None, initial_scale: float | None = None) → torch_mist.distributions.transforms.implementations.linear.Linear
- torch_mist.distributions.transforms.factories.conditional_linear(input_dim: int, context_dim: int, hidden_dims: List[int] | None = None, scale: float | None = None, initial_scale: float | None = None, nonlinearity: nn = nn.ReLU(True)) → torch_mist.distributions.transforms.implementations.linear.ConditionalLinear
- torch_mist.distributions.transforms.factories.permute(input_dim: int, permutation: List[int] | None = None, dim: int = -1) → torch_mist.distributions.transforms.implementations.permute.Permute
A helper function to create a Permute object for consistency with other helpers.
Parameters:
input_dim (int) – Dimension(s) of input variable to permute. Note that when dim < -1 this must be a tuple corresponding to the event shape.
permutation (torch.LongTensor) – Torch tensor of integer indices representing the permutation. Defaults to a random permutation.
dim (int) – The tensor dimension to permute. This value must be negative and defines the event dimension as abs(dim).
- torch_mist.distributions.transforms.factories.emanorm(input_dim: int, gamma: float = 0.99, normalize_inverse: bool = True) → torch_mist.distributions.transforms.implementations.normalize.EMANormalize
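The `linear` factory exposes `loc` and `scale` but this page does not spell out the map they define. The following is a minimal pure-Python sketch, assuming `Linear` acts as an elementwise affine map y = loc + scale * x on each of the `input_dim` coordinates. The class `AffineSketch` and its method names are hypothetical illustrations, not the torch_mist API. The sketch uses fixed scalars and does not model `initial_scale` or any learnable parameters.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class AffineSketch:
    """Hypothetical elementwise affine map y = loc + scale * x (not torch_mist's Linear)."""

    loc: float = 0.0
    scale: float = 1.0

    def __post_init__(self) -> None:
        # The map is only invertible when scale is non-zero.
        if self.scale == 0.0:
            raise ValueError("scale must be non-zero for the map to be invertible")

    def forward(self, x: Sequence[float]) -> List[float]:
        return [self.loc + self.scale * xi for xi in x]

    def inverse(self, y: Sequence[float]) -> List[float]:
        return [(yi - self.loc) / self.scale for yi in y]

    def log_abs_det_jacobian(self, x: Sequence[float]) -> float:
        # The Jacobian is diagonal with every entry equal to `scale`,
        # so log|det J| = input_dim * log|scale|.
        return len(x) * math.log(abs(self.scale))


t = AffineSketch(loc=1.0, scale=2.0)
y = t.forward([0.0, 1.0, 2.0])
print(y)             # [1.0, 3.0, 5.0]
print(t.inverse(y))  # [0.0, 1.0, 2.0]
```

The log-determinant term is what a normalizing flow adds to the base log-density. For an affine map it does not depend on x, only on the number of coordinates and on |scale|.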
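The `permute` parameters above describe a list of integer indices that defaults to a random permutation. Below is a minimal pure-Python sketch of that behaviour for the `dim=-1` case only. It assumes gather semantics (output position i takes input element `permutation[i]`), which is how index-select style permutations usually work; this is an assumption, not taken from the torch_mist source. The helper names and the extra `seed` argument are hypothetical.

```python
import random
from typing import List, Optional, Sequence


def make_permutation(input_dim: int, permutation: Optional[List[int]] = None,
                     seed: Optional[int] = None) -> List[int]:
    """Use the given permutation, or draw a random one (the documented default).

    `seed` is a hypothetical extra argument so the random draw is reproducible.
    """
    if permutation is None:
        permutation = list(range(input_dim))
        random.Random(seed).shuffle(permutation)
    # A valid permutation lists every index 0..input_dim-1 exactly once.
    if sorted(permutation) != list(range(input_dim)):
        raise ValueError("permutation must contain each of 0..input_dim-1 exactly once")
    return list(permutation)


def inverse_permutation(permutation: Sequence[int]) -> List[int]:
    # inv[p] = i undoes the gather y[i] = x[permutation[i]].
    inv = [0] * len(permutation)
    for i, p in enumerate(permutation):
        inv[p] = i
    return inv


def permute_last_dim(x: Sequence[float], permutation: Sequence[int]) -> List[float]:
    # Gather along the last (event) dimension, i.e. the dim=-1 case.
    return [x[p] for p in permutation]


perm = make_permutation(4, [2, 0, 3, 1])
y = permute_last_dim([10.0, 20.0, 30.0, 40.0], perm)
print(y)                                               # [30.0, 10.0, 40.0, 20.0]
print(permute_last_dim(y, inverse_permutation(perm)))  # [10.0, 20.0, 30.0, 40.0]
```

A permutation only reorders coordinates, so its log-absolute-Jacobian determinant is zero. This makes it a cheap layer to interleave between other flow transforms so that every coordinate eventually gets transformed.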
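The `emanorm` factory is documented only by its signature. The sketch below is a pure-Python illustration under two assumptions: `EMANormalize` standardizes each coordinate with a running mean and variance, and those statistics are updated as an exponential moving average in which `gamma` is the weight kept on the old value. The class `EMANormalizerSketch`, its methods and the `eps` constant are hypothetical. `normalize_inverse` is not modeled, because its meaning is not documented here.

```python
import math
from typing import List, Sequence


class EMANormalizerSketch:
    """Hypothetical per-coordinate normalizer with EMA statistics (not torch_mist's EMANormalize)."""

    def __init__(self, input_dim: int, gamma: float = 0.99, eps: float = 1e-8) -> None:
        if not 0.0 <= gamma < 1.0:
            raise ValueError("gamma must lie in [0, 1)")
        self.gamma = gamma
        self.eps = eps  # guards against division by zero for constant coordinates
        self.mean = [0.0] * input_dim
        self.var = [1.0] * input_dim

    def update(self, batch: Sequence[Sequence[float]]) -> None:
        # Blend batch statistics into the running ones:
        # running = gamma * running + (1 - gamma) * batch_stat.
        n = len(batch)
        for j in range(len(self.mean)):
            column = [row[j] for row in batch]
            batch_mean = sum(column) / n
            batch_var = sum((c - batch_mean) ** 2 for c in column) / n
            self.mean[j] = self.gamma * self.mean[j] + (1.0 - self.gamma) * batch_mean
            self.var[j] = self.gamma * self.var[j] + (1.0 - self.gamma) * batch_var

    def normalize(self, x: Sequence[float]) -> List[float]:
        return [(xj - m) / math.sqrt(v + self.eps) for xj, m, v in zip(x, self.mean, self.var)]

    def denormalize(self, z: Sequence[float]) -> List[float]:
        return [zj * math.sqrt(v + self.eps) + m for zj, m, v in zip(z, self.mean, self.var)]


norm = EMANormalizerSketch(input_dim=2, gamma=0.5)
norm.update([[1.0, 3.0], [3.0, 5.0]])  # batch mean [2, 4], batch variance [1, 1]
print(norm.mean, norm.var)             # [1.0, 2.0] [1.0, 1.0]
```

With the default `gamma=0.99`, each batch moves the running statistics only slightly. This keeps the normalization stable during training at the cost of adapting slowly to shifts in the data.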