From fd2c1ba1ed1c3919b6ddd418f0b3f82239f0baf5 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Wed, 3 Mar 2021 16:36:15 +0800 Subject: [PATCH] Update some hint --- qlib/workflow/task/collect.py | 9 ++++----- qlib/workflow/task/manage.py | 8 ++------ qlib/workflow/task/utils.py | 29 +++++++++++++++++++++-------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 7cdca30fa..9a67d8e06 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -4,17 +4,17 @@ from typing import Union from tqdm.auto import tqdm -class RollingEnsemble: +class RollingCollector: """ Rolling Models Ensemble based on (R)ecord This shares nothing with Ensemble """ - # TODO: 这边还可以加加速 + # TODO: speed up this class def __init__(self, get_key_func, flt_func=None): - self.get_key_func = get_key_func - self.flt_func = flt_func + self.get_key_func = get_key_func # user need to implement this method to get the key of a task based on task config + self.flt_func = flt_func # determine whether a task can be retained based on task config def __call__(self, exp_name) -> Union[pd.Series, dict]: # TODO; @@ -26,7 +26,6 @@ class RollingEnsemble: recs_flt = {} for rid, rec in tqdm(recs.items(), desc="Loading data"): - # rec = exp.get_recorder(recorder_id=rid) params = rec.load_object("param") if rec.status == rec.STATUS_FI: if self.flt_func is None or self.flt_func(params): diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 3bcac8360..1a4c341de 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -36,12 +36,8 @@ class TaskManager: The tasks manager assume that you will only update the tasks you fetched. The mongo fetch one and update will make it date updating secure. - Usage Examples from the CLI. - python -m blocks.tasks.__init__ task_stat --task_pool meta_task_rule - - NOTE: - - 假设: 存储在db里面的都是encode过的, 拿出来的都是decode过的 + - assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded """ STATUS_WAITING = "waiting" @@ -85,7 +81,7 @@ class TaskManager: return {k: str(v) for k, v in flt.items()} def replace_task(self, task, new_task, task_pool=None): - # 这里的假设是从接口拿出来的都是decode过的,在接口内部的都是 encode过的 + # assume that the data out of interface was decoded and the data in interface was encoded new_task = self._encode_task(new_task) task_pool = self._get_task_pool(task_pool) query = {"_id": ObjectId(task["_id"])} diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index d6089ff66..719359d5b 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -6,9 +6,19 @@ from qlib.data import D from qlib.config import C from qlib.log import get_module_logger from pymongo import MongoClient - +from typing import Union def get_mongodb(): + """ + + get database in MongoDB, which means you need to declare the address and the name of database. + for example: + C["mongo"] = { + "task_url" : "mongodb://localhost:27017/", + "task_db_name" : "rolling_db" + } + + """ try: cfg = C["mongo"] except KeyError: @@ -20,7 +30,9 @@ def get_mongodb(): class TimeAdjuster: - """找到合适的日期,然后adjust date""" + """ + find appropriate date and adjust date. + """ def __init__(self, future=False): self.cals = D.calendar(future=future) @@ -40,7 +52,7 @@ class TimeAdjuster: def max(self): """ - Return return the max calendar date + Return the max calendar date """ return max(self.cals) @@ -56,7 +68,7 @@ class TimeAdjuster: def align_time(self, time_point, tp_type="start"): """ - Align a timepoint to calendar weekdays + Align a timepoint to calendar weekdays Parameters ---------- @@ -67,7 +79,7 @@ class TimeAdjuster: """ return self.cals[self.align_idx(time_point, tp_type=tp_type)] - def align_seg(self, segment): + def align_seg(self, segment: Union[dict, tuple]): if isinstance(segment, dict): return {k: self.align_seg(seg) for k, seg in segment.items()} elif isinstance(segment, tuple): @@ -75,14 +87,15 @@ class TimeAdjuster: else: raise NotImplementedError(f"This type of input is not supported") - def truncate(self, segment, test_start, days: int): + def truncate(self, segment: tuple, test_start, days: int): """ truncate the segment based on the test_start date Parameters ---------- - segment : + segment : tuple time segment + test_start days : int The trading days to be truncated 大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据 @@ -101,7 +114,7 @@ class TimeAdjuster: SHIFT_SD = "sliding" SHIFT_EX = "expanding" - def shift(self, seg, step: int, rtype=SHIFT_SD): + def shift(self, seg: tuple, step: int, rtype=SHIFT_SD): """ shift the datatiem of segment