torch_mist.estimators.discriminative.implementations.js

Module Contents

Classes

JS

class torch_mist.estimators.discriminative.implementations.js.JS(critic: torch_mist.critic.Critic, neg_samples: int = 1)

Bases: torch_mist.estimators.discriminative.base.BaselineDiscriminativeMIEstimator

batch_loss(x: torch.Tensor, y: torch.Tensor) → torch.Tensor | None
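The page lists only the signature of batch_loss. The pure-Python sketch below illustrates what a Jensen-Shannon (JS) discriminative objective typically computes. It assumes the Deep InfoMax-style bound E_P[-softplus(-T(x, y))] - E_{P_X P_Y}[softplus(T(x, y'))], where T is the critic score. It also assumes that negative pairs are formed by pairing each x with `neg_samples` shifted y values from the same batch. The names `softplus`, `toy_critic`, `pair_scores`, `js_lower_bound` and `js_batch_loss` are hypothetical and are not part of torch_mist. The library's actual loss, negative sampling and baseline handling may differ.

```python
import math
from typing import List, Sequence, Tuple


def softplus(z: float) -> float:
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def toy_critic(x: float, y: float) -> float:
    # Hypothetical separable critic T(x, y) = x * y, for illustration only.
    return x * y


def pair_scores(
    xs: Sequence[float], ys: Sequence[float], neg_samples: int = 1
) -> Tuple[List[float], List[float]]:
    # Positive scores come from the aligned (joint) pairs (x_i, y_i).
    # Negative scores pair x_i with y_{(i + k) mod n} for k = 1..neg_samples,
    # an assumed way to approximate samples from p(x)p(y) within a batch.
    n = len(xs)
    if not 1 <= neg_samples < n:
        raise ValueError("neg_samples must be in [1, batch_size - 1]")
    pos = [toy_critic(x, y) for x, y in zip(xs, ys)]
    neg = [
        toy_critic(xs[i], ys[(i + k) % n])
        for k in range(1, neg_samples + 1)
        for i in range(n)
    ]
    return pos, neg


def js_lower_bound(pos: Sequence[float], neg: Sequence[float]) -> float:
    # E_P[-softplus(-T)] - E_{P_X P_Y}[softplus(T)]; this value is at most 0.
    joint_term = sum(-softplus(-t) for t in pos) / len(pos)
    marginal_term = sum(softplus(t) for t in neg) / len(neg)
    return joint_term - marginal_term


def js_batch_loss(pos: Sequence[float], neg: Sequence[float]) -> float:
    # Training minimizes the negated bound.
    return -js_lower_bound(pos, neg)


if __name__ == "__main__":
    xs = [1.0, -1.0, 0.5]
    ys = [1.0, -1.0, 0.5]
    pos, neg = pair_scores(xs, ys, neg_samples=2)
    print(js_batch_loss(pos, neg))
```

Because the loss is the negated bound, minimizing it raises the critic scores of joint pairs and lowers the scores of shuffled pairs. With all scores equal to 0, the loss is 2 log 2.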