# qlib/workflow/task/__init__.py
"""
Task related workflow is implemented in this folder

A typical task workflow

| Step                  | Description                                    |
|-----------------------+------------------------------------------------|
| TaskGen               | Generating tasks.                              |
| TaskManager(optional) | Manage generated tasks                         |
| run task              | retrieve tasks from TaskManager and run tasks. |
"""


# qlib/workflow/task/collect.py
class RollingEnsemble:
    """
    Rolling Models Ensemble based on (R)ecord

    This shares nothing with Ensemble

    Parameters
    ----------
    get_key_func : callable
        Maps the ``param`` object of a recorder to a group key. Predictions
        of all recorders sharing the same key are concatenated.
    flt_func : callable, optional
        Predicate on the ``param`` object of a recorder. Recorders for which
        it returns a falsy value are skipped. ``None`` keeps every finished
        recorder.
    """
    # TODO: this could be sped up further

    def __init__(self, get_key_func, flt_func=None):
        self.get_key_func = get_key_func
        self.flt_func = flt_func

    def __call__(self, exp_name) -> Union[pd.Series, dict]:
        """
        Collect the predictions of an experiment and merge them per group.

        Parameters
        ----------
        exp_name :
            Name of the experiment passed to ``R.get_exp``.

        Returns
        -------
        dict
            ``{group_key: prediction}`` where the prediction is the
            index-sorted concatenation of the first column of every
            ``pred.pkl`` in that group.
        """
        exp = R.get_exp(experiment_name=exp_name)
        recs = exp.list_recorders()

        recs_flt = self._filter_recorders(tqdm(recs.items(), desc="Loading data"))
        recs_group = self._group_recorders(recs_flt.values())
        return {key: self._merge_predictions(rec_l) for key, rec_l in recs_group.items()}

    def _filter_recorders(self, rec_items):
        """Keep the finished recorders accepted by ``flt_func``; attach their params."""
        recs_flt = {}
        for rid, rec in rec_items:
            # Check the status BEFORE touching any artifact: an unfinished or
            # failed recorder may not have saved "param" yet, and loading it
            # would abort the whole ensemble.
            if rec.status != rec.STATUS_FI:
                continue
            params = rec.load_object("param")
            if self.flt_func is None or self.flt_func(params):
                # cache the params on the recorder so grouping does not reload them
                rec.params = params
                recs_flt[rid] = rec
        return recs_flt

    def _group_recorders(self, recs):
        """Group recorders by ``get_key_func(rec.params)``, preserving input order."""
        recs_group = {}
        for rec in recs:
            recs_group.setdefault(self.get_key_func(rec.params), []).append(rec)
        return recs_group

    @staticmethod
    def _merge_predictions(rec_l):
        """Concatenate the first prediction column of each recorder and sort by index."""
        pred_l = [rec.load_object('pred.pkl').iloc[:, 0] for rec in rec_l]
        return pd.concat(pred_l).sort_index()
'''
this is a task generator
'''


class TaskGen(metaclass=abc.ABCMeta):
    """Abstract base class of task generators: callables that return a list of tasks."""

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> typing.List[dict]:
        """
        generate

        Parameters
        ----------
        args, kwargs:
            The info for generating tasks
            Example 1):
                input: a specific task template
                output: rolling version of the tasks
            Example 2):
                input: a specific task template
                output: a set of tasks with different losses

        Returns
        -------
        typing.List[dict]:
            A list of tasks
        """


