torch_mist.utils.train
Submodules
Package Contents
Functions
- torch_mist.utils.train.train_mi_estimator(estimator: torch_mist.estimators.base.MIEstimator, train_data: torch_mist.utils.data.utils.TensorDictLike, valid_data: torch_mist.utils.data.utils.TensorDictLike | None = None, valid_percentage: float = 0.1, batch_size: int | None = None, num_workers: int = 0, device: torch.device | str = torch.device('cpu'), max_epochs: int | None = None, max_iterations: int | None = None, optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_params: Dict[str, Any] | None = None, lr_annealing: bool = False, warmup_percentage: float = 0, verbose: bool = True, logger: torch_mist.utils.logging.logger.base.Logger | bool | None = None, early_stopping: bool = False, patience: int | None = None, tolerance: float = 0.001, fast_train: bool = False, train_logged_methods: List[str | Tuple[str, Callable]] | None = None, eval_logged_methods: List[str | Tuple[str, Callable]] | None = None) -> Any | None
- torch_mist.utils.train.train_model(model: torch_mist.nn.Model, train_data: torch_mist.utils.data.utils.TensorDictLike, train_method: str = 'loss', eval_method: str | None = None, valid_data: torch_mist.utils.data.utils.TensorDictLike | None = None, valid_percentage: float = 0.1, batch_size: int | None = None, num_workers: int = 0, device: torch.device | str = torch.device('cpu'), max_epochs: int | None = None, max_iterations: int | None = None, optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_params: Dict[str, Any] | None = None, lr_annealing: bool = False, warmup_percentage: float = 0, verbose: bool = True, logger: torch_mist.utils.logging.logger.base.Logger | bool | None = None, early_stopping: bool = False, patience: int | None = None, tolerance: float = 0.001, fast_train: bool = False, train_logged_methods: List[str | Tuple[str, Callable]] | None = None, eval_logged_methods: List[str | Tuple[str, Callable]] | None = None) -> Any | None
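The two signatures above share most of their keyword arguments, so a call site might keep the common settings in one place. The following is a minimal sketch under stated assumptions: the wrapper names `run_estimator_training` and `run_model_training` and the `TRAIN_KWARGS` dictionary are hypothetical, the chosen values are illustrative only, and the shape of `train_data` and of the return value is whatever the library defines (neither is described on this page).

```python
from typing import Any, Dict

# Shared keyword arguments. The names come from the documented signatures;
# the values are illustrative choices, not library recommendations.
TRAIN_KWARGS: Dict[str, Any] = {
    "valid_percentage": 0.1,  # documented default
    "batch_size": 64,
    "device": "cpu",          # the signature accepts a str or a torch.device
    "max_epochs": 10,
    "early_stopping": True,
    "patience": 3,
    "tolerance": 0.001,       # documented default
    "verbose": False,
}


def run_estimator_training(estimator: Any, train_data: Any) -> Any:
    """Hypothetical wrapper around train_mi_estimator.

    The import is deferred so that this module can be loaded even when
    torch_mist is not installed.
    """
    from torch_mist.utils.train import train_mi_estimator

    return train_mi_estimator(
        estimator=estimator, train_data=train_data, **TRAIN_KWARGS
    )


def run_model_training(model: Any, train_data: Any) -> Any:
    """Hypothetical wrapper around train_model.

    train_method="loss" is the documented default and is spelled out here
    only to make the difference from train_mi_estimator visible.
    """
    from torch_mist.utils.train import train_model

    return train_model(
        model=model, train_data=train_data, train_method="loss", **TRAIN_KWARGS
    )
```

Both functions are annotated as returning `Any | None`, so a caller should be prepared for `None` rather than assume a result object is always returned.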