mlexp.inference.torch_inference

class mlexp.inference.torch_inference.TorchInference(downloaded_files_path, inference_server_params, server)

Downloads logged parameters, hyperparameters and a PyTorch Lightning model from a particular server.

Parameters
  • downloaded_files_path (str) – Directory to which files from the server (MLflow or Neptune) will be downloaded.

  • inference_server_params (dict) – Parameters of the server to download files from. For MLflow: a dict with ‘run_id’ of the run and ‘tracking_uri’ of your MLflow server. For Neptune: a dict with ‘project’ in the form ‘{User}/{Project Name}’ and ‘run’, the run ID of your experiment.

  • server (Literal[‘mlflow’, ‘neptune’]) – Type of server.
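The two accepted shapes of inference_server_params can be checked before constructing the class. The helper below, validate_server_params, is hypothetical and not part of mlexp. It is a minimal sketch that only encodes the keys listed above (‘run_id’ and ‘tracking_uri’ for MLflow, ‘project’ and ‘run’ for Neptune) and the ‘{User}/{Project Name}’ project format; the example values are placeholders, not real runs.

```python
from typing import Literal

# Required keys per server type, taken from the parameter description above.
REQUIRED_KEYS = {
    "mlflow": {"run_id", "tracking_uri"},
    "neptune": {"project", "run"},
}


def validate_server_params(
    params: dict, server: Literal["mlflow", "neptune"]
) -> dict:
    """Hypothetical pre-flight check for inference_server_params (not mlexp API).

    Raises ValueError if the server type is unknown, a required key is missing,
    or a Neptune project is not of the form '{User}/{Project Name}'.
    """
    if server not in REQUIRED_KEYS:
        raise ValueError(f"Unknown server: {server!r}")
    missing = REQUIRED_KEYS[server] - set(params)
    if missing:
        raise ValueError(f"Missing keys for {server}: {sorted(missing)}")
    if server == "neptune":
        # Both the user part and the project-name part must be non-empty.
        user, sep, project_name = params["project"].partition("/")
        if not (user and sep and project_name):
            raise ValueError("Neptune 'project' must look like '{User}/{Project Name}'")
    return params


# Placeholder parameter dicts for each server type.
mlflow_params = {"run_id": "0123abcd", "tracking_uri": "http://localhost:5000"}
neptune_params = {"project": "my-user/my-project", "run": "PROJ-42"}

validate_server_params(mlflow_params, "mlflow")
validate_server_params(neptune_params, "neptune")
```

A validated dict would then be passed as inference_server_params together with the matching server value.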

get_params_model(metric='metric_mean_cv', step='best', fold_num='test', trained_model=True, nn_model_params={})

Get logged parameters, hyperparameters, metrics and the PyTorch Lightning neural network from a particular step and fold of the run.

Parameters
  • metric (str) – Name of the metric to fetch from the server; it is also used to find the best step when step is ‘best’.

  • step (Union[int, str]) – Step in the run. If int, the index of the step. If ‘best’, the step with the best value of the validation metric.

  • fold_num (Union[int, str]) – Cross-validation fold. If int, the index of the validation fold. If ‘test’, the test fold.

  • trained_model (bool) – Whether to download the trained model (True) or initialise a new model (False).

  • nn_model_params (dict) – Hyperparameters of the neural network to overwrite.
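The semantics of step and nn_model_params can be illustrated with a minimal sketch. Both helpers below, resolve_step and apply_overrides, are hypothetical and not part of mlexp. resolve_step assumes metric_history[i] holds the validation metric logged at step i, and whether ‘best’ means the maximum or the minimum is an assumption exposed as the hypothetical higher_is_better flag. apply_overrides assumes that overwriting means a shallow dict merge in which the supplied values win.

```python
from typing import Dict, List, Union


def resolve_step(
    metric_history: List[float],
    step: Union[int, str] = "best",
    higher_is_better: bool = True,
) -> int:
    """Hypothetical: map the `step` argument to a concrete step index.

    Assumes metric_history[i] is the validation metric at step i.
    Whether "best" is the max or the min is an assumption (higher_is_better).
    """
    if isinstance(step, int):
        if not 0 <= step < len(metric_history):
            raise IndexError(f"Step {step} out of range")
        return step
    if step == "best":
        pick = max if higher_is_better else min
        # On ties, max/min return the earliest step with that value.
        return pick(range(len(metric_history)), key=metric_history.__getitem__)
    raise ValueError(f"Unsupported step: {step!r}")


def apply_overrides(logged_hparams: Dict, nn_model_params: Dict) -> Dict:
    """Hypothetical shallow merge: values in nn_model_params take precedence."""
    return {**logged_hparams, **nn_model_params}


history = [0.71, 0.78, 0.76]  # placeholder metric values, one per step
best = resolve_step(history)  # → 1
hparams = apply_overrides({"lr": 1e-3, "hidden_dim": 128}, {"lr": 1e-4})
```

Because the merge is shallow, a nested hyperparameter dict passed in nn_model_params would replace the logged one entirely rather than being merged key by key.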

Return type

dict

Returns

Dictionary with the run's downloaded parameters, hyperparameters, model and logged metrics.