torch_mist.estimators.transformed.base
Module Contents
Attributes
- torch_mist.estimators.transformed.base.ERROR_MESSAGE = 'The TransformedMIEstimator can be called by passing two arguments or by specifying multiple...'
- torch_mist.estimators.transformed.base.SPLIT_SEQUENCE = '->'
Classes
- class torch_mist.estimators.transformed.base.TransformedMIEstimator(base_estimator: torch_mist.estimators.base.MIEstimator, transforms: Dict[str, Callable[[Any], Any]] | None = None, transforms_rename: Dict[Tuple[str, str], Callable[[Any], Any]] | None = None)
  Bases: torch_mist.estimators.base.MIEstimator
  - transform(**variables) → Dict[str, torch.Tensor]
  - _unfold_variables(*args, **variables) → Dict[str, Any]
  - batch_loss(*args, **variables) → torch.Tensor
  - log_ratio(*args, **variables) → torch.Tensor | List[torch.Tensor]
  - unnormalized_log_ratio(*args, **variables) → torch.Tensor
  - loss(*args, **variables) → torch.Tensor
  - mutual_information(*args, **variables) → torch.Tensor | Dict[str, torch.Tensor]
  - forward(*args, **kwargs) → torch.Tensor
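The constructor takes two optional mappings, `transforms` (keyed by a variable name) and `transforms_rename` (keyed by a `(str, str)` pair). The module also exposes `SPLIT_SEQUENCE = '->'`. This page does not say how these pieces interact, so the following is only a minimal sketch under stated assumptions:

- `transforms[name]` replaces a variable with a transformed version of itself.
- `transforms_rename[(source, target)]` stores the transformed `source` under a new name, `target`.
- `SPLIT_SEQUENCE` separates the two names in a string such as `'x->z'`.

The helpers `parse_rename_key` and `apply_transforms` are hypothetical and are not part of torch_mist. Plain Python values stand in for tensors.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Mirrors the documented constant; its use for parsing rename keys is an assumption.
SPLIT_SEQUENCE = "->"


def parse_rename_key(key: str) -> Tuple[str, str]:
    """Hypothetical helper: split a key such as 'x->z' into ('x', 'z')."""
    source, sep, target = key.partition(SPLIT_SEQUENCE)
    if not sep or not source.strip() or not target.strip():
        raise ValueError(f"Expected '<source>{SPLIT_SEQUENCE}<target>', got {key!r}")
    return source.strip(), target.strip()


def apply_transforms(
    variables: Dict[str, Any],
    transforms: Optional[Dict[str, Callable[[Any], Any]]] = None,
    transforms_rename: Optional[Dict[Tuple[str, str], Callable[[Any], Any]]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch of transform(**variables).

    Assumes transforms[name] overwrites variables[name], while
    transforms_rename[(source, target)] adds a new entry under target.
    Both kinds of function receive the original (untransformed) values.
    """
    out = dict(variables)  # copy so the caller's dict is not mutated
    for name, f in (transforms or {}).items():
        out[name] = f(variables[name])
    for (source, target), f in (transforms_rename or {}).items():
        out[target] = f(variables[source])
    return out


# Usage: scale y in place and derive a new variable z from x.
result = apply_transforms(
    {"x": 3, "y": 4},
    transforms={"y": lambda v: v * 10},
    transforms_rename={parse_rename_key("x->z"): lambda v: v + 1},
)
print(result)  # {'x': 3, 'y': 40, 'z': 4}
```

In this sketch, keeping the renamed output separate from the original lets the base estimator see both `x` and a derived `z`. Whether the library keeps or drops the source variable is not stated here.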
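`ERROR_MESSAGE` states that the estimator "can be called by passing two arguments or by specifying multiple..." (the rest is truncated on this page). The call methods listed above all accept `*args, **variables`. The sketch below shows one way such a dispatch could work before the call is delegated to a wrapped estimator. It assumes two positional arguments are read as `x` and `y`, and that keyword arguments pass through by name. `unfold_variables` and `ToyTransformedEstimator` are illustrative names, not the library's API. The base estimator is a plain function rather than an `MIEstimator`.

```python
from typing import Any, Callable, Dict

# Hypothetical message; the library's own ERROR_MESSAGE is truncated on this page.
_CALL_ERROR = "Pass exactly two positional arguments, or keyword arguments only."


def _identity(value: Any) -> Any:
    return value


def unfold_variables(*args: Any, **variables: Any) -> Dict[str, Any]:
    """Hypothetical sketch of _unfold_variables: normalize call arguments to a dict."""
    # Assumption: two positional arguments are interpreted as (x, y).
    if len(args) == 2 and not variables:
        return {"x": args[0], "y": args[1]}
    # Keyword-only calls are passed through unchanged.
    if not args:
        return dict(variables)
    raise ValueError(_CALL_ERROR)


class ToyTransformedEstimator:
    """Hypothetical wrapper: unfold arguments, transform them, delegate to a base estimator."""

    def __init__(
        self,
        base_estimator: Callable[..., Any],
        transforms: Dict[str, Callable[[Any], Any]],
    ):
        self.base_estimator = base_estimator
        self.transforms = transforms

    def __call__(self, *args: Any, **variables: Any) -> Any:
        unfolded = unfold_variables(*args, **variables)
        # Variables without a registered transform are left unchanged.
        transformed = {
            name: self.transforms.get(name, _identity)(value)
            for name, value in unfolded.items()
        }
        return self.base_estimator(**transformed)


# Usage: a toy "base estimator" that multiplies its inputs.
est = ToyTransformedEstimator(lambda x, y: x * y, {"x": lambda v: v + 1})
print(est(2, 5))      # 15
print(est(x=2, y=5))  # 15
```

Normalizing every call to a name-to-value dict first means the transform and delegation steps need only one code path, whichever calling style was used.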