torch_mist.estimators.hybrid.implementations.reweighted
Module Contents
Classes
- class torch_mist.estimators.hybrid.implementations.reweighted.ReweighedHybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)
  Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator
  - sample_negatives(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]