1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

modified the comments

This commit is contained in:
lzh222333
2021-03-08 13:25:11 +08:00
parent 2882929c5d
commit a244f87f95
3 changed files with 168 additions and 40 deletions

View File

@@ -17,7 +17,7 @@ def task_generator(*args, **kwargs) -> list:
for example:
There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple.
task_generator(a=a, b=b, c=c, A=A, B=B) will finally generate 18 task_config.
task_generator(a_key=a, b_key=b, c_key=c, A, B) will finally generate 3*2*3 = 18 task_config.
Parameters
----------
@@ -57,27 +57,37 @@ def task_generator(*args, **kwargs) -> list:
for gen in gen_list:
new_task_list = []
for task in tasks_list:
new_task_list.extend(gen(task))
new_task_list.extend(gen.generate(task))
gen_task_list = new_task_list
return gen_task_list
class TaskGen(metaclass=abc.ABCMeta):
"""
the base class for generate different tasks
Example 1:
input: a specific task template and rolling steps
output: rolling version of the tasks
Example 2:
input: a specific task template and losses list
output: a set of tasks with different losses
"""
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> typing.List[dict]:
def generate(self, task: dict) -> typing.List[dict]:
"""
the base class for generate different tasks
generate different tasks based on a task template
Parameters
----------
args, kwargs:
The info for generating tasks
Example 1):
input: a specific task template
output: rolling version of the tasks
Example 2):
input: a specific task template
output: a set of tasks with different losses
task: dict
a task template
Returns
-------
@@ -89,7 +99,7 @@ class TaskGen(metaclass=abc.ABCMeta):
class RollingGen(TaskGen):
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed window size, slide it from start date
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
"""
@@ -104,12 +114,13 @@ class RollingGen(TaskGen):
"""
self.step = step
self.rtype = rtype
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改
# TODO: Ask pengrong to update future date in dataset
self.ta = TimeAdjuster(future=True)
self.test_key = "test"
self.train_key = "train"
def __call__(self, task: dict):
def generate(self, task: dict):
"""
Converting the task into a rolling task
@@ -153,9 +164,9 @@ class RollingGen(TaskGen):
# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end porint
# 1) prepare the end point
segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and the init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
@@ -164,6 +175,7 @@ class RollingGen(TaskGen):
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
@@ -177,6 +189,7 @@ class RollingGen(TaskGen):
# No more rolling
break
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
res.append(t)

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
A task consists of 2 parts
A task consists of 3 parts
- tasks description: the desc will define the task
- tasks status: the status of the task
- tasks result information : A user can get the task with the task description and task result.
@@ -26,18 +26,22 @@ from qlib import auto_init
class TaskManager:
"""TaskManager
here is the what will a task looks like
{
'def': pickle serialized task definition. using pickle will make it easier
'filter': json-like data. This is for filtering the tasks.
'status': 'waiting' | 'running' | 'done'
'res': pickle serialized task result,
}
.. code-block:: python
{
'def': pickle serialized task definition. using pickle will make it easier
'filter': json-like data. This is for filtering the tasks.
'status': 'waiting' | 'running' | 'done'
'res': pickle serialized task result,
}
The tasks manager assume that you will only update the tasks you fetched.
The mongo fetch one and update will make it date updating secure.
NOTE:
- assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
.. note::
assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
"""
STATUS_WAITING = "waiting"
@@ -48,6 +52,14 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool=None):
"""
init Task Manager, remember to make the statement of MongoDB url and database name firstly.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
self.task_pool = task_pool
@@ -100,6 +112,19 @@ class TaskManager:
task_pool.insert_one(task)
def insert_task_def(self, task_def, task_pool=None):
"""
insert a task to task_pool
Parameters
----------
task_def: dict
task_pool: str
the name of Collection in MongoDB
Returns
-------
"""
task_pool = self._get_task_pool(task_pool)
task = self._encode_task(
{
@@ -111,6 +136,23 @@ class TaskManager:
self.insert_task(task, task_pool)
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
"""
if the tasks in task_def_l is new, then insert new tasks into the task_pool
Parameters
----------
task_def_l: list
a list of task
task_pool: str
the name of task_pool (collection name of MongoDB)
dry_run: bool
if insert those new tasks to task pool
print_nt: bool
if print new task
Returns
-------
"""
task_pool = self._get_task_pool(task_pool)
new_tasks = []
for t in task_def_l:
@@ -141,7 +183,7 @@ class TaskManager:
task = task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
)
# 这里我的 priority 必须是 高数优先级更高,因为 null会被在 ASCENDING时被排在最前面
# null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority
if task is None:
return None
task["status"] = self.STATUS_RUNNING
@@ -149,6 +191,20 @@ class TaskManager:
@contextmanager
def safe_fetch_task(self, query={}, task_pool=None):
"""
fetch task from task_pool using query with contextmanager
Parameters
----------
query: dict
the dict of query
task_pool: str
the name of Collection in MongoDB
Returns
-------
"""
task = self.fetch_task(query=query, task_pool=task_pool)
try:
yield task
@@ -167,12 +223,20 @@ class TaskManager:
yield task
def query(self, query={}, decode=True, task_pool=None):
"""query
"""
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
:param query:
:param decode:
:param task_pool:
Parameters
----------
query: dict
the dict of query
decode: bool
task_pool: str
the name of Collection in MongoDB
Returns
-------
"""
query = query.copy()
if "_id" in query:
@@ -196,6 +260,20 @@ class TaskManager:
task_pool.update_one({"_id": task["_id"]}, update_dict)
def remove(self, query={}, task_pool=None):
"""
remove the task using query
Parameters
----------
query: dict
the dict of query
task_pool: str
the name of Collection in MongoDB
Returns
-------
"""
query = query.copy()
task_pool = self._get_task_pool(task_pool)
if "_id" in query:
@@ -250,15 +328,15 @@ class TaskManager:
def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
"""run_task.
While task pool is not empty, use task_func to fetch and run tasks in task_pool
"""
While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
Parameters
----------
task_func : def (task_def, *args, **kwargs) -> <res which will be committed>
the function to run the task
task_pool :
The name of the task pool
task_pool : str
the name of the task pool (Collection in MongoDB)
force_release :
will the program force to release the resource
args :

