torch_mist.estimators.discriminative.implementations.tuba

Module Contents

Classes

TUBA

class torch_mist.estimators.discriminative.implementations.tuba.TUBA(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.LearnableBaseline, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
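The signature pairs a critic with a learnable baseline. That pairing matches the TUBA (Tractable Unnormalized Barber-Agakov) bound of Poole et al. (2019), "On Variational Bounds of Mutual Information". For a critic f and a log-baseline b, that bound reads I(X;Y) >= 1 + E_p(x,y)[f(x,y) - b(y)] - E_p(x)p(y)[exp(f(x,y) - b(y))]. The sketch below is a minimal standalone illustration of that formula in plain Python, under stated assumptions. It does not use torch_mist's API. The function name tuba_lower_bound, the plain-callable `critic` and `log_baseline` arguments, and conditioning the baseline on y are all hypothetical choices made for illustration. Negatives are drawn by re-pairing each y_i with `neg_samples` other x_j from the same batch, which is one plausible reading of the `neg_samples` parameter.

```python
import math
import random
from typing import Callable, List, Sequence

# Hypothetical standalone sketch of the TUBA bound (Poole et al., 2019).
# It does NOT call torch_mist: `critic` stands in for f(x, y) and
# `log_baseline` for b(y) = log a(y). Conditioning the baseline on y is an assumption.


def tuba_lower_bound(
    x: Sequence[float],
    y: Sequence[float],
    critic: Callable[[float, float], float],
    log_baseline: Callable[[float], float],
    neg_samples: int = 1,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of
    1 + E_{p(x,y)}[f(x,y) - b(y)] - E_{p(x)p(y)}[exp(f(x,y) - b(y))].
    """
    n = len(x)
    if n != len(y):
        raise ValueError("x and y must have the same length")
    if not 1 <= neg_samples <= n - 1:
        raise ValueError("neg_samples must be between 1 and len(x) - 1")
    rng = random.Random(seed)

    # Positive term: critic on the paired samples (x_i, y_i) ~ p(x, y),
    # shifted by the log-baseline.
    joint = sum(critic(x[i], y[i]) - log_baseline(y[i]) for i in range(n)) / n

    # Negative term: pair each y_i with `neg_samples` x_j (j != i) from the
    # same batch, an in-batch approximation of sampling from p(x)p(y).
    neg_terms: List[float] = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for j in rng.sample(others, neg_samples):
            neg_terms.append(math.exp(critic(x[j], y[i]) - log_baseline(y[i])))
    marginal = sum(neg_terms) / len(neg_terms)

    return 1.0 + joint - marginal


if __name__ == "__main__":
    xs = [0.1, 0.5, 0.9, 1.3]
    ys = [0.2, 0.4, 1.0, 1.2]  # y roughly tracks x
    f = lambda a, b: -((a - b) ** 2)  # hypothetical critic
    print(tuba_lower_bound(xs, ys, f, lambda _: 0.0, neg_samples=2))
```

A learnable baseline lets the estimator adapt the normalization per sample instead of fixing it. Setting `log_baseline` to the constant 1 (that is, a(y) = e) reduces this expression to the NWJ bound, which Poole et al. describe as a special case of TUBA.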