Source code for mlexp.inference.torch_inference

import importlib.util
import os
import pickle
import shutil
import sys
from collections.abc import Mapping
from types import ModuleType
from typing import Literal, Union

import pytorch_lightning as pl

from mlexp.inference._base_inference import _BaseModelInference, SERVER_INFERENCES


class TorchInference(_BaseModelInference):
    """Download logged parameters, hyperparameters and a pytorch-lightning
    neural network from a particular tracking server."""

    def __init__(
        self,
        downloaded_files_path: str,
        inference_server_params: dict,
        server: str,
    ):
        """
        :param downloaded_files_path: Directory to which files from the tracking
            server will be downloaded. WARNING: if this directory already exists
            it is deleted together with all of its contents and recreated.
        :param inference_server_params: Server params to download files from.
            For mlflow - dict with 'run_id' of particular run and 'tracking_uri'
            of your mlflow server. For neptune - 'project' in form of
            ({User}/{Project Name}) and 'run' as run id of your experiment.
        :param server: Type of server; must be one of the keys of
            ``SERVER_INFERENCES``.
        """
        self.downloaded_files_path = downloaded_files_path
        self.downloaded_params = {}
        self.server_inference = SERVER_INFERENCES[server](
            inference_server_params, downloaded_files_path
        )
        # Presumably done so every inference starts from an empty download
        # directory; note this removes any pre-existing content.
        if os.path.isdir(self.downloaded_files_path):
            shutil.rmtree(self.downloaded_files_path)
        for subdirectory in (
            "downloaded_models",
            "downloaded_utils",
            "downloaded_studies",
        ):
            os.makedirs(r"{}/{}/".format(downloaded_files_path, subdirectory))

    def _load_module(self, module_name: str, module_directory: str) -> ModuleType:
        """Load a Python source file as a module registered in ``sys.modules``.

        :param module_name: Name under which the module is registered.
        :param module_directory: Path to the module's ``.py`` file.
        :return: Loaded module.
        :raises ImportError: If no import spec/loader can be built for the path.
        """
        spec = importlib.util.spec_from_file_location(module_name, module_directory)
        if spec is None or spec.loader is None:
            raise ImportError(
                "Cannot load module {!r} from {!r}".format(
                    module_name, module_directory
                )
            )
        module = importlib.util.module_from_spec(spec)
        previous_module = sys.modules.get(module_name)
        # Register before executing, following the importlib
        # "Importing a source file directly" recipe.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module behind in sys.modules.
            if previous_module is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous_module
            raise
        return sys.modules[module_name]

    @staticmethod
    def _build_early_stopping(early_stopping_params):
        """Create the EarlyStopping callback from the logged EarlyStopping params."""
        early_stopping_cls = pl.callbacks.early_stopping.EarlyStopping
        if isinstance(early_stopping_params, Mapping):
            # A mapping of keyword arguments must be unpacked; passing it
            # positionally would bind the whole dict to ``monitor``.
            return early_stopping_cls(**early_stopping_params)
        return early_stopping_cls(early_stopping_params)

    def _build_trainer(self):
        """Create a ``pl.Trainer`` from the downloaded trainer params.

        When the run was logged with ``use_average_epochs_on_test_fold``, both
        ``min_epochs`` and ``max_epochs`` are temporarily set to
        ``validation_mean_epochs``; the logged values are always restored
        afterwards (even if construction fails), so
        ``downloaded_params["params"]`` keeps the originally logged values.
        """
        params = self.downloaded_params["params"]
        trainer_params = params["trainer_params"]
        initial_min_epochs = trainer_params["min_epochs"]
        initial_max_epochs = trainer_params["max_epochs"]
        # NOTE(review): applied regardless of ``fold_num`` even though the flag
        # name refers to the test fold - confirm this is intended.
        if self.downloaded_params["use_average_epochs_on_test_fold"]:
            trainer_params["min_epochs"] = params["validation_mean_epochs"]
            trainer_params["max_epochs"] = params["validation_mean_epochs"]
        try:
            metrics_callback = self.downloaded_params["metrics_callback"]()
            return pl.Trainer(
                **trainer_params,
                callbacks=[
                    self._build_early_stopping(params["EarlyStopping_params"]),
                    metrics_callback,
                ],
            )
        finally:
            trainer_params["min_epochs"] = initial_min_epochs
            trainer_params["max_epochs"] = initial_max_epochs

    def get_params_model(
        self,
        metric: str = "metric_mean_cv",
        step: Union[int, str] = "best",
        fold_num: Union[int, str] = "test",
        trained_model: bool = True,
        nn_model_params: Union[dict, None] = None,
    ) -> dict:
        """Get logged parameters, hyperparameters, metrics and a pytorch-lightning
        neural network from a particular step and fold of the run.

        :param metric: Metric to get from server and to use for finding the best step.
        :param step: Index of step in run. If int - index of step. If 'best' - best
            step, as selected by the server from ``metric`` and the run's direction.
        :param fold_num: Cross validation fold. If int - index of validation fold.
            If 'test' - test fold. Used to select the checkpoint file when
            ``trained_model`` is True.
        :param trained_model: If True, load the trained network from the logged
            checkpoint (stored under key 'trained_model'); otherwise initialise a
            new network from the logged model params (stored under key 'model').
        :param nn_model_params: Hyperparameters of the neural network that override
            the logged ones. ``None`` (default) means no overrides.
        :return: ``self.downloaded_params`` - the same dict on every call, updated
            in place, with keys 'direction', 'model_type', 'validation_metric',
            'use_average_epochs_on_test_fold', 'step', 'metric', 'params',
            'optuna_study', 'nn_model', 'data_loaders', 'metrics_callback',
            'trained_model' or 'model', and 'trainer'. Keys set by earlier calls
            are not cleared.
        """
        if nn_model_params is None:
            nn_model_params = {}

        (
            self.downloaded_params["direction"],
            self.downloaded_params["model_type"],
            self.downloaded_params["validation_metric"],
            self.downloaded_params["use_average_epochs_on_test_fold"],
        ) = self.server_inference.get_run_params()

        if step == "best":
            step = self.server_inference.get_best_step(
                self.downloaded_params["direction"], metric
            )
        self.downloaded_params["step"] = step
        self.downloaded_params["metric"] = self.server_inference.get_metric(
            step, metric
        )
        self.downloaded_params["params"] = self.server_inference.get_step_params(step)

        optuna_study_path = self.server_inference.get_file(
            r"saved_studies/optuna_study_{}.pickle".format(step),
            r"{}/downloaded_studies/optuna_study_{}.pickle".format(
                self.downloaded_files_path, step
            ),
        )
        # SECURITY: unpickling can execute arbitrary code - only use with a
        # tracking server whose artifacts you trust.
        with open(optuna_study_path, "rb") as f:
            self.downloaded_params["optuna_study"] = pickle.load(f)

        # Download all helper modules first, then import them.
        util_paths = {
            name: self.server_inference.get_file(
                "saved_utils/{}.py".format(name),
                r"{}/downloaded_utils/{}.py".format(self.downloaded_files_path, name),
            )
            for name in (
                "nn_model_module",
                "data_loaders_module",
                "metrics_callback_module",
            )
        }
        self.downloaded_params["nn_model"] = self._load_module(
            "nn_model_module", util_paths["nn_model_module"]
        ).nn_model
        self.downloaded_params["data_loaders"] = self._load_module(
            "data_loaders_module", util_paths["data_loaders_module"]
        ).train_val_data_loaders
        self.downloaded_params["metrics_callback"] = self._load_module(
            "metrics_callback_module", util_paths["metrics_callback_module"]
        ).MetricsCallback

        if trained_model:
            checkpoint_path = self.server_inference.get_file(
                r"saved_models/model_trial_{}_fold_{}.ckpt".format(step, fold_num),
                r"{}/downloaded_models/model_trial_{}_fold_{}.ckpt".format(
                    self.downloaded_files_path, step, fold_num
                ),
            )
            self.downloaded_params["trained_model"] = self.downloaded_params[
                "nn_model"
            ].load_from_checkpoint(checkpoint_path, **nn_model_params)
        else:
            # User overrides win over logged params; validation_metric wins over both.
            self.downloaded_params["model"] = self.downloaded_params["nn_model"](
                **{
                    **self.downloaded_params["params"]["model_params"],
                    **nn_model_params,
                    "validation_metric": self.downloaded_params["validation_metric"],
                }
            )

        self.downloaded_params["trainer"] = self._build_trainer()
        return self.downloaded_params