1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

update trainer and manage

This commit is contained in:
lzh222333
2021-05-27 06:04:46 +00:00
parent a467e10974
commit ca0363ded8
2 changed files with 36 additions and 36 deletions

View File

@@ -283,6 +283,9 @@ class TrainerRM(Trainer):
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
"""
Init TrainerR.
@@ -336,31 +339,24 @@ class TrainerRM(Trainer):
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
query = {"_id": {"$in": _id_list}}
run_task(
train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
# FIXME: reset to waiting automatically
for _id in _id_list:
is_prn = False
while tm.re_query(_id)["status"] == "running":
if not is_prn:
get_module_logger("TrainerRM").warn(
f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others."
)
is_prn = True
time.sleep(10)
tm.wait(query=query)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
rec.set_tags(**{self.TM_ID: _id})
recs.append(rec)
return recs
@@ -475,31 +471,21 @@ class DelayTrainerRM(TrainerRM):
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tasks = []
_id_list = []
for rec in recs:
tasks.append(rec.load_object("task"))
_id_list.append(rec.list_tags()[self.TM_ID])
query = {"_id": {"$in": _id_list}}
run_task(
end_train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
# FIXME: reset to waiting automatically
tm = TaskManager(task_pool=task_pool)
for query_task in tm.query({"filter": {"$in": tasks}}):
_id = query_task["_id"]
is_prn = False
while tm.re_query(_id)["status"] == "running":
if not is_prn:
get_module_logger("DelayTrainerRM").warn(
f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others."
)
is_prn = True
time.sleep(10)
TaskManager(task_pool=task_pool).wait(query=query)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})

View File

@@ -108,6 +108,15 @@ class TaskManager:
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
def _decode_query(self, query):
if "_id" in query:
if isinstance(query["_id"], dict):
for key in query["_id"]:
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
else:
query["_id"] = ObjectId(query["_id"])
return query
def replace_task(self, task, new_task):
"""
Use a new task to replace a old one
@@ -223,8 +232,7 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
@@ -282,8 +290,7 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
for t in self.task_pool.find(query):
yield self._decode_task(t)
@@ -338,8 +345,7 @@ class TaskManager:
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
self.task_pool.delete_many(query)
def task_stat(self, query={}) -> dict:
@@ -353,8 +359,7 @@ class TaskManager:
dict
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
@@ -376,8 +381,7 @@ class TaskManager:
def reset_status(self, query, status):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
@@ -401,9 +405,19 @@ class TaskManager:
return sum(task_stat.values())
def wait(self, query={}):
"""
When multiprocessing, the main process may fetch nothing from TaskManager because there are still some running tasks.
So the main process should wait until all tasks have been trained by other processes or machines.
Args:
query (dict, optional): the query dict. Defaults to {}.
"""
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
if last_undone_n == 0:
return
self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)