1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

balck format

This commit is contained in:
bxdd
2021-03-25 19:56:22 +08:00
parent efe134e9f4
commit a04c6bd6c9

View File

@@ -12,11 +12,12 @@ from qlib.contrib.data.handler import Alpha158
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.tests.data import GetData
class RollingDataWorkflow(object):
MARKET = "csi300"
start_time = "2010-01-01"
end_time = "2019-12-31"
end_time = "2019-12-31"
rolling_cnt = 5
def _init_qlib(self):
@@ -27,7 +28,7 @@ class RollingDataWorkflow(object):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def _dump_pre_handler(self, path):
handler_config = {
"class": "Alpha158",
@@ -51,13 +52,13 @@ class RollingDataWorkflow(object):
self._dump_pre_handler("pre_handler.py")
pre_handler = self._load_pre_handler("pre_handler.py")
train_start_time = (2010,1,1)
train_end_time = (2012,12,31)
valid_start_time = (2013,1,1)
valid_end_time = (2013,12,31)
test_start_time = (2014,1,1)
test_end_time = (2014,12,31)
train_start_time = (2010, 1, 1)
train_end_time = (2012, 12, 31)
valid_start_time = (2013, 1, 1)
valid_end_time = (2013, 12, 31)
test_start_time = (2014, 1, 1)
test_end_time = (2014, 12, 31)
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
@@ -70,9 +71,9 @@ class RollingDataWorkflow(object):
"end_time": datetime(*test_end_time),
"fit_start_time": datetime(*train_start_time),
"fit_end_time": datetime(*train_end_time),
"data_loader_kwargs":{
"data_loader_kwargs": {
"handler_config": pre_handler,
}
},
},
},
"segments": {
@@ -94,14 +95,23 @@ class RollingDataWorkflow(object):
"end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]),
},
segment_kwargs={
"train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])),
"valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])),
"test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])),
"train": (
datetime(train_start_time[0] + 1, *train_start_time[1:]),
datetime(train_end_time[0], *train_end_time[1:]),
),
"valid": (
datetime(valid_start_time[0] + 1, *valid_start_time[1:]),
datetime(valid_end_time[0], *valid_end_time[1:]),
),
"test": (
datetime(test_start_time[0] + 1, *test_start_time[1:]),
datetime(test_end_time[0], *test_end_time[1:]),
),
},
)
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
if __name__ == "__main__":
@@ -147,4 +157,3 @@ if __name__ == "__main__":
}
dataset = init_instance_by_config(task["dataset"])