1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

update rolling process

This commit is contained in:
bxdd
2021-03-25 16:14:22 +08:00
parent 1fcfe8e4ba
commit f6dc25b229
5 changed files with 192 additions and 11 deletions

View File

@@ -32,7 +32,6 @@ class HighfreqWorkflow(object):
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
MARKET = "all"
BENCHMARK = "SH000300"
start_time = "2020-09-15 00:00:00"
end_time = "2021-01-18 16:00:00"

View File

@@ -0,0 +1,34 @@
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.loader import DataLoaderDH
from qlib.contrib.data.handler import check_transform_proc
class RollingDataHandler(DataHandlerLP):
def __init__(
self,
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
data_loader_kwargs={}
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "DataLoaderDH",
"kwargs": {
**data_loader_kwargs
},
}
super().__init__(
instruments=None,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
)

View File

@@ -0,0 +1,145 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
import pickle
import datetime
import pandas as pd
from qlib.config import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.data.handler import Alpha158
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.tests.data import GetData
class RollingDataWorkflow(object):
MARKET = "csi300"
start_time = "2010-01-01"
end_time = "2019-12-31"
rolling_cnt = 5
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def _dump_pre_handler(self, path):
handler_config = {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": start_time,
"end_time": end_time,
"instruments": MARKET,
},
}
pre_handler = init_instance_by_config(handler_config)
pre_handler.to_pickle(path)
def _load_pre_handler(self, path):
with open(path, "rb") as file_dataset:
pre_handler = pickle.load(file_dataset)
return pre_handler
def rolling_process(self):
self._init_qlib()
self._dump_pre_handler("pre_handler.py")
pre_handler = self._load_pre_handler("pre_handler.py")
init_start_time = datetime.datetime(2010,1,1)
init_end_time = datetime.datetime(2014,12,31)
init_fit_end_time = datetime.datetime(2012,12,31)
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "RollingDataHandler",
"module_path": "rolling_handler",
"kwargs": {
"start_time": init_start_time,
"end_time": init_start_time,
"fit_start_time": init_fit_start_time,
"fit_end_time": init_fit_end_time,
"data_loader_kwargs":{
"handler_config": pre_handler,
}
},
},
"segments": {
"train": (init_start_time, init_fit_end_time),
"valid": (init_start_time, "2013-12-31"),
"test": (init_start_time, init_end_time),
},
},
}
dataset = init_instance_by_config(dataset_config)
for rolling_offset in range(rolling_cnt):
if rolling_offset:
dataset.init(
handler_kwargs={
"init_type": DataHandlerLP.IT_FIT_IND,
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segment_kwargs={
"train": ("2010-01-01", "2012-12-31"),
"valid": ("2013-01-01", "2013-12-31"),
"test": ("2014-01-01", "2014-12-31"),
},
)
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(task["dataset"])

View File

@@ -16,7 +16,7 @@ from ...data import D
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
from .utils import get_level_index, fetch_df_by_index
from .utils import fetch_df_by_index
from pathlib import Path
from .loader import DataLoader

View File

@@ -219,14 +219,14 @@ class StaticDataLoader(DataLoader):
self._data.sort_index(inplace=True)
class DataHandlerDL(DataLoader):
"""DataHandlerDL
DataHandler-based (D)ata (L)oader
class DataLoaderDH(DataLoader):
"""DataLoaderDH
DataLoader based on (D)ata (H)andler
It is designed to load multiple data from data handler
- If you just want to load data from single datahandler, you can write them in single data handler
"""
def __init__(self, handler_config: dict, fetch_config: dict = {}, is_group=False):
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
"""
Parameters
----------
@@ -243,8 +243,8 @@ class DataHandlerDL(DataLoader):
<handler_config> := <handler>
<handler> := DataHandler Instance | DataHandler Config
fetch_config : dict
fetch_config will be used to describe the different arguments of fetch method, such as squeeze, data_key, etc.
fetch_kwargs : dict
fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
is_group: bool
is_group will be used to describe whether the key of handler_config is group
@@ -258,7 +258,10 @@ class DataHandlerDL(DataLoader):
self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)
self.is_group = is_group
self.fetch_config = fetch_config
self.fetch_kwargs = {
"col_set":DataHandler.CS_RAW
}
self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs}
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is not None:
@@ -267,11 +270,11 @@ class DataHandlerDL(DataLoader):
if self.is_group:
df = pd.concat(
{
grp: dh.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config)
grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
for grp, dh in self.handlers.items()
},
axis=1,
)
else:
df = self.handler.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config)
df = self.handler.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
return df