From a244f87f95e70ca2f97a687be10e3e3f606517a0 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Mon, 8 Mar 2021 13:25:11 +0800 Subject: [PATCH] modified the comments --- qlib/workflow/task/gen.py | 47 +++++++++------ qlib/workflow/task/manage.py | 114 +++++++++++++++++++++++++++++------ qlib/workflow/task/utils.py | 47 +++++++++++++-- 3 files changed, 168 insertions(+), 40 deletions(-) diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index efbfe94a6..60fc5c221 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -17,7 +17,7 @@ def task_generator(*args, **kwargs) -> list: for example: There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple. - task_generator(a=a, b=b, c=c, A=A, B=B) will finally generate 18 task_config. + task_generator(a_key=a, b_key=b, c_key=c, A, B) will finally generate 3*2*3 = 18 task_config. Parameters ---------- @@ -57,27 +57,37 @@ def task_generator(*args, **kwargs) -> list: for gen in gen_list: new_task_list = [] for task in tasks_list: - new_task_list.extend(gen(task)) + new_task_list.extend(gen.generate(task)) gen_task_list = new_task_list return gen_task_list class TaskGen(metaclass=abc.ABCMeta): + """ + the base class for generate different tasks + + Example 1: + + input: a specific task template and rolling steps + + output: rolling version of the tasks + + Example 2: + + input: a specific task template and losses list + + output: a set of tasks with different losses + + """ @abc.abstractmethod - def __call__(self, *args, **kwargs) -> typing.List[dict]: + def generate(self, task: dict) -> typing.List[dict]: """ - the base class for generate different tasks + generate different tasks based on a task template Parameters ---------- - args, kwargs: - The info for generating tasks - Example 1): - input: a specific task template - output: rolling version of the tasks - Example 2): - input: a specific task template - output: a set of tasks with different losses + task: dict + a task template Returns ------- @@ -89,7 +99,7 @@ class TaskGen(metaclass=abc.ABCMeta): class RollingGen(TaskGen): ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date - ROLL_SD = TimeAdjuster.SHIFT_SD # fixed window size, slide it from start date + ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date def __init__(self, step: int = 40, rtype: str = ROLL_EX): """ @@ -104,12 +114,13 @@ class RollingGen(TaskGen): """ self.step = step self.rtype = rtype - self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改 + # TODO: Ask pengrong to update future date in dataset + self.ta = TimeAdjuster(future=True) self.test_key = "test" self.train_key = "train" - def __call__(self, task: dict): + def generate(self, task: dict): """ Converting the task into a rolling task @@ -153,9 +164,9 @@ class RollingGen(TaskGen): # calculate segments if prev_seg is None: # First rolling - # 1) prepare the end porint + # 1) prepare the end point segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) - test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1] + test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1] # 2) and the init test segments test_start_idx = self.ta.align_idx(segments[self.test_key][0]) segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) @@ -164,6 +175,7 @@ class RollingGen(TaskGen): try: for k, seg in prev_seg.items(): # decide how to shift + # expanding only for train data, the segments size of test data and valid data won't change if k == self.train_key and self.rtype == self.ROLL_EX: rtype = self.ta.SHIFT_EX else: @@ -177,6 +189,7 @@ class RollingGen(TaskGen): # No more rolling break + # update segments of this task t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) prev_seg = segments res.append(t) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 1a4c341de..ae4aee147 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """ -A task consists of 2 parts +A task consists of 3 parts - tasks description: the desc will define the task - tasks status: the status of the task - tasks result information : A user can get the task with the task description and task result. @@ -26,18 +26,22 @@ from qlib import auto_init class TaskManager: """TaskManager here is the what will a task looks like - { - 'def': pickle serialized task definition. using pickle will make it easier - 'filter': json-like data. This is for filtering the tasks. - 'status': 'waiting' | 'running' | 'done' - 'res': pickle serialized task result, - } + + .. code-block:: python + + { + 'def': pickle serialized task definition. using pickle will make it easier + 'filter': json-like data. This is for filtering the tasks. + 'status': 'waiting' | 'running' | 'done' + 'res': pickle serialized task result, + } The tasks manager assume that you will only update the tasks you fetched. The mongo fetch one and update will make it date updating secure. - NOTE: - - assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded + .. note:: + + assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded """ STATUS_WAITING = "waiting" @@ -48,6 +52,14 @@ class TaskManager: ENCODE_FIELDS_PREFIX = ["def", "res"] def __init__(self, task_pool=None): + """ + init Task Manager, remember to make the statement of MongoDB url and database name firstly. + + Parameters + ---------- + task_pool: str + the name of Collection in MongoDB + """ self.mdb = get_mongodb() self.task_pool = task_pool @@ -100,6 +112,19 @@ class TaskManager: task_pool.insert_one(task) def insert_task_def(self, task_def, task_pool=None): + """ + insert a task to task_pool + + Parameters + ---------- + task_def: dict + task_pool: str + the name of Collection in MongoDB + + Returns + ------- + + """ task_pool = self._get_task_pool(task_pool) task = self._encode_task( { @@ -111,6 +136,23 @@ class TaskManager: self.insert_task(task, task_pool) def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False): + """ + if the tasks in task_def_l is new, then insert new tasks into the task_pool + + Parameters + ---------- + task_def_l: list + a list of task + task_pool: str + the name of task_pool (collection name of MongoDB) + dry_run: bool + if insert those new tasks to task pool + print_nt: bool + if print new task + Returns + ------- + + """ task_pool = self._get_task_pool(task_pool) new_tasks = [] for t in task_def_l: @@ -141,7 +183,7 @@ class TaskManager: task = task_pool.find_one_and_update( query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)] ) - # 这里我的 priority 必须是 高数优先级更高,因为 null会被在 ASCENDING时被排在最前面 + # null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority if task is None: return None task["status"] = self.STATUS_RUNNING @@ -149,6 +191,20 @@ class TaskManager: @contextmanager def safe_fetch_task(self, query={}, task_pool=None): + """ + fetch task from task_pool using query with contextmanager + + Parameters + ---------- + query: dict + the dict of query + task_pool: str + the name of Collection in MongoDB + + Returns + ------- + + """ task = self.fetch_task(query=query, task_pool=task_pool) try: yield task @@ -167,12 +223,20 @@ class TaskManager: yield task def query(self, query={}, decode=True, task_pool=None): - """query + """ This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator - :param query: - :param decode: - :param task_pool: + Parameters + ---------- + query: dict + the dict of query + decode: bool + task_pool: str + the name of Collection in MongoDB + + Returns + ------- + """ query = query.copy() if "_id" in query: @@ -196,6 +260,20 @@ class TaskManager: task_pool.update_one({"_id": task["_id"]}, update_dict) def remove(self, query={}, task_pool=None): + """ + remove the task using query + + Parameters + ---------- + query: dict + the dict of query + task_pool: str + the name of Collection in MongoDB + + Returns + ------- + + """ query = query.copy() task_pool = self._get_task_pool(task_pool) if "_id" in query: @@ -250,15 +328,15 @@ class TaskManager: def run_task(task_func, task_pool, force_release=False, *args, **kwargs): - """run_task. - While task pool is not empty, use task_func to fetch and run tasks in task_pool + """ + While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool Parameters ---------- task_func : def (task_def, *args, **kwargs) -> the function to run the task - task_pool : - The name of the task pool + task_pool : str + the name of the task pool (Collection in MongoDB) force_release : will the program force to release the resource args : diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 719359d5b..5e94f55ae 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -52,11 +52,30 @@ class TimeAdjuster: def max(self): """ - Return the max calendar date + (Deprecated) + Return the max calendar datetime """ return max(self.cals) + def last_date(self) -> pd.Timestamp: + """ + Return the last datetime in the calendar + """ + return self.cals[-1] + def align_idx(self, time_point, tp_type="start"): + """ + align the index of time_point in the calendar + + Parameters + ---------- + time_point + tp_type : str + + Returns + ------- + index : int + """ time_point = pd.Timestamp(time_point) if tp_type == "start": idx = bisect.bisect_left(self.cals, time_point) @@ -68,11 +87,11 @@ class TimeAdjuster: def align_time(self, time_point, tp_type="start"): """ - Align a timepoint to calendar weekdays + Align time_point to trade date of calendar Parameters ---------- - time_point : + time_point Time point tp_type : str time point type (`"start"`, `"end"`) @@ -80,6 +99,24 @@ class TimeAdjuster: return self.cals[self.align_idx(time_point, tp_type=tp_type)] def align_seg(self, segment: Union[dict, tuple]): + """ + align the given date to trade date + + for example: + input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')} + + output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), + 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), + 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))} + + Parameters + ---------- + segment + + Returns + ------- + the start and end trade date (pd.Timestamp) between the given start and end date. + """ if isinstance(segment, dict): return {k: self.align_seg(seg) for k, seg in segment.items()} elif isinstance(segment, tuple): @@ -98,7 +135,7 @@ class TimeAdjuster: test_start days : int The trading days to be truncated - 大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据 + the data in this segment may need 'days' data """ test_idx = self.align_idx(test_start) if isinstance(segment, tuple): @@ -116,7 +153,7 @@ class TimeAdjuster: def shift(self, seg: tuple, step: int, rtype=SHIFT_SD): """ - shift the datatiem of segment + shift the datatime of segment Parameters ----------