torch_mist.nn

Module Contents

Classes

Model

Identity

Normalize

CachedModule

Functions

dense_nn(→ torch.nn.Module)

multi_head_dense_nn(→ Tuple[torch.nn.Module, ...])

class torch_mist.nn.Model

Bases: torch.nn.Module

lower_bound: bool = False
upper_bound: bool = False
abstract loss(*args, **kwargs)
class torch_mist.nn.Identity

Bases: torch.nn.Module

forward(x: Any)
class torch_mist.nn.Normalize

Bases: torch.nn.Module

forward(x)
class torch_mist.nn.CachedModule(module: torch.nn.Module)

Bases: torch.nn.Module

forward(*args, **kwargs)
__repr__()
torch_mist.nn.dense_nn(input_dim: int, output_dim: int, hidden_dims: List[int] | None = None, nonlinearity: Callable[[torch.Tensor], torch.Tensor] | None = None) → torch.nn.Module
torch_mist.nn.multi_head_dense_nn(input_dim: int, output_dim: int, n_shared_layers: int, hidden_dims: List[int], n_heads: int, cached_shared_forward: bool = True, nonlinearity: torch.nn.Module = nn.ReLU(True)) → Tuple[torch.nn.Module, ...]