torch_mist.estimators.discriminative.implementations.flo

Module Contents

Classes

FLO

class torch_mist.estimators.discriminative.implementations.flo.FLO(critic: torch_mist.critic.Critic, normalized_critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator
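Assuming FLO here refers to the Fenchel-Legendre Optimization lower bound on mutual information, the constructor's `critic` can be read as a score function g(x, y) and `normalized_critic` as an auxiliary function u(x, y). This mapping is an assumption and is not stated in this reference. The derivation below sketches why that bound holds: the Fenchel-Legendre dual of the logarithm turns the log-partition term of InfoNCE into an expression that is linear in the exponentiated scores. K negatives-plus-positive per sample is the usual convention and may differ from how `neg_samples` is used.

```latex
% Fenchel-Legendre identity for t > 0 (the minimiser is u^* = \log t):
\log t = \min_{u \in \mathbb{R}} \bigl( u + e^{-u} t - 1 \bigr)
\quad\Longrightarrow\quad
-\log t \ge -u - e^{-u} t + 1 \quad \text{for every } u \in \mathbb{R}.

% InfoNCE, with (x, y_1) \sim p(x, y), y = y_1, and y_2, \dots, y_K \sim p(y):
I(X;Y) \ge \mathbb{E}\Bigl[ -\log \tfrac{1}{K} \textstyle\sum_{k=1}^{K} e^{g(x, y_k) - g(x, y)} \Bigr] = I_{\mathrm{NCE}}(g).

% Applying the inequality inside the expectation with u = u(x, y):
I_{\mathrm{FLO}}(g, u) = \mathbb{E}\Bigl[ -u(x, y) - e^{-u(x, y)} \tfrac{1}{K} \textstyle\sum_{k=1}^{K} e^{g(x, y_k) - g(x, y)} + 1 \Bigr]
\le I_{\mathrm{NCE}}(g) \le I(X;Y).

% For a given sample the first gap is zero exactly when
% u(x, y) = \log \tfrac{1}{K} \sum_{k=1}^{K} e^{g(x, y_k) - g(x, y)}.
```

The design point is that the FLO objective contains no logarithm of a sum, so its per-sample gradient is an unbiased estimate, at the cost of learning the extra function u.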

unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
__repr__()
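As a library-independent illustration of what a `batch_loss` for this kind of estimator might compute, the sketch below evaluates the Fenchel-Legendre (FLO) objective in pure Python on a small score matrix. It is not torch_mist's implementation. The helper names `infonce_terms`, `flo_terms`, and `flo_loss` are hypothetical, as is the convention that `scores[i][j] = g(x_i, y_j)` with positive pairs on the diagonal and the other entries of each row serving as negatives. The real methods take `torch.Tensor` inputs and may draw negatives differently (for example according to `neg_samples`).

```python
import math
from typing import List, Sequence


def _logmeanexp(values: Sequence[float]) -> float:
    """Numerically stable log((1/n) * sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def infonce_terms(scores: List[List[float]]) -> List[float]:
    """Per-sample InfoNCE terms: -log mean_j exp(s[i][j] - s[i][i]).

    Hypothetical convention: scores[i][j] = g(x_i, y_j), positives on the diagonal.
    """
    return [-_logmeanexp([s - row[i] for s in row]) for i, row in enumerate(scores)]


def flo_terms(scores: List[List[float]], u: Sequence[float]) -> List[float]:
    """Per-sample FLO terms: -u_i - exp(-u_i) * mean_j exp(s[i][j] - s[i][i]) + 1.

    u[i] stands in for the auxiliary function u(x_i, y_i) (assumed role of
    `normalized_critic`). Each term is <= the matching InfoNCE term.
    """
    terms = []
    for i, row in enumerate(scores):
        mean_ratio = sum(math.exp(s - row[i]) for s in row) / len(row)
        terms.append(-u[i] - math.exp(-u[i]) * mean_ratio + 1.0)
    return terms


def flo_loss(scores: List[List[float]], u: Sequence[float]) -> float:
    """Loss to minimise: the negative batch mean of the FLO terms."""
    terms = flo_terms(scores, u)
    return -sum(terms) / len(terms)


if __name__ == "__main__":
    scores = [[2.0, 0.1, -0.5], [0.3, 1.5, 0.0], [-0.2, 0.4, 1.8]]
    nce = infonce_terms(scores)
    # The optimal u for each row is log mean_j exp(s[i][j] - s[i][i]) = -nce[i].
    u_opt = [-t for t in nce]
    print(flo_terms(scores, u_opt))  # matches nce up to rounding
    print(flo_terms(scores, [0.0, 0.0, 0.0]))  # each entry is <= nce[i]
```

Setting `u` to its per-row optimum recovers the InfoNCE value, while any other `u` gives a looser (smaller) bound. This is why, under the assumed mapping, training the normalized critic jointly with the score critic can only tighten the estimate.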