torch_mist.estimators.hybrid.implementations

Submodules

Package Contents

Classes

PQHybridMIEstimator

ResampledHybridMIEstimator

ReweighedHybridMIEstimator

class torch_mist.estimators.hybrid.implementations.PQHybridMIEstimator(q_QY_given_X: pyro.distributions.ConditionalDistribution, quantize_y: torch_mist.quantization.QuantizationFunction, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator, temperature: float = 1.0)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

property quantize_y: Callable[[torch.Tensor], torch.LongTensor]
disable_batch_validation()
enable_batch_validation()
sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
class torch_mist.estimators.hybrid.implementations.ResampledHybridMIEstimator(generative_estimator: torch_mist.estimators.generative.base.GenerativeMIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
class torch_mist.estimators.hybrid.implementations.ReweighedHybridMIEstimator(generative_estimator: torch_mist.estimators.base.MIEstimator, discriminative_estimator: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator)

Bases: torch_mist.estimators.hybrid.base.HybridMIEstimator

sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]