torch_mist.estimators.transformed.base

Module Contents

Classes

TransformedMIEstimator

Attributes

ERROR_MESSAGE

SPLIT_SEQUENCE

torch_mist.estimators.transformed.base.ERROR_MESSAGE = 'The TransformedMIEstimator can be called by passing two arguments or by specifying multiple...'
torch_mist.estimators.transformed.base.SPLIT_SEQUENCE = '->'
class torch_mist.estimators.transformed.base.TransformedMIEstimator(base_estimator: torch_mist.estimators.base.MIEstimator, transforms: Dict[str, Callable[[Any], Any]] | None = None, transforms_rename: Dict[Tuple[str, str], Callable[[Any], Any]] | None = None)

Bases: torch_mist.estimators.base.MIEstimator

transform(**variables) → Dict[str, torch.Tensor]
_unfold_variables(*args, **variables) → Dict[str, Any]
batch_loss(*args, **variables) → torch.Tensor
log_ratio(*args, **variables) → torch.Tensor | List[torch.Tensor]
unnormalized_log_ratio(*args, **variables) → torch.Tensor
loss(*args, **variables) → torch.Tensor
mutual_information(*args, **variables) → torch.Tensor | Dict[str, torch.Tensor]
forward(*args, **kwargs) → torch.Tensor