From c4d6e00470e7cb1e962c56e971a6bfe5874ffadc Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 4 Mar 2021 21:04:01 -0800 Subject: [PATCH] Fix logic of uri in ExpM and add test --- qlib/data/ops.py | 2 +- qlib/workflow/__init__.py | 6 +++--- qlib/workflow/exp.py | 15 +++++++-------- qlib/workflow/expm.py | 28 +++++++++++++++++----------- qlib/workflow/recorder.py | 2 +- tests/test_all_pipeline.py | 31 +++++++++++++++++++++++++++++-- 6 files changed, 58 insertions(+), 26 deletions(-) diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 940c24002..9307fdafa 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -1489,7 +1489,7 @@ OpsList = [ ] -class OpsWrapper(object): +class OpsWrapper: """Ops Wrapper""" def __init__(self): diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index c6bf0c86c..7bff505ce 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -39,8 +39,8 @@ class QlibRecorder: name of the recorder under the experiment one wants to start. uri : str The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored. - The default uri are set in the qlib.config. Note that this uri argument will not change the one defined in the config file. - Therefore, the next time when user call this function in the same experiment, + The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file. + Therefore, the next time when users call this function in the same experiment, they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur. """ run = self.start_exp(experiment_name, recorder_name, uri) @@ -280,7 +280,7 @@ class QlibRecorder: ------- The uri of current experiment manager. 
""" - return self.exp_manager.get_uri() + return self.exp_manager.uri def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 15bb7604c..18b0a143d 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -11,7 +11,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class Experiment(object): +class Experiment: """ This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) @@ -173,7 +173,7 @@ class MLflowExperiment(Experiment): self._uri = uri self._default_name = None self._default_rec_name = "mlflow_recorder" - self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) + self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) def start(self, recorder_name=None): logger.info(f"Experiment {self.id} starts running ...") @@ -208,7 +208,6 @@ class MLflowExperiment(Experiment): else: recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False if is_new: - mlflow.set_experiment(self.name) self.active_recorder = recorder # start the recorder self.active_recorder.start_run() @@ -237,7 +236,7 @@ class MLflowExperiment(Experiment): ), "Please input at least one of recorder id or name before retrieving recorder." 
if recorder_id is not None: try: - run = self.client.get_run(recorder_id) + run = self._client.get_run(recorder_id) recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run) return recorder except MlflowException: @@ -258,7 +257,7 @@ class MLflowExperiment(Experiment): max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") order_by = kwargs.get("order_by") - return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) + return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) def delete_recorder(self, recorder_id=None, recorder_name=None): assert ( @@ -266,10 +265,10 @@ class MLflowExperiment(Experiment): ), "Please input a valid recorder id or name before deleting." try: if recorder_id is not None: - self.client.delete_run(recorder_id) + self._client.delete_run(recorder_id) else: recorder = self._get_recorder(recorder_name=recorder_name) - self.client.delete_run(recorder.id) + self._client.delete_run(recorder.id) except MlflowException as e: raise Exception( f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct." @@ -278,7 +277,7 @@ class MLflowExperiment(Experiment): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! 
def list_recorders(self, max_results=UNLIMITED): - runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] + runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 541507a73..82265b585 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -14,19 +14,20 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class ExpManager(object): +class ExpManager: """ This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ def __init__(self, uri, default_exp_name): - self.uri = uri + self._default_uri = uri + self._current_uri = None self.default_exp_name = default_exp_name self.active_experiment = None # only one experiment can active each time def __repr__(self): - return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri) + return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri) def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs): """ @@ -206,7 +207,8 @@ class ExpManager(object): """ raise NotImplementedError(f"Please implement the `delete_exp` method.") - def get_uri(self): + @property + def uri(self): """ Get the default tracking URI or current URI. @@ -214,7 +216,7 @@ class ExpManager(object): ------- The tracking URI string. 
""" - return self.uri + return self._current_uri or self._default_uri def list_experiments(self): """ @@ -234,25 +236,27 @@ class MLflowExpManager(ExpManager): def __init__(self, uri, default_exp_name): super(MLflowExpManager, self).__init__(uri, default_exp_name) + self._client = None @property def client(self): # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib - if not hasattr(self, "_client"): + if self._client is None: self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) return self._client def start_exp(self, experiment_name=None, recorder_name=None, uri=None): - # set the tracking uri + # Set the tracking uri if uri is None: logger.info("No tracking URI is provided. Use the default tracking URI.") else: - self.uri = uri - # create experiment + # Temporarily re-set the current uri as the uri argument. + self._current_uri = uri + # Create experiment experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) - # set up active experiment + # Set up active experiment self.active_experiment = experiment - # start the experiment + # Start the experiment self.active_experiment.start(recorder_name) return self.active_experiment @@ -261,6 +265,8 @@ class MLflowExpManager(ExpManager): if self.active_experiment is not None: self.active_experiment.end(recorder_status) self.active_experiment = None + # When an experiment ends, we will release the current uri. + self._current_uri = None def create_exp(self, experiment_name=None): assert experiment_name is not None diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 31077176d..e75ae347b 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -11,7 +11,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class Recorder(object): +class Recorder: """ This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow. 
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html) diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 97f3f986a..a75eada75 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -96,7 +96,6 @@ port_analysis_config = { } -# train def train(): """train model @@ -110,8 +109,8 @@ def train(): # model initiaiton model = init_instance_by_config(task["model"]) - print(model) dataset = init_instance_by_config(task["dataset"]) + # To test __repr__ print(dataset) print(R) @@ -122,6 +121,7 @@ def train(): # prediction recorder = R.get_recorder() + # To test __repr__ print(recorder) rid = recorder.id sr = SignalRecord(model, dataset, recorder) @@ -137,6 +137,27 @@ def train(): return pred_score, {"ic": ic, "ric": ric}, rid +def fake_experiment(): + """A fake experiment workflow to test uri + + Returns + ------- + pass_or_not_for_default_uri: bool + pass_or_not_for_current_uri: bool + temporary_exp_dir: str + """ + + # start exp + default_uri = R.get_uri() + current_uri = 'file:./temp-test-exp-mag' + with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): + R.log_params(**flatten_dict(task)) + + current_uri_to_check = R.get_uri() + default_uri_to_check = R.get_uri() + return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri + + def backtest_analysis(pred, rid): """backtest and analysis @@ -185,6 +206,12 @@ class TestAllFlow(TestAutoData): "backtest failed", ) + def test_2_expmanager(self): + pass_default, pass_current, uri_path = fake_experiment() + self.assertTrue(pass_default, msg='default uri is incorrect') + self.assertTrue(pass_current, msg='current uri is incorrect') + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + def suite(): _suite = unittest.TestSuite()