mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix logic of uri in ExpM and add test
This commit is contained in:
@@ -1489,7 +1489,7 @@ OpsList = [
|
||||
]
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
class OpsWrapper:
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -39,8 +39,8 @@ class QlibRecorder:
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
|
||||
The default uri are set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
|
||||
Therefore, the next time when user call this function in the same experiment,
|
||||
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
|
||||
Therefore, the next time when users call this function in the same experiment,
|
||||
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name, uri)
|
||||
@@ -280,7 +280,7 @@ class QlibRecorder:
|
||||
-------
|
||||
The uri of current experiment manager.
|
||||
"""
|
||||
return self.exp_manager.get_uri()
|
||||
return self.exp_manager.uri
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
"""
|
||||
|
||||
@@ -11,7 +11,7 @@ from ..log import get_module_logger
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class Experiment(object):
|
||||
class Experiment:
|
||||
"""
|
||||
This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
@@ -173,7 +173,7 @@ class MLflowExperiment(Experiment):
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
@@ -208,7 +208,6 @@ class MLflowExperiment(Experiment):
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
mlflow.set_experiment(self.name)
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
@@ -237,7 +236,7 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input at least one of recorder id or name before retrieving recorder."
|
||||
if recorder_id is not None:
|
||||
try:
|
||||
run = self.client.get_run(recorder_id)
|
||||
run = self._client.get_run(recorder_id)
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
|
||||
return recorder
|
||||
except MlflowException:
|
||||
@@ -258,7 +257,7 @@ class MLflowExperiment(Experiment):
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
|
||||
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
assert (
|
||||
@@ -266,10 +265,10 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input a valid recorder id or name before deleting."
|
||||
try:
|
||||
if recorder_id is not None:
|
||||
self.client.delete_run(recorder_id)
|
||||
self._client.delete_run(recorder_id)
|
||||
else:
|
||||
recorder = self._get_recorder(recorder_name=recorder_name)
|
||||
self.client.delete_run(recorder.id)
|
||||
self._client.delete_run(recorder.id)
|
||||
except MlflowException as e:
|
||||
raise Exception(
|
||||
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
|
||||
@@ -278,7 +277,7 @@ class MLflowExperiment(Experiment):
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
@@ -14,19 +14,20 @@ from ..log import get_module_logger
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class ExpManager(object):
|
||||
class ExpManager:
|
||||
"""
|
||||
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
self.uri = uri
|
||||
self._default_uri = uri
|
||||
self._current_uri = None
|
||||
self.default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can active each time
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri)
|
||||
return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri)
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
|
||||
"""
|
||||
@@ -206,7 +207,8 @@ class ExpManager(object):
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_exp` method.")
|
||||
|
||||
def get_uri(self):
|
||||
@property
|
||||
def uri(self):
|
||||
"""
|
||||
Get the default tracking URI or current URI.
|
||||
|
||||
@@ -214,7 +216,7 @@ class ExpManager(object):
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
return self.uri
|
||||
return self._current_uri or self._default_uri
|
||||
|
||||
def list_experiments(self):
|
||||
"""
|
||||
@@ -234,25 +236,27 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
super(MLflowExpManager, self).__init__(uri, default_exp_name)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
|
||||
if not hasattr(self, "_client"):
|
||||
if self._client is None:
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
# set the tracking uri
|
||||
# Set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
self.uri = uri
|
||||
# create experiment
|
||||
# Temporarily re-set the current uri as the uri argument.
|
||||
self._current_uri = uri
|
||||
# Create experiment
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# set up active experiment
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name)
|
||||
|
||||
return self.active_experiment
|
||||
@@ -261,6 +265,8 @@ class MLflowExpManager(ExpManager):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
# When an experiment end, we will release the current uri.
|
||||
self._current_uri = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
assert experiment_name is not None
|
||||
|
||||
@@ -11,7 +11,7 @@ from ..log import get_module_logger
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class Recorder(object):
|
||||
class Recorder:
|
||||
"""
|
||||
This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
|
||||
@@ -96,7 +96,6 @@ port_analysis_config = {
|
||||
}
|
||||
|
||||
|
||||
# train
|
||||
def train():
|
||||
"""train model
|
||||
|
||||
@@ -110,8 +109,8 @@ def train():
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
print(model)
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
# To test __repr__
|
||||
print(dataset)
|
||||
print(R)
|
||||
|
||||
@@ -122,6 +121,7 @@ def train():
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
# To test __repr__
|
||||
print(recorder)
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
@@ -137,6 +137,27 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
"""A fake experiment workflow to test uri
|
||||
|
||||
Returns
|
||||
-------
|
||||
pass_or_not_for_default_uri: bool
|
||||
pass_or_not_for_current_uri: bool
|
||||
temporary_exp_dir: str
|
||||
"""
|
||||
|
||||
# start exp
|
||||
default_uri = R.get_uri()
|
||||
current_uri = 'file:./temp-test-exp-mag'
|
||||
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
|
||||
R.log_params(**flatten_dict(task))
|
||||
|
||||
current_uri_to_check = R.get_uri()
|
||||
default_uri_to_check = R.get_uri()
|
||||
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
|
||||
|
||||
|
||||
def backtest_analysis(pred, rid):
|
||||
"""backtest and analysis
|
||||
|
||||
@@ -185,6 +206,12 @@ class TestAllFlow(TestAutoData):
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
def test_2_expmanager(self):
|
||||
pass_default, pass_current, uri_path = fake_experiment()
|
||||
self.assertTrue(pass_default, msg='default uri is incorrect')
|
||||
self.assertTrue(pass_current, msg='current uri is incorrect')
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
|
||||
Reference in New Issue
Block a user