class RollingGen(TaskGen):
    """Generate rolling versions of a task by shifting its dataset segments."""

    ROLL_EX = TimeAdjuster.SHIFT_EX
    ROLL_SD = TimeAdjuster.SHIFT_SD

    def __init__(self, step: int = 40, rtype: str = ROLL_EX):
        """
        Generate tasks for rolling

        Parameters
        ----------
        step : int
            step to rolling
        rtype : str
            rolling type (expanding, rolling)
        """
        self.step = step
        self.rtype = rtype
        # Use the future calendar so that the end date of the last test
        # segment is not None (the original note says this may need revisiting).
        self.ta = TimeAdjuster(future=True)

        self.test_key = 'test'
        self.train_key = 'train'

    def __call__(self, task: dict):
        """
        Converting the task into a rolling task

        Parameters
        ----------
        task : dict
            A dict describing a task. For example.

            DEFAULT_TASK = {
                "model": {
                    "class": "LGBModel",
                    "module_path": "qlib.contrib.model.gbdt",
                },
                "dataset": {
                    "class": "DatasetH",
                    "module_path": "qlib.data.dataset",
                    "kwargs": {
                        "handler": {
                            "class": "Alpha158",
                            "module_path": "qlib.contrib.data.handler",
                            "kwargs": data_handler_config,
                        },
                        "segments": {
                            "train": ("2008-01-01", "2014-12-31"),
                            "valid": ("2015-01-01", "2016-12-20"),  # Please avoid leaking the future test data into validation
                            "test": ("2017-01-01", "2020-08-01"),
                        },
                    },
                },
                # You should record the data in specific sequence
                # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
            }

        Returns
        -------
        list of dict
            Deep copies of ``task`` whose ``['dataset']['kwargs']['segments']``
            are replaced by the successive rolling segments. ``task`` itself
            is not modified.
        """
        res = []

        prev_seg = None
        test_end = None
        while True:
            # calculate segments
            if prev_seg is None:
                # First rolling
                # 1) prepare the end point
                segments = copy.deepcopy(self.ta.align_seg(task['dataset']['kwargs']['segments']))
                test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
                # 2) and the init test segments
                test_start_idx = self.ta.align_idx(segments[self.test_key][0])
                segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
            else:
                segments = {}
                try:
                    for k, seg in prev_seg.items():
                        # decide how to shift: only the train segment may expand
                        if k == self.train_key and self.rtype == self.ROLL_EX:
                            rtype = self.ta.SHIFT_EX
                        else:
                            rtype = self.ta.SHIFT_SD
                        # shift the whole segment
                        segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
                    test_start = segments[self.test_key][0]
                    # A start of None means the shifted segment fell off the
                    # calendar; comparing None with a timestamp would raise
                    # TypeError, so treat it as "no more rolling".
                    if test_start is None or test_start > test_end:
                        break
                except KeyError:
                    # We reach the end of tasks
                    # No more rolling
                    break

            # copy the template only for tasks that are actually emitted
            t = copy.deepcopy(task)
            t['dataset']['kwargs']['segments'] = copy.deepcopy(segments)
            prev_seg = segments
            res.append(t)
        return res
"""
A task consists of 3 parts
- tasks description: the desc will define the task
- tasks status: the status of the task
- tasks result information : A user can get the task with the task description and task result.

"""


