From b24af7fff6311a2ff0e8e5456b359febf5d6099c Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Mon, 24 May 2021 05:07:38 +0000 Subject: [PATCH 1/6] multiprocessing support --- .../model_rolling/task_manager_rolling.py | 10 ++- .../online_srv/online_management_simulate.py | 15 +++- .../online_srv/rolling_online_management.py | 22 ++++- qlib/model/trainer.py | 88 ++++++++++++++++++- qlib/workflow/online/manager.py | 29 +++--- qlib/workflow/task/manage.py | 7 +- 6 files changed, 145 insertions(+), 26 deletions(-) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 4f3ac04b1..89233b37b 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -4,6 +4,7 @@ """ This example shows how a TrainerRM works based on TaskManager with rolling tasks. After training, how to collect the rolling results will be shown in task_collecting. +Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing. """ from pprint import pprint @@ -13,10 +14,10 @@ import qlib from qlib.config import REG_CN from qlib.workflow import R from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager +from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM +from qlib.model.trainer import TrainerRM, task_train data_handler_config = { @@ -122,6 +123,11 @@ class RollingTaskExample: trainer = TrainerRM(self.experiment_name, self.task_pool) trainer.train(tasks) + def worker(self): + # train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker. + print("========== worker ==========") + run_task(task_train, self.task_pool, experiment_name=self.experiment_name) + def task_collecting(self): print("========== task_collecting ==========") diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 4bb5022ee..de6dbcb21 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -78,8 +78,8 @@ class OnlineSimulationExample: provider_uri="~/.qlib/qlib_data/cn_data", region="cn", exp_name="rolling_exp", - task_url="mongodb://10.0.0.4:27017/", - task_db_name="rolling_db", + task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR + task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR task_pool="rolling_task", rolling_step=80, start_time="2018-09-10", @@ -113,7 +113,7 @@ class OnlineSimulationExample: self.rolling_gen = RollingGen( step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None ) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time. - self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR + self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR self.rolling_online_manager = OnlineManager( RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen), trainer=self.trainer, @@ -139,6 +139,15 @@ class OnlineSimulationExample: print("========== signals ==========") print(self.rolling_online_manager.get_signals()) + def worker(self): + # train tasks by other progress or machines for multiprocessing + # FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception. + print("========== worker ==========") + if isinstance(self.trainer, TrainerRM): + self.trainer.worker() + else: + print(f"{type(self.trainer)} is not supported for worker.") + if __name__ == "__main__": ## to run all workflow automatically with your own parameters, use the command below diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 25b8b2a0c..40da30db7 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -13,10 +13,12 @@ Finally, the OnlineManager will finish second routine and update all strategies. import os import fire import qlib +from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train from qlib.workflow import R from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager +from qlib.workflow.task.manage import TaskManager, run_task data_handler_config = { "start_time": "2013-01-01", @@ -80,8 +82,9 @@ class RollingOnlineExample: self, provider_uri="~/.qlib/qlib_data/cn_data", region="cn", - task_url="mongodb://10.0.0.4:27017/", - task_db_name="rolling_db", + trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM + task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR + task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR rolling_step=550, tasks=[task_xgboost_config], add_tasks=[task_lgb_config], @@ -104,17 +107,28 @@ class RollingOnlineExample: RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), ) ) - - self.rolling_online_manager = OnlineManager(strategies) + self.trainer = trainer + self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer) _ROLLING_MANAGER_PATH = ( ".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine. ) + def worker(self): + # train tasks by other progress or machines for multiprocessing + print("========== worker ==========") + if isinstance(self.trainer, TrainerRM): + for task in self.tasks + self.add_tasks: + name_id = task["model"]["class"] + self.trainer.worker(experiment_name=name_id) + else: + print(f"{type(self.trainer)} is not supported for worker.") + # Reset all things to the first status, be careful to save important data def reset(self): for task in self.tasks + self.add_tasks: name_id = task["model"]["class"] + TaskManager(task_pool=name_id).remove() exp = R.get_exp(experiment_name=name_id) for rid in exp.list_recorders(): exp.delete_recorder(rid) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index fd76e6728..07bb839a2 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -12,9 +12,11 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model """ import socket +import time from typing import Callable, List from qlib.data.dataset import Dataset +from qlib.log import get_module_logger from qlib.model.base import Model from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config from qlib.workflow import R @@ -190,6 +192,8 @@ class TrainerR(Trainer): Returns: List[Recorder]: a list of Recorders """ + if isinstance(tasks, dict): + tasks = [tasks] if len(tasks) == 0: return [] if train_func is None: @@ -213,6 +217,8 @@ class TrainerR(Trainer): Returns: List[Recorder]: the same list as the param. """ + if isinstance(recs, Recorder): + recs = [recs] for rec in recs: rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) return recs @@ -250,6 +256,8 @@ class DelayTrainerR(TrainerR): Returns: List[Recorder]: a list of Recorders """ + if isinstance(recs, Recorder): + recs = [recs] if end_train_func is None: end_train_func = self.end_train_func if experiment_name is None: @@ -315,6 +323,8 @@ class TrainerRM(Trainer): Returns: List[Recorder]: a list of Recorders """ + if isinstance(tasks, dict): + tasks = [tasks] if len(tasks) == 0: return [] if train_func is None: @@ -329,12 +339,24 @@ class TrainerRM(Trainer): run_task( train_func, task_pool, + query={"filter": {"$in": tasks}}, # only train these tasks experiment_name=experiment_name, before_status=before_status, after_status=after_status, **kwargs, ) + # FIXME: reset to waiting automatically + for _id in _id_list: + is_prn = False + while tm.re_query(_id)["status"] == "running": + if not is_prn: + get_module_logger("TrainerRM").warn( + f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others." + ) + is_prn = True + time.sleep(10) + recs = [] for _id in _id_list: rec = tm.re_query(_id)["res"] @@ -352,10 +374,33 @@ class TrainerRM(Trainer): Returns: List[Recorder]: the same list as the param. """ + if isinstance(recs, Recorder): + recs = [recs] for rec in recs: rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) return recs + def worker( + self, + train_func: Callable = None, + experiment_name: str = None, + ): + """ + The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines. + + Args: + train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method. + experiment_name (str): the experiment name, None for use default name. + """ + if train_func is None: + train_func = self.train_func + if experiment_name is None: + experiment_name = self.experiment_name + task_pool = self.task_pool + if task_pool is None: + task_pool = experiment_name + run_task(train_func, task_pool=task_pool, experiment_name=experiment_name) + class DelayTrainerRM(TrainerRM): """ @@ -395,6 +440,8 @@ class DelayTrainerRM(TrainerRM): Returns: List[Recorder]: a list of Recorders """ + if isinstance(tasks, dict): + tasks = [tasks] if len(tasks) == 0: return [] return super().train( @@ -410,8 +457,6 @@ class DelayTrainerRM(TrainerRM): Given a list of Recorder and return a list of trained Recorder. This class will finish real data loading and model fitting. - NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``. - Args: recs (list): a list of Recorder, the tasks have been saved to them. end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. @@ -421,7 +466,8 @@ class DelayTrainerRM(TrainerRM): Returns: List[Recorder]: a list of Recorders """ - + if isinstance(recs, Recorder): + recs = [recs] if end_train_func is None: end_train_func = self.end_train_func if experiment_name is None: @@ -441,6 +487,42 @@ class DelayTrainerRM(TrainerRM): before_status=TaskManager.STATUS_PART_DONE, **kwargs, ) + + # FIXME: reset to waiting automatically + tm = TaskManager(task_pool=task_pool) + for query_task in tm.query({"filter": {"$in": tasks}}): + _id = query_task["_id"] + is_prn = False + while tm.re_query(_id)["status"] == "running": + if not is_prn: + get_module_logger("DelayTrainerRM").warn( + f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others." + ) + is_prn = True + time.sleep(10) + for rec in recs: rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) return recs + + def worker(self, end_train_func=None, experiment_name: str = None): + """ + The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines. + + Args: + end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + experiment_name (str): the experiment name, None for use default name. + """ + if end_train_func is None: + end_train_func = self.end_train_func + if experiment_name is None: + experiment_name = self.experiment_name + task_pool = self.task_pool + if task_pool is None: + task_pool = experiment_name + run_task( + end_train_func, + task_pool=task_pool, + experiment_name=experiment_name, + before_status=TaskManager.STATUS_PART_DONE, + ) diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 443cd61ad..ef6cb8dfa 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation ========================= =================================================================================== Situations Description ========================= =================================================================================== -Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models. +Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It + will train models task by task and strategy by strategy. -Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models - in this routine. So it is not necessary to use DelayTrainer when do a REAL routine. +Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train + nothing until all tasks have been prepared. It makes user can train all tasks in + the end of `routine` or `first_train`. Simulation + Trainer When your models have some temporal dependence on the previous models, then you need to consider using Trainer. This means it will REAL train your models in @@ -103,17 +105,21 @@ class OnlineManager(Serializable): """ if strategies is None: strategies = self.strategies - for strategy in strategies: + models_list = [] + for strategy in strategies: self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") tasks = strategy.first_tasks() models = self.trainer.train(tasks, experiment_name=strategy.name_id) - models = self.trainer.end_train(models, experiment_name=strategy.name_id) + models_list.append(models) self.logger.info(f"Finished training {len(models)} models.") - online_models = strategy.prepare_online_models(models, **model_kwargs) self.history.setdefault(self.cur_time, {})[strategy] = online_models + if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay(): + for strategy, models in zip(strategies, models_list): + models = self.trainer.end_train(models, experiment_name=strategy.name_id) + def routine( self, cur_time: Union[str, pd.Timestamp] = None, @@ -139,20 +145,22 @@ class OnlineManager(Serializable): cur_time = D.calendar(freq=self.freq).max() self.cur_time = pd.Timestamp(cur_time) # None for latest date + models_list = [] for strategy in self.strategies: self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") if self.status == self.STATUS_NORMAL: strategy.tool.update_online_pred() tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) - models = self.trainer.train(tasks) - if self.status == self.STATUS_NORMAL or not self.trainer.is_delay(): - models = self.trainer.end_train(models, experiment_name=strategy.name_id) + models = self.trainer.train(tasks, experiment_name=strategy.name_id) + models_list.append(models) self.logger.info(f"Finished training {len(models)} models.") online_models = strategy.prepare_online_models(models, **model_kwargs) self.history.setdefault(self.cur_time, {})[strategy] = online_models - if not self.trainer.is_delay(): + if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay(): + for strategy, models in zip(self.strategies, models_list): + models = self.trainer.end_train(models, experiment_name=strategy.name_id) self.prepare_signals(**signal_kwargs) def get_collector(self) -> MergeCollector: @@ -297,6 +305,7 @@ class OnlineManager(Serializable): # NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way. self.prepare_signals(**signal_kwargs) if signals_time > cur_time: + # FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning. self.logger.warn( f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models." ) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 658eec4d6..0e495bb0f 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -69,7 +69,7 @@ class TaskManager: ENCODE_FIELDS_PREFIX = ["def", "res"] - def __init__(self, task_pool: str = None): + def __init__(self, task_pool: str): """ Init Task Manager, remember to make the statement of MongoDB url and database name firstly. @@ -79,8 +79,7 @@ class TaskManager: the name of Collection in MongoDB """ self.mdb = get_mongodb() - if task_pool is not None: - self.task_pool = getattr(self.mdb, task_pool) + self.task_pool = getattr(self.mdb, task_pool) self.logger = get_module_logger(self.__class__.__name__) def list(self) -> list: @@ -288,7 +287,7 @@ class TaskManager: for t in self.task_pool.find(query): yield self._decode_task(t) - def re_query(self, _id): + def re_query(self, _id) -> dict: """ Use _id to query task. From ca0363ded804ad97d21d2d151ef823df9336a7c5 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Thu, 27 May 2021 06:04:46 +0000 Subject: [PATCH 2/6] update trainer and manage --- qlib/model/trainer.py | 38 ++++++++++++------------------------ qlib/workflow/task/manage.py | 34 ++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 07bb839a2..ace3031ed 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -283,6 +283,9 @@ class TrainerRM(Trainer): STATUS_BEGIN = "begin_task_train" STATUS_END = "end_task_train" + # This tag is the _id in TaskManager to distinguish tasks. + TM_ID = "_id in TaskManager" + def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train): """ Init TrainerR. @@ -336,31 +339,24 @@ class TrainerRM(Trainer): task_pool = experiment_name tm = TaskManager(task_pool=task_pool) _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB + query = {"_id": {"$in": _id_list}} run_task( train_func, task_pool, - query={"filter": {"$in": tasks}}, # only train these tasks + query=query, # only train these tasks experiment_name=experiment_name, before_status=before_status, after_status=after_status, **kwargs, ) - # FIXME: reset to waiting automatically - for _id in _id_list: - is_prn = False - while tm.re_query(_id)["status"] == "running": - if not is_prn: - get_module_logger("TrainerRM").warn( - f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others." - ) - is_prn = True - time.sleep(10) + tm.wait(query=query) recs = [] for _id in _id_list: rec = tm.re_query(_id)["res"] rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) + rec.set_tags(**{self.TM_ID: _id}) recs.append(rec) return recs @@ -475,31 +471,21 @@ class DelayTrainerRM(TrainerRM): task_pool = self.task_pool if task_pool is None: task_pool = experiment_name - tasks = [] + _id_list = [] for rec in recs: - tasks.append(rec.load_object("task")) + _id_list.append(rec.list_tags()[self.TM_ID]) + query = {"_id": {"$in": _id_list}} run_task( end_train_func, task_pool, - query={"filter": {"$in": tasks}}, # only train these tasks + query=query, # only train these tasks experiment_name=experiment_name, before_status=TaskManager.STATUS_PART_DONE, **kwargs, ) - # FIXME: reset to waiting automatically - tm = TaskManager(task_pool=task_pool) - for query_task in tm.query({"filter": {"$in": tasks}}): - _id = query_task["_id"] - is_prn = False - while tm.re_query(_id)["status"] == "running": - if not is_prn: - get_module_logger("DelayTrainerRM").warn( - f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others." - ) - is_prn = True - time.sleep(10) + TaskManager(task_pool=task_pool).wait(query=query) for rec in recs: rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 0e495bb0f..167087260 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -108,6 +108,15 @@ class TaskManager: def _dict_to_str(self, flt): return {k: str(v) for k, v in flt.items()} + def _decode_query(self, query): + if "_id" in query: + if isinstance(query["_id"], dict): + for key in query["_id"]: + query["_id"][key] = [ObjectId(i) for i in query["_id"][key]] + else: + query["_id"] = ObjectId(query["_id"]) + return query + def replace_task(self, task, new_task): """ Use a new task to replace a old one @@ -223,8 +232,7 @@ class TaskManager: dict: a task(document in collection) after decoding """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) query.update({"status": status}) task = self.task_pool.find_one_and_update( query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)] @@ -282,8 +290,7 @@ class TaskManager: dict: a task(document in collection) after decoding """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) for t in self.task_pool.find(query): yield self._decode_task(t) @@ -338,8 +345,7 @@ class TaskManager: """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) self.task_pool.delete_many(query) def task_stat(self, query={}) -> dict: @@ -353,8 +359,7 @@ class TaskManager: dict """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) tasks = self.query(query=query, decode=False) status_stat = {} for t in tasks: @@ -376,8 +381,7 @@ class TaskManager: def reset_status(self, query, status): query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) print(self.task_pool.update_many(query, {"$set": {"status": status}})) def prioritize(self, task, priority: int): @@ -401,9 +405,19 @@ class TaskManager: return sum(task_stat.values()) def wait(self, query={}): + """ + When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks. + So main progress should wait until all tasks are trained well by other progress or machines. + + Args: + query (dict, optional): the query dict. Defaults to {}. + """ task_stat = self.task_stat(query) total = self._get_total(task_stat) last_undone_n = self._get_undone_n(task_stat) + if last_undone_n == 0: + return + self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.") with tqdm(total=total, initial=total - last_undone_n) as pbar: while True: time.sleep(10) From 94ab4bbf3feb5496720c6359dc85cfb1766ed5dd Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Tue, 1 Jun 2021 07:45:39 +0000 Subject: [PATCH 3/6] add docs --- qlib/workflow/task/manage.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 167087260..dd42caf65 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -24,7 +24,9 @@ from bson.binary import Binary from bson.objectid import ObjectId from pymongo.errors import InvalidDocument from qlib import auto_init, get_module_logger +import qlib from tqdm.cli import tqdm +import yaml from .utils import get_mongodb @@ -72,24 +74,26 @@ class TaskManager: def __init__(self, task_pool: str): """ Init Task Manager, remember to make the statement of MongoDB url and database name firstly. + A TaskManager instance serves a specific task pool. + The static method of this module serves the whole MongoDB. Parameters ---------- task_pool: str the name of Collection in MongoDB """ - self.mdb = get_mongodb() - self.task_pool = getattr(self.mdb, task_pool) + self.task_pool = getattr(get_mongodb(), task_pool) self.logger = get_module_logger(self.__class__.__name__) - def list(self) -> list: + @staticmethod + def list() -> list: """ - List the all collection(task_pool) of the db + List the all collection(task_pool) of the db. Returns: list """ - return self.mdb.list_collection_names() + return get_mongodb().list_collection_names() def _encode_task(self, task): for prefix in self.ENCODE_FIELDS_PREFIX: @@ -109,6 +113,16 @@ class TaskManager: return {k: str(v) for k, v in flt.items()} def _decode_query(self, query): + """ + If the query includes any `_id`, then it needs `ObjectId` to decode. + For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`. + + Args: + query (dict): query dict. Defaults to {}. + + Returns: + dict: the query after decoding. + """ if "_id" in query: if isinstance(query["_id"], dict): for key in query["_id"]: From ab6b88ce14814fe5679e4f5b5f9a016a2397c1a6 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Tue, 1 Jun 2021 07:48:14 +0000 Subject: [PATCH 4/6] delete useless import --- qlib/workflow/task/manage.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index dd42caf65..7a85036da 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -24,9 +24,7 @@ from bson.binary import Binary from bson.objectid import ObjectId from pymongo.errors import InvalidDocument from qlib import auto_init, get_module_logger -import qlib from tqdm.cli import tqdm -import yaml from .utils import get_mongodb From 8d05cd2dafcb8f8dbf2bcdb453b5d9236d3bd766 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Tue, 1 Jun 2021 09:40:53 +0000 Subject: [PATCH 5/6] modify tests.config.py --- .../online_srv/online_management_simulate.py | 64 ++++++++++++- .../online_srv/rolling_online_management.py | 6 +- qlib/tests/config.py | 94 ++++++++++++++----- 3 files changed, 138 insertions(+), 26 deletions(-) diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 5f024192f..8650859ff 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -5,6 +5,7 @@ This example is about how can simulate the OnlineManager based on rolling tasks. """ +from pprint import pprint import fire import qlib from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM @@ -13,7 +14,63 @@ from qlib.workflow.online.manager import OnlineManager from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager -from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE + +data_handler_config = { + "start_time": "2018-01-01", + "end_time": "2018-10-31", + "fit_start_time": "2018-01-01", + "fit_end_time": "2018-03-31", + "instruments": "csi100", +} + +dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2018-01-01", "2018-03-31"), + "valid": ("2018-04-01", "2018-05-31"), + "test": ("2018-06-01", "2018-09-10"), + }, + }, +} + +record_config = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + +# use lgb model +task_lgb_config = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": dataset_config, + "record": record_config, +} + +# use xgboost model +task_xgboost_config = { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": dataset_config, + "record": record_config, +} class OnlineSimulationExample: @@ -46,7 +103,10 @@ class OnlineSimulationExample: tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ if tasks is None: - tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] + #tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE] + tasks = [task_xgboost_config, task_lgb_config] + #pprint(CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE) + #pprint(task_xgboost_config) self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index b4f7245b7..99a91e027 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -18,7 +18,7 @@ from qlib.workflow import R from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager -from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG +from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING class RollingOnlineExample: @@ -34,9 +34,9 @@ class RollingOnlineExample: add_tasks=None, ): if add_tasks is None: - add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG] + add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING] if tasks is None: - tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG] + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING] mongo_conf = { "task_url": task_url, # your MongoDB url "task_db_name": task_db_name, # database name diff --git a/qlib/tests/config.py b/qlib/tests/config.py index 80461f6f9..c61b5651e 100644 --- a/qlib/tests/config.py +++ b/qlib/tests/config.py @@ -43,17 +43,29 @@ RECORD_CONFIG = [ ] -def get_data_handler_config(market=CSI300_MARKET): +def get_data_handler_config( + start_time="2008-01-01", + end_time="2020-08-01", + fit_start_time="2008-01-01", + fit_end_time="2014-12-31", + instruments=CSI300_MARKET, +): return { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, + "start_time": start_time, + "end_time": end_time, + "fit_start_time": fit_start_time, + "fit_end_time": fit_end_time, + "instruments": instruments, } -def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS): +def get_dataset_config( + dataset_class=DATASET_ALPHA158_CLASS, + train=("2008-01-01", "2014-12-31"), + valid=("2015-01-01", "2016-12-31"), + test=("2017-01-01", "2020-08-01"), + handler_kwargs={"instruments": CSI300_MARKET}, +): return { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS "handler": { "class": dataset_class, "module_path": "qlib.contrib.data.handler", - "kwargs": get_data_handler_config(market), + "kwargs": get_data_handler_config(**handler_kwargs), }, "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), + "train": train, + "valid": valid, + "test": test, }, }, } -def get_gbdt_task(market=CSI300_MARKET): +def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): return { "model": GBDT_MODEL, - "dataset": get_dataset_config(market), + "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs), } -def get_record_lgb_config(market=CSI300_MARKET): +def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): return { "model": { "class": "LGBModel", "module_path": "qlib.contrib.model.gbdt", }, - "dataset": get_dataset_config(market), + "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs), "record": RECORD_CONFIG, } -def get_record_xgboost_config(market=CSI300_MARKET): +def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): return { "model": { "class": "XGBModel", "module_path": "qlib.contrib.model.xgboost", }, - "dataset": get_dataset_config(market), + "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs), "record": RECORD_CONFIG, } -CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET) -CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET) +CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET}) +CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET}) -CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET) -CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET) +CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET}) +CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET}) + +# use for rolling_online_managment.py +ROLLING_HANDLER_CONFIG = { + "start_time": "2013-01-01", + "end_time": "2020-09-25", + "fit_start_time": "2013-01-01", + "fit_end_time": "2014-12-31", + "instruments": CSI100_MARKET, +} +ROLLING_DATASET_CONFIG = { + "train": ("2013-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2015-12-31"), + "test": ("2016-01-01", "2020-07-10"), +} +CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config( + dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG +) +CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config( + dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG +) + +# use for online_management_simulate.py +ONLINE_HANDLER_CONFIG = { + "start_time": "2018-01-01", + "end_time": "2018-10-31", + "fit_start_time": "2018-01-01", + "fit_end_time": "2018-03-31", + "instruments": CSI100_MARKET, +} +ONLINE_DATASET_CONFIG = { + "train": ("2018-01-01", "2018-03-31"), + "valid": ("2018-04-01", "2018-05-31"), + "test": ("2018-06-01", "2018-09-10"), +} +CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config( + dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG +) +CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config( + dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG +) From 811d2c975e4c277651e3e87220ac4c36eb63d8d4 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Wed, 2 Jun 2021 08:56:15 +0000 Subject: [PATCH 6/6] update & fix --- .../online_srv/online_management_simulate.py | 61 +------------------ .../online_srv/rolling_online_management.py | 1 + qlib/workflow/online/manager.py | 7 ++- 3 files changed, 7 insertions(+), 62 deletions(-) diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 8650859ff..bd7c4675d 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -16,62 +16,6 @@ from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE -data_handler_config = { - "start_time": "2018-01-01", - "end_time": "2018-10-31", - "fit_start_time": "2018-01-01", - "fit_end_time": "2018-03-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2018-01-01", "2018-03-31"), - "valid": ("2018-04-01", "2018-05-31"), - "test": ("2018-06-01", "2018-09-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} - class OnlineSimulationExample: def __init__( @@ -103,10 +47,7 @@ class OnlineSimulationExample: tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ if tasks is None: - #tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE] - tasks = [task_xgboost_config, task_lgb_config] - #pprint(CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE) - #pprint(task_xgboost_config) + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE] self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 99a91e027..6abbbfb0e 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -19,6 +19,7 @@ from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING +from qlib.workflow.task.manage import TaskManager class RollingOnlineExample: diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index ef6cb8dfa..dc1186038 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -163,17 +163,20 @@ class OnlineManager(Serializable): models = self.trainer.end_train(models, experiment_name=strategy.name_id) self.prepare_signals(**signal_kwargs) - def get_collector(self) -> MergeCollector: + def get_collector(self, **kwargs) -> MergeCollector: """ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy. This collector can be a basis as the signals preparation. + + Args: + **kwargs: the params for get_collector. Returns: MergeCollector: the collector to merge other collectors. """ collector_dict = {} for strategy in self.strategies: - collector_dict[strategy.name_id] = strategy.get_collector() + collector_dict[strategy.name_id] = strategy.get_collector(**kwargs) return MergeCollector(collector_dict, process_list=[]) def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):