1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

init version of online serving and rolling

This commit is contained in:
Young
2021-02-26 09:14:40 +00:00
parent fa8f1cba06
commit 1e5cf1c174
7 changed files with 623 additions and 0 deletions

View File

View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Task related workflow is implemented in this folder
A typical task workflow
| Step | Description |
|-----------------------+------------------------------------------------|
| TaskGen | Generating tasks. |
| TaskManager(optional) | Manage generated tasks |
| run task | retrive tasks from TaskManager and run tasks. |
"""

View File

@@ -0,0 +1,52 @@
from qlib.workflow import R
import pandas as pd
from typing import Union
from tqdm.auto import tqdm
class RollingEnsemble:
'''
Rolling Models Ensemble based on (R)ecord
This shares nothing with Ensemble
'''
# TODO: 这边还可以加加速
def __init__(self, get_key_func, flt_func=None):
self.get_key_func = get_key_func
self.flt_func = flt_func
def __call__(self, exp_name) -> Union[pd.Series, dict]:
# TODO;
# Should we split the scripts into several sub functions?
exp = R.get_exp(experiment_name=exp_name)
# filter records
recs = exp.list_recorders()
recs_flt = {}
for rid, rec in tqdm(recs.items(), desc="Loading data"):
# rec = exp.get_recorder(recorder_id=rid)
params = rec.load_object("param")
if rec.status == rec.STATUS_FI:
if self.flt_func is None or self.flt_func(params):
rec.params = params
recs_flt[rid] = rec
# group
recs_group = {}
for _, rec in recs_flt.items():
params = rec.params
group_key = self.get_key_func(params)
recs_group.setdefault(group_key, []).append(rec)
# reduce group
reduce_group = {}
for k, rec_l in recs_group.items():
pred_l = []
for rec in rec_l:
pred_l.append(rec.load_object('pred.pkl').iloc[:, 0])
pred = pd.concat(pred_l).sort_index()
reduce_group[k] = pred
return reduce_group

133
qlib/workflow/task/gen.py Normal file
View File

@@ -0,0 +1,133 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
'''
this is a task generator
'''
import abc
import copy
import typing
from .utils import TimeAdjuster
class TaskGen(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> typing.List[dict]:
"""
generate
Parameters
----------
args, kwargs:
The info for generating tasks
Example 1):
input: a specific task template
output: rolling version of the tasks
Example 2):
input: a specific task template
output: a set of tasks with different losses
Returns
-------
typing.List[dict]:
A list of tasks
"""
pass
class RollingGen(TaskGen):
ROLL_EX = TimeAdjuster.SHIFT_EX
ROLL_SD = TimeAdjuster.SHIFT_SD
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
"""
Generate tasks for rolling
Parameters
----------
step : int
step to rolling
rtype : str
rolling type (expanding, rolling)
"""
self.step = step
self.rtype = rtype
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改
self.test_key = 'test'
self.train_key = 'train'
def __call__(self, task: dict):
"""
Converting the task into a rolling task
Parameters
----------
task : dict
A dict describing a task. For example.
DEFAULT_TASK = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation
"test": ("2017-01-01", "2020-08-01"),
},
},
},
# You shoud record the data in specific sequence
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
}
"""
res = []
prev_seg = None
test_end = None
while True:
t = copy.deepcopy(task)
# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end porint
segments = copy.deepcopy(self.ta.align_seg(t['dataset']['kwargs']['segments']))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and the init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
else:
segments = {}
try:
for k, seg in prev_seg.items():
# 决定怎么shift
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# 整段数据做shift
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
t['dataset']['kwargs']['segments'] = copy.deepcopy(segments)
prev_seg = segments
res.append(t)
return res

View File

@@ -0,0 +1,290 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
A task consists of 2 parts
- tasks description: the desc will define the task
- tasks status: the status of the task
- tasks result information : A user can get the task with the task description and task result.
"""
from bson.binary import Binary
import pickle
from pymongo.errors import InvalidDocument
from fire import Fire
from bson.objectid import ObjectId
from contextlib import contextmanager
from loguru import logger
from tqdm.cli import tqdm
import time
import concurrent
import pymongo
from qlib.config import C
class TaskManager:
"""TaskManager
here is the what will a task looks like
{
'def': pickle serialized task definition. using pickle will make it easier
'filter': json-like data. This is for filtering the tasks.
'status': 'waiting' | 'running' | 'done'
'res': pickle serialized task result,
}
The tasks manager assume that you will only update the tasks you fetched.
The mongo fetch one and update will make it date updating secure.
Usage Examples from the CLI.
python -m blocks.tasks.__init__ task_stat --task_pool meta_task_rule
NOTE:
- 假设: 存储在db里面的都是encode过的 拿出来的都是decode过的
"""
STATUS_WAITING = 'waiting'
STATUS_RUNNING = 'running'
STATUS_DONE = 'done'
STATUS_PART_DONE = 'part_done'
ENCODE_FIELDS_PREFIX = ['def', 'res']
def __init__(self, task_pool=None):
self.mdb = get_mongodb()
self.task_pool = task_pool
def list(self):
return self.mdb.list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
task[k] = Binary(pickle.dumps(task[k]))
return task
def _decode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
task[k] = pickle.loads(task[k])
return task
def _get_task_pool(self, task_pool=None):
if task_pool is None:
task_pool = self.task_pool
if task_pool is None:
raise ValueError('You must specify a task pool.')
if isinstance(task_pool, str):
return getattr(self.mdb, task_pool)
return task_pool
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
def replace_task(self, task, new_task, task_pool=None):
# 这里的假设是从接口拿出来的都是decode过的在接口内部的都是 encode过的
new_task = self._encode_task(new_task)
task_pool = self._get_task_pool(task_pool)
query = {'_id': ObjectId(task['_id'])}
try:
task_pool.replace_one(query, new_task)
except InvalidDocument:
task['filter'] = self._dict_to_str(task['filter'])
task_pool.replace_one(query, new_task)
def insert_task(self, task, task_pool=None):
task_pool = self._get_task_pool(task_pool)
try:
task_pool.insert_one(task)
except InvalidDocument:
task['filter'] = self._dict_to_str(task['filter'])
task_pool.insert_one(task)
def insert_task_def(self, task_def, task_pool=None):
task_pool = self._get_task_pool(task_pool)
task = self._encode_task({
'def': task_def,
'filter': task_def, # FIXME: catch the raised error
'status': self.STATUS_WAITING,
})
self.insert_task(task, task_pool)
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
task_pool = self._get_task_pool(task_pool)
new_tasks = []
for t in task_def_l:
try:
r = task_pool.find_one({'filter': t})
except InvalidDocument:
r = task_pool.find_one({'filter': self._dict_to_str(t)})
if r is None:
new_tasks.append(t)
print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))
if print_nt: # print new task
for t in new_tasks:
print(t)
if dry_run:
return
for t in new_tasks:
self.insert_task_def(t, task_pool)
def fetch_task(self, query={}, task_pool=None):
task_pool = self._get_task_pool(task_pool)
query = query.copy()
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
query.update({'status': self.STATUS_WAITING})
task = task_pool.find_one_and_update(query, {'$set': {
'status': self.STATUS_RUNNING
}},
sort=[('priority', pymongo.DESCENDING)])
# 这里我的 priority 必须是 高数优先级更高,因为 null会被在 ASCENDING时被排在最前面
if task is None:
return None
task['status'] = self.STATUS_RUNNING
return self._decode_task(task)
@contextmanager
def safe_fetch_task(self, query={}, task_pool=None):
task = self.fetch_task(query=query, task_pool=task_pool)
try:
yield task
except Exception:
if task is not None:
logger.info('Returning task before raising error')
self.return_task(task)
logger.info('Task returned')
raise
def task_fetcher_iter(self, query={}, task_pool=None):
while True:
with self.safe_fetch_task(query=query, task_pool=task_pool) as task:
if task is None:
break
yield task
def query(self, query={}, decode=True, task_pool=None):
"""query
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
:param query:
:param decode:
:param task_pool:
"""
query = query.copy()
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
task_pool = self._get_task_pool(task_pool)
for t in task_pool.find(query):
yield self._decode_task(t)
def commit_task_res(self, task, res, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)
# A workaround to use the class attribute.
if status is None:
status = TaskManager.STATUS_DONE
task_pool.update_one({"_id": task['_id']}, {'$set': {'status': status, 'res': Binary(pickle.dumps(res))}})
def return_task(self, task, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)
if status is None:
status = TaskManager.STATUS_WAITING
update_dict = {'$set': {'status': status}}
task_pool.update_one({"_id": task['_id']}, update_dict)
def remove(self, query={}, task_pool=None):
query = query.copy()
task_pool = self._get_task_pool(task_pool)
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
task_pool.delete_many(query)
def task_stat(self, query={}, task_pool=None):
query = query.copy()
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
tasks = self.query(task_pool=task_pool, query=query, decode=False)
status_stat = {}
for t in tasks:
status_stat[t['status']] = status_stat.get(t['status'], 0) + 1
return status_stat
def reset_waiting(self, query={}, task_pool=None):
query = query.copy()
# default query
if 'status' not in query:
query['status'] = self.STATUS_RUNNING
return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)
def reset_status(self, query, status, task_pool=None):
query = query.copy()
task_pool = self._get_task_pool(task_pool)
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
print(task_pool.update_many(query, {"$set": {"status": status}}))
def _get_undone_n(self, task_stat):
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
def _get_total(self, task_stat):
return sum(task_stat.values())
def wait(self, query={}, task_pool=None):
task_stat = self.task_stat(query, task_pool)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)
undone_n = self._get_undone_n(self.task_stat(query, task_pool))
pbar.update(last_undone_n - undone_n)
last_undone_n = undone_n
if undone_n == 0:
break
def __str__(self):
return f"TaskManager({self.task_pool})"
def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
"""run_task.
While task pool is not empty, use task_func to fetch and run tasks in task_pool
Parameters
----------
task_func : def (task_def, *args, **kwargs) -> <res which will be committed>
the function to run the task
task_pool :
The name of the task pool
force_release :
will the program force to release the resource
args :
args
kwargs :
kwargs
"""
tm = TaskManager(task_pool)
ever_run = False
while True:
with tm.safe_fetch_task() as task:
if task is None:
break
logger.info(task['def'])
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
res = executor.submit(task_func, task['def'], *args, **kwargs).result()
else:
res = task_func(task['def'], *args, **kwargs)
tm.commit_task_res(task, res)
ever_run = True
return ever_run
if __name__ == '__main__':
Fire(TaskManager)

