1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00
Files
qlib/qlib/workflow/expm.py
2020-11-02 11:05:40 +08:00

232 lines
7.5 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import mlflow
import os
from pathlib import Path
from contextlib import contextmanager
from .exp import MLflowExperiment
from .recorder import MLflowRecorder
from ..log import get_module_logger
# Module-level logger for the workflow package.
# NOTE(review): assumes get_module_logger forwards `level` to logging.Logger.setLevel;
# the stdlib only accepts upper-case level names there ("Warning" raises ValueError).
logger = get_module_logger("workflow", "WARNING")
class ExpManager:
    """
    This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
    (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)

    This base class only holds the shared state (tracking URI, active recorder and the
    name -> experiment cache); the experiment operations raise ``NotImplementedError``
    and are meant to be overridden by a concrete manager.
    """

    def __init__(self):
        self.uri = None
        self.active_recorder = None  # only one recorder can be running at a time
        self.experiments = dict()  # store the experiment name --> Experiment object

    def start_exp(self, experiment_name=None, uri=None, **kwargs):
        """
        Start running an experiment.

        Parameters
        ----------
        experiment_name : str
            name of the active experiment.
        uri : str
            the current tracking URI.
        artifact_location : str
            the location to store all the artifacts.
        nested : boolean
            controls whether run is nested in parent run.

        Returns
        -------
        An object wrapped by context manager.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        raise NotImplementedError("Please implement the `start_exp` method.")

    def end_exp(self, **kwargs):
        """
        End a running experiment.

        Parameters
        ----------
        **kwargs
            implementation-specific options.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        raise NotImplementedError("Please implement the `end_exp` method.")

    def search_records(self, experiment_ids=None, **kwargs):
        """
        Get a pandas DataFrame of records that fit the search criteria.

        Parameters
        ----------
        experiment_ids : list
            list of experiment IDs.
        filter_string : str
            filter query string, defaults to searching all runs.
        run_view_type : int
            one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).
        max_results : int
            the maximum number of runs to put in the dataframe.
        order_by : list
            list of columns to order by (e.g., "metrics.rmse").

        Returns
        -------
        A pandas.DataFrame of runs.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        raise NotImplementedError("Please implement the `search_records` method.")

    def __create_exp(self, experiment_name, artifact_location=None):
        """
        Create an experiment.

        Note: the double-underscore name is mangled per class, so a subclass defining
        its own ``__create_exp`` does not override this one.

        Parameters
        ----------
        experiment_name : str
            the experiment name, which must be unique.
        artifact_location : str
            the location to store run artifacts.

        Returns
        -------
        An experiment object.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        raise NotImplementedError("Please implement the `create_exp` method.")

    def get_exp(self, experiment_id=None, experiment_name=None):
        """
        Retrieve an experiment by its id or name.

        Parameters
        ----------
        experiment_id : str
            the experiment id to return.
        experiment_name : str
            the experiment name to return.

        Returns
        -------
        An experiment object.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        raise NotImplementedError("Please implement the `get_exp` method.")

    def delete_exp(self, experiment_id):
        """
        Delete an experiment.

        Parameters
        ----------
        experiment_id : str
            the experiment id.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        # The message previously named `create_exp` by mistake.
        raise NotImplementedError("Please implement the `delete_exp` method.")

    def get_uri(self):
        """
        Get the current tracking URI.

        Returns
        -------
        The tracking URI string stored on this manager (``None`` until one is set).
        """
        return self.uri

    def get_recorder(self):
        """
        Get the current active Recorder.

        Returns
        -------
        A Recorder object, or ``None`` when no recorder is active.
        """
        return self.active_recorder
