torch_mist.estimators.discriminative.implementations.smile

Module Contents

Classes

SMILE

class torch_mist.estimators.discriminative.implementations.smile.SMILE(critic: torch_mist.critic.Critic, neg_samples: int = 1, tau: float = 5.0)

Bases: torch_mist.estimators.discriminative.implementations.js.JS, torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) -> torch.Tensor
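
The tau parameter and the _approx_log_partition method suggest the clipped log-partition estimate used by the SMILE estimator (Song and Ermon, 2020), in which critic scores are clamped to [-tau, tau] before the log-mean-exp is taken. The following is a minimal sketch of that idea in plain Python, not torch_mist's actual implementation: the function name clipped_log_mean_exp and the list-based inputs are hypothetical, and the importance weights log_w are left out.

```python
import math


def clipped_log_mean_exp(scores, tau=5.0):
    """Return log(mean(exp(clip(s, -tau, tau)))) over a list of critic scores.

    Hypothetical helper for illustration only; it is not part of torch_mist.
    """
    # Clamp every score to [-tau, tau] so a few large values cannot dominate.
    clipped = [max(-tau, min(tau, s)) for s in scores]
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(clipped)
    mean_exp = sum(math.exp(c - m) for c in clipped) / len(clipped)
    return m + math.log(mean_exp)
```

Because every clamped score lies in [-tau, tau], the result is also bounded by tau in absolute value, which is what limits the variance of the estimate at the cost of some bias.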