diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index be458a24d..641669898 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -110,7 +110,7 @@ class SignalRecord(RecordTemp): This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class. """ - def __init__(self, model=None, dataset=None, recorder=None, **kwargs): + def __init__(self, model=None, dataset=None, recorder=None): super().__init__(recorder=recorder) self.model = model self.dataset = dataset @@ -163,14 +163,16 @@ class SigAnaRecord(SignalRecord): artifact_path = "sig_analysis" - def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs): + def __init__(self, recorder, ana_long_short=False, ann_scaler=252): + super().__init__(recorder=recorder) self.ana_long_short = ana_long_short self.ann_scaler = ann_scaler - super().__init__(recorder=recorder, **kwargs) - # The name must be unique. Otherwise it will be overridden - def generate(self): - self.check(parent=True) + def generate(self, **kwargs): + try: + self.check(parent=True) + except: + super().generate() pred = self.load("pred.pkl") label = self.load("label.pkl") @@ -228,7 +230,7 @@ class PortAnaRecord(SignalRecord): config["backtest"] : dict define the backtest kwargs. """ - super().__init__(recorder=recorder) + super().__init__(recorder=recorder, **kwargs) self.strategy_config = config["strategy"] self.backtest_config = config["backtest"] @@ -236,10 +238,13 @@ class PortAnaRecord(SignalRecord): def generate(self, **kwargs): # check previously stored prediction results - self.check(parent=True) # "Make sure the parent process is completed and store the data properly." + try: + self.check(parent=True) # "Make sure the parent process is completed and store the data properly." + except: + super().generate() # custom strategy and get backtest - pred_score = super().load() + pred_score = super().load("pred.pkl") report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config) report_normal = report_dict.get("report_df") positions_normal = report_dict.get("positions") diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index fbf15d29a..ac0cad199 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -139,6 +139,33 @@ def train(): return pred_score, {"ic": ic, "ric": ric}, rid +def train_with_sigana(): + """train model followed by SigAnaRecord + + Returns + ------- + pred_score: pandas.DataFrame + predict scores + performance: dict + model performance + """ + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + + # start exp + with R.start(experiment_name="workflow"): + R.log_params(**flatten_dict(task)) + model.fit(dataset) + + # predict and calculate ic and ric + recorder = R.get_recorder() + sar = SigAnaRecord(recorder, model=model, dataset=dataset) + sar.generate() + ic = sar.load(sar.get_path("ic.pkl")) + ric = sar.load(sar.get_path("ric.pkl")) + return pred_score, {"ic": ic, "ric": ric}, rid + + def fake_experiment(): """A fake experiment workflow to test uri @@ -214,6 +241,11 @@ class TestAllFlow(TestAutoData): self.assertTrue(pass_current, msg="current uri is incorrect") shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + def test_3_train_with_sigana(self): + TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train_with_sigana() + self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed") + self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed") + def suite(): _suite = unittest.TestSuite()