1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00

update fix CI tests bugs

This commit is contained in:
Young
2021-10-02 08:28:17 +00:00
committed by you-n-g
parent 3a152f9b8b
commit 873129aa9b
10 changed files with 101 additions and 74 deletions

View File

@@ -100,7 +100,6 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
This table demonstrates the supported Python version of `Qlib`:
| | install with pip | install from source | plot |
| ------------- |:---------------------:|:--------------------:|:----:|
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.9 | :x: | :heavy_check_mark: | :x: |

View File

@@ -57,22 +57,20 @@ class MultiSegRecord(RecordTemp):
)
class SignalMseRecord(SignalRecord):
class SignalMseRecord(RecordTemp):
"""
This is the Signal MSE Record class that computes the mean squared error (MSE).
This class inherits the ``SignalMseRecord`` class.
"""
artifact_path = "sig_analysis"
depend_cls = SignalRecord
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder, **kwargs)
def generate(self, **kwargs):
try:
self.check(parent=True)
except FileExistsError:
super().generate()
def generate(self):
self.check()
pred = self.load("pred.pkl")
label = self.load("label.pkl")

View File

@@ -44,36 +44,10 @@ class TestAutoData(unittest.TestCase):
)
provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day}
client_config = {
"calendar_provider": {
"class": "LocalCalendarProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileCalendarStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
"feature_provider": {
"class": "LocalFeatureProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileFeatureStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
}
init(
provider_uri=cls.provider_uri,
provider_uri=provider_uri_map,
region=REG_CN,
expression_cache=None,
dataset_cache=None,
**client_config,
**cls._setup_kwargs,
)

View File

@@ -35,6 +35,10 @@ RECORD_CONFIG = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": {
"dataset": "<DATASET>",
"model": "<MODEL>",
},
},
{
"class": "SigAnaRecord",

View File

@@ -289,6 +289,25 @@ def init_instance_by_config(
return klass(**cls_kwargs, **kwargs)
@contextlib.contextmanager
def class_casting(obj: object, cls: type):
"""
Python doesn't provide the downcasting mechanism.
We use the trick here to downcast the class
Parameters
----------
obj : object
the object to be cast
cls : type
the target class type
"""
orig_cls = obj.__class__
obj.__class__ = cls
yield
obj.__class__ = orig_cls
def compare_dict_value(src_data: dict, dst_data: dict):
"""Compare dict value

View File

@@ -3,7 +3,7 @@
from contextlib import contextmanager
from typing import Text, Optional
from .expm import MLflowExpManager
from .expm import ExpManager
from .exp import Experiment
from .recorder import Recorder
from ..utils import Wrapper
@@ -16,7 +16,7 @@ class QlibRecorder:
"""
def __init__(self, exp_manager):
self.exp_manager = exp_manager
self.exp_manager: ExpManager = exp_manager
def __repr__(self):
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)
@@ -334,6 +334,26 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
@contextmanager
def uri_context(self, uri: Text):
"""
Temporarily set the exp_manager's uri to uri
NOTE:
- Please refer to the NOTE in the `set_uri`
Parameters
----------
uri : Text
the temporal uri
"""
prev_uri = self.exp_manager._current_uri
self.exp_manager.set_uri(uri)
try:
yield
finally:
self.exp_manager.set_uri(prev_uri)
def get_recorder(
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
) -> Recorder:

View File

@@ -16,7 +16,7 @@ from ..data.dataset.handler import DataHandlerLP
from ..backtest import backtest as normal_backtest
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
from ..utils import flatten_dict, class_casting
from ..utils.time import Freq
from ..strategy.base import BaseStrategy
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
@@ -32,6 +32,7 @@ class RecordTemp:
"""
artifact_path = None
depend_cls = None # the depend class of the record; the record will depend on the results generated by `depend_cls`
@classmethod
def get_path(cls, path=None):
@@ -98,21 +99,30 @@ class RecordTemp:
"""
return []
def check(self, cls="self"):
def check(self, include_self: bool = False):
"""
Check if the records is properly generated and saved.
It is useful in fololwing examples
- checking if the depended files complete before genrating new things.
- checking if the final files is completed
Parameters
----------
include_self : bool
is the file generated by self included
Raise
------
FileExistsError: whether the records are stored properly.
"""
artifacts = set(self.recorder.list_artifacts())
if cls == "self":
cls = self
flist = cls.list()
for item in flist:
if item not in artifacts:
raise FileExistsError(item)
if include_self:
for item in self.list():
if item not in artifacts:
raise FileExistsError(item)
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
class SignalRecord(RecordTemp):
@@ -127,26 +137,20 @@ class SignalRecord(RecordTemp):
@staticmethod
def generate_label(dataset):
# NOTE:
# Python doesn't provide the downcasting mechanism.
# We use the trick here to downcast the class
orig_cls = dataset.__class__
dataset.__class__ = DatasetH
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
dataset.__class__ = orig_cls
with class_casting(dataset, DatasetH):
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
return raw_label
def generate(self, **kwargs):
@@ -235,7 +239,7 @@ class SigAnaRecord(RecordTemp):
"""
artifact_path = "sig_analysis"
pre_class = SignalRecord
depend_cls = SignalRecord
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
super().__init__(recorder=recorder)
@@ -244,7 +248,7 @@ class SigAnaRecord(RecordTemp):
self.label_col = label_col
def generate(self, **kwargs):
self.check(self.pre_class)
self.check()
pred = self.load("pred.pkl")
label = self.load("label.pkl")

View File

@@ -15,8 +15,6 @@ from qlib.workflow.online.update import LabelUpdater
class TestRolling(TestAutoData):
_setup_kwargs = dict(expression_cache=None, dataset_cache=None)
def test_update_pred(self):
"""
This test is for testing if it will raise error if the `to_date` is out of the boundary.
@@ -26,6 +24,7 @@ class TestRolling(TestAutoData):
task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
}
exp_name = "online_srv_test"
@@ -65,6 +64,7 @@ class TestRolling(TestAutoData):
task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
}
exp_name = "online_srv_test"

