torch_mist.models.bottleneck
Module Contents
Classes
- class torch_mist.models.bottleneck.InformationBottleneck(mi_estimator: torch_mist.estimators.MIEstimator, beta: float, p_ZX_given_X: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None, p_ZY_given_Y: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None)
Bases: torch_mist.nn.Model
- abstract regularization(zx: torch.Tensor, zy: torch.Tensor, p_ZX_given_x: torch.distributions.Distribution, p_ZY_given_y: torch.distributions.Distribution)
- mutual_information(x: torch.Tensor, y: torch.Tensor)
- loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
- class torch_mist.models.bottleneck.VIB(mi_estimator: torch_mist.estimators.MIEstimator, q_ZX: torch.distributions.Distribution, beta: float, p_ZX_given_X: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None, p_ZY_given_Y: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None)
Bases: InformationBottleneck
- regularization(zx: torch.Tensor, zy: torch.Tensor, p_ZX_given_x: torch.distributions.Distribution, p_ZY_given_y: torch.distributions.Distribution)
- class torch_mist.models.bottleneck.MIB(mi_estimator: torch_mist.estimators.MIEstimator, beta: float, p_ZX_given_X: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None, p_ZY_given_Y: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None)
Bases: InformationBottleneck
- regularization(zx: torch.Tensor, zy: torch.Tensor, p_ZX_given_x: torch.distributions.Distribution, p_ZY_given_y: torch.distributions.Distribution)
- class torch_mist.models.bottleneck.CEB(mi_estimator: torch_mist.estimators.MIEstimator, q_ZX_given_ZY: pyro.distributions.ConditionalDistribution, beta: float, p_ZX_given_X: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None, p_ZY_given_Y: torch.nn.Module | pyro.distributions.ConditionalDistribution | None = None)
Bases: InformationBottleneck
- regularization(zx: torch.Tensor, zy: torch.Tensor, p_ZX_given_x: torch.distributions.Distribution, p_ZY_given_y: torch.distributions.Distribution)
- class torch_mist.models.bottleneck.TIB(mi_estimator: torch_mist.estimators.MIEstimator, q_Zt2_given_Zt1: pyro.distributions.ConditionalDistribution, p_Zt_given_Xt: pyro.distributions.ConditionalDistribution, beta: float)
Bases: CEB