1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Fix logic of uri in ExpM and add test

This commit is contained in:
D-X-Y
2021-03-04 21:04:01 -08:00
parent ee7eb79277
commit c4d6e00470
6 changed files with 58 additions and 26 deletions

View File

@@ -1489,7 +1489,7 @@ OpsList = [
]
class OpsWrapper(object):
class OpsWrapper:
"""Ops Wrapper"""
def __init__(self):

View File

@@ -39,8 +39,8 @@ class QlibRecorder:
name of the recorder under the experiment one wants to start.
uri : str
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
The default uri are set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
Therefore, the next time when user call this function in the same experiment,
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
Therefore, the next time when users call this function in the same experiment,
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
"""
run = self.start_exp(experiment_name, recorder_name, uri)
@@ -280,7 +280,7 @@ class QlibRecorder:
-------
The uri of current experiment manager.
"""
return self.exp_manager.get_uri()
return self.exp_manager.uri
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
"""

View File

@@ -11,7 +11,7 @@ from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
class Experiment(object):
class Experiment:
"""
This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
@@ -173,7 +173,7 @@ class MLflowExperiment(Experiment):
self._uri = uri
self._default_name = None
self._default_rec_name = "mlflow_recorder"
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
def start(self, recorder_name=None):
logger.info(f"Experiment {self.id} starts running ...")
@@ -208,7 +208,6 @@ class MLflowExperiment(Experiment):
else:
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
if is_new:
mlflow.set_experiment(self.name)
self.active_recorder = recorder
# start the recorder
self.active_recorder.start_run()
@@ -237,7 +236,7 @@ class MLflowExperiment(Experiment):
), "Please input at least one of recorder id or name before retrieving recorder."
if recorder_id is not None:
try:
run = self.client.get_run(recorder_id)
run = self._client.get_run(recorder_id)
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
return recorder
except MlflowException:
@@ -258,7 +257,7 @@ class MLflowExperiment(Experiment):
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
order_by = kwargs.get("order_by")
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
def delete_recorder(self, recorder_id=None, recorder_name=None):
assert (
@@ -266,10 +265,10 @@ class MLflowExperiment(Experiment):
), "Please input a valid recorder id or name before deleting."
try:
if recorder_id is not None:
self.client.delete_run(recorder_id)
self._client.delete_run(recorder_id)
else:
recorder = self._get_recorder(recorder_name=recorder_name)
self.client.delete_run(recorder.id)
self._client.delete_run(recorder.id)
except MlflowException as e:
raise Exception(
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
@@ -278,7 +277,7 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results=UNLIMITED):
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])

View File

@@ -14,19 +14,20 @@ from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
class ExpManager(object):
class ExpManager:
"""
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""
def __init__(self, uri, default_exp_name):
self.uri = uri
self._default_uri = uri
self._current_uri = None
self.default_exp_name = default_exp_name
self.active_experiment = None # only one experiment can active each time
def __repr__(self):
return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri)
return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri)
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
"""
@@ -206,7 +207,8 @@ class ExpManager(object):
"""
raise NotImplementedError(f"Please implement the `delete_exp` method.")
def get_uri(self):
@property
def uri(self):
"""
Get the default tracking URI or current URI.
@@ -214,7 +216,7 @@ class ExpManager(object):
-------
The tracking URI string.
"""
return self.uri
return self._current_uri or self._default_uri
def list_experiments(self):
"""
@@ -234,25 +236,27 @@ class MLflowExpManager(ExpManager):
def __init__(self, uri, default_exp_name):
super(MLflowExpManager, self).__init__(uri, default_exp_name)
self._client = None
@property
def client(self):
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
if not hasattr(self, "_client"):
if self._client is None:
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
return self._client
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
# set the tracking uri
# Set the tracking uri
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
else:
self.uri = uri
# create experiment
# Temporarily re-set the current uri as the uri argument.
self._current_uri = uri
# Create experiment
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
# set up active experiment
# Set up active experiment
self.active_experiment = experiment
# start the experiment
# Start the experiment
self.active_experiment.start(recorder_name)
return self.active_experiment
@@ -261,6 +265,8 @@ class MLflowExpManager(ExpManager):
if self.active_experiment is not None:
self.active_experiment.end(recorder_status)
self.active_experiment = None
# When an experiment end, we will release the current uri.
self._current_uri = None
def create_exp(self, experiment_name=None):
assert experiment_name is not None

View File

@@ -11,7 +11,7 @@ from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
class Recorder(object):
class Recorder:
"""
This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)

View File

@@ -96,7 +96,6 @@ port_analysis_config = {
}
# train
def train():
"""train model
@@ -110,8 +109,8 @@ def train():
# model initiaiton
model = init_instance_by_config(task["model"])
print(model)
dataset = init_instance_by_config(task["dataset"])
# To test __repr__
print(dataset)
print(R)
@@ -122,6 +121,7 @@ def train():
# prediction
recorder = R.get_recorder()
# To test __repr__
print(recorder)
rid = recorder.id
sr = SignalRecord(model, dataset, recorder)
@@ -137,6 +137,27 @@ def train():
return pred_score, {"ic": ic, "ric": ric}, rid
def fake_experiment():
"""A fake experiment workflow to test uri
Returns
-------
pass_or_not_for_default_uri: bool
pass_or_not_for_current_uri: bool
temporary_exp_dir: str
"""
# start exp
default_uri = R.get_uri()
current_uri = 'file:./temp-test-exp-mag'
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
R.log_params(**flatten_dict(task))
current_uri_to_check = R.get_uri()
default_uri_to_check = R.get_uri()
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
def backtest_analysis(pred, rid):
"""backtest and analysis
@@ -185,6 +206,12 @@ class TestAllFlow(TestAutoData):
"backtest failed",
)
def test_2_expmanager(self):
pass_default, pass_current, uri_path = fake_experiment()
self.assertTrue(pass_default, msg='default uri is incorrect')
self.assertTrue(pass_current, msg='current uri is incorrect')
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
def suite():
_suite = unittest.TestSuite()