1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

add task_generator method and update some hint

This commit is contained in:
lzh222333
2021-03-03 15:42:39 +08:00
parent b84156fde8
commit 05cf0e1edc

View File

@@ -9,11 +9,64 @@ import typing
from .utils import TimeAdjuster
def task_generator(*args, **kwargs) -> list:
"""
Accept the dict of task config and the TaskGen to generate different tasks.
There is no limit to the number and position of input.
The key of input will add to task config.
for example:
There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple.
task_generator(a=a, b=b, c=c, A=A, B=B) will finally generate 18 task_config.
Parameters
----------
args : dict or TaskGen
kwargs : dict or TaskGen
Returns
-------
gen_task_list : list
a list of task config after generating
"""
tasks_list = []
gen_list = []
tmp_id = 1
for task in args:
if isinstance(task, dict):
task["task_key"] = tmp_id
tmp_id += 1
tasks_list.append(task)
elif isinstance(task, TaskGen):
gen_list.append(task)
else:
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
for key, task in kwargs.items():
if isinstance(task, dict):
task["task_key"] = key
tasks_list.append(task)
elif isinstance(task, TaskGen):
gen_list.append(task)
else:
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
# generate gen_task_list
gen_task_list = []
for gen in gen_list:
new_task_list = []
for task in tasks_list:
new_task_list.extend(gen(task))
gen_task_list = new_task_list
return gen_task_list
class TaskGen(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> typing.List[dict]:
"""
generate
the base class for generate different tasks
Parameters
----------
@@ -35,9 +88,8 @@ class TaskGen(metaclass=abc.ABCMeta):
class RollingGen(TaskGen):
ROLL_EX = TimeAdjuster.SHIFT_EX
ROLL_SD = TimeAdjuster.SHIFT_SD
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed window size, slide it from start date
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
"""
@@ -48,7 +100,7 @@ class RollingGen(TaskGen):
step : int
step to rolling
rtype : str
rolling type (expanding, rolling)
rolling type (expanding, sliding)
"""
self.step = step
self.rtype = rtype
@@ -111,12 +163,12 @@ class RollingGen(TaskGen):
segments = {}
try:
for k, seg in prev_seg.items():
# 决定怎么shift
# decide how to shift
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# 整段数据做shift
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break