torch_mist.estimators.hybrid.base

Module Contents

Classes

HybridMIEstimator

class torch_mist.estimators.hybrid.base.HybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor | None]
resampling_strategy()
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
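
Example

The listing above gives signatures only, so the following is a rough, library-free sketch of how a class with this shape might be subclassed. Everything in it is an assumption made for illustration: ToyHybridBase and ShiftedNegatives are hypothetical stand-ins, plain floats and lists replace torch.Tensor, and the additive combination of the two component estimates is one plausible reading of the class, not confirmed torch_mist behaviour. The only details taken from the listing are that the constructor receives a generative and a discriminative component, that sample_negatives is abstract, and that it returns a pair whose second element may be None.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple


class ToyHybridBase(ABC):
    """Hypothetical stand-in that mirrors the listed interface; NOT torch_mist code."""

    def __init__(
        self,
        generative_log_ratio: Callable[[float, float], float],
        discriminative_critic: Callable[[float, float], float],
    ) -> None:
        self.generative_log_ratio = generative_log_ratio
        self.discriminative_critic = discriminative_critic

    def unnormalized_log_ratio(self, x: float, y: float) -> float:
        # Assumption: the generative estimate is corrected additively by the critic.
        return self.generative_log_ratio(x, y) + self.discriminative_critic(x, y)

    @abstractmethod
    def sample_negatives(
        self, x: List[float], y: List[float]
    ) -> Tuple[List[float], Optional[List[float]]]:
        """Return negative samples for y and an optional second sequence."""


class ShiftedNegatives(ToyHybridBase):
    """Hypothetical subclass: builds negatives by rotating the batch of y."""

    def sample_negatives(self, x, y):
        # Pair each x with the y of the next sample; no second sequence is returned.
        return y[1:] + y[:1], None


est = ShiftedNegatives(
    generative_log_ratio=lambda x, y: -0.5 * (x - y) ** 2,
    discriminative_critic=lambda x, y: 0.1 * x * y,
)
negatives, extra = est.sample_negatives([0.0, 1.0, 2.0], [0.1, 1.1, 2.1])
print(negatives, extra)
print(est.unnormalized_log_ratio(1.0, 1.0))
```

In this sketch the base class cannot be instantiated because it derives from abc.ABC. Whether the real HybridMIEstimator enforces its abstract method in the same way is not stated in the listing.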