From 88b0871c12d0b139da489c53e02444606f6ca634 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 16 Mar 2021 22:55:28 +0800 Subject: [PATCH] Add RMSE for contrib.workflow.record_temp and unit tests --- qlib/contrib/workflow/record_temp.py | 5 +++-- tests/test_all_pipeline.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 2b9930743..7094d844e 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -36,12 +36,13 @@ class SignalMseRecord(SignalRecord): mse = mean_squared_error(pred.values[masks], label[masks]) metrics = { "MSE": mse, + "RMSE": np.sqrt(mse) } - objects = {"mse.pkl": mse} + objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)} self.recorder.log_metrics(**metrics) self.recorder.save_objects(**objects, artifact_path=self.get_path()) pprint(metrics) def list(self): - paths = [self.get_path("mse.pkl")] + paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")] return paths diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 21a82cd30..29d39179d 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -19,6 +19,7 @@ from qlib.contrib.evaluate import ( backtest as normal_backtest, risk_analysis, ) +from qlib.contrib.workflow.record_temp import SignalMseRecord from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord @@ -164,6 +165,9 @@ def train_with_sigana(): ic = sar.load(sar.get_path("ic.pkl")) ric = sar.load(sar.get_path("ric.pkl")) pred_score = sar.load("pred.pkl") + + smr = SignalMseRecord(recorder) + smr.generate() uri_path = R.get_uri() return pred_score, {"ic": ic, "ric": ric}, uri_path