mlexp.trainers.sklearn_trainer
- class mlexp.trainers.sklearn_trainer.SklearnTrainer(sklearn_estimator, validation_metric, direction, saved_files_path, optimization_metric='metric_mean_cv')
Training, logging and hyperparameter search for scikit-learn models.
- Parameters
sklearn_estimator (Type[BaseEstimator]) – Scikit-learn estimator class to be fitted.
validation_metric (Callable[[ndarray, ndarray], float]) – Score function or loss function with signature validation_metric(y_true, y_pred); must return the metric value as a float or integer.
direction (str) – Direction of optimization.
saved_files_path (str) – Directory in which logging files are saved.
optimization_metric (str) – Metric to optimize.
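validation_metric only needs the (y_true, y_pred) signature and a numeric return value, so functions from sklearn.metrics such as mean_absolute_error already fit. A custom metric can be written the same way; the sketch below defines a hypothetical root-mean-squared-error function with NumPy. Because it is a loss, it would be paired with a direction that minimizes; the accepted direction strings are not listed in this reference, so Optuna's "minimize"/"maximize" convention is only an assumption.

```python
import numpy as np


def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Hypothetical validation_metric: root mean squared error (lower is better)."""
    errors = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    # float() converts the NumPy scalar into a plain Python float
    return float(np.sqrt(np.mean(errors ** 2)))
```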
- init_run(logging_server, upload_files=[], **run_params)
Initializes a new run on the logging server.
- Parameters
logging_server (Literal[‘neptune’, ‘optuna’]) – Logging server to use.
upload_files (Iterable[str]) – Paths to files that will be logged in the initiated run.
run_params – Server-specific run parameters:
If logging_server == “mlflow”: MLflow run parameters as kwargs (passed to mlflow.start_run).
If logging_server == “neptune”: kwarg neptune_run_params, a dict of Neptune run parameters (passed to neptune.init_run).
- Return type
str
- Returns
ID of the created run.
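The two logging servers take run_params in different shapes: MLflow parameters are passed as flat keyword arguments, while Neptune parameters are wrapped in a single neptune_run_params dict. The hypothetical helper below only mirrors that documented rule to make the difference concrete; it is not mlexp's implementation and it does not call MLflow or Neptune.

```python
from typing import Any, Dict, Tuple


def split_run_params(logging_server: str, **run_params: Any) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: which function receives which kwargs, per the docs above."""
    if logging_server == "mlflow":
        # every kwarg is forwarded unchanged to mlflow.start_run
        return "mlflow.start_run", dict(run_params)
    if logging_server == "neptune":
        # only the contents of the neptune_run_params dict go to neptune.init_run
        return "neptune.init_run", dict(run_params.get("neptune_run_params", {}))
    raise ValueError(f"unsupported logging server: {logging_server!r}")
```

Assuming trainer is an SklearnTrainer instance, this corresponds to calls such as trainer.init_run("mlflow", run_name="baseline") versus trainer.init_run("neptune", neptune_run_params={"project": "workspace/project"}). run_name (mlflow.start_run) and project (neptune.init_run) are real parameters of those libraries; the values are placeholders.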
- train(X, y, cv, n_trials, params_func, sampler)
Runs training, hyperparameter search and logging.
- Parameters
X (ndarray) – Training features.
y (ndarray) – Target values.
cv (list) – Validation indexes. Every element except the last is used for hyperparameter search; the last element is the test fold.
Example:
[
    [[0, 1, 2, 3], [4, 5]],  # first validation fold
    [[6, 7, 8, 9], [10, 11]],  # second validation fold
    [[12, 13, 14, 15], [16, 17]],  # test fold
]
Observations with indexes [0, 1, 2, 3] are used to train the model, which is then evaluated on observations with indexes [4, 5].
Observations with indexes [6, 7, 8, 9] are used to train the model, which is then evaluated on observations with indexes [10, 11].
Observations with indexes [12, 13, 14, 15] are used to train the model, which is then evaluated on observations with indexes [16, 17].
Metrics from the first two folds are used during hyperparameter optimization; the metric from the last fold is only logged.
n_trials (int) – Number of hyperparameter search iterations.
params_func (Callable[[Trial], dict]) – Function that accepts an optuna.trial.Trial and returns a dict of hyperparameters. Read more about params_func in the User Guide.
sampler (BaseSampler) – Hyperparameter sampler from optuna.samplers.
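The sketch below prepares two of the inputs train expects. contiguous_folds is a hypothetical helper that rebuilds the cv example above as back-to-back [train, validation] index blocks; params_func is a hypothetical search space for a random-forest-style estimator, built with Optuna's real Trial.suggest_int and Trial.suggest_float methods. The hyperparameter names and ranges are illustrative assumptions, and it is assumed that the returned dict is used to configure the estimator.

```python
from typing import TYPE_CHECKING, Any, Dict, List

if TYPE_CHECKING:  # optuna is only needed here for the type hint
    import optuna

Fold = List[List[int]]  # [train_indexes, validation_indexes]


def contiguous_folds(n_folds: int, train_size: int, val_size: int) -> List[Fold]:
    """Hypothetical helper: back-to-back [train, validation] index blocks."""
    folds: List[Fold] = []
    start = 0
    for _ in range(n_folds):
        train = list(range(start, start + train_size))
        val = list(range(start + train_size, start + train_size + val_size))
        folds.append([train, val])
        start += train_size + val_size  # next fold starts right after this one
    return folds


def params_func(trial: "optuna.trial.Trial") -> Dict[str, Any]:
    """Hypothetical search space for a random-forest-style estimator."""
    return {
        # number of trees: 50, 100, ..., 500
        "n_estimators": trial.suggest_int("n_estimators", 50, 500, step=50),
        # maximum tree depth, sampled on a log scale between 2 and 32
        "max_depth": trial.suggest_int("max_depth", 2, 32, log=True),
        # fraction of features considered at each split
        "max_features": trial.suggest_float("max_features", 0.1, 1.0),
    }


# Rebuilds the cv example: the first two folds drive the search, the last is only logged.
cv = contiguous_folds(n_folds=3, train_size=4, val_size=2)
search_folds, test_fold = cv[:-1], cv[-1]
```

Assuming trainer is an SklearnTrainer instance, the call could then look like trainer.train(X, y, cv=cv, n_trials=50, params_func=params_func, sampler=optuna.samplers.TPESampler(seed=0)). TPESampler and its seed argument are real Optuna API; the sampler choice and n_trials value are placeholders.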