mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Add MultiSegRecord in contrib.workflow and decouple its tests from test_all_pipeline
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
@@ -5,14 +5,43 @@ import re
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
from typing import Dict, Text, Any
|
||||
import numpy as np
|
||||
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
|
||||
"""
|
||||
This is the multiple segments signal record class that generates the signal prediction.
|
||||
This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model, dataset, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
if not isinstance(dataset, qlib_dataset.DatasetH):
|
||||
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
|
||||
def generate(self, segments: Dict[Text, Any], save: bool = False):
|
||||
# generate prediciton
|
||||
for key, segment in segments.items():
|
||||
predics = self.model.predict(self.dataset, segment)
|
||||
if isinstance(pred, pd.Series):
|
||||
predics = predictions.to_frame("score")
|
||||
# self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
labels = self.dataset.prepare(
|
||||
segments=segment, col_set="label", data_key=dataset.handler.DataHandlerLP.DK_R
|
||||
)
|
||||
# compute ic, rank_ic
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
|
||||
@@ -159,7 +159,10 @@ class Experiment:
|
||||
if create:
|
||||
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
recorder, is_new = (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
if is_new:
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
@@ -174,7 +177,10 @@ class Experiment:
|
||||
try:
|
||||
if recorder_id is None and recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
return (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
|
||||
@@ -159,7 +159,10 @@ class ExpManager:
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
else:
|
||||
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
exp, is_new = (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
if is_new:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
@@ -172,7 +175,10 @@ class ExpManager:
|
||||
automatically create a new experiment based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
return (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
|
||||
@@ -39,7 +39,13 @@ class RecordTemp:
|
||||
return "/".join(names)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self.recorder = recorder
|
||||
self._recorder = recorder
|
||||
|
||||
@property
|
||||
def recorder(self):
|
||||
if self._recorder is None:
|
||||
raise ValueError("This RecordTemp did not set recorder yet.")
|
||||
return self._recorder
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
@@ -248,11 +254,20 @@ class PortAnaRecord(SignalRecord):
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"report_normal.pkl": report_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{"positions_normal.pkl": positions_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
order_normal = report_dict.get("order_list")
|
||||
if order_normal:
|
||||
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"order_normal.pkl": order_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
|
||||
Reference in New Issue
Block a user