mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
add task_generator method and update some hint
This commit is contained in:
@@ -9,11 +9,64 @@ import typing
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
def task_generator(*args, **kwargs) -> list:
|
||||
"""
|
||||
Accept the dict of task config and the TaskGen to generate different tasks.
|
||||
There is no limit to the number and position of input.
|
||||
The key of input will add to task config.
|
||||
|
||||
for example:
|
||||
There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple.
|
||||
task_generator(a=a, b=b, c=c, A=A, B=B) will finally generate 18 task_config.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : dict or TaskGen
|
||||
kwargs : dict or TaskGen
|
||||
|
||||
Returns
|
||||
-------
|
||||
gen_task_list : list
|
||||
a list of task config after generating
|
||||
"""
|
||||
tasks_list = []
|
||||
gen_list = []
|
||||
|
||||
tmp_id = 1
|
||||
for task in args:
|
||||
if isinstance(task, dict):
|
||||
task["task_key"] = tmp_id
|
||||
tmp_id += 1
|
||||
tasks_list.append(task)
|
||||
elif isinstance(task, TaskGen):
|
||||
gen_list.append(task)
|
||||
else:
|
||||
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
|
||||
|
||||
for key, task in kwargs.items():
|
||||
if isinstance(task, dict):
|
||||
task["task_key"] = key
|
||||
tasks_list.append(task)
|
||||
elif isinstance(task, TaskGen):
|
||||
gen_list.append(task)
|
||||
else:
|
||||
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
|
||||
|
||||
# generate gen_task_list
|
||||
gen_task_list = []
|
||||
for gen in gen_list:
|
||||
new_task_list = []
|
||||
for task in tasks_list:
|
||||
new_task_list.extend(gen(task))
|
||||
gen_task_list = new_task_list
|
||||
return gen_task_list
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> typing.List[dict]:
|
||||
"""
|
||||
generate
|
||||
the base class for generate different tasks
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -35,9 +88,8 @@ class TaskGen(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class RollingGen(TaskGen):
|
||||
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed window size, slide it from start date
|
||||
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
|
||||
"""
|
||||
@@ -48,7 +100,7 @@ class RollingGen(TaskGen):
|
||||
step : int
|
||||
step to rolling
|
||||
rtype : str
|
||||
rolling type (expanding, rolling)
|
||||
rolling type (expanding, sliding)
|
||||
"""
|
||||
self.step = step
|
||||
self.rtype = rtype
|
||||
@@ -111,12 +163,12 @@ class RollingGen(TaskGen):
|
||||
segments = {}
|
||||
try:
|
||||
for k, seg in prev_seg.items():
|
||||
# 决定怎么shift
|
||||
# decide how to shift
|
||||
if k == self.train_key and self.rtype == self.ROLL_EX:
|
||||
rtype = self.ta.SHIFT_EX
|
||||
else:
|
||||
rtype = self.ta.SHIFT_SD
|
||||
# 整段数据做shift
|
||||
# shift the segments data
|
||||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
|
||||
if segments[self.test_key][0] > test_end:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user