torch_mist.estimators.base

Module Contents

Classes

MIEstimator

class torch_mist.estimators.base.MIEstimator

Bases: torch_mist.nn.Model

infomax_gradient: Dict[str, bool]
abstract log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
mutual_information(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor | Dict[str, torch.Tensor]
abstract batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor | Dict[str, torch.Tensor]