diff --git a/qlib/contrib/workflow/__init__.py b/qlib/contrib/workflow/__init__.py index e69de29bb..9945e179c 100644 --- a/qlib/contrib/workflow/__init__.py +++ b/qlib/contrib/workflow/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .record_temp import MultiSegRecord +from .record_temp import SignalMseRecord diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 3fdf0c281..4baa15faa 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -5,14 +5,43 @@ import re import pandas as pd from sklearn.metrics import mean_squared_error from pprint import pprint +from typing import Dict, Text, Any import numpy as np +from ...workflow.record_temp import RecordTemp from ...workflow.record_temp import SignalRecord +from ...data import dataset as qlib_dataset from ...log import get_module_logger logger = get_module_logger("workflow", "INFO") +class MultiSegRecord(RecordTemp): + """ + This is the multiple segments signal record class that generates the signal prediction. + This class inherits the ``RecordTemp`` class. + """ + + def __init__(self, model, dataset, recorder=None): + super().__init__(recorder=recorder) + if not isinstance(dataset, qlib_dataset.DatasetH): + raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset))) + self.model = model + self.dataset = dataset + + def generate(self, segments: Dict[Text, Any], save: bool = False): + # generate prediciton + for key, segment in segments.items(): + predics = self.model.predict(self.dataset, segment) + if isinstance(pred, pd.Series): + predics = predictions.to_frame("score") + # self.recorder.save_objects(**{"pred.pkl": pred}) + labels = self.dataset.prepare( + segments=segment, col_set="label", data_key=dataset.handler.DataHandlerLP.DK_R + ) + # compute ic, rank_ic + + class SignalMseRecord(SignalRecord): """ This is the Signal MSE Record class that computes the mean squared error (MSE). diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 5ed4362de..0f420cec4 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -159,7 +159,10 @@ class Experiment: if create: recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name) else: - recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False + recorder, is_new = ( + self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), + False, + ) if is_new: self.active_recorder = recorder # start the recorder @@ -174,7 +177,10 @@ class Experiment: try: if recorder_id is None and recorder_name is None: recorder_name = self._default_rec_name - return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False + return ( + self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), + False, + ) except ValueError: if recorder_name is None: recorder_name = self._default_rec_name diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 95cad4c6e..28d6d92c7 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -159,7 +159,10 @@ class ExpManager: if create: exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) else: - exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False + exp, is_new = ( + self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), + False, + ) if is_new: self.active_experiment = exp # start the recorder @@ -172,7 +175,10 @@ class ExpManager: automatically create a new experiment based on the given id and name. """ try: - return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False + return ( + self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), + False, + ) except ValueError: if experiment_name is None: experiment_name = self._default_exp_name diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 2c1b6fecc..ed8039ac8 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -39,7 +39,13 @@ class RecordTemp: return "/".join(names) def __init__(self, recorder): - self.recorder = recorder + self._recorder = recorder + + @property + def recorder(self): + if self._recorder is None: + raise ValueError("This RecordTemp did not set recorder yet.") + return self._recorder def generate(self, **kwargs): """ @@ -248,11 +254,20 @@ class PortAnaRecord(SignalRecord): report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config) report_normal = report_dict.get("report_df") positions_normal = report_dict.get("positions") - self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) - self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects( + **{"report_normal.pkl": report_normal}, + artifact_path=PortAnaRecord.get_path(), + ) + self.recorder.save_objects( + **{"positions_normal.pkl": positions_normal}, + artifact_path=PortAnaRecord.get_path(), + ) order_normal = report_dict.get("order_list") if order_normal: - self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects( + **{"order_normal.pkl": order_normal}, + artifact_path=PortAnaRecord.get_path(), + ) # analysis analysis = dict() diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 29d39179d..d34c1773a 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -6,24 +6,11 @@ import shutil import unittest from pathlib import Path -import numpy as np -import pandas as pd - import qlib -from qlib.config import REG_CN, C -from qlib.utils import drop_nan_by_y_index -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.contrib.workflow.record_temp import SignalMseRecord -from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.config import C +from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord -from qlib.tests.data import GetData from qlib.tests import TestAutoData @@ -166,8 +153,6 @@ def train_with_sigana(): ric = sar.load(sar.get_path("ric.pkl")) pred_score = sar.load("pred.pkl") - smr = SignalMseRecord(recorder) - smr.generate() uri_path = R.get_uri() return pred_score, {"ic": ic, "ric": ric}, uri_path @@ -256,8 +241,10 @@ class TestAllFlow(TestAutoData): def suite(): _suite = unittest.TestSuite() - _suite.addTest(TestAllFlow("test_0_train")) - _suite.addTest(TestAllFlow("test_1_backtest")) + _suite.addTest(TestAllFlow("test_0_train_with_sigana")) + _suite.addTest(TestAllFlow("test_1_train")) + _suite.addTest(TestAllFlow("test_2_backtest")) + _suite.addTest(TestAllFlow("test_3_expmanager")) return _suite diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py new file mode 100644 index 000000000..92ed7e8d1 --- /dev/null +++ b/tests/test_contrib_workflow.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import shutil +import unittest +from pathlib import Path + +import qlib +from qlib.config import C +from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord +from qlib.utils import init_instance_by_config, flatten_dict +from qlib.workflow import R +from qlib.tests import TestAutoData + + +market = "csi300" +benchmark = "SH000300" + +################################### +# train model +################################### +data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, +} + +task = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, +} + + +def test_multiseg(): + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + with R.start(experiment_name="workflow"): + R.log_params(**flatten_dict(task)) + model.fit(dataset) + + # prediction + recorder = R.get_recorder() + sr = MultiSegRecord(model, dataset, recorder) + sr.generate(dict(valid="valid", test="test")) + + uri = R.get_uri() + + return uri + + +class TestAllFlow(TestAutoData): + def test_0_multiseg(self): + uri_path = test_multiseg() + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + + +def suite(): + _suite = unittest.TestSuite() + _suite.addTest(TestAllFlow("test_0_multiseg")) + return _suite + + +if __name__ == "__main__": + runner = unittest.TextTestRunner() + runner.run(suite())