1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

Merge pull request #435 from you-n-g/online_srv

Multiprocessing support for Online Serving
This commit is contained in:
you-n-g
2021-06-02 17:12:19 +08:00
committed by GitHub
7 changed files with 257 additions and 70 deletions

View File

@@ -4,6 +4,7 @@
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
"""
from pprint import pprint
@@ -13,7 +14,7 @@ import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
@@ -68,6 +69,11 @@ class RollingTaskExample:
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
def worker(self):
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
print("========== worker ==========")
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
def task_collecting(self):
print("========== task_collecting ==========")

View File

@@ -5,6 +5,7 @@
This example is about how can simulate the OnlineManager based on rolling tasks.
"""
from pprint import pprint
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
class OnlineSimulationExample:
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
def worker(self):
# train tasks by other progress or machines for multiprocessing
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
self.trainer.worker()
else:
print(f"{type(self.trainer)} is not supported for worker.")
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below

View File

@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
import os
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
from qlib.workflow.task.manage import TaskManager
class RollingOnlineExample:
@@ -25,16 +27,17 @@ class RollingOnlineExample:
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
rolling_step=550,
tasks=None,
add_tasks=None,
):
if add_tasks is None:
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
@@ -53,17 +56,28 @@ class RollingOnlineExample:
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager = OnlineManager(strategies)
self.trainer = trainer
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
def worker(self):
# train tasks by other progress or machines for multiprocessing
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
self.trainer.worker(experiment_name=name_id)
else:
print(f"{type(self.trainer)} is not supported for worker.")
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
TaskManager(task_pool=name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)

View File

@@ -12,9 +12,11 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
"""
import socket
import time
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
@@ -190,6 +192,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -213,6 +217,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -250,6 +256,8 @@ class DelayTrainerR(TrainerR):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -275,6 +283,9 @@ class TrainerRM(Trainer):
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
"""
Init TrainerR.
@@ -315,6 +326,8 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -326,19 +339,24 @@ class TrainerRM(Trainer):
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
query = {"_id": {"$in": _id_list}}
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
tm.wait(query=query)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
rec.set_tags(**{self.TM_ID: _id})
recs.append(rec)
return recs
@@ -352,10 +370,33 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(
self,
train_func: Callable = None,
experiment_name: str = None,
):
"""
The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
Args:
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
"""
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
class DelayTrainerRM(TrainerRM):
"""
@@ -395,6 +436,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
return super().train(
@@ -410,8 +453,6 @@ class DelayTrainerRM(TrainerRM):
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
@@ -421,7 +462,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -429,18 +471,44 @@ class DelayTrainerRM(TrainerRM):
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tasks = []
_id_list = []
for rec in recs:
tasks.append(rec.load_object("task"))
_id_list.append(rec.list_tags()[self.TM_ID])
query = {"_id": {"$in": _id_list}}
run_task(
end_train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
TaskManager(task_pool=task_pool).wait(query=query)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(self, end_train_func=None, experiment_name: str = None):
"""
The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
Args:
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(
end_train_func,
task_pool=task_pool,
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
)

View File

@@ -43,17 +43,29 @@ RECORD_CONFIG = [
]
def get_data_handler_config(market=CSI300_MARKET):
def get_data_handler_config(
start_time="2008-01-01",
end_time="2020-08-01",
fit_start_time="2008-01-01",
fit_end_time="2014-12-31",
instruments=CSI300_MARKET,
):
return {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
"start_time": start_time,
"end_time": end_time,
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
"instruments": instruments,
}
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
def get_dataset_config(
dataset_class=DATASET_ALPHA158_CLASS,
train=("2008-01-01", "2014-12-31"),
valid=("2015-01-01", "2016-12-31"),
test=("2017-01-01", "2020-08-01"),
handler_kwargs={"instruments": CSI300_MARKET},
):
return {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
"handler": {
"class": dataset_class,
"module_path": "qlib.contrib.data.handler",
"kwargs": get_data_handler_config(market),
"kwargs": get_data_handler_config(**handler_kwargs),
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
"train": train,
"valid": valid,
"test": test,
},
},
}
def get_gbdt_task(market=CSI300_MARKET):
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": GBDT_MODEL,
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
}
def get_record_lgb_config(market=CSI300_MARKET):
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
def get_record_xgboost_config(market=CSI300_MARKET):
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
# use for rolling_online_managment.py
ROLLING_HANDLER_CONFIG = {
"start_time": "2013-01-01",
"end_time": "2020-09-25",
"fit_start_time": "2013-01-01",
"fit_end_time": "2014-12-31",
"instruments": CSI100_MARKET,
}
ROLLING_DATASET_CONFIG = {
"train": ("2013-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2015-12-31"),
"test": ("2016-01-01", "2020-07-10"),
}
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
)
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
)
# use for online_management_simulate.py
ONLINE_HANDLER_CONFIG = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": CSI100_MARKET,
}
ONLINE_DATASET_CONFIG = {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
}
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
)
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
)

