torch_mist.baseline.base

Module Contents

Classes

Baseline

ConstantBaseline

ExponentialMovingAverage

BatchLogMeanExp

LearnableBaseline

InterpolatedBaseline

class torch_mist.baseline.base.Baseline

Bases: torch.nn.Module

abstract forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor
class torch_mist.baseline.base.ConstantBaseline(value: float = 0)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor
class torch_mist.baseline.base.ExponentialMovingAverage(gamma: float = 0.9)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor
class torch_mist.baseline.base.BatchLogMeanExp(dims: str)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor
class torch_mist.baseline.base.LearnableBaseline(net: torch.nn.Module)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor
class torch_mist.baseline.base.InterpolatedBaseline(baseline_1: Baseline, baseline_2: Baseline, alpha: float)

Bases: Baseline

forward(x: torch.Tensor, f_: torch.Tensor) → torch.Tensor