1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 17:41:18 +08:00

Add MultiSegRecord in contrib.workflow and decouple its tests from test_all_pipeline

This commit is contained in:
D-X-Y
2021-03-28 00:33:59 -07:00
parent 0387eaf7ab
commit 9d04ae4676
7 changed files with 171 additions and 27 deletions

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .record_temp import MultiSegRecord
from .record_temp import SignalMseRecord

View File

@@ -5,14 +5,43 @@ import re
import pandas as pd
from sklearn.metrics import mean_squared_error
from pprint import pprint
from typing import Dict, Text, Any
import numpy as np
from ...workflow.record_temp import RecordTemp
from ...workflow.record_temp import SignalRecord
from ...data import dataset as qlib_dataset
from ...log import get_module_logger
logger = get_module_logger("workflow", "INFO")
class MultiSegRecord(RecordTemp):
"""
This is the multiple segments signal record class that generates the signal prediction.
This class inherits the ``RecordTemp`` class.
"""
def __init__(self, model, dataset, recorder=None):
super().__init__(recorder=recorder)
if not isinstance(dataset, qlib_dataset.DatasetH):
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
self.model = model
self.dataset = dataset
def generate(self, segments: Dict[Text, Any], save: bool = False):
# generate prediciton
for key, segment in segments.items():
predics = self.model.predict(self.dataset, segment)
if isinstance(pred, pd.Series):
predics = predictions.to_frame("score")
# self.recorder.save_objects(**{"pred.pkl": pred})
labels = self.dataset.prepare(
segments=segment, col_set="label", data_key=dataset.handler.DataHandlerLP.DK_R
)
# compute ic, rank_ic
class SignalMseRecord(SignalRecord):
"""
This is the Signal MSE Record class that computes the mean squared error (MSE).

View File

@@ -159,7 +159,10 @@ class Experiment:
if create:
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
else:
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
recorder, is_new = (
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
False,
)
if is_new:
self.active_recorder = recorder
# start the recorder
@@ -174,7 +177,10 @@ class Experiment:
try:
if recorder_id is None and recorder_name is None:
recorder_name = self._default_rec_name
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
return (
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
False,
)
except ValueError:
if recorder_name is None:
recorder_name = self._default_rec_name

View File

@@ -159,7 +159,10 @@ class ExpManager:
if create:
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
else:
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
exp, is_new = (
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
False,
)
if is_new:
self.active_experiment = exp
# start the recorder
@@ -172,7 +175,10 @@ class ExpManager:
automatically create a new experiment based on the given id and name.
"""
try:
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
return (
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
False,
)
except ValueError:
if experiment_name is None:
experiment_name = self._default_exp_name

View File

@@ -39,7 +39,13 @@ class RecordTemp:
return "/".join(names)
def __init__(self, recorder):
self.recorder = recorder
self._recorder = recorder
@property
def recorder(self):
if self._recorder is None:
raise ValueError("This RecordTemp did not set recorder yet.")
return self._recorder
def generate(self, **kwargs):
"""
@@ -248,11 +254,20 @@ class PortAnaRecord(SignalRecord):
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
report_normal = report_dict.get("report_df")
positions_normal = report_dict.get("positions")
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(
**{"report_normal.pkl": report_normal},
artifact_path=PortAnaRecord.get_path(),
)
self.recorder.save_objects(
**{"positions_normal.pkl": positions_normal},
artifact_path=PortAnaRecord.get_path(),
)
order_normal = report_dict.get("order_list")
if order_normal:
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(
**{"order_normal.pkl": order_normal},
artifact_path=PortAnaRecord.get_path(),
)
# analysis
analysis = dict()

View File

@@ -6,24 +6,11 @@ import shutil
import unittest
from pathlib import Path
import numpy as np
import pandas as pd
import qlib
from qlib.config import REG_CN, C
from qlib.utils import drop_nan_by_y_index
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.contrib.workflow.record_temp import SignalMseRecord
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.config import C
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
from qlib.tests.data import GetData
from qlib.tests import TestAutoData
@@ -166,8 +153,6 @@ def train_with_sigana():
ric = sar.load(sar.get_path("ric.pkl"))
pred_score = sar.load("pred.pkl")
smr = SignalMseRecord(recorder)
smr.generate()
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path
@@ -256,8 +241,10 @@ class TestAllFlow(TestAutoData):
def suite():
_suite = unittest.TestSuite()
_suite.addTest(TestAllFlow("test_0_train"))
_suite.addTest(TestAllFlow("test_1_backtest"))
_suite.addTest(TestAllFlow("test_0_train_with_sigana"))
_suite.addTest(TestAllFlow("test_1_train"))
_suite.addTest(TestAllFlow("test_2_backtest"))
_suite.addTest(TestAllFlow("test_3_expmanager"))
return _suite

View File

@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import shutil
import unittest
from pathlib import Path
import qlib
from qlib.config import C
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.tests import TestAutoData
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
def test_multiseg():
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
model.fit(dataset)
# prediction
recorder = R.get_recorder()
sr = MultiSegRecord(model, dataset, recorder)
sr.generate(dict(valid="valid", test="test"))
uri = R.get_uri()
return uri
class TestAllFlow(TestAutoData):
def test_0_multiseg(self):
uri_path = test_multiseg()
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
def suite():
_suite = unittest.TestSuite()
_suite.addTest(TestAllFlow("test_0_multiseg"))
return _suite
if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())