diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 940c24002..9307fdafa 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -1489,7 +1489,7 @@ OpsList = [ ] -class OpsWrapper(object): +class OpsWrapper: """Ops Wrapper""" def __init__(self): diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index c6bf0c86c..7bff505ce 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -39,8 +39,8 @@ class QlibRecorder: name of the recorder under the experiment one wants to start. uri : str The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored. - The default uri are set in the qlib.config. Note that this uri argument will not change the one defined in the config file. - Therefore, the next time when user call this function in the same experiment, + The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file. + Therefore, the next time when users call this function in the same experiment, they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur. """ run = self.start_exp(experiment_name, recorder_name, uri) @@ -280,7 +280,7 @@ class QlibRecorder: ------- The uri of current experiment manager. """ - return self.exp_manager.get_uri() + return self.exp_manager.uri def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 15bb7604c..18b0a143d 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -11,7 +11,7 @@ from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class Experiment(object): +class Experiment: """ This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow. 
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html) @@ -173,7 +173,7 @@ class MLflowExperiment(Experiment): self._uri = uri self._default_name = None self._default_rec_name = "mlflow_recorder" - self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) + self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) def start(self, recorder_name=None): logger.info(f"Experiment {self.id} starts running ...") @@ -208,7 +208,6 @@ class MLflowExperiment(Experiment): else: recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False if is_new: - mlflow.set_experiment(self.name) self.active_recorder = recorder # start the recorder self.active_recorder.start_run() @@ -237,7 +236,7 @@ class MLflowExperiment(Experiment): ), "Please input at least one of recorder id or name before retrieving recorder." if recorder_id is not None: try: - run = self.client.get_run(recorder_id) + run = self._client.get_run(recorder_id) recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run) return recorder except MlflowException: @@ -258,7 +257,7 @@ class MLflowExperiment(Experiment): max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") order_by = kwargs.get("order_by") - return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) + return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) def delete_recorder(self, recorder_id=None, recorder_name=None): assert ( @@ -266,10 +265,10 @@ class MLflowExperiment(Experiment): ), "Please input a valid recorder id or name before deleting." try: if recorder_id is not None: - self.client.delete_run(recorder_id) + self._client.delete_run(recorder_id) else: recorder = self._get_recorder(recorder_name=recorder_name) - self.client.delete_run(recorder.id) + self._client.delete_run(recorder.id) except MlflowException as e: raise Exception( f"Error: {e}. 
Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct." @@ -278,7 +277,7 @@ class MLflowExperiment(Experiment): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! def list_recorders(self, max_results=UNLIMITED): - runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] + runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 541507a73..4ba72a634 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -7,28 +7,39 @@ from mlflow.entities import ViewType import os from pathlib import Path from contextlib import contextmanager +from typing import Optional, Text + from .exp import MLflowExperiment, Experiment -from .recorder import Recorder, MLflowRecorder +from .recorder import Recorder from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") -class ExpManager(object): +class ExpManager: """ This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. 
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ - def __init__(self, uri, default_exp_name): - self.uri = uri + def __init__(self, uri: Text, default_exp_name: Optional[Text]): + self._default_uri = uri + self._current_uri = None self.default_exp_name = default_exp_name self.active_experiment = None # only one experiment can active each time def __repr__(self): - return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri) + return "{name}(default_uri={duri}, current_uri={curi})".format( + name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri + ) - def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs): + def start_exp( + self, + experiment_name: Optional[Text] = None, + recorder_name: Optional[Text] = None, + uri: Optional[Text] = None, + **kwargs, + ): """ Start an experiment. This method includes first get_or_create an experiment, and then set it to be active. @@ -48,7 +59,7 @@ class ExpManager(object): """ raise NotImplementedError(f"Please implement the `start_exp` method.") - def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs): """ End an active experiment. @@ -61,7 +72,7 @@ class ExpManager(object): """ raise NotImplementedError(f"Please implement the `end_exp` method.") - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): """ Create an experiment. @@ -206,7 +217,8 @@ class ExpManager(object): """ raise NotImplementedError(f"Please implement the `delete_exp` method.") - def get_uri(self): + @property + def uri(self): """ Get the default tracking URI or current URI. @@ -214,7 +226,31 @@ class ExpManager(object): ------- The tracking URI string. 
""" - return self.uri + return self._current_uri or self._default_uri + + def set_uri(self, uri: Optional[Text] = None): + """ + Set the current tracking URI and the corresponding variables. + + Parameters + ---------- + uri : str + + """ + if uri is None: + logger.info("No tracking URI is provided. Use the default tracking URI.") + self._current_uri = self._default_uri + else: + # Temporarily re-set the current uri as the uri argument. + self._current_uri = uri + # Customized features for subclasses. + self._set_uri() + + def _set_uri(self): + """ + Customized features for subclasses' set_uri function. + """ + raise NotImplementedError(f"Please implement the `_set_uri` method.") def list_experiments(self): """ @@ -232,37 +268,43 @@ class MLflowExpManager(ExpManager): Use mlflow to implement ExpManager. """ - def __init__(self, uri, default_exp_name): + def __init__(self, uri: Text, default_exp_name: Optional[Text]): super(MLflowExpManager, self).__init__(uri, default_exp_name) + self._client = None + + def _set_uri(self): + self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) + logger.info("{:}".format(self._client)) @property def client(self): # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib - if not hasattr(self, "_client"): + if self._client is None: self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) return self._client - def start_exp(self, experiment_name=None, recorder_name=None, uri=None): - # set the tracking uri - if uri is None: - logger.info("No tracking URI is provided. 
Use the default tracking URI.") - else: - self.uri = uri - # create experiment + def start_exp( + self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None + ): + # Set the tracking uri + self.set_uri(uri) + # Create experiment experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) - # set up active experiment + # Set up active experiment self.active_experiment = experiment - # start the experiment + # Start the experiment self.active_experiment.start(recorder_name) return self.active_experiment - def end_exp(self, recorder_status: str = Recorder.STATUS_S): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S): if self.active_experiment is not None: self.active_experiment.end(recorder_status) self.active_experiment = None + # When an experiment ends, we will release the current uri. + self._current_uri = None - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): assert experiment_name is not None # init experiment experiment_id = self.client.create_experiment(experiment_name) diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 97f3f986a..d9d684697 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -96,7 +96,6 @@ port_analysis_config = { } -# train def train(): """train model @@ -110,8 +109,8 @@ def train(): # model initiaiton model = init_instance_by_config(task["model"]) - print(model) dataset = init_instance_by_config(task["dataset"]) + # To test __repr__ print(dataset) print(R) @@ -122,6 +121,7 @@ def train(): # prediction recorder = R.get_recorder() + # To test __repr__ print(recorder) rid = recorder.id sr = SignalRecord(model, dataset, recorder) @@ -137,6 +137,27 @@ def train(): return pred_score, {"ic": ic, "ric": ric}, rid +def fake_experiment(): + """A fake experiment workflow to test uri + + Returns + ------- + pass_or_not_for_default_uri: bool + pass_or_not_for_current_uri: 
bool + temporary_exp_dir: str + """ + + # start exp + default_uri = R.get_uri() + current_uri = "file:./temp-test-exp-mag" + with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): + R.log_params(**flatten_dict(task)) + + current_uri_to_check = R.get_uri() + default_uri_to_check = R.get_uri() + return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri + + def backtest_analysis(pred, rid): """backtest and analysis @@ -185,6 +206,12 @@ class TestAllFlow(TestAutoData): "backtest failed", ) + def test_2_expmanager(self): + pass_default, pass_current, uri_path = fake_experiment() + self.assertTrue(pass_default, msg="default uri is incorrect") + self.assertTrue(pass_current, msg="current uri is incorrect") + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + def suite(): _suite = unittest.TestSuite()