torch_mist.distributions.transforms.factories

Module Contents

Functions

linear(...)

conditional_linear(...)

permute(...)

A helper function to create a Permute object for consistency with other helpers.

emanorm(...)

torch_mist.distributions.transforms.factories.linear(input_dim: int, loc: float | None = None, scale: float | None = None, initial_scale: float | None = None) → torch_mist.distributions.transforms.implementations.linear.Linear

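This page gives no docstring for `linear`. Assuming (not confirmed here) that the returned `Linear` transform is an element-wise affine map y = loc + scale * x, as affine layers in normalizing flows usually are, the pure-Python sketch below shows what such a transform computes in the forward and inverse directions and why its log-determinant is cheap. `AffineSketch` is a hypothetical name, and how it treats `loc`, `scale` and `initial_scale` is a guess; the real class is a PyTorch transform and may make these values learnable.

```python
import math
from typing import List, Optional


class AffineSketch:
    """Hypothetical element-wise affine map y = loc + scale * x (illustration only)."""

    def __init__(self, input_dim: int, loc: Optional[float] = None,
                 scale: Optional[float] = None, initial_scale: Optional[float] = None):
        # Assumption: a given loc/scale is broadcast to every dimension. When scale
        # is None, initial_scale (else 1.0) seeds the per-dimension value that a
        # real implementation would probably optimise.
        self.loc = [loc if loc is not None else 0.0] * input_dim
        if scale is not None:
            start = scale
        elif initial_scale is not None:
            start = initial_scale
        else:
            start = 1.0
        self.scale = [start] * input_dim

    def forward(self, x: List[float]) -> List[float]:
        return [m + s * xi for m, s, xi in zip(self.loc, self.scale, x)]

    def inverse(self, y: List[float]) -> List[float]:
        return [(yi - m) / s for m, s, yi in zip(self.loc, self.scale, y)]

    def log_abs_det_jacobian(self) -> float:
        # The Jacobian is diagonal, so log|det J| is the sum of log|scale_i|.
        return sum(math.log(abs(s)) for s in self.scale)


t = AffineSketch(3, loc=1.0, scale=2.0)
print(t.forward([0.0, 1.0, -1.0]))  # [1.0, 3.0, -1.0]
```

Because each output depends only on the matching input, the inverse is a per-element division and the log-determinant needs no matrix algebra.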
torch_mist.distributions.transforms.factories.conditional_linear(input_dim: int, context_dim: int, hidden_dims: List[int] | None = None, scale: float | None = None, initial_scale: float | None = None, nonlinearity: conditional_linear.nn = nn.ReLU(True)) → torch_mist.distributions.transforms.implementations.linear.ConditionalLinear

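`conditional_linear` is also undocumented here. Its signature suggests an affine transform whose shift and scale are predicted from a `context_dim`-sized context by a network with `hidden_dims` hidden layers and `nonlinearity` activations (the default `nn.ReLU(True)` is PyTorch's in-place ReLU). The sketch below assumes exactly that and is not torch_mist's implementation: `ConditionalAffineSketch`, `DenseLayer` and the softplus used to keep the scale positive are hypothetical choices, `scale`/`initial_scale` are omitted, and weights are random rather than trained.

```python
import math
import random
from typing import List, Optional, Tuple


def softplus(r: float) -> float:
    # Numerically stable log(1 + exp(r)); always > 0.
    return max(r, 0.0) + math.log1p(math.exp(-abs(r)))


class DenseLayer:
    """Hypothetical fully connected layer with small random weights."""

    def __init__(self, n_in: int, n_out: int, rng: random.Random):
        self.w = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
        self.b = [0.0] * n_out

    def __call__(self, v: List[float]) -> List[float]:
        return [sum(wij * vj for wij, vj in zip(row, v)) + bi
                for row, bi in zip(self.w, self.b)]


class ConditionalAffineSketch:
    """y = loc(c) + scale(c) * x, with loc and scale predicted from context c."""

    def __init__(self, input_dim: int, context_dim: int,
                 hidden_dims: Optional[List[int]] = None, seed: int = 0):
        rng = random.Random(seed)
        # Assumption: None means no hidden layers (a single linear map).
        dims = [context_dim] + list(hidden_dims or []) + [2 * input_dim]
        self.layers = [DenseLayer(a, b, rng) for a, b in zip(dims[:-1], dims[1:])]
        self.input_dim = input_dim

    def params(self, context: List[float]) -> Tuple[List[float], List[float]]:
        h = context
        for i, layer in enumerate(self.layers):
            h = layer(h)
            if i < len(self.layers) - 1:
                h = [max(0.0, a) for a in h]  # ReLU between layers
        loc = h[: self.input_dim]
        scale = [softplus(r) for r in h[self.input_dim:]]  # keep scale > 0
        return loc, scale

    def forward(self, x: List[float], context: List[float]) -> List[float]:
        loc, scale = self.params(context)
        return [m + s * xi for m, s, xi in zip(loc, scale, x)]

    def inverse(self, y: List[float], context: List[float]) -> List[float]:
        loc, scale = self.params(context)
        return [(yi - m) / s for m, s, yi in zip(loc, scale, y)]

    def log_abs_det_jacobian(self, context: List[float]) -> float:
        _, scale = self.params(context)
        return sum(math.log(s) for s in scale)


t = ConditionalAffineSketch(input_dim=2, context_dim=3, hidden_dims=[8, 8])
x, c = [0.5, -1.5], [1.0, 0.0, 2.0]
y = t.forward(x, c)
print(t.inverse(y, c))  # approximately [0.5, -1.5]
```

The network only has to produce one shift and one positive scale per input dimension, so the transform stays invertible for any context.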
torch_mist.distributions.transforms.factories.permute(input_dim: int, permutation: List[int] | None = None, dim: int = -1) → torch_mist.distributions.transforms.implementations.permute.Permute

A helper function to create a Permute object for consistency with other helpers.

Parameters:
  • input_dim (int) – Dimension(s) of the input variable to permute. Note that when dim < -1, this must be a tuple corresponding to the event shape.

  • permutation (torch.LongTensor) – Torch tensor of integer indices representing the permutation. Defaults to a random permutation.

  • dim (int) – The tensor dimension to permute. This value must be negative and defines the event dim as abs(dim).
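To illustrate what a permutation transform does, here is a pure-Python sketch that handles only the default case `dim = -1` on a flat vector. It mirrors the documented behaviour (explicit permutation, else a random one) and adds the inverse and the zero log-determinant that make permutations useful between flow layers. `PermuteSketch` and its `seed` argument are hypothetical; the real `Permute` object operates on tensors.

```python
import random
from typing import List, Optional, Sequence


class PermuteSketch:
    """Hypothetical permutation transform over the last dimension (illustration only)."""

    def __init__(self, input_dim: int, permutation: Optional[List[int]] = None,
                 seed: Optional[int] = None):
        if permutation is None:
            # Default: a random permutation, as documented above.
            permutation = list(range(input_dim))
            random.Random(seed).shuffle(permutation)
        if sorted(permutation) != list(range(input_dim)):
            raise ValueError("permutation must use each index 0..input_dim-1 exactly once")
        self.permutation = list(permutation)
        # inverse_permutation[j] is the output position that holds input j.
        self.inverse_permutation = [0] * input_dim
        for new_pos, old_pos in enumerate(self.permutation):
            self.inverse_permutation[old_pos] = new_pos

    def forward(self, x: Sequence) -> list:
        # y[k] = x[permutation[k]]
        return [x[i] for i in self.permutation]

    def inverse(self, y: Sequence) -> list:
        return [y[i] for i in self.inverse_permutation]

    def log_abs_det_jacobian(self) -> float:
        # Reordering coordinates preserves volume, so log|det J| = 0.
        return 0.0


p = PermuteSketch(3, permutation=[2, 0, 1])
print(p.forward(["a", "b", "c"]))  # ['c', 'a', 'b']
print(p.inverse(["c", "a", "b"]))  # ['a', 'b', 'c']
```

Storing the inverse permutation once at construction makes both directions a single indexing pass.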

torch_mist.distributions.transforms.factories.emanorm(input_dim: int, gamma: float = 0.99, normalize_inverse: bool = True) → torch_mist.distributions.transforms.implementations.normalize.EMANormalize
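`emanorm` has no docstring on this page. Going by its name and the `gamma` parameter, a plausible reading (an assumption, not confirmed here) is a normalization transform that tracks per-dimension mean and variance with an exponential moving average, stat ← gamma * stat + (1 − gamma) * batch_stat, and standardizes inputs with them. The sketch below implements only that idea: `EMANormalizeSketch`, `update` and `eps` are hypothetical, and the effect of `normalize_inverse` is not documented here, so it is left out.

```python
import math
from typing import List


class EMANormalizeSketch:
    """Hypothetical EMA-based standardization (illustration only)."""

    def __init__(self, input_dim: int, gamma: float = 0.99, eps: float = 1e-8):
        self.gamma = gamma
        self.eps = eps
        self.mean = [0.0] * input_dim
        self.var = [1.0] * input_dim

    def update(self, batch: List[List[float]]) -> None:
        # Blend each dimension's batch mean/variance into the running estimates.
        n = len(batch)
        for d in range(len(self.mean)):
            col = [row[d] for row in batch]
            m = sum(col) / n
            v = sum((c - m) ** 2 for c in col) / n
            self.mean[d] = self.gamma * self.mean[d] + (1 - self.gamma) * m
            self.var[d] = self.gamma * self.var[d] + (1 - self.gamma) * v

    def forward(self, x: List[float]) -> List[float]:
        return [(xi - m) / math.sqrt(v + self.eps)
                for xi, m, v in zip(x, self.mean, self.var)]

    def inverse(self, y: List[float]) -> List[float]:
        return [yi * math.sqrt(v + self.eps) + m
                for yi, m, v in zip(y, self.mean, self.var)]


norm = EMANormalizeSketch(2, gamma=0.99)
for _ in range(2000):
    norm.update([[10.0, -4.0], [12.0, -6.0]])  # batch mean [11, -5], variance [1, 1]
print(norm.forward([11.0, -5.0]))  # close to [0.0, 0.0]
```

A larger `gamma` makes the running statistics smoother but slower to follow shifts in the data; with `gamma = 0.99` the initial values have decayed to about 0.99**2000 of their weight after the loop above.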