mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Merge pull request #345 from D-X-Y/main
Fix errors when SignalRecord is not called before SigAna/PortAna
This commit is contained in:
0
qlib/contrib/workflow/__init__.py
Normal file
0
qlib/contrib/workflow/__init__.py
Normal file
45
qlib/contrib/workflow/record_temp.py
Normal file
45
qlib/contrib/workflow/record_temp.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
import numpy as np
|
||||
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
This class inherits the ``SignalMseRecord`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
masks = ~np.isnan(label.values)
|
||||
mse = mean_squared_error(pred.values[masks], label[masks])
|
||||
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
return paths
|
||||
@@ -110,7 +110,7 @@ class SignalRecord(RecordTemp):
|
||||
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
|
||||
def __init__(self, model=None, dataset=None, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
@@ -164,13 +164,15 @@ class SigAnaRecord(SignalRecord):
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
# The name must be unique. Otherwise it will be overridden
|
||||
|
||||
def generate(self):
|
||||
self.check(parent=True)
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
@@ -228,7 +230,7 @@ class PortAnaRecord(SignalRecord):
|
||||
config["backtest"] : dict
|
||||
define the backtest kwargs.
|
||||
"""
|
||||
super().__init__(recorder=recorder)
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.backtest_config = config["backtest"]
|
||||
@@ -236,10 +238,13 @@ class PortAnaRecord(SignalRecord):
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# check previously stored prediction results
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
try:
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
pred_score = super().load("pred.pkl")
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
|
||||
@@ -19,6 +19,7 @@ from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.contrib.workflow.record_temp import SignalMseRecord
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
@@ -139,6 +140,38 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def train_with_sigana():
|
||||
"""train model followed by SigAnaRecord
|
||||
|
||||
Returns
|
||||
-------
|
||||
pred_score: pandas.DataFrame
|
||||
predict scores
|
||||
performance: dict
|
||||
model performance
|
||||
"""
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow_with_sigana"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# predict and calculate ic and ric
|
||||
recorder = R.get_recorder()
|
||||
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
pred_score = sar.load("pred.pkl")
|
||||
|
||||
smr = SignalMseRecord(recorder)
|
||||
smr.generate()
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
"""A fake experiment workflow to test uri
|
||||
|
||||
@@ -195,12 +228,18 @@ class TestAllFlow(TestAutoData):
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
|
||||
|
||||
def test_0_train(self):
|
||||
def test_0_train_with_sigana(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_train(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
|
||||
def test_1_backtest(self):
|
||||
def test_2_backtest(self):
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
@@ -208,7 +247,7 @@ class TestAllFlow(TestAutoData):
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
def test_2_expmanager(self):
|
||||
def test_3_expmanager(self):
|
||||
pass_default, pass_current, uri_path = fake_experiment()
|
||||
self.assertTrue(pass_default, msg="default uri is incorrect")
|
||||
self.assertTrue(pass_current, msg="current uri is incorrect")
|
||||
|
||||
Reference in New Issue
Block a user