mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Update some hint
This commit is contained in:
@@ -4,17 +4,17 @@ from typing import Union
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class RollingEnsemble:
|
||||
class RollingCollector:
|
||||
"""
|
||||
Rolling Models Ensemble based on (R)ecord
|
||||
|
||||
This shares nothing with Ensemble
|
||||
"""
|
||||
|
||||
# TODO: 这边还可以加加速
|
||||
# TODO: speed up this class
|
||||
def __init__(self, get_key_func, flt_func=None):
|
||||
self.get_key_func = get_key_func
|
||||
self.flt_func = flt_func
|
||||
self.get_key_func = get_key_func # user need to implement this method to get the key of a task based on task config
|
||||
self.flt_func = flt_func # determine whether a task can be retained based on task config
|
||||
|
||||
def __call__(self, exp_name) -> Union[pd.Series, dict]:
|
||||
# TODO;
|
||||
@@ -26,7 +26,6 @@ class RollingEnsemble:
|
||||
|
||||
recs_flt = {}
|
||||
for rid, rec in tqdm(recs.items(), desc="Loading data"):
|
||||
# rec = exp.get_recorder(recorder_id=rid)
|
||||
params = rec.load_object("param")
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if self.flt_func is None or self.flt_func(params):
|
||||
|
||||
@@ -36,12 +36,8 @@ class TaskManager:
|
||||
The tasks manager assume that you will only update the tasks you fetched.
|
||||
The mongo fetch one and update will make it date updating secure.
|
||||
|
||||
Usage Examples from the CLI.
|
||||
python -m blocks.tasks.__init__ task_stat --task_pool meta_task_rule
|
||||
|
||||
|
||||
NOTE:
|
||||
- 假设: 存储在db里面的都是encode过的, 拿出来的都是decode过的
|
||||
- assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
|
||||
"""
|
||||
|
||||
STATUS_WAITING = "waiting"
|
||||
@@ -85,7 +81,7 @@ class TaskManager:
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def replace_task(self, task, new_task, task_pool=None):
|
||||
# 这里的假设是从接口拿出来的都是decode过的,在接口内部的都是 encode过的
|
||||
# assume that the data out of interface was decoded and the data in interface was encoded
|
||||
new_task = self._encode_task(new_task)
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
query = {"_id": ObjectId(task["_id"])}
|
||||
|
||||
@@ -6,9 +6,19 @@ from qlib.data import D
|
||||
from qlib.config import C
|
||||
from qlib.log import get_module_logger
|
||||
from pymongo import MongoClient
|
||||
|
||||
from typing import Union
|
||||
|
||||
def get_mongodb():
|
||||
"""
|
||||
|
||||
get database in MongoDB, which means you need to declare the address and the name of database.
|
||||
for example:
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/",
|
||||
"task_db_name" : "rolling_db"
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
cfg = C["mongo"]
|
||||
except KeyError:
|
||||
@@ -20,7 +30,9 @@ def get_mongodb():
|
||||
|
||||
|
||||
class TimeAdjuster:
|
||||
"""找到合适的日期,然后adjust date"""
|
||||
"""
|
||||
find appropriate date and adjust date.
|
||||
"""
|
||||
|
||||
def __init__(self, future=False):
|
||||
self.cals = D.calendar(future=future)
|
||||
@@ -40,7 +52,7 @@ class TimeAdjuster:
|
||||
|
||||
def max(self):
|
||||
"""
|
||||
Return return the max calendar date
|
||||
Return the max calendar date
|
||||
"""
|
||||
return max(self.cals)
|
||||
|
||||
@@ -56,7 +68,7 @@ class TimeAdjuster:
|
||||
|
||||
def align_time(self, time_point, tp_type="start"):
|
||||
"""
|
||||
Align a timepoint to calendar weekdays
|
||||
Align a timepoint to calendar weekdays
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -67,7 +79,7 @@ class TimeAdjuster:
|
||||
"""
|
||||
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
|
||||
|
||||
def align_seg(self, segment):
|
||||
def align_seg(self, segment: Union[dict, tuple]):
|
||||
if isinstance(segment, dict):
|
||||
return {k: self.align_seg(seg) for k, seg in segment.items()}
|
||||
elif isinstance(segment, tuple):
|
||||
@@ -75,14 +87,15 @@ class TimeAdjuster:
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def truncate(self, segment, test_start, days: int):
|
||||
def truncate(self, segment: tuple, test_start, days: int):
|
||||
"""
|
||||
truncate the segment based on the test_start date
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segment :
|
||||
segment : tuple
|
||||
time segment
|
||||
test_start
|
||||
days : int
|
||||
The trading days to be truncated
|
||||
大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据
|
||||
@@ -101,7 +114,7 @@ class TimeAdjuster:
|
||||
SHIFT_SD = "sliding"
|
||||
SHIFT_EX = "expanding"
|
||||
|
||||
def shift(self, seg, step: int, rtype=SHIFT_SD):
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD):
|
||||
"""
|
||||
shift the datatiem of segment
|
||||
|
||||
|
||||
Reference in New Issue
Block a user