torch_mist.distributions.factories
Module Contents
Functions
- torch_mist.distributions.factories.delete_unused_kwargs(transform_factory: Callable[[Any], Any], all_kwargs: Dict[str, Any], warnings: bool = True)
- torch_mist.distributions.factories.make_transforms(input_dim: int, transform_name: str = 'conditional_linear', normalization: str | None = None, n_transforms: int = 1, **kwargs) -> List[torch.distributions.Transform | pyro.distributions.ConditionalTransform]
- torch_mist.distributions.factories.transformed_normal(input_dim: int, transform_name: str = 'linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) -> torch.distributions.Distribution
- torch_mist.distributions.factories.conditional_transformed_normal(input_dim: int, context_dim: int, transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) -> torch_mist.distributions.transforms.ConditionalDistributionModule
- torch_mist.distributions.factories.joint_transformed_normal(input_dims: Dict[str, int], transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) -> torch_mist.distributions.joint.base.JointDistribution
- torch_mist.distributions.factories.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)
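
The listing above gives only the signature of delete_unused_kwargs, not its behaviour. Assuming from the name and parameters that it drops keyword arguments a transform factory does not accept (and optionally warns about them), the idea can be sketched with the standard library alone. Everything below is hypothetical: filter_kwargs_sketch and make_linear are illustrative names, not part of torch_mist, and the real function may differ in what it returns or how it warns.

```python
import inspect
import warnings as py_warnings
from typing import Any, Callable, Dict


def filter_kwargs_sketch(
    factory: Callable[..., Any],
    all_kwargs: Dict[str, Any],
    warn: bool = True,
) -> Dict[str, Any]:
    """Hypothetical stand-in: keep only the kwargs that `factory` accepts."""
    params = inspect.signature(factory).parameters
    # A factory that declares **kwargs accepts every keyword argument.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(all_kwargs)
    used = {k: v for k, v in all_kwargs.items() if k in params}
    unused = sorted(set(all_kwargs) - set(used))
    if warn and unused:
        py_warnings.warn(f"Unused arguments for {factory.__name__}: {unused}")
    return used


def make_linear(input_dim: int, bias: bool = True) -> str:
    # Toy factory used only for illustration (assumption, not a torch_mist API).
    return f"linear(input_dim={input_dim}, bias={bias})"


kwargs = {"input_dim": 4, "bias": False, "hidden_dims": [32, 32]}
used = filter_kwargs_sketch(make_linear, kwargs, warn=False)
result = make_linear(**used)
```

A filter of this kind would let factories such as make_transforms forward a single **kwargs dictionary to different transform constructors without raising TypeError on arguments that a given constructor does not define.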