torch_mist.distributions.joint.base

Module Contents

Classes

JointDistribution

ConditionedJointDistribution

ConditionalNamedDistribution

Attributes

T

torch_mist.distributions.joint.base.T
class torch_mist.distributions.joint.base.JointDistribution(variables: List[str], name: str = 'p')

Bases: torch.nn.Module, torch.distributions.Distribution, pyro.distributions.ConditionalDistribution

full_name()
abstract _log_prob(**kwargs) → torch.Tensor
log_prob(*args, **kwargs) → torch.Tensor
abstract _marginal(variables: List[str]) → T
marginal(*variables) → T
conditional(*variables) → pyro.distributions.ConditionalDistribution
condition(**conditioning) → T
_mutual_information(variable_1: str, variable_2: str) → torch.Tensor
mutual_information(variable_1: str | None = None, variable_2: str | None = None) → torch.Tensor
abstract _entropy(variables: List[str]) → torch.Tensor
entropy(*variables) → torch.Tensor
class torch_mist.distributions.joint.base.ConditionedJointDistribution(joint: JointDistribution, marginal: JointDistribution, cond_dict: Dict[str, torch.Tensor])

Bases: JointDistribution

_log_prob(**kwargs) → torch.Tensor
__repr__()
class torch_mist.distributions.joint.base.ConditionalNamedDistribution(joint: JointDistribution, condition_on: List[str])

Bases: pyro.distributions.ConditionalDistribution

condition(context: Dict[str, torch.Tensor] | torch.Tensor) → JointDistribution
__repr__()