mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
Add MultiSegRecord in contrib.workflow and decouple its tests from test_all_pipeline
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
@@ -5,14 +5,43 @@ import re
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
from typing import Dict, Text, Any
|
||||
import numpy as np
|
||||
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
|
||||
"""
|
||||
This is the multiple segments signal record class that generates the signal prediction.
|
||||
This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model, dataset, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
if not isinstance(dataset, qlib_dataset.DatasetH):
|
||||
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
|
||||
def generate(self, segments: Dict[Text, Any], save: bool = False):
|
||||
# generate prediciton
|
||||
for key, segment in segments.items():
|
||||
predics = self.model.predict(self.dataset, segment)
|
||||
if isinstance(pred, pd.Series):
|
||||
predics = predictions.to_frame("score")
|
||||
# self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
labels = self.dataset.prepare(
|
||||
segments=segment, col_set="label", data_key=dataset.handler.DataHandlerLP.DK_R
|
||||
)
|
||||
# compute ic, rank_ic
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
|
||||
@@ -159,7 +159,10 @@ class Experiment:
|
||||
if create:
|
||||
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
recorder, is_new = (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
if is_new:
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
@@ -174,7 +177,10 @@ class Experiment:
|
||||
try:
|
||||
if recorder_id is None and recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
return (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
|
||||
@@ -159,7 +159,10 @@ class ExpManager:
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
else:
|
||||
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
exp, is_new = (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
if is_new:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
@@ -172,7 +175,10 @@ class ExpManager:
|
||||
automatically create a new experiment based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
return (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
|
||||
@@ -39,7 +39,13 @@ class RecordTemp:
|
||||
return "/".join(names)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self.recorder = recorder
|
||||
self._recorder = recorder
|
||||
|
||||
@property
|
||||
def recorder(self):
|
||||
if self._recorder is None:
|
||||
raise ValueError("This RecordTemp did not set recorder yet.")
|
||||
return self._recorder
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
@@ -248,11 +254,20 @@ class PortAnaRecord(SignalRecord):
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"report_normal.pkl": report_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{"positions_normal.pkl": positions_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
order_normal = report_dict.get("order_list")
|
||||
if order_normal:
|
||||
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"order_normal.pkl": order_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
|
||||
@@ -6,24 +6,11 @@ import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN, C
|
||||
from qlib.utils import drop_nan_by_y_index
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.contrib.workflow.record_temp import SignalMseRecord
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.config import C
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
@@ -166,8 +153,6 @@ def train_with_sigana():
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
pred_score = sar.load("pred.pkl")
|
||||
|
||||
smr = SignalMseRecord(recorder)
|
||||
smr.generate()
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
@@ -256,8 +241,10 @@ class TestAllFlow(TestAutoData):
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_train"))
|
||||
_suite.addTest(TestAllFlow("test_1_backtest"))
|
||||
_suite.addTest(TestAllFlow("test_0_train_with_sigana"))
|
||||
_suite.addTest(TestAllFlow("test_1_train"))
|
||||
_suite.addTest(TestAllFlow("test_2_backtest"))
|
||||
_suite.addTest(TestAllFlow("test_3_expmanager"))
|
||||
return _suite
|
||||
|
||||
|
||||
|
||||
97
tests/test_contrib_workflow.py
Normal file
97
tests/test_contrib_workflow.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_multiseg():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = MultiSegRecord(model, dataset, recorder)
|
||||
sr.generate(dict(valid="valid", test="test"))
|
||||
|
||||
uri = R.get_uri()
|
||||
|
||||
return uri
|
||||
|
||||
|
||||
class TestAllFlow(TestAutoData):
|
||||
def test_0_multiseg(self):
|
||||
uri_path = test_multiseg()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_multiseg"))
|
||||
return _suite
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite())
|
||||
Reference in New Issue
Block a user