mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 03:50:57 +08:00
Add RMSE for contrib.workflow.record_temp and unit tests
This commit is contained in:
@@ -36,12 +36,13 @@ class SignalMseRecord(SignalRecord):
|
||||
mse = mean_squared_error(pred.values[masks], label[masks])
|
||||
metrics = {
|
||||
"MSE": mse,
|
||||
"RMSE": np.sqrt(mse)
|
||||
}
|
||||
objects = {"mse.pkl": mse}
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl")]
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
return paths
|
||||
|
||||
@@ -19,6 +19,7 @@ from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.contrib.workflow.record_temp import SignalMseRecord
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
@@ -164,6 +165,9 @@ def train_with_sigana():
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
pred_score = sar.load("pred.pkl")
|
||||
|
||||
smr = SignalMseRecord(recorder)
|
||||
smr.generate()
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
|
||||
Reference in New Issue
Block a user