From 05cf0e1edcdbe0b696b7e8c1cde538e3a5168dfa Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Wed, 3 Mar 2021 15:42:39 +0800 Subject: [PATCH] add task_generator method and update some hint --- qlib/workflow/task/gen.py | 66 ++++++++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 9b031435e..efbfe94a6 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -9,11 +9,64 @@ import typing from .utils import TimeAdjuster +def task_generator(*args, **kwargs) -> list: + """ + Accept the dict of task config and the TaskGen to generate different tasks. + There is no limit to the number and position of input. + The key of input will add to task config. + + for example: + There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple. + task_generator(a=a, b=b, c=c, A=A, B=B) will finally generate 18 task_config. + + Parameters + ---------- + args : dict or TaskGen + kwargs : dict or TaskGen + + Returns + ------- + gen_task_list : list + a list of task config after generating + """ + tasks_list = [] + gen_list = [] + + tmp_id = 1 + for task in args: + if isinstance(task, dict): + task["task_key"] = tmp_id + tmp_id += 1 + tasks_list.append(task) + elif isinstance(task, TaskGen): + gen_list.append(task) + else: + raise NotImplementedError(f"{type(task)} is not supported in task_generator") + + for key, task in kwargs.items(): + if isinstance(task, dict): + task["task_key"] = key + tasks_list.append(task) + elif isinstance(task, TaskGen): + gen_list.append(task) + else: + raise NotImplementedError(f"{type(task)} is not supported in task_generator") + + # generate gen_task_list + gen_task_list = [] + for gen in gen_list: + new_task_list = [] + for task in tasks_list: + new_task_list.extend(gen(task)) + gen_task_list = new_task_list + return gen_task_list + + class TaskGen(metaclass=abc.ABCMeta): @abc.abstractmethod def __call__(self, *args, **kwargs) -> typing.List[dict]: """ - generate + the base class for generate different tasks Parameters ---------- @@ -35,9 +88,8 @@ class TaskGen(metaclass=abc.ABCMeta): class RollingGen(TaskGen): - - ROLL_EX = TimeAdjuster.SHIFT_EX - ROLL_SD = TimeAdjuster.SHIFT_SD + ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date + ROLL_SD = TimeAdjuster.SHIFT_SD # fixed window size, slide it from start date def __init__(self, step: int = 40, rtype: str = ROLL_EX): """ @@ -48,7 +100,7 @@ class RollingGen(TaskGen): step : int step to rolling rtype : str - rolling type (expanding, rolling) + rolling type (expanding, sliding) """ self.step = step self.rtype = rtype @@ -111,12 +163,12 @@ class RollingGen(TaskGen): segments = {} try: for k, seg in prev_seg.items(): - # 决定怎么shift + # decide how to shift if k == self.train_key and self.rtype == self.ROLL_EX: rtype = self.ta.SHIFT_EX else: rtype = self.ta.SHIFT_SD - # 整段数据做shift + # shift the segments data segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) if segments[self.test_key][0] > test_end: break