134
qlib/workflow/task/utils.py Normal file
View File

@@ -0,0 +1,134 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bisect
import pandas as pd
from qlib.data import D
from qlib.config import C
from qlib.log import get_module_logger
from pymongo import MongoClient
def get_mongodb():
try:
cfg = C['mongo']
except KeyError:
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
raise
client = MongoClient(cfg['task_url'])
return client.get_database(name=cfg['task_db_name'])
class TimeAdjuster:
'''找到合适的日期然后adjust date'''
def __init__(self, future=False):
self.cals = D.calendar(future=future)
def get(self, idx: int):
"""
Get datetime by index
Parameters
----------
idx : int
index of the calendar
"""
if idx >= len(self.cals):
return None
return self.cals[idx]
def max(self):
"""
Return return the max calendar date
"""
return max(self.cals)
def align_idx(self, time_point, tp_type="start"):
time_point = pd.Timestamp(time_point)
if tp_type == 'start':
idx = bisect.bisect_left(self.cals, time_point)
elif tp_type == 'end':
idx = bisect.bisect_right(self.cals, time_point) - 1
else:
raise NotImplementedError(f"This type of input is not supported")
return idx
def align_time(self, time_point, tp_type="start"):
"""
Align a timepoint to calendar weekdays
Parameters
----------
time_point :
Time point
tp_type : str
time point type (`"start"`, `"end"`)
"""
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
def align_seg(self, segment):
if isinstance(segment, dict):
return {k: self.align_seg(seg) for k, seg in segment.items()}
elif isinstance(segment, tuple):
return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
else:
raise NotImplementedError(f"This type of input is not supported")
def truncate(self, segment, test_start, days: int):
"""
truncate the segment based on the test_start date
Parameters
----------
segment :
time segment
days : int
The trading days to be truncated
大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据
"""
test_idx = self.align_idx(test_start)
if isinstance(segment, tuple):
new_seg = []
for time_point in segment:
tp_idx = min(self.align_idx(time_point), test_idx - days)
assert (tp_idx > 0)
new_seg.append(self.get(tp_idx))
return tuple(new_seg)
else:
raise NotImplementedError(f"This type of input is not supported")
SHIFT_SD = "sliding"
SHIFT_EX = "expanding"
def shift(self, seg, step: int, rtype=SHIFT_SD):
"""
shift the datatiem of segment
Parameters
----------
seg :
datetime segment
step : int
rolling step
rtype : str
rolling type ("sliding" or "expanding")
Raises
------
KeyError:
shift will raise error if the index(both start and end) is out of self.cal
"""
if isinstance(seg, tuple):
start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
if rtype == self.SHIFT_SD:
start_idx += step
end_idx += step
elif rtype == self.SHIFT_EX:
end_idx += step
else:
raise NotImplementedError(f"This type of input is not supported")
if start_idx > len(self.cals):
raise KeyError("The segment is out of valid calendar")
return self.get(start_idx), self.get(end_idx)
else:
raise NotImplementedError(f"This type of input is not supported")

View File

@@ -55,6 +55,7 @@ REQUIRED = [
"tornado",
"joblib>=0.17.0",
"ruamel.yaml>=0.16.12",
"pymongo==3.7.2", # For task management
]
# Numpy include