diff --git a/.gitignore b/.gitignore index 5b3745a02..0ddd5d21f 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ tags .pytest_cache/ .vscode/ + +*.swp diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 8ff8c1210..ecbeebc95 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -1,5 +1,5 @@ from ...utils.serial import Serializable -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Dict, Text, Optional from ...utils import init_instance_by_config, np_ffill from ...log import get_module_logger from .handler import DataHandler, DataHandlerLP @@ -76,7 +76,7 @@ class DatasetH(Dataset): - The processing is related to data split. """ - def __init__(self, handler: Union[dict, DataHandler], segments: dict): + def __init__(self, handler: Union[Dict, DataHandler], segments: Dict): """ Parameters ---------- @@ -87,7 +87,7 @@ class DatasetH(Dataset): """ super().__init__(handler, segments) - def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None): + def init(self, handler_kwargs: Optional[Dict] = None, segment_kwargs: Optional[Dict] = None): """ Initialize the DatasetH @@ -124,7 +124,7 @@ class DatasetH(Dataset): raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}") self.segments = segment_kwargs.copy() - def setup_data(self, handler: Union[dict, DataHandler], segments: dict): + def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]): """ Setup the underlying data. @@ -156,6 +156,11 @@ class DatasetH(Dataset): self.handler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() + def __repr__(self): + return "{name}(handler={handler}, segments={segments})".format( + name=self.__class__.__name__, handler=self.handler, segments=self.segments + ) + def _prepare_seg(self, slc: slice, **kwargs): """ Give a slice, retrieve the according data @@ -168,7 +173,7 @@ class DatasetH(Dataset): def prepare( self, - segments: Union[List[str], Tuple[str], str, slice], + segments: Union[List[Text], Tuple[Text], Text, slice], col_set=DataHandler.CS_ALL, data_key=DataHandlerLP.DK_I, **kwargs, @@ -178,7 +183,7 @@ class DatasetH(Dataset): Parameters ---------- - segments : Union[List[str], Tuple[str], str, slice] + segments : Union[List[Text], Tuple[Text], Text, slice] Describe the scope of the data to be prepared Here are some examples: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2889c4465..050043ba6 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -35,7 +35,7 @@ class DataHandler(Serializable): The data handler try to maintain a handler with 2 level. `datetime` & `instruments`. - Any order of the index level can be suported(The order will implied in the data). + Any order of the index level can be suported (The order will be implied in the data). The order <`datetime`, `instruments`> will be used when the dataframe index name is missed. Example of the data: @@ -47,8 +47,8 @@ class DataHandler(Serializable): $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 - SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 - SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 + SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 + SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 """ diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 94f3ff5b8..8bc7e1fa7 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -1489,7 +1489,7 @@ OpsList = [ ] -class OpsWrapper(object): +class OpsWrapper: """Ops Wrapper""" def __init__(self): diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index aedf73a9c..7bff505ce 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -16,8 +16,11 @@ class QlibRecorder: def __init__(self, exp_manager): self.exp_manager = exp_manager + def __repr__(self): + return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager) + @contextmanager - def start(self, experiment_name=None, recorder_name=None): + def start(self, experiment_name=None, recorder_name=None, uri=None): """ Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code: @@ -34,8 +37,13 @@ class QlibRecorder: name of the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. + uri : str + The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored. + The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file. + Therefore, the next time when users call this function in the same experiment, + they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur. """ - run = self.start_exp(experiment_name, recorder_name) + run = self.start_exp(experiment_name, recorder_name, uri) try: yield run except Exception as e: @@ -272,7 +280,7 @@ class QlibRecorder: ------- The uri of current experiment manager. """ - return self.exp_manager.get_uri() + return self.exp_manager.uri def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index c2548971a..18b0a143d 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -23,7 +23,7 @@ class Experiment: self.active_recorder = None # only one recorder can running each time def __repr__(self): - return str(self.info) + return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info) def __str__(self): return str(self.info) @@ -173,11 +173,9 @@ class MLflowExperiment(Experiment): self._uri = uri self._default_name = None self._default_rec_name = "mlflow_recorder" - self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) + self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) def start(self, recorder_name=None): - # set the active experiment - mlflow.set_experiment(self.name) logger.info(f"Experiment {self.id} starts running ...") # set up recorder recorder = self.create_recorder(recorder_name) @@ -210,7 +208,6 @@ class MLflowExperiment(Experiment): else: recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False if is_new: - mlflow.set_experiment(self.name) self.active_recorder = recorder # start the recorder self.active_recorder.start_run() @@ -239,7 +236,7 @@ class MLflowExperiment(Experiment): ), "Please input at least one of recorder id or name before retrieving recorder." if recorder_id is not None: try: - run = self.client.get_run(recorder_id) + run = self._client.get_run(recorder_id) recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run) return recorder except MlflowException: @@ -260,7 +257,7 @@ class MLflowExperiment(Experiment): max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") order_by = kwargs.get("order_by") - return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) + return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) def delete_recorder(self, recorder_id=None, recorder_name=None): assert ( @@ -268,10 +265,10 @@ class MLflowExperiment(Experiment): ), "Please input a valid recorder id or name before deleting." try: if recorder_id is not None: - self.client.delete_run(recorder_id) + self._client.delete_run(recorder_id) else: recorder = self._get_recorder(recorder_name=recorder_name) - self.client.delete_run(recorder.id) + self._client.delete_run(recorder.id) except MlflowException as e: raise Exception( f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct." @@ -280,7 +277,7 @@ class MLflowExperiment(Experiment): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! def list_recorders(self, max_results=UNLIMITED): - runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] + runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index a50dce7c9..4ba72a634 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -7,8 +7,10 @@ from mlflow.entities import ViewType import os from pathlib import Path from contextlib import contextmanager +from typing import Optional, Text + from .exp import MLflowExperiment, Experiment -from .recorder import Recorder, MLflowRecorder +from .recorder import Recorder from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") @@ -20,12 +22,24 @@ class ExpManager: (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ - def __init__(self, uri, default_exp_name): - self.uri = uri + def __init__(self, uri: Text, default_exp_name: Optional[Text]): + self._default_uri = uri + self._current_uri = None self.default_exp_name = default_exp_name self.active_experiment = None # only one experiment can active each time - def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs): + def __repr__(self): + return "{name}(default_uri={duri}, current_uri={curi})".format( + name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri + ) + + def start_exp( + self, + experiment_name: Optional[Text] = None, + recorder_name: Optional[Text] = None, + uri: Optional[Text] = None, + **kwargs, + ): """ Start an experiment. This method includes first get_or_create an experiment, and then set it to be active. @@ -45,7 +59,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `start_exp` method.") - def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs): """ End an active experiment. @@ -58,7 +72,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `end_exp` method.") - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): """ Create an experiment. @@ -203,7 +217,8 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `delete_exp` method.") - def get_uri(self): + @property + def uri(self): """ Get the default tracking URI or current URI. @@ -211,7 +226,31 @@ class ExpManager: ------- The tracking URI string. """ - return self.uri + return self._current_uri or self._default_uri + + def set_uri(self, uri: Optional[Text] = None): + """ + Set the current tracking URI and the corresponding variables. + + Parameters + ---------- + uri : str + + """ + if uri is None: + logger.info("No tracking URI is provided. Use the default tracking URI.") + self._current_uri = self._default_uri + else: + # Temporarily re-set the current uri as the uri argument. + self._current_uri = uri + # Customized features for subclasses. + self._set_uri() + + def _set_uri(self): + """ + Customized features for subclasses' set_uri function. + """ + raise NotImplementedError(f"Please implement the `_set_uri` method.") def list_experiments(self): """ @@ -229,37 +268,43 @@ class MLflowExpManager(ExpManager): Use mlflow to implement ExpManager. """ - def __init__(self, uri, default_exp_name): + def __init__(self, uri: Text, default_exp_name: Optional[Text]): super(MLflowExpManager, self).__init__(uri, default_exp_name) + self._client = None + + def _set_uri(self): + self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) + logger.info("{:}".format(self._client)) @property def client(self): # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib - if not hasattr(self, "_client"): + if self._client is None: self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) return self._client - def start_exp(self, experiment_name=None, recorder_name=None, uri=None): - # set the tracking uri - if uri is None: - logger.info("No tracking URI is provided. Use the default tracking URI.") - else: - self.uri = uri - # create experiment + def start_exp( + self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None + ): + # Set the tracking uri + self.set_uri(uri) + # Create experiment experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) - # set up active experiment + # Set up active experiment self.active_experiment = experiment - # start the experiment + # Start the experiment self.active_experiment.start(recorder_name) return self.active_experiment - def end_exp(self, recorder_status: str = Recorder.STATUS_S): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S): if self.active_experiment is not None: self.active_experiment.end(recorder_status) self.active_experiment = None + # When an experiment end, we will release the current uri. + self._current_uri = None - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): assert experiment_name is not None # init experiment experiment_id = self.client.create_experiment(experiment_name) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index ceb57150c..e75ae347b 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -34,7 +34,7 @@ class Recorder: self.status = Recorder.STATUS_S def __repr__(self): - return str(self.info) + return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info) def __str__(self): return str(self.info) diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index f6e77cba4..d9d684697 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -96,7 +96,6 @@ port_analysis_config = { } -# train def train(): """train model @@ -111,6 +110,9 @@ def train(): # model initiaiton model = init_instance_by_config(task["model"]) dataset = init_instance_by_config(task["dataset"]) + # To test __repr__ + print(dataset) + print(R) # start exp with R.start(experiment_name="workflow"): @@ -119,6 +121,8 @@ def train(): # prediction recorder = R.get_recorder() + # To test __repr__ + print(recorder) rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate() @@ -133,6 +137,27 @@ def train(): return pred_score, {"ic": ic, "ric": ric}, rid +def fake_experiment(): + """A fake experiment workflow to test uri + + Returns + ------- + pass_or_not_for_default_uri: bool + pass_or_not_for_current_uri: bool + temporary_exp_dir: str + """ + + # start exp + default_uri = R.get_uri() + current_uri = "file:./temp-test-exp-mag" + with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): + R.log_params(**flatten_dict(task)) + + current_uri_to_check = R.get_uri() + default_uri_to_check = R.get_uri() + return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri + + def backtest_analysis(pred, rid): """backtest and analysis @@ -181,6 +206,12 @@ class TestAllFlow(TestAutoData): "backtest failed", ) + def test_2_expmanager(self): + pass_default, pass_current, uri_path = fake_experiment() + self.assertTrue(pass_default, msg="default uri is incorrect") + self.assertTrue(pass_current, msg="current uri is incorrect") + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + def suite(): _suite = unittest.TestSuite()