mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
update fix CI tests bugs
This commit is contained in:
@@ -100,7 +100,6 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
@@ -57,22 +57,20 @@ class MultiSegRecord(RecordTemp):
|
||||
)
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
class SignalMseRecord(RecordTemp):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
This class inherits the ``SignalMseRecord`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
def generate(self):
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
|
||||
@@ -44,36 +44,10 @@ class TestAutoData(unittest.TestCase):
|
||||
)
|
||||
|
||||
provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day}
|
||||
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
init(
|
||||
provider_uri=cls.provider_uri,
|
||||
provider_uri=provider_uri_map,
|
||||
region=REG_CN,
|
||||
expression_cache=None,
|
||||
dataset_cache=None,
|
||||
**client_config,
|
||||
**cls._setup_kwargs,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,10 @@ RECORD_CONFIG = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": {
|
||||
"dataset": "<DATASET>",
|
||||
"model": "<MODEL>",
|
||||
},
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
|
||||
@@ -289,6 +289,25 @@ def init_instance_by_config(
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def class_casting(obj: object, cls: type):
|
||||
"""
|
||||
Python doesn't provide the downcasting mechanism.
|
||||
We use the trick here to downcast the class
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : object
|
||||
the object to be cast
|
||||
cls : type
|
||||
the target class type
|
||||
"""
|
||||
orig_cls = obj.__class__
|
||||
obj.__class__ = cls
|
||||
yield
|
||||
obj.__class__ = orig_cls
|
||||
|
||||
|
||||
def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
"""Compare dict value
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Optional
|
||||
from .expm import MLflowExpManager
|
||||
from .expm import ExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
@@ -16,7 +16,7 @@ class QlibRecorder:
|
||||
"""
|
||||
|
||||
def __init__(self, exp_manager):
|
||||
self.exp_manager = exp_manager
|
||||
self.exp_manager: ExpManager = exp_manager
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)
|
||||
@@ -334,6 +334,26 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
@contextmanager
|
||||
def uri_context(self, uri: Text):
|
||||
"""
|
||||
Temporarily set the exp_manager's uri to uri
|
||||
|
||||
NOTE:
|
||||
- Please refer to the NOTE in the `set_uri`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : Text
|
||||
the temporal uri
|
||||
"""
|
||||
prev_uri = self.exp_manager._current_uri
|
||||
self.exp_manager.set_uri(uri)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.exp_manager.set_uri(prev_uri)
|
||||
|
||||
def get_recorder(
|
||||
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
|
||||
) -> Recorder:
|
||||
|
||||
@@ -16,7 +16,7 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..backtest import backtest as normal_backtest
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..utils import flatten_dict, class_casting
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
@@ -32,6 +32,7 @@ class RecordTemp:
|
||||
"""
|
||||
|
||||
artifact_path = None
|
||||
depend_cls = None # the depend class of the record; the record will depend on the results generated by `depend_cls`
|
||||
|
||||
@classmethod
|
||||
def get_path(cls, path=None):
|
||||
@@ -98,21 +99,30 @@ class RecordTemp:
|
||||
"""
|
||||
return []
|
||||
|
||||
def check(self, cls="self"):
|
||||
def check(self, include_self: bool = False):
|
||||
"""
|
||||
Check if the records is properly generated and saved.
|
||||
It is useful in fololwing examples
|
||||
- checking if the depended files complete before genrating new things.
|
||||
- checking if the final files is completed
|
||||
|
||||
Parameters
|
||||
----------
|
||||
include_self : bool
|
||||
is the file generated by self included
|
||||
|
||||
Raise
|
||||
------
|
||||
FileExistsError: whether the records are stored properly.
|
||||
"""
|
||||
artifacts = set(self.recorder.list_artifacts())
|
||||
if cls == "self":
|
||||
cls = self
|
||||
flist = cls.list()
|
||||
for item in flist:
|
||||
if item not in artifacts:
|
||||
raise FileExistsError(item)
|
||||
if include_self:
|
||||
for item in self.list():
|
||||
if item not in artifacts:
|
||||
raise FileExistsError(item)
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
|
||||
|
||||
class SignalRecord(RecordTemp):
|
||||
@@ -127,26 +137,20 @@ class SignalRecord(RecordTemp):
|
||||
|
||||
@staticmethod
|
||||
def generate_label(dataset):
|
||||
# NOTE:
|
||||
# Python doesn't provide the downcasting mechanism.
|
||||
# We use the trick here to downcast the class
|
||||
orig_cls = dataset.__class__
|
||||
dataset.__class__ = DatasetH
|
||||
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
dataset.__class__ = orig_cls
|
||||
with class_casting(dataset, DatasetH):
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
return raw_label
|
||||
|
||||
def generate(self, **kwargs):
|
||||
@@ -235,7 +239,7 @@ class SigAnaRecord(RecordTemp):
|
||||
"""
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
pre_class = SignalRecord
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
|
||||
super().__init__(recorder=recorder)
|
||||
@@ -244,7 +248,7 @@ class SigAnaRecord(RecordTemp):
|
||||
self.label_col = label_col
|
||||
|
||||
def generate(self, **kwargs):
|
||||
self.check(self.pre_class)
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
|
||||
@@ -15,8 +15,6 @@ from qlib.workflow.online.update import LabelUpdater
|
||||
|
||||
|
||||
class TestRolling(TestAutoData):
|
||||
_setup_kwargs = dict(expression_cache=None, dataset_cache=None)
|
||||
|
||||
def test_update_pred(self):
|
||||
"""
|
||||
This test is for testing if it will raise error if the `to_date` is out of the boundary.
|
||||
@@ -26,6 +24,7 @@ class TestRolling(TestAutoData):
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
|
||||
}
|
||||
|
||||
exp_name = "online_srv_test"
|
||||
@@ -65,6 +64,7 @@ class TestRolling(TestAutoData):
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
|
||||
}
|
||||
|
||||
exp_name = "online_srv_test"
|
||||
|
||||
@@ -47,6 +47,7 @@ def train(uri_path: str = None):
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load(sr.get_path("pred.pkl"))
|
||||
|
||||
# calculate ic and ric
|
||||
sar = SigAnaRecord(recorder)
|
||||
@@ -54,7 +55,7 @@ def train(uri_path: str = None):
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
|
||||
return {"ic": ic, "ric": ric}, rid
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def train_with_sigana(uri_path: str = None):
|
||||
@@ -73,16 +74,20 @@ def train_with_sigana(uri_path: str = None):
|
||||
with R.start(experiment_name="workflow_with_sigana", uri=uri_path):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load(sr.get_path("pred.pkl"))
|
||||
|
||||
# predict and calculate ic and ric
|
||||
recorder = R.get_recorder()
|
||||
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
|
||||
sar = SigAnaRecord(recorder)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
|
||||
uri_path = R.get_uri()
|
||||
return {"ic": ic, "ric": ric}, uri_path
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
@@ -122,7 +127,9 @@ def backtest_analysis(pred, rid, uri_path: str = None):
|
||||
the analysis result
|
||||
|
||||
"""
|
||||
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
|
||||
with R.uri_context(uri=uri_path):
|
||||
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
|
||||
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
model = recorder.load_object("trained_model")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
@@ -32,7 +33,8 @@ def train_mse(uri_path: str = None):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalMseRecord(recorder, model=model, dataset=dataset)
|
||||
SignalRecord(recorder=recorder, model=model, dataset=dataset).generate()
|
||||
sr = SignalMseRecord(recorder)
|
||||
sr.generate()
|
||||
uri = R.get_uri()
|
||||
return uri
|
||||
|
||||
Reference in New Issue
Block a user