1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 03:50:57 +08:00

Add RMSE for contrib.workflow.record_temp and unit tests

This commit is contained in:
D-X-Y
2021-03-16 22:55:28 +08:00
parent d4aa681652
commit 88b0871c12
2 changed files with 7 additions and 2 deletions

View File

@@ -36,12 +36,13 @@ class SignalMseRecord(SignalRecord):
mse = mean_squared_error(pred.values[masks], label[masks])
metrics = {
"MSE": mse,
"RMSE": np.sqrt(mse)
}
objects = {"mse.pkl": mse}
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
pprint(metrics)
def list(self):
paths = [self.get_path("mse.pkl")]
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
return paths

View File

@@ -19,6 +19,7 @@ from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.contrib.workflow.record_temp import SignalMseRecord
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
@@ -164,6 +165,9 @@ def train_with_sigana():
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
pred_score = sar.load("pred.pkl")
smr = SignalMseRecord(recorder)
smr.generate()
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path