mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
update trainer and manage
This commit is contained in:
@@ -283,6 +283,9 @@ class TrainerRM(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
# This tag is the _id in TaskManager to distinguish tasks.
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
"""
|
||||
Init TrainerRM.
|
||||
@@ -336,31 +339,24 @@ class TrainerRM(Trainer):
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# FIXME: reset to waiting automatically
|
||||
for _id in _id_list:
|
||||
is_prn = False
|
||||
while tm.re_query(_id)["status"] == "running":
|
||||
if not is_prn:
|
||||
get_module_logger("TrainerRM").warn(
|
||||
f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others."
|
||||
)
|
||||
is_prn = True
|
||||
time.sleep(10)
|
||||
tm.wait(query=query)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
rec.set_tags(**{self.TM_ID: _id})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
@@ -475,31 +471,21 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
_id_list = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
_id_list.append(rec.list_tags()[self.TM_ID])
|
||||
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# FIXME: reset to waiting automatically
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
for query_task in tm.query({"filter": {"$in": tasks}}):
|
||||
_id = query_task["_id"]
|
||||
is_prn = False
|
||||
while tm.re_query(_id)["status"] == "running":
|
||||
if not is_prn:
|
||||
get_module_logger("DelayTrainerRM").warn(
|
||||
f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others."
|
||||
)
|
||||
is_prn = True
|
||||
time.sleep(10)
|
||||
TaskManager(task_pool=task_pool).wait(query=query)
|
||||
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
|
||||
@@ -108,6 +108,15 @@ class TaskManager:
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def _decode_query(self, query):
|
||||
if "_id" in query:
|
||||
if isinstance(query["_id"], dict):
|
||||
for key in query["_id"]:
|
||||
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
|
||||
else:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
return query
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
"""
|
||||
Use a new task to replace an old one
|
||||
@@ -223,8 +232,7 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
query.update({"status": status})
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
@@ -282,8 +290,7 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
@@ -338,8 +345,7 @@ class TaskManager:
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
self.task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}) -> dict:
|
||||
@@ -353,8 +359,7 @@ class TaskManager:
|
||||
dict
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
tasks = self.query(query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
@@ -376,8 +381,7 @@ class TaskManager:
|
||||
|
||||
def reset_status(self, query, status):
    """
    Set the status of every task matching the query.

    Args:
        query (dict): the query dict selecting the tasks to update. An `_id`
            entry (a single id or an operator dict such as `{"$in": [...]}`)
            is converted by `self._decode_query`.
        status (str): the status value written to each matched task.
    """
    # Work on a copy so that the caller's dict is not rebound key by key.
    # The conversion of `_id` is left entirely to `_decode_query`: wrapping
    # `query["_id"]` in `ObjectId` here first would fail for operator dicts.
    query = self._decode_query(query.copy())
    print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
@@ -401,9 +405,19 @@ class TaskManager:
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}):
|
||||
"""
|
||||
When multiprocessing, the main process may fetch nothing from TaskManager because there are still some running tasks.
|
||||
So the main process should wait until all tasks have been trained by other processes or machines.
|
||||
|
||||
Args:
|
||||
query (dict, optional): the query dict. Defaults to {}.
|
||||
"""
|
||||
task_stat = self.task_stat(query)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
if last_undone_n == 0:
|
||||
return
|
||||
self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
Reference in New Issue
Block a user