torch_mist.estimators.discriminative.implementations.infonce

Module Contents

Classes

InfoNCE

class torch_mist.estimators.discriminative.implementations.infonce.InfoNCE(critic: torch_mist.critic.Critic, neg_samples: int = 0)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator
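For reference, an estimator named InfoNCE usually targets the standard InfoNCE lower bound. Here the critic f scores N paired samples (x_i, y_i), and every y_j with j ≠ i serves as a negative for x_i. This page does not confirm the exact normalization this class uses, so treat the formula below as the textbook form, not as a statement of the implementation:

$$
I(X;Y) \;\ge\; \mathbb{E}\left[\frac{1}{N}\sum_{i=1}^{N}\left( f(x_i, y_i) \;-\; \log \frac{1}{N}\sum_{j=1}^{N} e^{f(x_i, y_j)} \right)\right]
$$

The second term inside the sum is a per-sample estimate of the log-partition function log Z(x_i). Because the positive pair also appears in that average, the bound can never exceed log N.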

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None)
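The exact behaviour of `_approx_log_partition` is not documented here. In the standard InfoNCE formulation, the log-partition term for each x_i is a log-mean-exp of critic scores taken over all candidates y_j in the batch. The framework-free sketch below shows that computation. It makes several assumptions. The helper names `log_mean_exp`, `approx_log_partition` and `infonce_estimate` are hypothetical and are not part of torch_mist. `scores[i][j]` stands in for the critic output f(x_i, y_j), playing roughly the role of `f_`. The optional `log_w` argument is not modelled.

```python
import math
from typing import List


def log_mean_exp(values: List[float]) -> float:
    """Numerically stable log((1/N) * sum_j exp(v_j))."""
    m = max(values)  # shift by the max so exp() cannot overflow
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


def approx_log_partition(scores: List[List[float]]) -> List[float]:
    """Hypothetical helper: one log-partition estimate per row.

    scores[i][j] is the critic value f(x_i, y_j). The diagonal holds the
    positive pairs; off-diagonal entries act as in-batch negatives.
    Assumption: the real method's optional log_w argument is ignored here.
    """
    return [log_mean_exp(row) for row in scores]


def infonce_estimate(scores: List[List[float]]) -> float:
    """Batch average of f(x_i, y_i) - log Z(x_i), the textbook InfoNCE bound."""
    log_z = approx_log_partition(scores)
    n = len(scores)
    return sum(scores[i][i] - log_z[i] for i in range(n)) / n


if __name__ == "__main__":
    # A critic that cannot tell pairs apart gives an estimate of 0 nats.
    print(infonce_estimate([[1.0, 1.0], [1.0, 1.0]]))
    # A critic that strongly prefers the positives approaches log N = log 3.
    print(infonce_estimate([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]))
```

Averaging rather than summing inside the logarithm (the `- math.log(len(values))` term) is what caps the estimate at log N. This cap is why InfoNCE estimates saturate when the true mutual information is large relative to the batch size.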