torch_mist.baseline

Submodules

Package Contents

Classes

Baseline

BatchLogMeanExp

LearnableBaseline

ConstantBaseline

InterpolatedBaseline

ExponentialMovingAverage

Functions

baseline_nn(x_dim, hidden_dims[, nonlinearity]) → torch_mist.baseline.base.LearnableBaseline

class torch_mist.baseline.Baseline

Bases: torch.nn.Module

abstract forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor

class torch_mist.baseline.BatchLogMeanExp(dims: str)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor

class torch_mist.baseline.LearnableBaseline(net: torch.nn.Module)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor

class torch_mist.baseline.ConstantBaseline(value: float = 0)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor

class torch_mist.baseline.InterpolatedBaseline(baseline_1: Baseline, baseline_2: Baseline, alpha: float)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor

class torch_mist.baseline.ExponentialMovingAverage(gamma: float = 0.9)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor

torch_mist.baseline.baseline_nn(x_dim: int, hidden_dims: List[int], nonlinearity: Callable = nn.ReLU(True)) → torch_mist.baseline.base.LearnableBaseline