View File

@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
========================= ===================================================================================
Situations Description
========================= ===================================================================================
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
will train models task by task and strategy by strategy.
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
nothing until all tasks have been prepared. It makes user can train all tasks in
the end of `routine` or `first_train`.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
@@ -103,17 +105,21 @@ class OnlineManager(Serializable):
"""
if strategies is None:
strategies = self.strategies
for strategy in strategies:
models_list = []
for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
@@ -139,33 +145,38 @@ class OnlineManager(Serializable):
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
models_list = []
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.trainer.is_delay():
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.prepare_signals(**signal_kwargs)
def get_collector(self) -> MergeCollector:
def get_collector(self, **kwargs) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
This collector can be a basis as the signals preparation.
Args:
**kwargs: the params for get_collector.
Returns:
MergeCollector: the collector to merge other collectors.
"""
collector_dict = {}
for strategy in self.strategies:
collector_dict[strategy.name_id] = strategy.get_collector()
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
return MergeCollector(collector_dict, process_list=[])
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
@@ -297,6 +308,7 @@ class OnlineManager(Serializable):
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
# FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning.
self.logger.warn(
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)

View File

@@ -69,28 +69,29 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool: str = None):
def __init__(self, task_pool: str):
"""
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
A TaskManager instance serves a specific task pool.
The static method of this module serves the whole MongoDB.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
if task_pool is not None:
self.task_pool = getattr(self.mdb, task_pool)
self.task_pool = getattr(get_mongodb(), task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self) -> list:
@staticmethod
def list() -> list:
"""
List the all collection(task_pool) of the db
List the all collection(task_pool) of the db.
Returns:
list
"""
return self.mdb.list_collection_names()
return get_mongodb().list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
@@ -109,6 +110,25 @@ class TaskManager:
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
def _decode_query(self, query):
"""
If the query includes any `_id`, then it needs `ObjectId` to decode.
For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
Args:
query (dict): query dict. Defaults to {}.
Returns:
dict: the query after decoding.
"""
if "_id" in query:
if isinstance(query["_id"], dict):
for key in query["_id"]:
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
else:
query["_id"] = ObjectId(query["_id"])
return query
def replace_task(self, task, new_task):
"""
Use a new task to replace a old one
@@ -224,8 +244,7 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
@@ -283,12 +302,11 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
for t in self.task_pool.find(query):
yield self._decode_task(t)
def re_query(self, _id):
def re_query(self, _id) -> dict:
"""
Use _id to query task.
@@ -339,8 +357,7 @@ class TaskManager:
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
self.task_pool.delete_many(query)
def task_stat(self, query={}) -> dict:
@@ -354,8 +371,7 @@ class TaskManager:
dict
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
@@ -377,8 +393,7 @@ class TaskManager:
def reset_status(self, query, status):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
@@ -402,9 +417,19 @@ class TaskManager:
return sum(task_stat.values())
def wait(self, query={}):
"""
When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks.
So main progress should wait until all tasks are trained well by other progress or machines.
Args:
query (dict, optional): the query dict. Defaults to {}.
"""
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
if last_undone_n == 0:
return
self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)