From 873129aa9b36629bb973bfba578e37a0325d4ec1 Mon Sep 17 00:00:00 2001 From: Young Date: Sat, 2 Oct 2021 08:28:17 +0000 Subject: [PATCH] update fix CI tests bugs --- README.md | 1 - qlib/contrib/workflow/record_temp.py | 10 ++-- qlib/tests/__init__.py | 28 +---------- qlib/tests/config.py | 4 ++ qlib/utils/__init__.py | 19 ++++++++ qlib/workflow/__init__.py | 24 +++++++++- qlib/workflow/record_temp.py | 64 +++++++++++++------------ tests/rolling_tests/test_update_pred.py | 4 +- tests/test_all_pipeline.py | 17 +++++-- tests/test_contrib_workflow.py | 4 +- 10 files changed, 101 insertions(+), 74 deletions(-) diff --git a/README.md b/README.md index 6ceb26e66..820324e3d 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,6 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how This table demonstrates the supported Python version of `Qlib`: | | install with pip | install from source | plot | | ------------- |:---------------------:|:--------------------:|:----:| -| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: | | Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | Python 3.9 | :x: | :heavy_check_mark: | :x: | diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index bedf89105..e7c80cf6e 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -57,22 +57,20 @@ class MultiSegRecord(RecordTemp): ) -class SignalMseRecord(SignalRecord): +class SignalMseRecord(RecordTemp): """ This is the Signal MSE Record class that computes the mean squared error (MSE). This class inherits the ``SignalMseRecord`` class. """ artifact_path = "sig_analysis" + depend_cls = SignalRecord def __init__(self, recorder, **kwargs): super().__init__(recorder=recorder, **kwargs) - def generate(self, **kwargs): - try: - self.check(parent=True) - except FileExistsError: - super().generate() + def generate(self): + self.check() pred = self.load("pred.pkl") label = self.load("label.pkl") diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index cc452ae0f..549c0e752 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -44,36 +44,10 @@ class TestAutoData(unittest.TestCase): ) provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day} - - client_config = { - "calendar_provider": { - "class": "LocalCalendarProvider", - "module_path": "qlib.data.data", - "kwargs": { - "backend": { - "class": "FileCalendarStorage", - "module_path": "qlib.data.storage.file_storage", - "kwargs": {"provider_uri_map": provider_uri_map}, - } - }, - }, - "feature_provider": { - "class": "LocalFeatureProvider", - "module_path": "qlib.data.data", - "kwargs": { - "backend": { - "class": "FileFeatureStorage", - "module_path": "qlib.data.storage.file_storage", - "kwargs": {"provider_uri_map": provider_uri_map}, - } - }, - }, - } init( - provider_uri=cls.provider_uri, + provider_uri=provider_uri_map, region=REG_CN, expression_cache=None, dataset_cache=None, - **client_config, **cls._setup_kwargs, ) diff --git a/qlib/tests/config.py b/qlib/tests/config.py index c61b5651e..f01715992 100644 --- a/qlib/tests/config.py +++ b/qlib/tests/config.py @@ -35,6 +35,10 @@ RECORD_CONFIG = [ { "class": "SignalRecord", "module_path": "qlib.workflow.record_temp", + "kwargs": { + "dataset": "", + "model": "", + }, }, { "class": "SigAnaRecord", diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index e247ea23b..6a3f871d9 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -289,6 +289,25 @@ def init_instance_by_config( return klass(**cls_kwargs, **kwargs) +@contextlib.contextmanager +def class_casting(obj: object, cls: type): + """ + Python doesn't provide the downcasting mechanism. + We use the trick here to downcast the class + + Parameters + ---------- + obj : object + the object to be cast + cls : type + the target class type + """ + orig_cls = obj.__class__ + obj.__class__ = cls + yield + obj.__class__ = orig_cls + + def compare_dict_value(src_data: dict, dst_data: dict): """Compare dict value diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 51a6ed553..4b16bd387 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from typing import Text, Optional -from .expm import MLflowExpManager +from .expm import ExpManager from .exp import Experiment from .recorder import Recorder from ..utils import Wrapper @@ -16,7 +16,7 @@ class QlibRecorder: """ def __init__(self, exp_manager): - self.exp_manager = exp_manager + self.exp_manager: ExpManager = exp_manager def __repr__(self): return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager) @@ -334,6 +334,26 @@ class QlibRecorder: """ self.exp_manager.set_uri(uri) + @contextmanager + def uri_context(self, uri: Text): + """ + Temporarily set the exp_manager's uri to uri + + NOTE: + - Please refer to the NOTE in the `set_uri` + + Parameters + ---------- + uri : Text + the temporal uri + """ + prev_uri = self.exp_manager._current_uri + self.exp_manager.set_uri(uri) + try: + yield + finally: + self.exp_manager.set_uri(prev_uri) + def get_recorder( self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None ) -> Recorder: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index bae14d642..4a3898a25 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -16,7 +16,7 @@ from ..data.dataset.handler import DataHandlerLP from ..backtest import backtest as normal_backtest from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger -from ..utils import flatten_dict +from ..utils import flatten_dict, class_casting from ..utils.time import Freq from ..strategy.base import BaseStrategy from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec @@ -32,6 +32,7 @@ class RecordTemp: """ artifact_path = None + depend_cls = None # the depend class of the record; the record will depend on the results generated by `depend_cls` @classmethod def get_path(cls, path=None): @@ -98,21 +99,30 @@ class RecordTemp: """ return [] - def check(self, cls="self"): + def check(self, include_self: bool = False): """ Check if the records is properly generated and saved. + It is useful in fololwing examples + - checking if the depended files complete before genrating new things. + - checking if the final files is completed + + Parameters + ---------- + include_self : bool + is the file generated by self included Raise ------ FileExistsError: whether the records are stored properly. """ artifacts = set(self.recorder.list_artifacts()) - if cls == "self": - cls = self - flist = cls.list() - for item in flist: - if item not in artifacts: - raise FileExistsError(item) + if include_self: + for item in self.list(): + if item not in artifacts: + raise FileExistsError(item) + if self.depend_cls is not None: + with class_casting(self, self.depend_cls): + self.check(include_self=True) class SignalRecord(RecordTemp): @@ -127,26 +137,20 @@ class SignalRecord(RecordTemp): @staticmethod def generate_label(dataset): - # NOTE: - # Python doesn't provide the downcasting mechanism. - # We use the trick here to downcast the class - orig_cls = dataset.__class__ - dataset.__class__ = DatasetH - - params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R) - try: - # Assume the backend handler is DataHandlerLP - raw_label = dataset.prepare(**params) - except TypeError: - # The argument number is not right - del params["data_key"] - # The backend handler should be DataHandler - raw_label = dataset.prepare(**params) - except AttributeError: - # The data handler is initialize with `drop_raw=True`... - # So raw_label is not available - raw_label = None - dataset.__class__ = orig_cls + with class_casting(dataset, DatasetH): + params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R) + try: + # Assume the backend handler is DataHandlerLP + raw_label = dataset.prepare(**params) + except TypeError: + # The argument number is not right + del params["data_key"] + # The backend handler should be DataHandler + raw_label = dataset.prepare(**params) + except AttributeError: + # The data handler is initialize with `drop_raw=True`... + # So raw_label is not available + raw_label = None return raw_label def generate(self, **kwargs): @@ -235,7 +239,7 @@ class SigAnaRecord(RecordTemp): """ artifact_path = "sig_analysis" - pre_class = SignalRecord + depend_cls = SignalRecord def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0): super().__init__(recorder=recorder) @@ -244,7 +248,7 @@ class SigAnaRecord(RecordTemp): self.label_col = label_col def generate(self, **kwargs): - self.check(self.pre_class) + self.check() pred = self.load("pred.pkl") label = self.load("label.pkl") diff --git a/tests/rolling_tests/test_update_pred.py b/tests/rolling_tests/test_update_pred.py index 7b900d0b4..b22152fd2 100644 --- a/tests/rolling_tests/test_update_pred.py +++ b/tests/rolling_tests/test_update_pred.py @@ -15,8 +15,6 @@ from qlib.workflow.online.update import LabelUpdater class TestRolling(TestAutoData): - _setup_kwargs = dict(expression_cache=None, dataset_cache=None) - def test_update_pred(self): """ This test is for testing if it will raise error if the `to_date` is out of the boundary. @@ -26,6 +24,7 @@ class TestRolling(TestAutoData): task["record"] = { "class": "SignalRecord", "module_path": "qlib.workflow.record_temp", + "kwargs": {"dataset": "", "model": ""}, } exp_name = "online_srv_test" @@ -65,6 +64,7 @@ class TestRolling(TestAutoData): task["record"] = { "class": "SignalRecord", "module_path": "qlib.workflow.record_temp", + "kwargs": {"dataset": "", "model": ""}, } exp_name = "online_srv_test" diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 606a0ea3b..da68139a8 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -47,6 +47,7 @@ def train(uri_path: str = None): rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate() + pred_score = sr.load(sr.get_path("pred.pkl")) # calculate ic and ric sar = SigAnaRecord(recorder) @@ -54,7 +55,7 @@ def train(uri_path: str = None): ic = sar.load(sar.get_path("ic.pkl")) ric = sar.load(sar.get_path("ric.pkl")) - return {"ic": ic, "ric": ric}, rid + return pred_score, {"ic": ic, "ric": ric}, rid def train_with_sigana(uri_path: str = None): @@ -73,16 +74,20 @@ def train_with_sigana(uri_path: str = None): with R.start(experiment_name="workflow_with_sigana", uri=uri_path): R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) + recorder = R.get_recorder() + + sr = SignalRecord(model, dataset, recorder) + sr.generate() + pred_score = sr.load(sr.get_path("pred.pkl")) # predict and calculate ic and ric - recorder = R.get_recorder() - sar = SigAnaRecord(recorder, model=model, dataset=dataset) + sar = SigAnaRecord(recorder) sar.generate() ic = sar.load(sar.get_path("ic.pkl")) ric = sar.load(sar.get_path("ric.pkl")) uri_path = R.get_uri() - return {"ic": ic, "ric": ric}, uri_path + return pred_score, {"ic": ic, "ric": ric}, uri_path def fake_experiment(): @@ -122,7 +127,9 @@ def backtest_analysis(pred, rid, uri_path: str = None): the analysis result """ - recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid) + with R.uri_context(uri=uri_path): + recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) model = recorder.load_object("trained_model") diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py index 2f5feb78e..081cf78ea 100644 --- a/tests/test_contrib_workflow.py +++ b/tests/test_contrib_workflow.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from qlib.workflow.record_temp import SignalRecord import shutil import unittest from pathlib import Path @@ -32,7 +33,8 @@ def train_mse(uri_path: str = None): R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) recorder = R.get_recorder() - sr = SignalMseRecord(recorder, model=model, dataset=dataset) + SignalRecord(recorder=recorder, model=model, dataset=dataset).generate() + sr = SignalMseRecord(recorder) sr.generate() uri = R.get_uri() return uri