torch_mist.estimators.discriminative.base

Module Contents

Classes

DiscriminativeMIEstimator

BaselineDiscriminativeMIEstimator

class torch_mist.estimators.discriminative.base.DiscriminativeMIEstimator(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.base.MIEstimator

lower_bound: bool = True
infomax_gradient: Dict[str, bool]
unnormalized_log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
n_negatives_to_use(N: int)
sample_negatives(x: torch.Tensor, y: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
log_ratio(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
abstract _approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
approx_log_partition(x: torch.Tensor, y: torch.Tensor, x_: torch.Tensor, y_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
__repr__()
class torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator(critic: torch_mist.critic.Critic, baseline: torch_mist.baseline.Baseline, neg_samples: int = 1)

Bases: DiscriminativeMIEstimator

_approx_log_partition(x: torch.Tensor, y: torch.Tensor, f_: torch.Tensor, log_w: torch.Tensor | None) → torch.Tensor
__repr__()