torch_mist.baseline
Submodules
Package Contents
Classes
Functions
- class torch_mist.baseline.Baseline
  Bases: torch.nn.Module
  - abstract forward(x: torch.Tensor, f_: torch.Tensor) -> torch.Tensor
- class torch_mist.baseline.BatchLogMeanExp(dims: str)
  Bases: Baseline
  - forward(x: torch.Tensor, f_: torch.Tensor) -> torch.Tensor
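The class name suggests a log-mean-exp reduction over a batch of scores `f_`. The following is a minimal sketch of that reduction in plain PyTorch, not the library's implementation. The accepted values of the `dims` string are not listed on this page, so the hypothetical helper `log_mean_exp` takes an integer `dim` instead.

```python
import math

import torch


def log_mean_exp(f_: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Numerically stable log(mean(exp(f_))) along `dim` (hypothetical helper)."""
    n = f_.shape[dim]
    # logsumexp avoids overflow; subtracting log(n) turns the sum into a mean.
    return torch.logsumexp(f_, dim=dim) - math.log(n)


scores = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
batch_baseline = log_mean_exp(scores, dim=0)  # one value per column

# Large inputs would overflow a naive exp(), but stay finite here.
large = log_mean_exp(torch.tensor([1000.0, 1000.0]))
```

Using `torch.logsumexp` rather than `torch.log(torch.exp(f_).mean(dim))` keeps the result finite when the scores are large.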
- class torch_mist.baseline.LearnableBaseline(net: torch.nn.Module)
  Bases: Baseline
  - forward(x: torch.Tensor, f_: torch.Tensor) -> torch.Tensor
- class torch_mist.baseline.ConstantBaseline(value: float = 0)
  Bases: Baseline
  - forward(x: torch.Tensor, f_: torch.Tensor) -> torch.Tensor
- class torch_mist.baseline.InterpolatedBaseline(baseline_1: Baseline, baseline_2: Baseline, alpha: float)
  Bases: Baseline
  - forward(x: torch.Tensor, f_: torch.Tensor) -> torch.Tensor
- class torch_mist.baseline.ExponentialMovingAverage(gamma: float = 0.9)
  Bases: Baseline
  - forward(x: torch.Tensor, f_: torch.Tensor) -> torch.Tensor
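To illustrate how a decay factor such as `gamma` typically behaves, here is a rough sketch of an exponential moving average kept inside a `torch.nn.Module`. It is an assumption-laden stand-in, not the library's code: the class name `EMASketch` is hypothetical, and the choice to average the batch mean of `f_` (rather than some other statistic) is an assumption made for illustration.

```python
import torch
from torch import nn


class EMASketch(nn.Module):
    """Hypothetical stand-in: tracks a running average of the batch mean of f_."""

    def __init__(self, gamma: float = 0.9):
        super().__init__()
        self.gamma = gamma
        # A buffer follows the module across devices and into state_dict.
        self.register_buffer("running", torch.zeros(()))
        self.initialized = False

    def forward(self, x: torch.Tensor, f_: torch.Tensor) -> torch.Tensor:
        # Assumption: the tracked statistic is the detached batch mean.
        batch_value = f_.detach().mean()
        if not self.initialized:
            # The first batch seeds the running value.
            self.running.copy_(batch_value)
            self.initialized = True
        else:
            # running <- gamma * running + (1 - gamma) * batch_value
            self.running.mul_(self.gamma).add_((1 - self.gamma) * batch_value)
        return self.running.clone()


ema = EMASketch(gamma=0.9)
x = torch.zeros(4, 3)
first = ema(x, torch.full((4,), 1.0))   # seeded with the first batch mean
second = ema(x, torch.full((4,), 2.0))  # 0.9 * 1.0 + 0.1 * 2.0
```

A `gamma` closer to 1 makes the running value change more slowly from batch to batch.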
- torch_mist.baseline.baseline_nn(x_dim: int, hidden_dims: List[int], nonlinearity: Callable = nn.ReLU(True)) -> torch_mist.baseline.base.LearnableBaseline
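The parameters of `baseline_nn` (`x_dim`, `hidden_dims`, `nonlinearity`) suggest a multilayer perceptron over `x`. The sketch below builds such a network in plain PyTorch to show what those arguments would plausibly control. The helper `mlp_sketch` is hypothetical, and the single scalar output per sample is an assumption; the library's actual architecture may differ.

```python
from typing import Callable, List

import torch
from torch import nn


def mlp_sketch(
    x_dim: int,
    hidden_dims: List[int],
    nonlinearity: Callable = nn.ReLU(True),
) -> nn.Sequential:
    """Hypothetical helper: MLP from x_dim inputs to one scalar per sample."""
    layers = []
    in_dim = x_dim
    for hidden_dim in hidden_dims:
        # One linear layer followed by the nonlinearity per hidden size.
        layers += [nn.Linear(in_dim, hidden_dim), nonlinearity]
        in_dim = hidden_dim
    # Assumption: the baseline is a single scalar for each input row.
    layers.append(nn.Linear(in_dim, 1))
    return nn.Sequential(*layers)


net = mlp_sketch(x_dim=5, hidden_dims=[16, 8])
out = net(torch.randn(4, 5))  # shape (4, 1)
```

With the library installed, the corresponding call per the signature above would be `baseline_nn(x_dim=5, hidden_dims=[16, 8])`, which returns a `LearnableBaseline`.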