torch_mist.estimators.generative.base
Module Contents
Classes
- class torch_mist.estimators.generative.base.GenerativeMIEstimator
Bases: torch_mist.estimators.base.MIEstimator
- unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- abstract approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- abstract approx_log_p_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.estimators.generative.base.ConditionalGenerativeMIEstimator(q_Y_given_X: pyro.distributions.ConditionalDistribution)
Bases: GenerativeMIEstimator
- approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- __repr__()
- class torch_mist.estimators.generative.base.JointGenerativeMIEstimator(q_XY: torch_mist.distributions.JointDistribution)
Bases: GenerativeMIEstimator
- property q_X
- property q_Y
- approx_log_p_xy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- approx_log_p_x(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- approx_log_p_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- approx_log_p_y_given_x(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- __repr__()
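
Example (illustrative sketch)
The method names above suggest how the pieces might fit together: a log-ratio formed as log q(y|x) - log q(y), with the joint variant obtaining log q(y|x) as log q(x, y) - log q(x). This page does not confirm that relationship, so treat it as an assumption. The sketch below is not torch_mist code. `ToyJointEstimator` and `normal_log_prob` are hypothetical names, plain Python floats stand in for `torch.Tensor`, and a standard bivariate Gaussian with known correlation `rho` stands in for a learned `q_XY`.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def normal_log_prob(value, mean, std):
    """Log-density of a univariate Normal(mean, std) evaluated at `value`."""
    return -0.5 * ((value - mean) / std) ** 2 - math.log(std) - 0.5 * LOG_2PI


class ToyJointEstimator:
    """Hypothetical stand-in that mirrors the method names listed above.

    The "model" is a standard bivariate Gaussian with correlation `rho`
    (assumption: nothing here is learned, and nothing here is torch_mist API).
    """

    def __init__(self, rho):
        self.rho = rho

    def approx_log_p_xy(self, x, y):
        # Joint log-density of a zero-mean, unit-variance bivariate Gaussian.
        one_minus_rho2 = 1.0 - self.rho ** 2
        quad = (x * x - 2.0 * self.rho * x * y + y * y) / one_minus_rho2
        return -LOG_2PI - 0.5 * math.log(one_minus_rho2) - 0.5 * quad

    def approx_log_p_x(self, x, y):
        # Marginal of x; `y` is unused and kept only to match the listed signature.
        return normal_log_prob(x, 0.0, 1.0)

    def approx_log_p_y(self, x, y):
        # Marginal of y; `x` is unused and kept only to match the listed signature.
        return normal_log_prob(y, 0.0, 1.0)

    def approx_log_p_y_given_x(self, x, y):
        # Assumed relationship: log q(y|x) = log q(x, y) - log q(x).
        return self.approx_log_p_xy(x, y) - self.approx_log_p_x(x, y)

    def log_ratio(self, x, y):
        # Assumed relationship: log-ratio = log q(y|x) - log q(y).
        return self.approx_log_p_y_given_x(x, y) - self.approx_log_p_y(x, y)


estimator = ToyJointEstimator(rho=0.5)
pointwise = estimator.log_ratio(-1.2, 0.3)
```

Under these assumptions, averaging `log_ratio` over samples drawn from the joint distribution would give a mutual information estimate, and with `rho=0` the log-ratio is zero for every pair because the joint factorises into its marginals.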