mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Fix errors when SignalRecord is not called before SigAna/PortAna
This commit is contained in:
@@ -110,7 +110,7 @@ class SignalRecord(RecordTemp):
|
||||
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
|
||||
def __init__(self, model=None, dataset=None, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
@@ -163,14 +163,16 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252):
|
||||
super().__init__(recorder=recorder)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
# The name must be unique. Otherwise it will be overridden
|
||||
|
||||
def generate(self):
|
||||
self.check(parent=True)
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except:
|
||||
super().generate()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
@@ -228,7 +230,7 @@ class PortAnaRecord(SignalRecord):
|
||||
config["backtest"] : dict
|
||||
define the backtest kwargs.
|
||||
"""
|
||||
super().__init__(recorder=recorder)
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.backtest_config = config["backtest"]
|
||||
@@ -236,10 +238,13 @@ class PortAnaRecord(SignalRecord):
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# check previously stored prediction results
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
try:
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
except:
|
||||
super().generate()
|
||||
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
pred_score = super().load("pred.pkl")
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
|
||||
@@ -139,6 +139,33 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def train_with_sigana():
|
||||
"""train model followed by SigAnaRecord
|
||||
|
||||
Returns
|
||||
-------
|
||||
pred_score: pandas.DataFrame
|
||||
predict scores
|
||||
performance: dict
|
||||
model performance
|
||||
"""
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# predict and calculate ic and ric
|
||||
recorder = R.get_recorder()
|
||||
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
"""A fake experiment workflow to test uri
|
||||
|
||||
@@ -214,6 +241,11 @@ class TestAllFlow(TestAutoData):
|
||||
self.assertTrue(pass_current, msg="current uri is incorrect")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_3_train_with_sigana(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train_with_sigana()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
|
||||
Reference in New Issue
Block a user