class TaskManager:
    """TaskManager
    here is what a task looks like
    {
        'def': pickle serialized task definition.  using pickle will make it easier
        'filter': json-like data. This is for filtering the tasks.
        'status': 'waiting' | 'running' | 'done'
        'res': pickle serialized task result,
    }

    The tasks manager assumes that you will only update the tasks you fetched.
    The mongo "fetch one and update" makes the data updating secure.

    Usage Examples from the CLI.
        python -m qlib.workflow.task.manage task_stat --task_pool meta_task_rule


    NOTE:
    - Assumption: everything stored in the db is encoded, everything taken
      out through this interface is decoded.
    - Encoded fields are unpickled on read; only point this class at a
      database whose content you trust.
    """
    STATUS_WAITING = 'waiting'
    STATUS_RUNNING = 'running'
    STATUS_DONE = 'done'
    STATUS_PART_DONE = 'part_done'

    # every key starting with one of these prefixes is pickled in the db
    ENCODE_FIELDS_PREFIX = ['def', 'res']

    def __init__(self, task_pool=None):
        """
        Parameters
        ----------
        task_pool : str or collection, optional
            Default task pool (collection name or collection object) used
            when a method is called without ``task_pool``.
        """
        # Imported here because the module never imported get_mongodb at
        # file level, which made the constructor fail with NameError.
        from qlib.workflow.task.utils import get_mongodb

        self.mdb = get_mongodb()
        self.task_pool = task_pool

    def list(self):
        """Return the names of all collections (task pools) in the database."""
        return self.mdb.list_collection_names()

    def _encode_task(self, task):
        """Pickle (in place) every field whose key starts with an encoded prefix."""
        prefixes = tuple(self.ENCODE_FIELDS_PREFIX)
        for k in list(task.keys()):
            if k.startswith(prefixes):
                task[k] = Binary(pickle.dumps(task[k]))
        return task

    def _decode_task(self, task):
        """Unpickle (in place) every field whose key starts with an encoded prefix."""
        prefixes = tuple(self.ENCODE_FIELDS_PREFIX)
        for k in list(task.keys()):
            if k.startswith(prefixes):
                # NOTE(review): pickle.loads executes arbitrary code if the
                # database content is not trusted.
                task[k] = pickle.loads(task[k])
        return task

    def _get_task_pool(self, task_pool=None):
        """Resolve ``task_pool`` (falling back to the default) to a collection object."""
        if task_pool is None:
            task_pool = self.task_pool
        if task_pool is None:
            raise ValueError('You must specify a task pool.')
        if isinstance(task_pool, str):
            return getattr(self.mdb, task_pool)
        return task_pool

    def _dict_to_str(self, flt):
        """Return a copy of ``flt`` with every value converted by ``str``."""
        return {k: str(v) for k, v in flt.items()}

    def replace_task(self, task, new_task, task_pool=None):
        """Replace the stored document of ``task`` (matched by ``_id``) with ``new_task``."""
        # The assumption here is that what comes out of the interface is
        # decoded and what is handled inside the interface is encoded.
        new_task = self._encode_task(new_task)
        task_pool = self._get_task_pool(task_pool)
        query = {'_id': ObjectId(task['_id'])}
        try:
            task_pool.replace_one(query, new_task)
        except InvalidDocument:
            # Stringify the filter of the document that is actually written;
            # the original code patched `task` and then retried with the
            # unchanged `new_task`, so the retry failed in the same way.
            new_task['filter'] = self._dict_to_str(new_task['filter'])
            task_pool.replace_one(query, new_task)

    def insert_task(self, task, task_pool=None):
        """Insert an (already encoded) task document; stringify its filter if mongo rejects it."""
        task_pool = self._get_task_pool(task_pool)
        try:
            task_pool.insert_one(task)
        except InvalidDocument:
            task['filter'] = self._dict_to_str(task['filter'])
            task_pool.insert_one(task)

    def insert_task_def(self, task_def, task_pool=None):
        """Wrap a task definition into a waiting task document and insert it."""
        task_pool = self._get_task_pool(task_pool)
        task = self._encode_task({
            'def': task_def,
            'filter': task_def,  # FIXME: catch the raised error
            'status': self.STATUS_WAITING,
        })
        self.insert_task(task, task_pool)

    def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
        """
        Insert the task definitions that are not in the pool yet.

        Parameters
        ----------
        task_def_l : list
            task definitions
        task_pool :
            task pool name or collection
        dry_run : bool
            only report, do not insert
        print_nt : bool
            print the new tasks
        """
        task_pool = self._get_task_pool(task_pool)
        new_tasks = []
        for t in task_def_l:
            try:
                r = task_pool.find_one({'filter': t})
            except InvalidDocument:
                r = task_pool.find_one({'filter': self._dict_to_str(t)})
            if r is None:
                new_tasks.append(t)
        print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))

        if print_nt:  # print new task
            for t in new_tasks:
                print(t)

        if dry_run:
            return

        for t in new_tasks:
            self.insert_task_def(t, task_pool)

    def fetch_task(self, query={}, task_pool=None):
        """
        Atomically pick one waiting task matching ``query`` and mark it running.

        Returns the decoded task, or ``None`` when no waiting task matches.
        ``query`` is copied before being modified, so the shared default is safe.
        """
        task_pool = self._get_task_pool(task_pool)
        query = query.copy()
        if '_id' in query:
            query['_id'] = ObjectId(query['_id'])
        query.update({'status': self.STATUS_WAITING})
        # A larger `priority` number must mean higher priority, because null
        # would be sorted first under ASCENDING order.
        task = task_pool.find_one_and_update(query, {'$set': {
            'status': self.STATUS_RUNNING
        }},
                                             sort=[('priority', pymongo.DESCENDING)])
        if task is None:
            return None
        # the returned document is the pre-update one; reflect the new status
        task['status'] = self.STATUS_RUNNING
        return self._decode_task(task)

    @contextmanager
    def safe_fetch_task(self, query={}, task_pool=None):
        """
        Context manager around `fetch_task` that puts the task back to
        waiting when the body raises, then re-raises the error.
        """
        task = self.fetch_task(query=query, task_pool=task_pool)
        try:
            yield task
        except (Exception, KeyboardInterrupt):
            # KeyboardInterrupt is not a subclass of Exception; without it an
            # interrupted worker would leave its task stuck in 'running'.
            if task is not None:
                logger.info('Returning task before raising error')
                self.return_task(task)
                logger.info('Task returned')
            raise

    def task_fetcher_iter(self, query={}, task_pool=None):
        """Yield fetched tasks one by one until no waiting task is left."""
        while True:
            with self.safe_fetch_task(query=query, task_pool=task_pool) as task:
                if task is None:
                    break
                yield task

    def query(self, query={}, decode=True, task_pool=None):
        """query
        This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator

        :param query: mongo query; an `_id` entry is converted to ObjectId
        :param decode: unpickle the encoded fields of each task when True,
                       yield the raw documents when False
        :param task_pool: task pool name or collection
        """
        query = query.copy()
        if '_id' in query:
            query['_id'] = ObjectId(query['_id'])
        task_pool = self._get_task_pool(task_pool)
        for t in task_pool.find(query):
            yield self._decode_task(t) if decode else t

    def commit_task_res(self, task, res, status=None, task_pool=None):
        """Store the pickled result of ``task`` and set its status (default: done)."""
        task_pool = self._get_task_pool(task_pool)
        # A workaround to use the class attribute.
        if status is None:
            status = TaskManager.STATUS_DONE
        task_pool.update_one({"_id": task['_id']}, {'$set': {'status': status, 'res': Binary(pickle.dumps(res))}})

    def return_task(self, task, status=None, task_pool=None):
        """Set the status of ``task`` back (default: waiting) without storing a result."""
        task_pool = self._get_task_pool(task_pool)
        if status is None:
            status = TaskManager.STATUS_WAITING
        update_dict = {'$set': {'status': status}}
        task_pool.update_one({"_id": task['_id']}, update_dict)

    def remove(self, query={}, task_pool=None):
        """Delete every task matching ``query``."""
        query = query.copy()
        task_pool = self._get_task_pool(task_pool)
        if '_id' in query:
            query['_id'] = ObjectId(query['_id'])
        task_pool.delete_many(query)

    def task_stat(self, query={}, task_pool=None):
        """Return ``{status: count}`` for the tasks matching ``query``."""
        query = query.copy()
        if '_id' in query:
            query['_id'] = ObjectId(query['_id'])
        # only the status is needed, so skip unpickling the payloads
        tasks = self.query(task_pool=task_pool, query=query, decode=False)
        status_stat = {}
        for t in tasks:
            status_stat[t['status']] = status_stat.get(t['status'], 0) + 1
        return status_stat

    def reset_waiting(self, query={}, task_pool=None):
        """Set matching tasks back to waiting (by default the running ones)."""
        query = query.copy()
        # default query
        if 'status' not in query:
            query['status'] = self.STATUS_RUNNING
        return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)

    def reset_status(self, query, status, task_pool=None):
        """Set the status of every task matching ``query`` and print the update result."""
        query = query.copy()
        task_pool = self._get_task_pool(task_pool)
        if '_id' in query:
            query['_id'] = ObjectId(query['_id'])
        print(task_pool.update_many(query, {"$set": {"status": status}}))

    def _get_undone_n(self, task_stat):
        """Number of waiting plus running tasks in a `task_stat` result."""
        return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)

    def _get_total(self, task_stat):
        """Total number of tasks in a `task_stat` result."""
        return sum(task_stat.values())

    def wait(self, query={}, task_pool=None):
        """Block, polling every 10 seconds, until no matching task is waiting or running."""
        task_stat = self.task_stat(query, task_pool)
        total = self._get_total(task_stat)
        last_undone_n = self._get_undone_n(task_stat)
        with tqdm(total=total, initial=total - last_undone_n) as pbar:
            while True:
                time.sleep(10)
                undone_n = self._get_undone_n(self.task_stat(query, task_pool))
                pbar.update(last_undone_n - undone_n)
                last_undone_n = undone_n
                if undone_n == 0:
                    break

    def __str__(self):
        return f"TaskManager({self.task_pool})"


