torch_mist.nn
Module Contents
Classes
Functions
- class torch_mist.nn.Model
  Bases: torch.nn.Module
  - lower_bound: bool = False
  - upper_bound: bool = False
  - abstract loss(*args, **kwargs)
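The documented interface of `Model` is small: two class-level flags and an abstract `loss`. The sketch below does not import `torch_mist`. It defines a hypothetical stand-in, `ModelLike`, that mirrors only the members listed above, plus a hypothetical subclass, `InfoNCESketch`, that sets `lower_bound = True` and implements `loss`. Reading the flags as "this model's estimate is a lower/upper bound", and the choice of an InfoNCE-style objective with a dot-product critic, are illustrative assumptions, not the library's implementation.

```python
import math
from abc import ABC, abstractmethod

import torch
from torch import nn


class ModelLike(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented members of torch_mist.nn.Model."""

    lower_bound: bool = False
    upper_bound: bool = False

    @abstractmethod
    def loss(self, *args, **kwargs) -> torch.Tensor:
        """Return a scalar to minimize; subclasses must override this."""


class InfoNCESketch(ModelLike):
    """Hypothetical subclass: an InfoNCE-style lower bound with a dot-product critic."""

    lower_bound = True  # assumption: the flag marks the estimate as a lower bound

    def __init__(self, x_dim: int, y_dim: int, emb_dim: int = 16):
        super().__init__()
        self.f_x = nn.Linear(x_dim, emb_dim)
        self.f_y = nn.Linear(y_dim, emb_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # scores[i, j] = <f_x(x_i), f_y(y_j)>; the diagonal holds the paired samples
        scores = self.f_x(x) @ self.f_y(y).T
        n = scores.shape[0]
        # mean_i [scores[i, i] - logsumexp_j scores[i, j]] + log n, which is at most log n
        return (scores.diag() - scores.logsumexp(dim=1)).mean() + math.log(n)

    def loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Maximizing the bound is the same as minimizing its negation
        return -self(x, y)
```

Because `ModelLike` also derives from `ABC`, calling `ModelLike()` directly raises `TypeError` until a subclass overrides `loss`; whether the real `Model` enforces this the same way is not stated in this reference.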
- class torch_mist.nn.CachedModule(module: torch.nn.Module)
  Bases: torch.nn.Module
  - forward(*args, **kwargs)
  - __repr__()
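This reference lists only the constructor, `forward`, and `__repr__` for `CachedModule`, so its exact caching rule is not stated here. A minimal sketch, assuming it reuses the wrapped module's last output when called again with the same input tensor object. The class name `CachedModuleSketch` and the `calls` counter are hypothetical, added so the effect is observable.

```python
import torch
from torch import nn


class CachedModuleSketch(nn.Module):
    """Hypothetical wrapper: recompute `module` only when the input tensor changes."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self.calls = 0  # how many times the wrapped module actually ran
        self._last_input = None
        self._last_output = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity check: the same tensor object means the cached result is reused
        if x is not self._last_input:
            self._last_input = x
            self._last_output = self.module(x)
            self.calls += 1
        return self._last_output

    def __repr__(self) -> str:
        return f"CachedModuleSketch({self.module!r})"
```

Keying on object identity rather than tensor values keeps the check cheap, but an equal-valued new tensor still triggers recomputation, and the wrapper holds a reference to the last input until the next call.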
- torch_mist.nn.dense_nn(input_dim: int, output_dim: int, hidden_dims: List[int] | None = None, nonlinearity: Callable[[torch.Tensor], torch.Tensor] | None = None) → torch.nn.Module
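`dense_nn` maps `input_dim` to `output_dim` through optional hidden layers. A minimal sketch of such a builder, assuming a stack of `nn.Linear` layers with the nonlinearity after every hidden layer and none after the output, ReLU when `nonlinearity` is `None`, and no hidden layers when `hidden_dims` is `None`. The names `dense_nn_sketch` and `_Lambda`, and those defaults, are assumptions, not taken from the library.

```python
from typing import Callable, List, Optional

import torch
from torch import nn


class _Lambda(nn.Module):
    """Wraps a plain callable (e.g. torch.tanh) so it can sit inside nn.Sequential."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x)


def dense_nn_sketch(
    input_dim: int,
    output_dim: int,
    hidden_dims: Optional[List[int]] = None,
    nonlinearity: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> nn.Module:
    hidden_dims = list(hidden_dims or [])  # assumption: None means no hidden layers
    if nonlinearity is None:
        act: nn.Module = nn.ReLU()  # assumption: ReLU by default
    elif isinstance(nonlinearity, nn.Module):
        act = nonlinearity
    else:
        act = _Lambda(nonlinearity)

    dims = [input_dim] + hidden_dims
    layers: List[nn.Module] = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), act]  # hidden layer followed by activation
    layers.append(nn.Linear(dims[-1], output_dim))  # linear output layer
    return nn.Sequential(*layers)
```

For example, `dense_nn_sketch(5, 3, [16, 8], torch.tanh)` builds Linear(5, 16), tanh, Linear(16, 8), tanh, Linear(8, 3).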
- torch_mist.nn.multi_head_dense_nn(input_dim: int, output_dim: int, n_shared_layers: int, hidden_dims: List[int], n_heads: int, cached_shared_forward: bool = True, nonlinearity: torch.nn.Module = nn.ReLU(True)) → Tuple[torch.nn.Module, ...]
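`multi_head_dense_nn` returns a tuple of networks, and its parameters suggest that the first `n_shared_layers` layers are shared by all `n_heads` heads, with `cached_shared_forward` controlling whether that shared trunk is computed once per input. A sketch under those assumptions: the trunk uses `hidden_dims[:n_shared_layers]`, each head gets its own copies of the remaining hidden layers plus an output layer, and caching is by input-tensor identity. The helper `_CacheLast` and the name `multi_head_dense_nn_sketch` are hypothetical.

```python
from typing import List, Tuple

import torch
from torch import nn


class _CacheLast(nn.Module):
    """Hypothetical helper: rerun `module` only when given a new input tensor object."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self.calls = 0  # real forward passes through the shared trunk
        self._x = None
        self._y = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x is not self._x:
            self._x, self._y = x, self.module(x)
            self.calls += 1
        return self._y


def multi_head_dense_nn_sketch(
    input_dim: int,
    output_dim: int,
    n_shared_layers: int,
    hidden_dims: List[int],
    n_heads: int,
    cached_shared_forward: bool = True,
    nonlinearity: nn.Module = nn.ReLU(True),
) -> Tuple[nn.Module, ...]:
    if not 0 <= n_shared_layers <= len(hidden_dims):
        raise ValueError("n_shared_layers must be between 0 and len(hidden_dims)")
    dims = [input_dim] + list(hidden_dims)

    # Shared trunk: the first n_shared_layers hidden layers (assumption)
    trunk_layers: List[nn.Module] = []
    for i in range(n_shared_layers):
        trunk_layers += [nn.Linear(dims[i], dims[i + 1]), nonlinearity]
    shared: nn.Module = nn.Sequential(*trunk_layers)
    if cached_shared_forward:
        shared = _CacheLast(shared)

    heads = []
    for _ in range(n_heads):
        # Each head reuses the same trunk object but owns its remaining layers
        layers: List[nn.Module] = [shared]
        for i in range(n_shared_layers, len(hidden_dims)):
            layers += [nn.Linear(dims[i], dims[i + 1]), nonlinearity]
        layers.append(nn.Linear(dims[-1], output_dim))
        heads.append(nn.Sequential(*layers))
    return tuple(heads)
```

Sharing one trunk object means its parameters receive gradients from every head, while caching avoids recomputing it when all heads are evaluated on the same batch.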