torch_mist.estimators.discriminative.implementations.flo
Module Contents
Classes
- class torch_mist.estimators.discriminative.implementations.flo.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)
  Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
  - unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
  - log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
  - batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
  - _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
  - __repr__()
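The entries above list signatures only. As a hedged illustration of the quantity an FLO (Fenchel-Legendre Optimization) estimator targets, the sketch below computes a Monte Carlo estimate of the FLO lower bound on I(X;Y) from precomputed scores. It uses plain Python rather than torch. The bound rests on the identity -log t = max_u (-u - exp(-u) * t + 1) for t > 0, applied to t = E_{p(y')}[exp(g(x, y') - g(x, y))]. The helpers `flo_lower_bound`, `flo_loss`, and `optimal_u` are hypothetical and are not part of torch_mist. Reading `critic` as the score g, `normalized_critic` as the auxiliary score u, and `neg_samples` as the number of negatives y' drawn per x is an assumption based on the signatures, not something this page states.

```python
import math
from typing import List, Sequence


def _mean_shifted_exp(g_pos: float, g_negs: Sequence[float]) -> float:
    # Monte Carlo estimate of t = E_{p(y')}[exp(g(x, y') - g(x, y))]
    return sum(math.exp(g - g_pos) for g in g_negs) / len(g_negs)


def flo_lower_bound(
    f_pos: Sequence[float],
    f_neg: Sequence[Sequence[float]],
    u: Sequence[float],
) -> float:
    """Monte Carlo estimate of the FLO bound (hypothetical helper, not torch_mist API).

    f_pos[i]    : critic score g(x_i, y_i) for the i-th positive pair
    f_neg[i][k] : critic score g(x_i, y'_k) for negatives y'_k ~ p(y)
    u[i]        : auxiliary score u(x_i, y_i)
    """
    total = 0.0
    for g_pos, g_negs, u_i in zip(f_pos, f_neg, u):
        t = _mean_shifted_exp(g_pos, g_negs)
        # Fenchel-Legendre surrogate for -log t; equals -log t when u_i = log t
        total += -u_i - math.exp(-u_i) * t + 1.0
    return total / len(f_pos)


def flo_loss(
    f_pos: Sequence[float],
    f_neg: Sequence[Sequence[float]],
    u: Sequence[float],
) -> float:
    # Training minimizes the negative bound, jointly over the critic g and u
    return -flo_lower_bound(f_pos, f_neg, u)


def optimal_u(f_pos: Sequence[float], f_neg: Sequence[Sequence[float]]) -> List[float]:
    # For a fixed critic, each surrogate term is maximized at u_i = log t_i
    return [math.log(_mean_shifted_exp(p, negs)) for p, negs in zip(f_pos, f_neg)]


f_pos = [2.0, 1.5]
f_neg = [[0.1, -0.3, 0.4], [0.0, 0.2, -0.1]]
u_star = optimal_u(f_pos, f_neg)
print(flo_lower_bound(f_pos, f_neg, u_star))
print(flo_lower_bound(f_pos, f_neg, [v + 0.5 for v in u_star]))  # smaller value
```

For a fixed critic, `optimal_u` makes each surrogate term equal to -log t_i. Any other u gives a strictly smaller value, which is why u can be trained jointly with g by minimizing `flo_loss`. With the expectation over p(y') taken exactly, if g equals log p(y|x)/p(y) up to an additive function of x, then the optimal u satisfies -u(x, y) = log p(y|x)/p(y).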