def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
    """run_task.
    While task pool is not empty, use task_func to fetch and run tasks in task_pool

    Parameters
    ----------
    task_func : def (task_def, *args, **kwargs) ->
        the function to run the task
    task_pool :
        The name of the task pool
    force_release :
        will the program force to release the resource
        (each task is then run in a fresh single-worker process pool)
    args :
        args
    kwargs :
        kwargs

    Returns
    -------
    bool
        True if at least one task has been run.
    """
    # `import concurrent` alone does not load the `futures` submodule.
    import concurrent.futures

    tm = TaskManager(task_pool)

    ever_run = False

    while True:
        with tm.safe_fetch_task() as task:
            if task is None:
                break
            logger.info(task['def'])
            if force_release:
                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
                    res = executor.submit(task_func, task['def'], *args, **kwargs).result()
            else:
                res = task_func(task['def'], *args, **kwargs)
            tm.commit_task_res(task, res)
            ever_run = True

    return ever_run


if __name__ == '__main__':
    Fire(TaskManager)
def get_mongodb():
    """
    Return the task database described by ``C['mongo']``.

    Reads ``task_url`` and ``task_db_name`` from ``C['mongo']``.

    Raises
    ------
    KeyError
        if ``C['mongo']`` is not configured (an error is logged first).
    """
    try:
        cfg = C['mongo']
    except KeyError:
        get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
        raise

    client = MongoClient(cfg['task_url'])
    return client.get_database(name=cfg['task_db_name'])


