1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/examples/rolling_process_data/workflow.py
2021-03-25 19:54:52 +08:00

151 lines
5.3 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
import pickle
import pandas as pd
from datetime import datetime
from qlib.config import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.data.handler import Alpha158
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.tests.data import GetData
class RollingDataWorkflow(object):
    """Rolling-window data preparation example built on a cached Alpha158 handler.

    ``rolling_process`` builds an Alpha158 handler for ``MARKET`` over
    ``[start_time, end_time]`` once, pickles it to disk, reloads it, and hands it
    (as ``handler_config``) to a ``RollingDataHandler`` from the local
    ``rolling_handler`` module, wrapped in a ``DatasetH``. It then prepares the
    train/valid/test segments for ``rolling_cnt`` windows. Window ``k`` is the
    base window with every boundary moved ``k`` calendar years forward.
    """

    MARKET = "csi300"
    # Time range of the cached (pre-computed) handler. Every shifted window
    # below must stay inside this range.
    start_time = "2010-01-01"
    end_time = "2019-12-31"
    # Number of rolling windows (offsets 0 .. rolling_cnt - 1, in years).
    rolling_cnt = 5

    def _init_qlib(self):
        """Initialize qlib on the default CN data, downloading it first if missing."""
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def _dump_pre_handler(self, path):
        """Build the Alpha158 handler for ``MARKET`` over ``[start_time, end_time]`` and pickle it to ``path``."""
        handler_config = {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                "start_time": self.start_time,
                "end_time": self.end_time,
                "instruments": self.MARKET,
            },
        }
        pre_handler = init_instance_by_config(handler_config)
        pre_handler.to_pickle(path)

    def _load_pre_handler(self, path):
        """Unpickle and return the object stored at ``path``.

        ``pickle.load`` can execute arbitrary code. Only point ``path`` at a file
        this workflow wrote itself (see ``_dump_pre_handler``), never at
        untrusted input.
        """
        with open(path, "rb") as file_dataset:
            pre_handler = pickle.load(file_dataset)
        return pre_handler

    @staticmethod
    def _shift_year(date_tuple, offset):
        """Return ``datetime(*date_tuple)`` moved ``offset`` calendar years forward.

        ``date_tuple`` is ``(year, month, day)``. A Feb-29 date shifted into a
        non-leap year raises ``ValueError``; the windows used here avoid that.
        """
        return datetime(date_tuple[0] + offset, *date_tuple[1:])

    @classmethod
    def _window(cls, start, end, offset):
        """Return the ``(start, end)`` datetime pair with both ends shifted by ``offset`` years."""
        return cls._shift_year(start, offset), cls._shift_year(end, offset)

    def rolling_process(self):
        """Prepare train/valid/test data for each of the ``rolling_cnt`` yearly windows."""
        self._init_qlib()
        # The file holds a pickle, not Python source. The .pkl suffix keeps it
        # from being mistaken for, or shadowing, an importable module.
        pre_handler_path = "pre_handler.pkl"
        self._dump_pre_handler(pre_handler_path)
        pre_handler = self._load_pre_handler(pre_handler_path)

        # Base (offset 0) window as (year, month, day) tuples.
        train_start_time = (2010, 1, 1)
        train_end_time = (2012, 12, 31)
        valid_start_time = (2013, 1, 1)
        valid_end_time = (2013, 12, 31)
        test_start_time = (2014, 1, 1)
        test_end_time = (2014, 12, 31)

        dataset_config = {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "RollingDataHandler",
                    "module_path": "rolling_handler",
                    "kwargs": {
                        "start_time": datetime(*train_start_time),
                        "end_time": datetime(*test_end_time),
                        "fit_start_time": datetime(*train_start_time),
                        "fit_end_time": datetime(*train_end_time),
                        "data_loader_kwargs": {
                            "handler_config": pre_handler,
                        },
                    },
                },
                "segments": {
                    "train": (datetime(*train_start_time), datetime(*train_end_time)),
                    "valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
                    "test": (datetime(*test_start_time), datetime(*test_end_time)),
                },
            },
        }

        dataset = init_instance_by_config(dataset_config)

        for rolling_offset in range(self.rolling_cnt):
            if rolling_offset:
                # Shift BOTH ends of every range by the same number of years.
                # This keeps each segment well-formed (start <= end) and makes
                # consecutive windows actually advance.
                train = self._window(train_start_time, train_end_time, rolling_offset)
                valid = self._window(valid_start_time, valid_end_time, rolling_offset)
                test = self._window(test_start_time, test_end_time, rolling_offset)
                # NOTE(review): the fit window given at construction
                # (fit_start_time / fit_end_time) is not shifted here. Confirm
                # whether RollingDataHandler should refit on the shifted train range.
                dataset.init(
                    handler_kwargs={
                        "init_type": DataHandlerLP.IT_FIT_SEQ,
                        "start_time": train[0],
                        "end_time": test[1],
                    },
                    segment_kwargs={
                        "train": train,
                        "valid": valid,
                        "test": test,
                    },
                )
            # Prepared data for the current window; plug training/dumping in here.
            dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
if __name__ == "__main__":
    # Make sure the default CN data bundle exists locally, then start qlib on it.
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    data_missing = not exists_qlib_data(provider_uri)
    if data_missing:
        print(f"Qlib data is not found in {provider_uri}")
        GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    market, benchmark = "csi300", "SH000300"

    # ---- dataset for model training ----
    # NOTE(review): this entry point never instantiates RollingDataWorkflow. It
    # only builds a plain Alpha158 DatasetH. To run the rolling example defined
    # above, call RollingDataWorkflow().rolling_process() instead.
    data_handler_config = {
        "start_time": "2008-01-01",
        "end_time": "2020-08-01",
        "fit_start_time": "2008-01-01",
        "fit_end_time": "2014-12-31",
        "instruments": market,
    }
    handler = {
        "class": "Alpha158",
        "module_path": "qlib.contrib.data.handler",
        "kwargs": data_handler_config,
    }
    segments = {
        "train": ("2008-01-01", "2014-12-31"),
        "valid": ("2015-01-01", "2016-12-31"),
        "test": ("2017-01-01", "2020-08-01"),
    }
    task = {
        "dataset": {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {"handler": handler, "segments": segments},
        },
    }
    dataset = init_instance_by_config(task["dataset"])