1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

Update some hint

This commit is contained in:
lzh222333
2021-03-03 16:36:15 +08:00
parent 05cf0e1edc
commit fd2c1ba1ed
3 changed files with 27 additions and 19 deletions

View File

@@ -4,17 +4,17 @@ from typing import Union
from tqdm.auto import tqdm
class RollingEnsemble:
class RollingCollector:
"""
Rolling Models Ensemble based on (R)ecord
This shares nothing with Ensemble
"""
# TODO: 这边还可以加加速
# TODO: speed up this class
def __init__(self, get_key_func, flt_func=None):
self.get_key_func = get_key_func
self.flt_func = flt_func
self.get_key_func = get_key_func # user need to implement this method to get the key of a task based on task config
self.flt_func = flt_func # determine whether a task can be retained based on task config
def __call__(self, exp_name) -> Union[pd.Series, dict]:
# TODO;
@@ -26,7 +26,6 @@ class RollingEnsemble:
recs_flt = {}
for rid, rec in tqdm(recs.items(), desc="Loading data"):
# rec = exp.get_recorder(recorder_id=rid)
params = rec.load_object("param")
if rec.status == rec.STATUS_FI:
if self.flt_func is None or self.flt_func(params):

View File

@@ -36,12 +36,8 @@ class TaskManager:
The tasks manager assume that you will only update the tasks you fetched.
The mongo fetch one and update will make it date updating secure.
Usage Examples from the CLI.
python -m blocks.tasks.__init__ task_stat --task_pool meta_task_rule
NOTE:
- 假设: 存储在db里面的都是encode过的 拿出来的都是decode过的
- assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
"""
STATUS_WAITING = "waiting"
@@ -85,7 +81,7 @@ class TaskManager:
return {k: str(v) for k, v in flt.items()}
def replace_task(self, task, new_task, task_pool=None):
# 这里的假设是从接口拿出来的都是decode过的在接口内部的都是 encode过的
# assume that the data out of interface was decoded and the data in interface was encoded
new_task = self._encode_task(new_task)
task_pool = self._get_task_pool(task_pool)
query = {"_id": ObjectId(task["_id"])}

View File

@@ -6,9 +6,19 @@ from qlib.data import D
from qlib.config import C
from qlib.log import get_module_logger
from pymongo import MongoClient
from typing import Union
def get_mongodb():
"""
get database in MongoDB, which means you need to declare the address and the name of database.
for example:
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/",
"task_db_name" : "rolling_db"
}
"""
try:
cfg = C["mongo"]
except KeyError:
@@ -20,7 +30,9 @@ def get_mongodb():
class TimeAdjuster:
"""找到合适的日期然后adjust date"""
"""
find appropriate date and adjust date.
"""
def __init__(self, future=False):
self.cals = D.calendar(future=future)
@@ -40,7 +52,7 @@ class TimeAdjuster:
def max(self):
"""
Return return the max calendar date
Return the max calendar date
"""
return max(self.cals)
@@ -56,7 +68,7 @@ class TimeAdjuster:
def align_time(self, time_point, tp_type="start"):
"""
Align a timepoint to calendar weekdays
Align a timepoint to calendar weekdays
Parameters
----------
@@ -67,7 +79,7 @@ class TimeAdjuster:
"""
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
def align_seg(self, segment):
def align_seg(self, segment: Union[dict, tuple]):
if isinstance(segment, dict):
return {k: self.align_seg(seg) for k, seg in segment.items()}
elif isinstance(segment, tuple):
@@ -75,14 +87,15 @@ class TimeAdjuster:
else:
raise NotImplementedError(f"This type of input is not supported")
def truncate(self, segment, test_start, days: int):
def truncate(self, segment: tuple, test_start, days: int):
"""
truncate the segment based on the test_start date
Parameters
----------
segment :
segment : tuple
time segment
test_start
days : int
The trading days to be truncated
大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据
@@ -101,7 +114,7 @@ class TimeAdjuster:
SHIFT_SD = "sliding"
SHIFT_EX = "expanding"
def shift(self, seg, step: int, rtype=SHIFT_SD):
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD):
"""
shift the datatiem of segment