torch_mist.estimators.generative.base

Module Contents

Classes

GenerativeMIEstimator

ConditionalGenerativeMIEstimator

JointGenerativeMIEstimator

class torch_mist.estimators.generative.base.GenerativeMIEstimator

Bases: torch_mist.estimators.base.MIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
class torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator(q_Y_given_X: pyro.distributions.ConditionalDistribution)

Bases: GenerativeMIEstimator

approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.generative.base.JointGenerativeMIEstimator(q_XY: torch_mist.distributions.JointDistribution)

Bases: GenerativeMIEstimator

property q_X
property q_Y
approx_log_p_xy(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
approx_log_p_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
approx_log_p_y(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()