diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 01de59c0e..c2ca36db3 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -32,7 +32,6 @@ class HighfreqWorkflow(object): SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} MARKET = "all" - BENCHMARK = "SH000300" start_time = "2020-09-15 00:00:00" end_time = "2021-01-18 16:00:00" diff --git a/examples/rolling_process_data/rolling_handler.py b/examples/rolling_process_data/rolling_handler.py new file mode 100644 index 000000000..50a36f219 --- /dev/null +++ b/examples/rolling_process_data/rolling_handler.py @@ -0,0 +1,34 @@ +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.dataset.loader import DataLoaderDH +from qlib.contrib.data.handler import check_transform_proc + + +class RollingDataHandler(DataHandlerLP): + def __init__( + self, + start_time=None, + end_time=None, + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + data_loader_kwargs={} + ): + infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) + learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + + data_loader = { + "class": "DataLoaderDH", + "kwargs": { + **data_loader_kwargs + }, + } + + super().__init__( + instruments=None, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + ) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index e69de29bb..8581f149b 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import qlib +import pickle +import datetime +import pandas as pd +from qlib.config import REG_CN +from qlib.data.dataset.handler import DataHandlerLP +from qlib.contrib.data.handler import Alpha158 +from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.tests.data import GetData + +class RollingDataWorkflow(object): + + MARKET = "csi300" + + start_time = "2010-01-01" + end_time = "2019-12-31" + rolling_cnt = 5 + + def _init_qlib(self): + """initialize qlib""" + # use yahoo_cn_1min data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + qlib.init(provider_uri=provider_uri, region=REG_CN) + + def _dump_pre_handler(self, path): + handler_config = { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": { + "start_time": start_time, + "end_time": end_time, + "instruments": MARKET, + }, + } + pre_handler = init_instance_by_config(handler_config) + pre_handler.to_pickle(path) + + def _load_pre_handler(self, path): + with open(path, "rb") as file_dataset: + pre_handler = pickle.load(file_dataset) + return pre_handler + + def rolling_process(self): + self._init_qlib() + self._dump_pre_handler("pre_handler.py") + pre_handler = self._load_pre_handler("pre_handler.py") + + init_start_time = datetime.datetime(2010,1,1) + init_end_time = datetime.datetime(2014,12,31) + init_fit_end_time = datetime.datetime(2012,12,31) + + dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "RollingDataHandler", + "module_path": "rolling_handler", + "kwargs": { + "start_time": init_start_time, + "end_time": init_start_time, + "fit_start_time": init_fit_start_time, + "fit_end_time": init_fit_end_time, + "data_loader_kwargs":{ + "handler_config": pre_handler, + } + }, + }, + "segments": { + "train": (init_start_time, init_fit_end_time), + "valid": (init_start_time, "2013-12-31"), + "test": (init_start_time, init_end_time), + }, + }, + } + + dataset = init_instance_by_config(dataset_config) + + for rolling_offset in range(rolling_cnt): + if rolling_offset: + dataset.init( + handler_kwargs={ + "init_type": DataHandlerLP.IT_FIT_IND, + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segment_kwargs={ + "train": ("2010-01-01", "2012-12-31"), + "valid": ("2013-01-01", "2013-12-31"), + "test": ("2014-01-01", "2014-12-31"), + }, + ) + + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + market = "csi300" + benchmark = "SH000300" + + ################################### + # train model + ################################### + data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, + } + + task = { + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, + } + + dataset = init_instance_by_config(task["dataset"]) + diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 050043ba6..f4795c566 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -16,7 +16,7 @@ from ...data import D from ...config import C from ...utils import parse_config, transform_end_date, init_instance_by_config from ...utils.serial import Serializable -from .utils import get_level_index, fetch_df_by_index +from .utils import fetch_df_by_index from pathlib import Path from .loader import DataLoader diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 884d15635..f88aaf05e 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -219,14 +219,14 @@ class StaticDataLoader(DataLoader): self._data.sort_index(inplace=True) -class DataHandlerDL(DataLoader): - """DataHandlerDL - DataHandler-based (D)ata (L)oader +class DataLoaderDH(DataLoader): + """DataLoaderDH + DataLoader based on (D)ata (H)andler It is designed to load multiple data from data handler - If you just want to load data from single datahandler, you can write them in single data handler """ - def __init__(self, handler_config: dict, fetch_config: dict = {}, is_group=False): + def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False): """ Parameters ---------- @@ -243,8 +243,8 @@ class DataHandlerDL(DataLoader): := := DataHandler Instance | DataHandler Config - fetch_config : dict - fetch_config will be used to describe the different arguments of fetch method, such as squeeze, data_key, etc. + fetch_kwargs : dict + fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc. is_group: bool is_group will be used to describe whether the key of handler_config is group @@ -258,7 +258,10 @@ class DataHandlerDL(DataLoader): self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler) self.is_group = is_group - self.fetch_config = fetch_config + self.fetch_kwargs = { + "col_set":DataHandler.CS_RAW + } + self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs} def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if instruments is not None: @@ -267,11 +270,11 @@ class DataHandlerDL(DataLoader): if self.is_group: df = pd.concat( { - grp: dh.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config) + grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) for grp, dh in self.handlers.items() }, axis=1, ) else: - df = self.handler.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config) + df = self.handler.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) return df