1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

Add MSERecord in contrib.workflow

This commit is contained in:
D-X-Y
2021-03-16 12:54:12 +00:00
parent 4cb74d77d1
commit d4aa681652
2 changed files with 47 additions and 0 deletions

View File

View File

@@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import pandas as pd
from sklearn.metrics import mean_squared_error
from pprint import pprint
import numpy as np
from ...workflow.record_temp import SignalRecord
from ...log import get_module_logger
logger = get_module_logger("workflow", "INFO")
class SignalMseRecord(SignalRecord):
"""
This is the Signal MSE Record class that computes the mean squared error (MSE).
This class inherits the ``SignalMseRecord`` class.
"""
artifact_path = "sig_analysis"
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder, **kwargs)
def generate(self, **kwargs):
try:
self.check(parent=True)
except FileExistsError:
super().generate()
pred = self.load("pred.pkl")
label = self.load("label.pkl")
masks = ~np.isnan(label.values)
mse = mean_squared_error(pred.values[masks], label[masks])
metrics = {
"MSE": mse,
}
objects = {"mse.pkl": mse}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
pprint(metrics)
def list(self):
paths = [self.get_path("mse.pkl")]
return paths