1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

Fix errors when SignalRecord is not called before SigAna/PortAna

This commit is contained in:
D-X-Y
2021-03-16 08:11:05 +00:00
parent d47e35d64e
commit 9f57681032
2 changed files with 46 additions and 9 deletions

View File

@@ -110,7 +110,7 @@ class SignalRecord(RecordTemp):
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
"""
def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
def __init__(self, model=None, dataset=None, recorder=None):
super().__init__(recorder=recorder)
self.model = model
self.dataset = dataset
@@ -163,14 +163,16 @@ class SigAnaRecord(SignalRecord):
artifact_path = "sig_analysis"
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
def __init__(self, recorder, ana_long_short=False, ann_scaler=252):
super().__init__(recorder=recorder)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
super().__init__(recorder=recorder, **kwargs)
# The name must be unique. Otherwise it will be overridden
def generate(self):
self.check(parent=True)
def generate(self, **kwargs):
try:
self.check(parent=True)
except:
super().generate()
pred = self.load("pred.pkl")
label = self.load("label.pkl")
@@ -228,7 +230,7 @@ class PortAnaRecord(SignalRecord):
config["backtest"] : dict
define the backtest kwargs.
"""
super().__init__(recorder=recorder)
super().__init__(recorder=recorder, **kwargs)
self.strategy_config = config["strategy"]
self.backtest_config = config["backtest"]
@@ -236,10 +238,13 @@ class PortAnaRecord(SignalRecord):
def generate(self, **kwargs):
# check previously stored prediction results
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
try:
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
except:
super().generate()
# custom strategy and get backtest
pred_score = super().load()
pred_score = super().load("pred.pkl")
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
report_normal = report_dict.get("report_df")
positions_normal = report_dict.get("positions")

View File

@@ -139,6 +139,33 @@ def train():
return pred_score, {"ic": ic, "ric": ric}, rid
def train_with_sigana():
"""train model followed by SigAnaRecord
Returns
-------
pred_score: pandas.DataFrame
predict scores
performance: dict
model performance
"""
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
# start exp
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
model.fit(dataset)
# predict and calculate ic and ric
recorder = R.get_recorder()
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
return pred_score, {"ic": ic, "ric": ric}, rid
def fake_experiment():
"""A fake experiment workflow to test uri
@@ -214,6 +241,11 @@ class TestAllFlow(TestAutoData):
self.assertTrue(pass_current, msg="current uri is incorrect")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
def test_3_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train_with_sigana()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
def suite():
_suite = unittest.TestSuite()