From ca0363ded804ad97d21d2d151ef823df9336a7c5 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Thu, 27 May 2021 06:04:46 +0000 Subject: [PATCH] update trainer and manage --- qlib/model/trainer.py | 38 ++++++++++++------------------------ qlib/workflow/task/manage.py | 34 ++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 07bb839a2..ace3031ed 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -283,6 +283,9 @@ class TrainerRM(Trainer): STATUS_BEGIN = "begin_task_train" STATUS_END = "end_task_train" + # This tag is the _id in TaskManager to distinguish tasks. + TM_ID = "_id in TaskManager" + def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train): """ Init TrainerR. @@ -336,31 +339,24 @@ class TrainerRM(Trainer): task_pool = experiment_name tm = TaskManager(task_pool=task_pool) _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB + query = {"_id": {"$in": _id_list}} run_task( train_func, task_pool, - query={"filter": {"$in": tasks}}, # only train these tasks + query=query, # only train these tasks experiment_name=experiment_name, before_status=before_status, after_status=after_status, **kwargs, ) - # FIXME: reset to waiting automatically - for _id in _id_list: - is_prn = False - while tm.re_query(_id)["status"] == "running": - if not is_prn: - get_module_logger("TrainerRM").warn( - f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others." 
- ) - is_prn = True - time.sleep(10) + tm.wait(query=query) recs = [] for _id in _id_list: rec = tm.re_query(_id)["res"] rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) + rec.set_tags(**{self.TM_ID: _id}) recs.append(rec) return recs @@ -475,31 +471,21 @@ class DelayTrainerRM(TrainerRM): task_pool = self.task_pool if task_pool is None: task_pool = experiment_name - tasks = [] + _id_list = [] for rec in recs: - tasks.append(rec.load_object("task")) + _id_list.append(rec.list_tags()[self.TM_ID]) + query = {"_id": {"$in": _id_list}} run_task( end_train_func, task_pool, - query={"filter": {"$in": tasks}}, # only train these tasks + query=query, # only train these tasks experiment_name=experiment_name, before_status=TaskManager.STATUS_PART_DONE, **kwargs, ) - # FIXME: reset to waiting automatically - tm = TaskManager(task_pool=task_pool) - for query_task in tm.query({"filter": {"$in": tasks}}): - _id = query_task["_id"] - is_prn = False - while tm.re_query(_id)["status"] == "running": - if not is_prn: - get_module_logger("DelayTrainerRM").warn( - f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others." 
- ) - is_prn = True - time.sleep(10) + TaskManager(task_pool=task_pool).wait(query=query) for rec in recs: rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 0e495bb0f..167087260 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -108,6 +108,15 @@ class TaskManager: def _dict_to_str(self, flt): return {k: str(v) for k, v in flt.items()} + def _decode_query(self, query): + if "_id" in query: + if isinstance(query["_id"], dict): + for key in query["_id"]: + query["_id"][key] = [ObjectId(i) for i in query["_id"][key]] + else: + query["_id"] = ObjectId(query["_id"]) + return query + def replace_task(self, task, new_task): """ Use a new task to replace a old one @@ -223,8 +232,7 @@ class TaskManager: dict: a task(document in collection) after decoding """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) query.update({"status": status}) task = self.task_pool.find_one_and_update( query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)] @@ -282,8 +290,7 @@ class TaskManager: dict: a task(document in collection) after decoding """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) for t in self.task_pool.find(query): yield self._decode_task(t) @@ -338,8 +345,7 @@ class TaskManager: """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) self.task_pool.delete_many(query) def task_stat(self, query={}) -> dict: @@ -353,8 +359,7 @@ class TaskManager: dict """ query = query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) tasks = self.query(query=query, decode=False) status_stat = {} for t in tasks: @@ -376,8 +381,7 @@ class TaskManager: def reset_status(self, query, status): query = 
query.copy() - if "_id" in query: - query["_id"] = ObjectId(query["_id"]) + query = self._decode_query(query) print(self.task_pool.update_many(query, {"$set": {"status": status}})) def prioritize(self, task, priority: int): @@ -401,9 +405,19 @@ class TaskManager: return sum(task_stat.values()) def wait(self, query={}): + """ + When multiprocessing, the main process may fetch nothing from TaskManager because there are still some running tasks. + So the main process should wait until all tasks have been trained by other processes or machines. + + Args: + query (dict, optional): the query dict. Defaults to {}. + """ task_stat = self.task_stat(query) total = self._get_total(task_stat) last_undone_n = self._get_undone_n(task_stat) + if last_undone_n == 0: + return + self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.") with tqdm(total=total, initial=total - last_undone_n) as pbar: while True: time.sleep(10)