View File

@@ -52,11 +52,30 @@ class TimeAdjuster:
def max(self):
"""
Return the max calendar date
(Deprecated)
Return the max calendar datetime
"""
return max(self.cals)
def last_date(self) -> pd.Timestamp:
"""
Return the last datetime in the calendar
"""
return self.cals[-1]
def align_idx(self, time_point, tp_type="start"):
"""
align the index of time_point in the calendar
Parameters
----------
time_point
tp_type : str
Returns
-------
index : int
"""
time_point = pd.Timestamp(time_point)
if tp_type == "start":
idx = bisect.bisect_left(self.cals, time_point)
@@ -68,11 +87,11 @@ class TimeAdjuster:
def align_time(self, time_point, tp_type="start"):
"""
Align a timepoint to calendar weekdays
Align time_point to trade date of calendar
Parameters
----------
time_point :
time_point
Time point
tp_type : str
time point type (`"start"`, `"end"`)
@@ -80,6 +99,24 @@ class TimeAdjuster:
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
def align_seg(self, segment: Union[dict, tuple]):
"""
align the given date to trade date
for example:
input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}
output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}
Parameters
----------
segment
Returns
-------
the start and end trade date (pd.Timestamp) between the given start and end date.
"""
if isinstance(segment, dict):
return {k: self.align_seg(seg) for k, seg in segment.items()}
elif isinstance(segment, tuple):
@@ -98,7 +135,7 @@ class TimeAdjuster:
test_start
days : int
The trading days to be truncated
大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据
the data in this segment may need 'days' data
"""
test_idx = self.align_idx(test_start)
if isinstance(segment, tuple):
@@ -116,7 +153,7 @@ class TimeAdjuster:
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD):
"""
shift the datatiem of segment
shift the datatime of segment
Parameters
----------