torch_mist.distributions

Subpackages

Submodules

Package Contents

Classes

ConditionalCategoricalModule

JointDistribution

ConditionalTransformedNormalModule

TransformedNormalModule

JointTransformedNormalModule

ConditionalDistributionModule

Helper class that provides a standard way to create an ABC using inheritance.

NormalModule

StandardNormalModule

TransformedNormalModule

ConditionalTransformedNormalModule

JointTransformedNormalModule

CategoricalModule

ConditionalCategoricalModule

EmpiricalDistribution

Functions

fetch_transform(transform_name)

delete_unused_kwargs(transform_factory, all_kwargs[, ...])

make_transforms(...)

transformed_normal(→ torch.distributions.Distribution)

conditional_transformed_normal(...)

joint_transformed_normal(...)

conditional_categorical(n_classes, context_dim, ...[, ...])

class torch_mist.distributions.ConditionalCategoricalModule(net: torch.nn.Module, temperature: float = 1.0)

Bases: torch_mist.distributions.transforms.ConditionalDistributionModule

condition(x)
class torch_mist.distributions.JointDistribution(variables: List[str], name: str = 'p')

Bases: torch.nn.Module, torch.distributions.Distribution, pyro.distributions.ConditionalDistribution

full_name()
abstract _log_prob(**kwargs) → torch.Tensor
log_prob(*args, **kwargs) → torch.Tensor
abstract _marginal(variables: List[str]) → T
marginal(*variables) → T
conditional(*variables) → pyro.distributions.ConditionalDistribution
condition(**conditioning) → T
_mutual_information(variable_1: str, variable_2: str) → torch.Tensor
mutual_information(variable_1: str | None = None, variable_2: str | None = None) → torch.Tensor
abstract _entropy(variables: List[str]) → torch.Tensor
entropy(*variables) → torch.Tensor
class torch_mist.distributions.ConditionalTransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])

Bases: torch_mist.distributions.transforms.ConditionalTransformedDistributionModule

class torch_mist.distributions.TransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform])

Bases: torch_mist.distributions.transforms.TransformedDistributionModule

class torch_mist.distributions.JointTransformedNormalModule(input_dims: Dict[str, int], transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])

Bases: torch_mist.distributions.joint.wrapper.TorchJointDistribution

class torch_mist.distributions.ConditionalDistributionModule

Bases: pyro.distributions.ConditionalDistribution, torch.nn.Module, abc.ABC

Helper class that provides a standard way to create an ABC using inheritance.

torch_mist.distributions.fetch_transform(transform_name: str)
torch_mist.distributions.delete_unused_kwargs(transform_factory: Callable[[Any], Any], all_kwargs: Dict[str, Any], warnings: bool = True)
torch_mist.distributions.make_transforms(input_dim: int, transform_name: str = 'conditional_linear', normalization: str | None = None, n_transforms: int = 1, **kwargs) → List[torch.distributions.Transform | pyro.distributions.ConditionalTransform]
torch_mist.distributions.transformed_normal(input_dim: int, transform_name: str = 'linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch.distributions.Distribution
torch_mist.distributions.conditional_transformed_normal(input_dim: int, context_dim: int, transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch_mist.distributions.transforms.ConditionalDistributionModule
torch_mist.distributions.joint_transformed_normal(input_dims: Dict[str, int], transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch_mist.distributions.joint.base.JointDistribution
torch_mist.distributions.conditional_categorical(n_classes: int, context_dim: int, hidden_dims: List[int], temperature: float = 1.0)
class torch_mist.distributions.NormalModule(loc: torch.Tensor, scale: torch.Tensor, learnable: bool = False)

Bases: torch.distributions.Distribution, torch.nn.Module

rsample(sample_shape=torch.Size())
log_prob(value)
__repr__()
class torch_mist.distributions.StandardNormalModule(n_dim: int)

Bases: NormalModule

class torch_mist.distributions.TransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform])

Bases: torch_mist.distributions.transforms.TransformedDistributionModule

class torch_mist.distributions.ConditionalTransformedNormalModule(input_dim: int, transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])

Bases: torch_mist.distributions.transforms.ConditionalTransformedDistributionModule

class torch_mist.distributions.JointTransformedNormalModule(input_dims: Dict[str, int], transforms: List[torch.distributions.Transform | pyro.distributions.ConditionalTransform])

Bases: torch_mist.distributions.joint.wrapper.TorchJointDistribution

class torch_mist.distributions.CategoricalModule(logits: torch.tensor, temperature: float = 1.0)

Bases: torch.distributions.Distribution, torch.nn.Module

rsample(sample_shape=torch.Size())
log_prob(value)
__repr__()
class torch_mist.distributions.ConditionalCategoricalModule(net: torch.nn.Module, temperature: float = 1.0)

Bases: torch_mist.distributions.transforms.ConditionalDistributionModule

condition(x)
class torch_mist.distributions.EmpiricalDistribution

Bases: torch.distributions.Distribution

add_samples(samples)
sample(sample_shape: torch.Size = torch.Size()) → torch.Tensor
update()
__repr__()