View File

@@ -47,6 +47,7 @@ def train(uri_path: str = None):
rid = recorder.id
sr = SignalRecord(model, dataset, recorder)
sr.generate()
pred_score = sr.load(sr.get_path("pred.pkl"))
# calculate ic and ric
sar = SigAnaRecord(recorder)
@@ -54,7 +55,7 @@ def train(uri_path: str = None):
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
return {"ic": ic, "ric": ric}, rid
return pred_score, {"ic": ic, "ric": ric}, rid
def train_with_sigana(uri_path: str = None):
@@ -73,16 +74,20 @@ def train_with_sigana(uri_path: str = None):
with R.start(experiment_name="workflow_with_sigana", uri=uri_path):
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
pred_score = sr.load(sr.get_path("pred.pkl"))
# predict and calculate ic and ric
recorder = R.get_recorder()
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
sar = SigAnaRecord(recorder)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
uri_path = R.get_uri()
return {"ic": ic, "ric": ric}, uri_path
return pred_score, {"ic": ic, "ric": ric}, uri_path
def fake_experiment():
@@ -122,7 +127,9 @@ def backtest_analysis(pred, rid, uri_path: str = None):
the analysis result
"""
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
with R.uri_context(uri=uri_path):
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
model = recorder.load_object("trained_model")

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.workflow.record_temp import SignalRecord
import shutil
import unittest
from pathlib import Path
@@ -32,7 +33,8 @@ def train_mse(uri_path: str = None):
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
recorder = R.get_recorder()
sr = SignalMseRecord(recorder, model=model, dataset=dataset)
SignalRecord(recorder=recorder, model=model, dataset=dataset).generate()
sr = SignalMseRecord(recorder)
sr.generate()
uri = R.get_uri()
return uri