torch_mist.distributions.factories

Module Contents

Functions

delete_unused_kwargs(transform_factory, all_kwargs[, ...])

make_transforms(...)

transformed_normal(→ torch.distributions.Distribution)

conditional_transformed_normal(...)

joint_transformed_normal(...)

conditional_categorical(n_classes, context_dim, ...[, ...])

torch_mist.distributions.factories.delete_unused_kwargs(transform_factory: Callable[[Any], Any], all_kwargs: Dict[str, Any], warnings: bool = True)
torch_mist.distributions.factories.make_transforms(input_dim: int, transform_name: str = 'conditional_linear', normalization: str | None = None, n_transforms: int = 1, **kwargs) → List[torch.distributions.Transform | pyro.distributions.ConditionalTransform]
torch_mist.distributions.factories.transformed_normal(input_dim: int, transform_name: str = 'linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch.distributions.Distribution
torch_mist.distributions.factories.conditional_transformed_normal(input_dim: int, context_dim: int, transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch_mist.distributions.transforms.ConditionalDistributionModule
torch_mist.distributions.factories.joint_transformed_normal(input_dims: Dict[str, int], transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch_mist.distributions.joint.base.JointDistribution
torch_mist.distributions.factories.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)