mlexp.inference.torch_inference
- class mlexp.inference.torch_inference.TorchInference(downloaded_files_path, inference_server_params, server)
Downloads logged parameters, hyperparameters and a PyTorch Lightning model from a particular server.
- Parameters
downloaded_files_path (str) – Directory to which files from the server will be downloaded.
inference_server_params (dict) – Server parameters used to download files. For mlflow: a dict with ‘run_id’ of a particular run and ‘tracking_uri’ of your mlflow server. For neptune: ‘project’ in the form {User}/{Project Name} and ‘run’, the run id of your experiment.
server (Literal[‘mlflow’, ‘neptune’]) – Type of server.
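For illustration, the two accepted shapes of inference_server_params can be written out as plain dicts. This is a minimal sketch: REQUIRED_KEYS and check_server_params are hypothetical names (not part of mlexp) that only check the keys listed above are present, and all values are placeholders.

```python
from typing import Literal

# Hypothetical table of required keys per server type, taken from the parameter description.
REQUIRED_KEYS = {
    "mlflow": {"run_id", "tracking_uri"},
    "neptune": {"project", "run"},
}


def check_server_params(server: Literal["mlflow", "neptune"], params: dict) -> None:
    """Hypothetical helper: raise ValueError if params lacks a key required for server."""
    if server not in REQUIRED_KEYS:
        raise ValueError(f"Unknown server type: {server!r}")
    missing = REQUIRED_KEYS[server] - set(params)
    if missing:
        raise ValueError(f"Missing keys for {server}: {sorted(missing)}")


# Placeholder values; substitute your own run id, tracking URI and project.
mlflow_params = {"run_id": "<run-id>", "tracking_uri": "http://localhost:5000"}
neptune_params = {"project": "my-user/my-project", "run": "<run-id>"}

check_server_params("mlflow", mlflow_params)
check_server_params("neptune", neptune_params)
```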
- get_params_model(metric='metric_mean_cv', step='best', fold_num='test', trained_model=True, nn_model_params={})
Get logged parameters, hyperparameters, metrics and the PyTorch Lightning neural network from a particular step and fold of the run.
- Parameters
metric (str) – Metric to fetch from the server; it is used to find the best step.
step (Union[int, str]) – Step of the run. If int, the index of the step. If ‘best’, the best step (the one with the best value of the validation metric).
fold_num (Union[int, str]) – Cross-validation fold. If int, the index of the validation fold. If ‘test’, the test fold.
trained_model (bool) – Whether to download the trained model or initialise a new model.
nn_model_params (dict) – Dictionary of neural network hyperparameters to be overwritten.
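The documented semantics of step and nn_model_params can be mimicked in a few lines of plain Python. This is a minimal sketch, not mlexp's implementation: resolve_step and override_hparams are hypothetical names, the metric history is a made-up list of per-step values, and the higher_is_better flag is an assumption, since the documentation does not say in which direction ‘best’ is measured.

```python
from typing import Sequence, Union


def resolve_step(history: Sequence[float], step: Union[int, str], higher_is_better: bool = True) -> int:
    """Hypothetical: map the `step` argument to a concrete step index."""
    if isinstance(step, int):
        return step
    if step != "best":
        raise ValueError(f"step must be an int or 'best', got {step!r}")
    # Return the index of the best validation metric value (assumed direction).
    pick = max if higher_is_better else min
    return pick(range(len(history)), key=lambda i: history[i])


def override_hparams(logged: dict, nn_model_params: dict) -> dict:
    """Hypothetical: values in nn_model_params replace the logged hyperparameters."""
    return {**logged, **nn_model_params}


history = [0.71, 0.78, 0.75]  # placeholder metric values for steps 0, 1, 2
print(resolve_step(history, "best"))                          # 1
print(resolve_step(history, "best", higher_is_better=False))  # 0
print(resolve_step(history, 2))                               # 2
print(override_hparams({"lr": 0.01, "hidden_dim": 64}, {"lr": 0.001}))
# {'lr': 0.001, 'hidden_dim': 64}
```

Merging with {**logged, **nn_model_params} lets the caller's values win while keeping every logged hyperparameter the caller did not mention.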
- Return type
dict
- Returns
Dictionary with the downloaded run parameters, hyperparameters, model and logged metrics.