torch_mist.estimators.base
Module Contents
Classes
- class torch_mist.estimators.base.MIEstimator
Bases: torch_mist.nn.Model
- infomax_gradient: Dict[str, bool]
- abstract log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- abstract unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- mutual_information(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | Dict[str, torch.Tensor]
- abstract batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
- forward(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | Dict[str, torch.Tensor]
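
Example (illustrative sketch)

The signatures above do not say how the concrete methods relate to the abstract ones. The sketch below shows one plausible arrangement in plain PyTorch, without importing torch_mist. `SketchEstimator` and `BilinearInfoNCE` are hypothetical names, and the bodies of `mutual_information`, `loss`, and `forward` (averaging `log_ratio` or `batch_loss` over the batch) are assumptions made for illustration, not the library's implementation. The concrete subclass uses an InfoNCE-style contrastive normalization as a stand-in technique.

```python
import math
from abc import ABC, abstractmethod

import torch
from torch import nn


class SketchEstimator(nn.Module, ABC):
    """Hypothetical stand-in that mirrors the method names listed above."""

    @abstractmethod
    def log_ratio(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def unnormalized_log_ratio(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def batch_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...

    def mutual_information(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Assumption: the estimate is the batch mean of the per-sample log-ratio.
        return self.log_ratio(x, y).mean()

    def loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Assumption: the scalar loss is the batch mean of the per-sample loss.
        return self.batch_loss(x, y).mean()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Assumption: calling the module returns the mutual information estimate.
        return self.mutual_information(x, y)


class BilinearInfoNCE(SketchEstimator):
    """Hypothetical estimator with a bilinear critic f(x, y) = x^T W y."""

    def __init__(self, x_dim: int, y_dim: int):
        super().__init__()
        self.W = nn.Parameter(torch.eye(x_dim, y_dim))

    def _scores(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # scores[i, j] = f(x_i, y_j) for every pairing in the batch.
        return x @ self.W @ y.T

    def unnormalized_log_ratio(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Critic value on the paired samples only (the diagonal).
        return self._scores(x, y).diagonal()

    def log_ratio(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        scores = self._scores(x, y)
        n = scores.shape[0]
        # Log of the mean of exp(f(x_i, y_j)) over all y_j in the batch.
        log_norm = torch.logsumexp(scores, dim=1) - math.log(n)
        return scores.diagonal() - log_norm

    def batch_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Minimizing the negative log-ratio maximizes the estimate.
        return -self.log_ratio(x, y)


torch.manual_seed(0)
x = torch.randn(64, 4)
y = x + 0.1 * torch.randn(64, 4)  # y is a noisy copy of x

estimator = BilinearInfoNCE(x_dim=4, y_dim=4)
mi = estimator(x, y)          # scalar estimate, in nats
loss = estimator.loss(x, y)   # scalar training objective
loss.backward()               # gradients flow into estimator.W
```

With this normalization the estimate cannot exceed the log of the batch size (about 4.16 nats for 64 samples). Whether torch_mist's own estimators share that property depends on the concrete subclass.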