mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
modified the comments
This commit is contained in:
@@ -17,7 +17,7 @@ def task_generator(*args, **kwargs) -> list:
|
||||
|
||||
for example:
|
||||
There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple.
|
||||
task_generator(a=a, b=b, c=c, A=A, B=B) will finally generate 18 task_config.
|
||||
task_generator(a_key=a, b_key=b, c_key=c, A, B) will finally generate 3*2*3 = 18 task_config.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -57,27 +57,37 @@ def task_generator(*args, **kwargs) -> list:
|
||||
for gen in gen_list:
|
||||
new_task_list = []
|
||||
for task in tasks_list:
|
||||
new_task_list.extend(gen(task))
|
||||
new_task_list.extend(gen.generate(task))
|
||||
gen_task_list = new_task_list
|
||||
return gen_task_list
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
the base class for generating different tasks
|
||||
|
||||
Example 1:
|
||||
|
||||
input: a specific task template and rolling steps
|
||||
|
||||
output: rolling version of the tasks
|
||||
|
||||
Example 2:
|
||||
|
||||
input: a specific task template and losses list
|
||||
|
||||
output: a set of tasks with different losses
|
||||
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> typing.List[dict]:
|
||||
def generate(self, task: dict) -> typing.List[dict]:
|
||||
"""
|
||||
the base class for generate different tasks
|
||||
generate different tasks based on a task template
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args, kwargs:
|
||||
The info for generating tasks
|
||||
Example 1):
|
||||
input: a specific task template
|
||||
output: rolling version of the tasks
|
||||
Example 2):
|
||||
input: a specific task template
|
||||
output: a set of tasks with different losses
|
||||
task: dict
|
||||
a task template
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -89,7 +99,7 @@ class TaskGen(metaclass=abc.ABCMeta):
|
||||
|
||||
class RollingGen(TaskGen):
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed window size, slide it from start date
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
|
||||
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
|
||||
"""
|
||||
@@ -104,12 +114,13 @@ class RollingGen(TaskGen):
|
||||
"""
|
||||
self.step = step
|
||||
self.rtype = rtype
|
||||
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改
|
||||
# TODO: Ask pengrong to update future date in dataset
|
||||
self.ta = TimeAdjuster(future=True)
|
||||
|
||||
self.test_key = "test"
|
||||
self.train_key = "train"
|
||||
|
||||
def __call__(self, task: dict):
|
||||
def generate(self, task: dict):
|
||||
"""
|
||||
Converting the task into a rolling task
|
||||
|
||||
@@ -153,9 +164,9 @@ class RollingGen(TaskGen):
|
||||
# calculate segments
|
||||
if prev_seg is None:
|
||||
# First rolling
|
||||
# 1) prepare the end porint
|
||||
# 1) prepare the end point
|
||||
segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
# 2) and the init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
@@ -164,6 +175,7 @@ class RollingGen(TaskGen):
|
||||
try:
|
||||
for k, seg in prev_seg.items():
|
||||
# decide how to shift
|
||||
# expanding only for train data, the segments size of test data and valid data won't change
|
||||
if k == self.train_key and self.rtype == self.ROLL_EX:
|
||||
rtype = self.ta.SHIFT_EX
|
||||
else:
|
||||
@@ -177,6 +189,7 @@ class RollingGen(TaskGen):
|
||||
# No more rolling
|
||||
break
|
||||
|
||||
# update segments of this task
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
prev_seg = segments
|
||||
res.append(t)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
A task consists of 2 parts
|
||||
A task consists of 3 parts
|
||||
- tasks description: the desc will define the task
|
||||
- tasks status: the status of the task
|
||||
- tasks result information : A user can get the task with the task description and task result.
|
||||
@@ -26,18 +26,22 @@ from qlib import auto_init
|
||||
class TaskManager:
|
||||
"""TaskManager
|
||||
here is what a task looks like
|
||||
{
|
||||
'def': pickle serialized task definition. using pickle will make it easier
|
||||
'filter': json-like data. This is for filtering the tasks.
|
||||
'status': 'waiting' | 'running' | 'done'
|
||||
'res': pickle serialized task result,
|
||||
}
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'def': pickle serialized task definition. using pickle will make it easier
|
||||
'filter': json-like data. This is for filtering the tasks.
|
||||
'status': 'waiting' | 'running' | 'done'
|
||||
'res': pickle serialized task result,
|
||||
}
|
||||
|
||||
The tasks manager assume that you will only update the tasks you fetched.
|
||||
The MongoDB find-one-and-update operation makes the data update safe.
|
||||
|
||||
NOTE:
|
||||
- assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
|
||||
.. note::
|
||||
|
||||
assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
|
||||
"""
|
||||
|
||||
STATUS_WAITING = "waiting"
|
||||
@@ -48,6 +52,14 @@ class TaskManager:
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool=None):
|
||||
"""
|
||||
Initialize the TaskManager; remember to configure the MongoDB URL and database name first.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
self.task_pool = task_pool
|
||||
|
||||
@@ -100,6 +112,19 @@ class TaskManager:
|
||||
task_pool.insert_one(task)
|
||||
|
||||
def insert_task_def(self, task_def, task_pool=None):
|
||||
"""
|
||||
insert a task to task_pool
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_def: dict
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
task = self._encode_task(
|
||||
{
|
||||
@@ -111,6 +136,23 @@ class TaskManager:
|
||||
self.insert_task(task, task_pool)
|
||||
|
||||
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
|
||||
"""
|
||||
if the tasks in task_def_l are new, then insert the new tasks into the task_pool
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_def_l: list
|
||||
a list of task
|
||||
task_pool: str
|
||||
the name of task_pool (collection name of MongoDB)
|
||||
dry_run: bool
|
||||
if insert those new tasks to task pool
|
||||
print_nt: bool
|
||||
if print new task
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
new_tasks = []
|
||||
for t in task_def_l:
|
||||
@@ -141,7 +183,7 @@ class TaskManager:
|
||||
task = task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
)
|
||||
# 这里我的 priority 必须是 高数优先级更高,因为 null会被在 ASCENDING时被排在最前面
|
||||
# null is sorted first when using ASCENDING, so DESCENDING is used: the larger the number, the higher the priority
|
||||
if task is None:
|
||||
return None
|
||||
task["status"] = self.STATUS_RUNNING
|
||||
@@ -149,6 +191,20 @@ class TaskManager:
|
||||
|
||||
@contextmanager
|
||||
def safe_fetch_task(self, query={}, task_pool=None):
|
||||
"""
|
||||
fetch task from task_pool using query with contextmanager
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
task = self.fetch_task(query=query, task_pool=task_pool)
|
||||
try:
|
||||
yield task
|
||||
@@ -167,12 +223,20 @@ class TaskManager:
|
||||
yield task
|
||||
|
||||
def query(self, query={}, decode=True, task_pool=None):
|
||||
"""query
|
||||
"""
|
||||
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
|
||||
|
||||
:param query:
|
||||
:param decode:
|
||||
:param task_pool:
|
||||
Parameters
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
decode: bool
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
@@ -196,6 +260,20 @@ class TaskManager:
|
||||
task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def remove(self, query={}, task_pool=None):
|
||||
"""
|
||||
remove the task using query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
if "_id" in query:
|
||||
@@ -250,15 +328,15 @@ class TaskManager:
|
||||
|
||||
|
||||
def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
"""run_task.
|
||||
While task pool is not empty, use task_func to fetch and run tasks in task_pool
|
||||
"""
|
||||
While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_func : def (task_def, *args, **kwargs) -> <res which will be committed>
|
||||
the function to run the task
|
||||
task_pool :
|
||||
The name of the task pool
|
||||
task_pool : str
|
||||
the name of the task pool (Collection in MongoDB)
|
||||
force_release :
|
||||
will the program force to release the resource
|
||||
args :
|
||||
|
||||
@@ -52,11 +52,30 @@ class TimeAdjuster:
|
||||
|
||||
def max(self):
|
||||
"""
|
||||
Return the max calendar date
|
||||
(Deprecated)
|
||||
Return the max calendar datetime
|
||||
"""
|
||||
return max(self.cals)
|
||||
|
||||
def last_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Return the last datetime in the calendar
|
||||
"""
|
||||
return self.cals[-1]
|
||||
|
||||
def align_idx(self, time_point, tp_type="start"):
|
||||
"""
|
||||
align the index of time_point in the calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_point
|
||||
tp_type : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
index : int
|
||||
"""
|
||||
time_point = pd.Timestamp(time_point)
|
||||
if tp_type == "start":
|
||||
idx = bisect.bisect_left(self.cals, time_point)
|
||||
@@ -68,11 +87,11 @@ class TimeAdjuster:
|
||||
|
||||
def align_time(self, time_point, tp_type="start"):
|
||||
"""
|
||||
Align a timepoint to calendar weekdays
|
||||
Align time_point to trade date of calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_point :
|
||||
time_point
|
||||
Time point
|
||||
tp_type : str
|
||||
time point type (`"start"`, `"end"`)
|
||||
@@ -80,6 +99,24 @@ class TimeAdjuster:
|
||||
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
|
||||
|
||||
def align_seg(self, segment: Union[dict, tuple]):
|
||||
"""
|
||||
align the given date to trade date
|
||||
|
||||
for example:
|
||||
input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}
|
||||
|
||||
output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
|
||||
'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
|
||||
'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segment
|
||||
|
||||
Returns
|
||||
-------
|
||||
the start and end trade date (pd.Timestamp) between the given start and end date.
|
||||
"""
|
||||
if isinstance(segment, dict):
|
||||
return {k: self.align_seg(seg) for k, seg in segment.items()}
|
||||
elif isinstance(segment, tuple):
|
||||
@@ -98,7 +135,7 @@ class TimeAdjuster:
|
||||
test_start
|
||||
days : int
|
||||
The trading days to be truncated
|
||||
大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据
|
||||
In most cases this is because the data in this segment (usually features) will use `days` days of data
|
||||
"""
|
||||
test_idx = self.align_idx(test_start)
|
||||
if isinstance(segment, tuple):
|
||||
@@ -116,7 +153,7 @@ class TimeAdjuster:
|
||||
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD):
|
||||
"""
|
||||
shift the datatiem of segment
|
||||
shift the datetime of segment
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Reference in New Issue
Block a user