class TimeAdjuster:
    """Find appropriate dates on the trading calendar and adjust dates/segments."""

    # rolling types accepted by `shift`
    SHIFT_SD = "sliding"
    SHIFT_EX = "expanding"

    def __init__(self, future=False):
        self.cals = D.calendar(future=future)

    def get(self, idx: int):
        """
        Get datetime by index

        Parameters
        ----------
        idx : int
            index of the calendar

        Returns
        -------
        The calendar date at ``idx``, or ``None`` when ``idx`` is past the
        end of the calendar.
        """
        if idx >= len(self.cals):
            return None
        return self.cals[idx]

    def max(self):
        """
        Return the max calendar date
        """
        return max(self.cals)

    def align_idx(self, time_point, tp_type="start"):
        """
        Return the calendar index a time point aligns to.

        Parameters
        ----------
        time_point :
            Time point (anything accepted by ``pd.Timestamp``)
        tp_type : str
            ``"start"``: index of the first calendar date >= time_point
            (may equal ``len(self.cals)`` when time_point is after the calendar);
            ``"end"``: index of the last calendar date <= time_point
            (``-1`` when time_point is before the calendar).
        """
        time_point = pd.Timestamp(time_point)
        if tp_type == 'start':
            idx = bisect.bisect_left(self.cals, time_point)
        elif tp_type == 'end':
            idx = bisect.bisect_right(self.cals, time_point) - 1
        else:
            raise NotImplementedError("This type of input is not supported")
        return idx

    def align_time(self, time_point, tp_type="start"):
        """
        Align a timepoint to the calendar

        Parameters
        ----------
        time_point :
            Time point
        tp_type : str
            time point type (`"start"`, `"end"`)
        """
        return self.cals[self.align_idx(time_point, tp_type=tp_type)]

    def align_seg(self, segment):
        """
        Align a ``(start, end)`` tuple, or recursively a dict of segments,
        to the calendar.
        """
        if isinstance(segment, dict):
            return {k: self.align_seg(seg) for k, seg in segment.items()}
        elif isinstance(segment, tuple):
            return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
        else:
            raise NotImplementedError("This type of input is not supported")

    def truncate(self, segment, test_start, days: int):
        """
        truncate the segment based on the test_start date

        Parameters
        ----------
        segment :
            time segment
        test_start :
            start of the test period; no point of the result is later than
            ``days`` trading days before it
        days : int
            The trading days to be truncated.
            In most cases this is needed because the data of this period
            (usually features) makes use of `days` days of data.
        """
        test_idx = self.align_idx(test_start)
        if isinstance(segment, tuple):
            new_seg = []
            for time_point in segment:
                tp_idx = min(self.align_idx(time_point), test_idx - days)
                assert (tp_idx > 0)
                new_seg.append(self.get(tp_idx))
            return tuple(new_seg)
        else:
            raise NotImplementedError("This type of input is not supported")

    def shift(self, seg, step: int, rtype=SHIFT_SD):
        """
        shift the datetime of segment

        Parameters
        ----------
        seg :
            datetime segment
        step : int
            rolling step
        rtype : str
            rolling type ("sliding" or "expanding")

        Returns
        -------
        tuple
            ``(start, end)``; ``end`` is ``None`` when it is shifted past
            the end of the calendar.

        Raises
        ------
        KeyError:
            shift will raise error if the shifted start is out of self.cals
        """
        if isinstance(seg, tuple):
            start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
            if rtype == self.SHIFT_SD:
                start_idx += step
                end_idx += step
            elif rtype == self.SHIFT_EX:
                end_idx += step
            else:
                raise NotImplementedError("This type of input is not supported")
            # Valid indices are 0 .. len - 1, so `== len` is already out of
            # range; `>` would let it through and return a (None, None) segment.
            if start_idx >= len(self.cals):
                raise KeyError("The segment is out of valid calendar")
            return self.get(start_idx), self.get(end_idx)
        else:
            raise NotImplementedError("This type of input is not supported")