From a9a70dfddf131a278ba1a3d2c7395bf94d5fab08 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Wed, 3 Mar 2021 06:47:52 +0000 Subject: [PATCH 1/7] Update repr for DatasetH and ExpManager --- .gitignore | 2 ++ qlib/data/dataset/__init__.py | 17 +++++++++++------ qlib/workflow/__init__.py | 10 ++++++++-- qlib/workflow/expm.py | 3 +++ 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 5b3745a02..0ddd5d21f 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ tags .pytest_cache/ .vscode/ + +*.swp diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 8ff8c1210..ecbeebc95 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -1,5 +1,5 @@ from ...utils.serial import Serializable -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Dict, Text, Optional from ...utils import init_instance_by_config, np_ffill from ...log import get_module_logger from .handler import DataHandler, DataHandlerLP @@ -76,7 +76,7 @@ class DatasetH(Dataset): - The processing is related to data split. """ - def __init__(self, handler: Union[dict, DataHandler], segments: dict): + def __init__(self, handler: Union[Dict, DataHandler], segments: Dict): """ Parameters ---------- @@ -87,7 +87,7 @@ class DatasetH(Dataset): """ super().__init__(handler, segments) - def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None): + def init(self, handler_kwargs: Optional[Dict] = None, segment_kwargs: Optional[Dict] = None): """ Initialize the DatasetH @@ -124,7 +124,7 @@ class DatasetH(Dataset): raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}") self.segments = segment_kwargs.copy() - def setup_data(self, handler: Union[dict, DataHandler], segments: dict): + def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]): """ Setup the underlying data. 
@@ -156,6 +156,11 @@ class DatasetH(Dataset): self.handler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() + def __repr__(self): + return "{name}(handler={handler}, segments={segments})".format( + name=self.__class__.__name__, handler=self.handler, segments=self.segments + ) + def _prepare_seg(self, slc: slice, **kwargs): """ Give a slice, retrieve the according data @@ -168,7 +173,7 @@ class DatasetH(Dataset): def prepare( self, - segments: Union[List[str], Tuple[str], str, slice], + segments: Union[List[Text], Tuple[Text], Text, slice], col_set=DataHandler.CS_ALL, data_key=DataHandlerLP.DK_I, **kwargs, @@ -178,7 +183,7 @@ class DatasetH(Dataset): Parameters ---------- - segments : Union[List[str], Tuple[str], str, slice] + segments : Union[List[Text], Tuple[Text], Text, slice] Describe the scope of the data to be prepared Here are some examples: diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index aedf73a9c..9602963d8 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -16,8 +16,11 @@ class QlibRecorder: def __init__(self, exp_manager): self.exp_manager = exp_manager + def __repr__(self): + return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager) + @contextmanager - def start(self, experiment_name=None, recorder_name=None): + def start(self, experiment_name=None, recorder_name=None, uri=None): """ Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code: @@ -34,8 +37,11 @@ class QlibRecorder: name of the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. + uri : str + the tracking uri of the experiment, where all the artifacts/metrics etc. will be stored. + The default uri are set in the qlib.config. 
""" - run = self.start_exp(experiment_name, recorder_name) + run = self.start_exp(experiment_name, recorder_name, uri) try: yield run except Exception as e: diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index a50dce7c9..e905434cb 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -25,6 +25,9 @@ class ExpManager: self.default_exp_name = default_exp_name self.active_experiment = None # only one experiment can active each time + def __repr__(self): + return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri) + def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs): """ Start an experiment. This method includes first get_or_create an experiment, and then From 229a39d0d353788f649d414c01e229ec1a3ff92e Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Wed, 3 Mar 2021 07:30:55 +0000 Subject: [PATCH 2/7] Fix typos in DataHandler's doc --- qlib/data/dataset/handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2889c4465..050043ba6 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -35,7 +35,7 @@ class DataHandler(Serializable): The data handler try to maintain a handler with 2 level. `datetime` & `instruments`. - Any order of the index level can be suported(The order will implied in the data). + Any order of the index level can be supported (The order will be implied in the data). The order <`datetime`, `instruments`> will be used when the dataframe index name is missed. 
Example of the data: @@ -47,8 +47,8 @@ class DataHandler(Serializable): $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 - SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 - SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 + SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 + SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 """ From 592db903b3afbfd1628fd7e6ad4f150a5c5d13cc Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 4 Mar 2021 05:02:56 +0000 Subject: [PATCH 3/7] Update repr for Experiment & Recorder --- qlib/workflow/__init__.py | 6 ++++-- qlib/workflow/exp.py | 4 ++-- qlib/workflow/expm.py | 2 +- qlib/workflow/recorder.py | 4 ++-- tests/test_all_pipeline.py | 4 ++++ 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 9602963d8..c6bf0c86c 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -38,8 +38,10 @@ class QlibRecorder: recorder_name : str name of the recorder under the experiment one wants to start. uri : str - the tracking uri of the experiment, where all the artifacts/metrics etc. will be stored. - The default uri are set in the qlib.config. + The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored. + The default uri are set in the qlib.config. Note that this uri argument will not change the one defined in the config file. + Therefore, the next time when user call this function in the same experiment, + they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur. 
""" run = self.start_exp(experiment_name, recorder_name, uri) try: diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index c2548971a..7c98fd68f 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -11,7 +11,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class Experiment: +class Experiment(object): """ This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) @@ -23,7 +23,7 @@ class Experiment: self.active_recorder = None # only one recorder can running each time def __repr__(self): - return str(self.info) + return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info) def __str__(self): return str(self.info) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index e905434cb..541507a73 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -14,7 +14,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class ExpManager: +class ExpManager(object): """ This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index ceb57150c..31077176d 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -11,7 +11,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class Recorder: +class Recorder(object): """ This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow. 
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html) @@ -34,7 +34,7 @@ class Recorder: self.status = Recorder.STATUS_S def __repr__(self): - return str(self.info) + return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info) def __str__(self): return str(self.info) diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index f6e77cba4..97f3f986a 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -110,7 +110,10 @@ def train(): # model initiaiton model = init_instance_by_config(task["model"]) + print(model) dataset = init_instance_by_config(task["dataset"]) + print(dataset) + print(R) # start exp with R.start(experiment_name="workflow"): @@ -119,6 +122,7 @@ def train(): # prediction recorder = R.get_recorder() + print(recorder) rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate() From ee7eb79277896ad3d1729dcf028cb0a027b9d538 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 4 Mar 2021 06:15:24 +0000 Subject: [PATCH 4/7] Fix unexpected mlruns folder error --- qlib/workflow/exp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 7c98fd68f..15bb7604c 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -176,8 +176,6 @@ class MLflowExperiment(Experiment): self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) def start(self, recorder_name=None): - # set the active experiment - mlflow.set_experiment(self.name) logger.info(f"Experiment {self.id} starts running ...") # set up recorder recorder = self.create_recorder(recorder_name) From c4d6e00470e7cb1e962c56e971a6bfe5874ffadc Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 4 Mar 2021 21:04:01 -0800 Subject: [PATCH 5/7] Fix logic of uri in ExpM and add test --- qlib/data/ops.py | 2 +- qlib/workflow/__init__.py | 6 +++--- qlib/workflow/exp.py | 15 +++++++-------- qlib/workflow/expm.py | 28 +++++++++++++++++----------- 
qlib/workflow/recorder.py | 2 +- tests/test_all_pipeline.py | 31 +++++++++++++++++++++++++++++-- 6 files changed, 58 insertions(+), 26 deletions(-) diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 940c24002..9307fdafa 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -1489,7 +1489,7 @@ OpsList = [ ] -class OpsWrapper(object): +class OpsWrapper: """Ops Wrapper""" def __init__(self): diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index c6bf0c86c..7bff505ce 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -39,8 +39,8 @@ class QlibRecorder: name of the recorder under the experiment one wants to start. uri : str The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored. - The default uri are set in the qlib.config. Note that this uri argument will not change the one defined in the config file. - Therefore, the next time when user call this function in the same experiment, + The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file. + Therefore, the next time when users call this function in the same experiment, they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur. """ run = self.start_exp(experiment_name, recorder_name, uri) @@ -280,7 +280,7 @@ class QlibRecorder: ------- The uri of current experiment manager. """ - return self.exp_manager.get_uri() + return self.exp_manager.uri def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 15bb7604c..18b0a143d 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -11,7 +11,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class Experiment(object): +class Experiment: """ This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow. 
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html) @@ -173,7 +173,7 @@ class MLflowExperiment(Experiment): self._uri = uri self._default_name = None self._default_rec_name = "mlflow_recorder" - self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) + self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) def start(self, recorder_name=None): logger.info(f"Experiment {self.id} starts running ...") @@ -208,7 +208,6 @@ class MLflowExperiment(Experiment): else: recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False if is_new: - mlflow.set_experiment(self.name) self.active_recorder = recorder # start the recorder self.active_recorder.start_run() @@ -237,7 +236,7 @@ class MLflowExperiment(Experiment): ), "Please input at least one of recorder id or name before retrieving recorder." if recorder_id is not None: try: - run = self.client.get_run(recorder_id) + run = self._client.get_run(recorder_id) recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run) return recorder except MlflowException: @@ -258,7 +257,7 @@ class MLflowExperiment(Experiment): max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") order_by = kwargs.get("order_by") - return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) + return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) def delete_recorder(self, recorder_id=None, recorder_name=None): assert ( @@ -266,10 +265,10 @@ class MLflowExperiment(Experiment): ), "Please input a valid recorder id or name before deleting." try: if recorder_id is not None: - self.client.delete_run(recorder_id) + self._client.delete_run(recorder_id) else: recorder = self._get_recorder(recorder_name=recorder_name) - self.client.delete_run(recorder.id) + self._client.delete_run(recorder.id) except MlflowException as e: raise Exception( f"Error: {e}. 
Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct." @@ -278,7 +277,7 @@ class MLflowExperiment(Experiment): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! def list_recorders(self, max_results=UNLIMITED): - runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] + runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 541507a73..82265b585 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -14,19 +14,20 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class ExpManager(object): +class ExpManager: """ This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ def __init__(self, uri, default_exp_name): - self.uri = uri + self._default_uri = uri + self._current_uri = None self.default_exp_name = default_exp_name self.active_experiment = None # only one experiment can active each time def __repr__(self): - return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri) + return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri) def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs): """ @@ -206,7 +207,8 @@ class ExpManager(object): """ raise NotImplementedError(f"Please implement the `delete_exp` method.") - def get_uri(self): + @property + def uri(self): """ Get the default tracking URI or current URI. @@ -214,7 +216,7 @@ class ExpManager(object): ------- The tracking URI string. 
""" - return self.uri + return self._current_uri or self._default_uri def list_experiments(self): """ @@ -234,25 +236,27 @@ class MLflowExpManager(ExpManager): def __init__(self, uri, default_exp_name): super(MLflowExpManager, self).__init__(uri, default_exp_name) + self._client = None @property def client(self): # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib - if not hasattr(self, "_client"): + if self._client is None: self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) return self._client def start_exp(self, experiment_name=None, recorder_name=None, uri=None): - # set the tracking uri + # Set the tracking uri if uri is None: logger.info("No tracking URI is provided. Use the default tracking URI.") else: - self.uri = uri - # create experiment + # Temporarily re-set the current uri as the uri argument. + self._current_uri = uri + # Create experiment experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) - # set up active experiment + # Set up active experiment self.active_experiment = experiment - # start the experiment + # Start the experiment self.active_experiment.start(recorder_name) return self.active_experiment @@ -261,6 +265,8 @@ class MLflowExpManager(ExpManager): if self.active_experiment is not None: self.active_experiment.end(recorder_status) self.active_experiment = None + # When an experiment end, we will release the current uri. + self._current_uri = None def create_exp(self, experiment_name=None): assert experiment_name is not None diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 31077176d..e75ae347b 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -11,7 +11,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class Recorder(object): +class Recorder: """ This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow. 
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html) diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 97f3f986a..a75eada75 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -96,7 +96,6 @@ port_analysis_config = { } -# train def train(): """train model @@ -110,8 +109,8 @@ def train(): # model initiaiton model = init_instance_by_config(task["model"]) - print(model) dataset = init_instance_by_config(task["dataset"]) + # To test __repr__ print(dataset) print(R) @@ -122,6 +121,7 @@ def train(): # prediction recorder = R.get_recorder() + # To test __repr__ print(recorder) rid = recorder.id sr = SignalRecord(model, dataset, recorder) @@ -137,6 +137,27 @@ def train(): return pred_score, {"ic": ic, "ric": ric}, rid +def fake_experiment(): + """A fake experiment workflow to test uri + + Returns + ------- + pass_or_not_for_default_uri: bool + pass_or_not_for_current_uri: bool + temporary_exp_dir: str + """ + + # start exp + default_uri = R.get_uri() + current_uri = 'file:./temp-test-exp-mag' + with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): + R.log_params(**flatten_dict(task)) + + current_uri_to_check = R.get_uri() + default_uri_to_check = R.get_uri() + return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri + + def backtest_analysis(pred, rid): """backtest and analysis @@ -185,6 +206,12 @@ class TestAllFlow(TestAutoData): "backtest failed", ) + def test_2_expmanager(self): + pass_default, pass_current, uri_path = fake_experiment() + self.assertTrue(pass_default, msg='default uri is incorrect') + self.assertTrue(pass_current, msg='current uri is incorrect') + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + def suite(): _suite = unittest.TestSuite() From 452fb8f013904a0f119771da7dc37b4d083201b4 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 4 Mar 2021 22:33:35 -0800 Subject: [PATCH 6/7] Make mlflow client consistent 
with uri --- qlib/workflow/expm.py | 66 +++++++++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 15 deletions(-) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 82265b585..362b1a82b 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -7,8 +7,10 @@ from mlflow.entities import ViewType import os from pathlib import Path from contextlib import contextmanager +from typing import Optional, Text + from .exp import MLflowExperiment, Experiment -from .recorder import Recorder, MLflowRecorder +from .recorder import Recorder from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") @@ -20,16 +22,24 @@ class ExpManager: (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ - def __init__(self, uri, default_exp_name): + def __init__(self, uri: Text, default_exp_name: Optional[Text]): self._default_uri = uri self._current_uri = None self.default_exp_name = default_exp_name self.active_experiment = None # only one experiment can active each time def __repr__(self): - return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri) + return "{name}(default_uri={duri}, current_uri={curi})".format( + name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri + ) - def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs): + def start_exp( + self, + experiment_name: Optional[Text] = None, + recorder_name: Optional[Text] = None, + uri: Optional[Text] = None, + **kwargs, + ): """ Start an experiment. This method includes first get_or_create an experiment, and then set it to be active. @@ -49,7 +59,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `start_exp` method.") - def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs): """ End an active experiment. 
@@ -62,7 +72,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `end_exp` method.") - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): """ Create an experiment. @@ -218,6 +228,30 @@ class ExpManager: """ return self._current_uri or self._default_uri + def set_uri(self, uri: Optional[Text] = None): + """ + Set the current tracking URI and the corresponding variables. + + Parameters + ---------- + uri : str + + """ + if uri is None: + logger.info("No tracking URI is provided. Use the default tracking URI.") + self._current_uri = self._default_uri + else: + # Temporarily re-set the current uri as the uri argument. + self._current_uri = uri + # Customized features for subclasses. + self._set_uri() + + def _set_uri(self): + """ + Customized features for subclasses' set_uri function. + """ + raise NotImplementedError(f"Please implement the `_set_uri` method.") + def list_experiments(self): """ List all the existing experiments. @@ -234,10 +268,14 @@ class MLflowExpManager(ExpManager): Use mlflow to implement ExpManager. 
""" - def __init__(self, uri, default_exp_name): + def __init__(self, uri: Text, default_exp_name: Optional[Text]): super(MLflowExpManager, self).__init__(uri, default_exp_name) self._client = None + def _set_uri(self): + self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) + logger.info('{:}'.format(self._client)) + @property def client(self): # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib @@ -245,13 +283,11 @@ class MLflowExpManager(ExpManager): self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) return self._client - def start_exp(self, experiment_name=None, recorder_name=None, uri=None): + def start_exp( + self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None + ): # Set the tracking uri - if uri is None: - logger.info("No tracking URI is provided. Use the default tracking URI.") - else: - # Temporarily re-set the current uri as the uri argument. - self._current_uri = uri + self.set_uri(uri) # Create experiment experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) # Set up active experiment @@ -261,14 +297,14 @@ class MLflowExpManager(ExpManager): return self.active_experiment - def end_exp(self, recorder_status: str = Recorder.STATUS_S): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S): if self.active_experiment is not None: self.active_experiment.end(recorder_status) self.active_experiment = None # When an experiment end, we will release the current uri. 
self._current_uri = None - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): assert experiment_name is not None # init experiment experiment_id = self.client.create_experiment(experiment_name) From e327f404e33c7fc12f7c2fb43dd383b5d9dcaba4 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 4 Mar 2021 22:37:58 -0800 Subject: [PATCH 7/7] Fix pylint issues --- qlib/workflow/expm.py | 2 +- tests/test_all_pipeline.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 362b1a82b..4ba72a634 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -274,7 +274,7 @@ class MLflowExpManager(ExpManager): def _set_uri(self): self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) - logger.info('{:}'.format(self._client)) + logger.info("{:}".format(self._client)) @property def client(self): diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index a75eada75..d9d684697 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -149,10 +149,10 @@ def fake_experiment(): # start exp default_uri = R.get_uri() - current_uri = 'file:./temp-test-exp-mag' + current_uri = "file:./temp-test-exp-mag" with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): R.log_params(**flatten_dict(task)) - + current_uri_to_check = R.get_uri() default_uri_to_check = R.get_uri() return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri @@ -208,8 +208,8 @@ class TestAllFlow(TestAutoData): def test_2_expmanager(self): pass_default, pass_current, uri_path = fake_experiment() - self.assertTrue(pass_default, msg='default uri is incorrect') - self.assertTrue(pass_current, msg='current uri is incorrect') + self.assertTrue(pass_default, msg="default uri is incorrect") + self.assertTrue(pass_current, msg="current uri is incorrect") 
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))