torch_mist.estimators.generative

Subpackages

Submodules

Package Contents

Classes

GenerativeMIEstimator

ConditionalGenerativeMIEstimator

BA

CLUB

DoE

GM

L1Out

DummyGenerativeMIEstimator

JointDistribution

BA

CLUB

DoE

GM

L1Out

DummyGenerativeMIEstimator

Functions

joint_transformed_normal(...)

ba(...) → torch_mist.estimators.generative.implementations.BA

club(...)

doe(...) → torch_mist.estimators.generative.implementations.DoE

dummy_generative(...)

gm(...) → torch_mist.estimators.generative.implementations.GM

l1out(...)

class torch_mist.estimators.generative.GenerativeMIEstimator

Bases: torch_mist.estimators.base.MIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.ConditionalGenerativeMIEstimator(q_Y_given_X: pyro.distributions.ConditionalDistribution)

Bases: GenerativeMIEstimator

approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.generative.BA(q_Y_given_X: pyro.distributions.ConditionalDistribution, entropy_y: torch.Tensor | None = None)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
mutual_information(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.CLUB(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = 0)

Bases: torch_mist.estimators.generative.implementations.l1out.L1Out

infomax_gradient: Dict[str, bool]
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.DoE(q_Y_given_X: pyro.distributions.ConditionalDistribution, q_Y: torch.distributions.Distribution)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
approx_log_p_y(y: torch.Tensor, x: torch.Tensor | None = None) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.generative.GM(q_XY: torch_mist.distributions.joint.base.JointDistribution, q_Y: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution, q_X: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution)

Bases: torch_mist.estimators.generative.base.JointGenerativeMIEstimator

property q_X
property q_Y
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.L1Out(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = -1)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
_broadcast_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.DummyGenerativeMIEstimator

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
torch_mist.estimators.generative.joint_transformed_normal(input_dims: Dict[str, int], transform_name: str = 'conditional_linear', n_transforms: int = 1, normalization: str | None = None, **kwargs) → torch_mist.distributions.joint.base.JointDistribution
class torch_mist.estimators.generative.JointDistribution(variables: List[str], name: str = 'p')

Bases: torch.nn.Module, torch.distributions.Distribution, pyro.distributions.ConditionalDistribution

full_name()
abstract _log_prob(**kwargs) → torch.Tensor
log_prob(*args, **kwargs) → torch.Tensor
abstract _marginal(variables: List[str]) → T
marginal(*variables) → T
conditional(*variables) → pyro.distributions.ConditionalDistribution
condition(**conditioning) → T
_mutual_information(variable_1: str, variable_2: str) → torch.Tensor
mutual_information(variable_1: str | None = None, variable_2: str | None = None) → torch.Tensor
abstract _entropy(variables: List[str]) → torch.Tensor
entropy(*variables) → torch.Tensor
class torch_mist.estimators.generative.BA(q_Y_given_X: pyro.distributions.ConditionalDistribution, entropy_y: torch.Tensor | None = None)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
mutual_information(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.CLUB(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = 0)

Bases: torch_mist.estimators.generative.implementations.l1out.L1Out

infomax_gradient: Dict[str, bool]
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.DoE(q_Y_given_X: pyro.distributions.ConditionalDistribution, q_Y: torch.distributions.Distribution)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
approx_log_p_y(y: torch.Tensor, x: torch.Tensor | None = None) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.generative.GM(q_XY: torch_mist.distributions.joint.base.JointDistribution, q_Y: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution, q_X: torch.distributions.Distribution | torch_mist.distributions.joint.base.JointDistribution)

Bases: torch_mist.estimators.generative.base.JointGenerativeMIEstimator

property q_X
property q_Y
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.L1Out(q_Y_given_X: pyro.distributions.ConditionalDistribution, neg_samples: int = -1)

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

infomax_gradient: Dict[str, bool]
_broadcast_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.DummyGenerativeMIEstimator

Bases: torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
torch_mist.estimators.generative.ba(entropy_y: float | torch.Tensor, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) → torch_mist.estimators.generative.implementations.BA
torch_mist.estimators.generative.club(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) → torch_mist.estimators.generative.implementations.CLUB
torch_mist.estimators.generative.doe(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, q_Y: torch.distributions.Distribution | None = None, conditional_transform_name: str = 'conditional_linear', n_conditional_transforms: int = 1, marginal_transform_name: str = 'linear', n_marginal_transforms: int = 1) → torch_mist.estimators.generative.implementations.DoE
torch_mist.estimators.generative.dummy_generative(**kwargs) → torch_mist.estimators.generative.implementations.DummyGenerativeMIEstimator
torch_mist.estimators.generative.gm(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] = None, q_XY: torch_mist.distributions.joint.base.JointDistribution | None = None, q_Y: torch.distributions.Distribution | None = None, q_X: torch.distributions.Distribution | None = None, joint_transform_name: str = 'affine_autoregressive', n_joint_transforms: int = 1, marginal_transform_name: str = 'linear', n_marginal_transforms: int = 1) → torch_mist.estimators.generative.implementations.GM
torch_mist.estimators.generative.l1out(x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, q_Y_given_X: pyro.distributions.ConditionalDistribution | None = None, transform_name: str = 'conditional_linear', n_transforms: int = 1) → torch_mist.estimators.generative.implementations.L1Out