torch_mist.baseline.factories

Module Contents

Functions

baseline_nn(x_dim, hidden_dims, nonlinearity) → torch_mist.baseline.base.LearnableBaseline

torch_mist.baseline.factories.baseline_nn(x_dim: int, hidden_dims: List[int], nonlinearity: Callable = nn.ReLU(True)) → torch_mist.baseline.base.LearnableBaseline
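
The signature suggests that the factory builds a feed-forward network over inputs of size x_dim, with one hidden layer per entry in hidden_dims. This page does not confirm that, so the following is only a minimal plain-PyTorch sketch under that assumption. The name mlp_baseline_sketch is hypothetical and not part of torch_mist, and the scalar output per sample is also an assumption.

```python
from typing import Callable, List

import torch
from torch import nn


def mlp_baseline_sketch(
    x_dim: int,
    hidden_dims: List[int],
    nonlinearity: Callable = nn.ReLU(True),
) -> nn.Module:
    """Hypothetical stand-in (not the torch_mist implementation):
    an MLP that maps each x of size x_dim to a single scalar."""
    layers = []
    in_dim = x_dim
    for hidden_dim in hidden_dims:
        layers.append(nn.Linear(in_dim, hidden_dim))
        # ReLU holds no parameters, so reusing one instance is safe.
        layers.append(nonlinearity)
        in_dim = hidden_dim
    # Assumed output: one scalar baseline value per input sample.
    layers.append(nn.Linear(in_dim, 1))
    return nn.Sequential(*layers)


net = mlp_baseline_sketch(x_dim=5, hidden_dims=[64, 32])
x = torch.randn(8, 5)          # batch of 8 samples, each of size x_dim
out = net(x).squeeze(-1)       # drop the trailing singleton dimension
```

With the real library, the equivalent call would presumably be baseline_nn(x_dim=5, hidden_dims=[64, 32]), which returns a LearnableBaseline rather than a bare nn.Sequential.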