mlexp.trainers.torch_trainer
- class mlexp.trainers.torch_trainer.TorchTrainer(nn_model_module, data_loaders_module, metrics_callback_module, validation_metric, direction, saved_files_path, use_average_epochs_on_test_fold=True, optimization_metric='metric_mean_cv')
Training, logging, and hyperparameter search for a PyTorch Lightning neural network.
- Parameters
nn_model_module (module) – Module with a class nn_model that inherits from pytorch_lightning.LightningModule.
data_loaders_module (module) – Module with a function train_val_data_loaders with signature Callable[[numpy.ndarray, numpy.ndarray, list[list[int], list[int]]], [torch.utils.data.DataLoader, torch.utils.data.DataLoader]], i.e. it takes the features, the targets, and one [train indexes, validation indexes] fold, and returns the training and validation data loaders.
metrics_callback_module (module) – Module with a class MetricsCallback that inherits from pytorch_lightning.Callback. The class must have these two methods:
get_metric – returns a list with the metric value for each epoch;
get_n_epochs – returns the number of epochs as an int.
validation_metric (Callable[[ndarray, ndarray], float]) – Score or loss function with signature validation_metric(y_true, y_pred); must return the metric value as a float or int.
direction (str) – Direction of optimization.
saved_files_path (str) – Directory in which to save logging files.
use_average_epochs_on_test_fold (bool) – Whether to train the model on the test fold for the mean number of epochs from the validation folds, or to use the number of epochs from params_func.
optimization_metric (str) – Metric to optimize.
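The two user-supplied pieces with the least obvious contracts are validation_metric and the MetricsCallback class. Below is a minimal sketch, assuming (hypothetically) that nn_model logs its validation metric under the name "val_metric" via self.log(...) and that get_n_epochs should return the number of validation epochs actually run; neither assumption comes from this reference, so check the User Guide. RMSE is used only as an example loss. The try/except fallback exists solely so the sketch runs without Lightning installed.

```python
import numpy as np

try:
    from pytorch_lightning import Callback
except ImportError:  # fallback stub so this sketch runs without Lightning installed
    class Callback:  # type: ignore[no-redef]
        pass


def validation_metric(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Example loss: root mean squared error (lower is better)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


class MetricsCallback(Callback):
    """Records one validation metric value per epoch.

    Assumption: the LightningModule logs the metric under `metric_name`
    (hypothetical default "val_metric"), so it appears in trainer.callback_metrics.
    """

    def __init__(self, metric_name: str = "val_metric"):
        super().__init__()
        self.metric_name = metric_name
        self.metrics: list[float] = []

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip Lightning's sanity-check validation pass that runs before training.
        if getattr(trainer, "sanity_checking", False):
            return
        value = trainer.callback_metrics.get(self.metric_name)
        if value is not None:
            self.metrics.append(float(value))  # works for scalar tensors and floats

    def get_metric(self) -> list[float]:
        # Metric values by epoch, in order.
        return self.metrics

    def get_n_epochs(self) -> int:
        # Assumption: the number of epochs for which a metric was recorded.
        return len(self.metrics)
```

If the trainer instead expects the epoch with the best metric (for example, when early stopping is used), get_n_epochs would return that index plus one rather than the length of the list.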
- init_run(logging_server, upload_files=[], **run_params)
Initiate a run on the logging server.
- Parameters
logging_server (Literal['neptune', 'mlflow']) – Logging server.
upload_files (Iterable[str]) – Paths of files to log in the initiated run.
run_params –
If logging_server == 'mlflow': MLflow run parameters as kwargs (passed to mlflow.start_run).
If logging_server == 'neptune': the kwarg neptune_run_params, a dict of Neptune run parameters (passed to neptune.init_run).
- Return type
str
- Returns
Run ID of the created run.
- train(X, y, cv, n_trials, params_func, sampler)
Run training, hyperparameter search, and logging.
- Parameters
X (ndarray) – Training features.
y (ndarray) – Target values.
cv (list) – Validation indexes. All but the last element of the list are used for hyperparameter search; the last element is the test fold.
Example:
    [
        [[0, 1, 2, 3], [4, 5]],        # first validation fold
        [[6, 7, 8, 9], [10, 11]],      # second validation fold
        [[12, 13, 14, 15], [16, 17]],  # test fold
    ]
Observations with indexes [0, 1, 2, 3] are used to train a model, which is then evaluated on observations with indexes [4, 5].
Observations with indexes [6, 7, 8, 9] are used to train a model, which is then evaluated on observations with indexes [10, 11].
Observations with indexes [12, 13, 14, 15] are used to train a model, which is then evaluated on observations with indexes [16, 17].
Metrics from the first two folds are used during hyperparameter optimization; the metric from the last fold is only logged.
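A cv list with the layout of the example can be built programmatically. The helper below, make_block_cv, is hypothetical (not part of mlexp) and assumes contiguous, non-overlapping blocks of train_size training indexes followed by val_size validation indexes; with the values shown it reproduces the example exactly, including the final test fold.

```python
def make_block_cv(n_folds: int, train_size: int, val_size: int) -> list[list[list[int]]]:
    """Hypothetical helper: consecutive [train_idx, val_idx] blocks.

    The last block plays the role of the test fold.
    """
    cv = []
    start = 0
    for _ in range(n_folds):
        train_idx = list(range(start, start + train_size))
        val_idx = list(range(start + train_size, start + train_size + val_size))
        cv.append([train_idx, val_idx])
        start += train_size + val_size  # next block starts right after this one
    return cv


cv = make_block_cv(n_folds=3, train_size=4, val_size=2)
# Mirrors the split described above: all but the last fold drive the search,
# the last fold is the test fold.
search_folds, test_fold = cv[:-1], cv[-1]
```

Index pairs from any splitter that yields (train, test) arrays, such as scikit-learn's TimeSeriesSplit, could be converted to the same shape with .tolist().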
n_trials (int) – Number of hyperparameter search iterations (Optuna trials).
params_func (Callable[[Trial], dict]) – Function that accepts an optuna.trial.Trial and returns a dict of hyperparameters. Read more about params_func in the User Guide.
sampler (BaseSampler) – Hyperparameter sampler from optuna.samplers.
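A minimal params_func sketch. The hyperparameter names (lr, hidden_size, dropout, n_epochs) and their ranges are hypothetical; which keys the trainer and nn_model actually read is described in the User Guide. It uses Optuna's standard Trial.suggest_float, suggest_int, and suggest_categorical methods, and imports optuna only for the type hint, so the function itself has no hard dependency on it.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # optuna is needed only for the type hint
    from optuna.trial import Trial


def params_func(trial: "Trial") -> dict:
    """Hypothetical search space; key names are illustrative, not required by mlexp."""
    return {
        "lr": trial.suggest_float("lr", 1e-4, 1e-1, log=True),
        "hidden_size": trial.suggest_int("hidden_size", 16, 256),
        "dropout": trial.suggest_categorical("dropout", [0.0, 0.1, 0.3]),
        # Epoch budget; on the test fold it may be replaced by the mean from the
        # validation folds when use_average_epochs_on_test_fold=True.
        "n_epochs": trial.suggest_int("n_epochs", 5, 50),
    }
```

It could then be passed as trainer.train(X, y, cv, n_trials=50, params_func=params_func, sampler=optuna.samplers.TPESampler(seed=0)). Calling it with optuna.trial.FixedTrial({...}) is a quick way to check the returned dict before launching a full search.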