class MLflowExpManager(ExpManager):
    """
    Use mlflow to implement ExpManager.

    Experiments created through this manager are cached in ``self.experiments``,
    keyed by experiment name, so starting an experiment with a name this manager
    already created reuses the cached experiment object.
    """

    # Experiment name used when `start_exp` is called without one.
    _DEFAULT_EXP_NAME = "experiment"

    def __init__(self):
        super(MLflowExpManager, self).__init__()
        self.uri = None

    def start_exp(self, experiment_name=None, uri=None):
        """
        Create (or reuse) an experiment, attach a new recorder to it and start a run.

        Parameters
        ----------
        experiment_name : str, optional
            name of the experiment; ``"experiment"`` is used when omitted.
        uri : str, optional
            tracking URI handed to ``mlflow.set_tracking_uri``.

        Returns
        -------
        The result of the new recorder's ``start_run`` call.
        """
        # create experiment
        experiment = self.__create_exp(experiment_name, uri)
        # set up recorder
        recorder = experiment.create_recorder()
        self.active_recorder = recorder
        # Store under the resolved name (the argument may be None) so the experiment
        # can be found again via `get_exp(experiment_name=...)` and reused later.
        self.experiments[experiment.name] = experiment
        return self.active_recorder.start_run(experiment_id=experiment.id)

    def end_exp(self):
        """
        End the run of the active recorder and clear it.

        Logs a warning and does nothing when no recorder is active.
        """
        if self.active_recorder is None:
            logger.warning("No active experiment to end.")
            return
        self.active_recorder.end_run()
        self.active_recorder = None

    def __create_exp(self, experiment_name=None, uri=None):
        """
        Configure the tracking URI, then create a new experiment or reuse a cached one.

        Parameters
        ----------
        experiment_name : str, optional
            name of the experiment; ``"experiment"`` is used when omitted.
        uri : str, optional
            tracking URI; when omitted, the URI previously set on this manager is
            kept (or mlflow's default when none was ever set).

        Returns
        -------
        An ``MLflowExperiment`` whose ``id`` and ``name`` are filled in.

        Raises
        ------
        Exception
            if this manager has not created the experiment but mlflow already
            knows an experiment with that name.
        """
        # set the tracking uri
        if uri is not None:
            self.uri = uri
        elif self.uri is None:
            # Only warn when we really fall back to mlflow's default location.
            logger.warning(
                "No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
            )
        mlflow.set_tracking_uri(self.uri)

        # Resolve the default name first so both cases share the same create/reuse logic.
        if experiment_name is None:
            logger.warning("No experiment name provided. The default experiment name is set as `experiment`.")
            experiment_name = self._DEFAULT_EXP_NAME

        if experiment_name in self.experiments:
            # Reuse the experiment this manager created earlier.
            experiment = self.experiments[experiment_name]
            experiment_id = experiment.id
        else:
            if mlflow.get_experiment_by_name(experiment_name) is not None:
                raise Exception(
                    "The experiment has already been created before. Please pick another name or delete the files under uri."
                )
            experiment = MLflowExperiment()
            experiment_id = mlflow.create_experiment(experiment_name)

        # set the active experiment
        mlflow.set_experiment(experiment_name)
        # set up experiment
        experiment.id = experiment_id
        experiment.name = experiment_name
        return experiment

    def search_records(self, experiment_ids=None, **kwargs):
        """
        Search runs with ``mlflow.search_runs``.

        Parameters
        ----------
        experiment_ids : list, optional
            list of experiment IDs (defaults to ``None``, as in the base class).
        filter_string : str
            filter query string; ``""`` (all runs) when omitted or None.
        run_view_type : int
            mlflow ``ViewType`` value; ``1`` (ACTIVE_ONLY) when omitted or None.
        max_results : int
            maximum number of runs; ``100000`` when omitted or None.
        order_by : list
            list of columns to order by.

        Returns
        -------
        A pandas.DataFrame of runs.
        """
        filter_string = kwargs.get("filter_string")
        if filter_string is None:
            filter_string = ""
        run_view_type = kwargs.get("run_view_type")
        if run_view_type is None:
            run_view_type = 1
        max_results = kwargs.get("max_results")
        if max_results is None:
            max_results = 100000
        # Keyword arguments keep this call correct if mlflow reorders/extends the signature.
        return mlflow.search_runs(
            experiment_ids=experiment_ids,
            filter_string=filter_string,
            run_view_type=run_view_type,
            max_results=max_results,
            order_by=kwargs.get("order_by"),
        )

    def get_exp(self, experiment_id=None, experiment_name=None):
        """
        Retrieve a cached experiment by name (preferred when given) or by id.

        Raises
        ------
        AssertionError
            if neither ``experiment_id`` nor ``experiment_name`` is given.
        KeyError
            if ``experiment_name`` is given but not cached by this manager.
        Exception
            if only ``experiment_id`` is given and no cached experiment has it.
        """
        assert (
            experiment_id is not None or experiment_name is not None
        ), "Please provide at least one of the experiment id or name to retrieve an experiment."
        if experiment_name is not None:
            return self.experiments[experiment_name]
        for experiment in self.experiments.values():
            if experiment.id == experiment_id:
                return experiment
        # Previously unreachable: an unknown id silently returned None.
        raise Exception("No valid experiment is found. Please make sure the id and name are correctly given.")

    def delete_exp(self, experiment_id):
        """
        Delete an experiment in mlflow and drop it from this manager's cache.

        Parameters
        ----------
        experiment_id : str
            the experiment id.
        """
        mlflow.delete_experiment(experiment_id)
        self.experiments = {key: val for key, val in self.experiments.items() if val.id != experiment_id}