diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 7094d844e..3fdf0c281 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -34,10 +34,7 @@ class SignalMseRecord(SignalRecord): label = self.load("label.pkl") masks = ~np.isnan(label.values) mse = mean_squared_error(pred.values[masks], label[masks]) - metrics = { - "MSE": mse, - "RMSE": np.sqrt(mse) - } + metrics = {"MSE": mse, "RMSE": np.sqrt(mse)} objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)} self.recorder.log_metrics(**metrics) self.recorder.save_objects(**objects, artifact_path=self.get_path())