1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

bug fixed & examples fire

This commit is contained in:
lzh222333
2021-04-07 03:33:27 +00:00
parent 431a9c92c1
commit cb42e99bee
10 changed files with 250 additions and 232 deletions

View File

@@ -97,8 +97,8 @@ def task_generating():
def task_training(tasks, task_pool, exp_name):
trainer = TrainerRM()
trainer.train(tasks, exp_name, task_pool)
trainer = TrainerRM(exp_name, task_pool)
trainer.train(tasks)
# This part corresponds to "Task Collecting" in the document
@@ -119,7 +119,7 @@ def task_collecting(task_pool, exp_name):
return False
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
)
print(artifact)
@@ -128,7 +128,7 @@ def main(
provider_uri="~/.qlib/qlib_data/cn_data",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
exp_name="rolling_exp",
experiment_name="rolling_exp",
task_pool="rolling_task",
):
mongo_conf = {
@@ -137,11 +137,13 @@ def main(
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
# reset(task_pool, exp_name)
# tasks = task_generating()
# task_training(tasks, task_pool, exp_name)
task_collecting(task_pool, exp_name)
reset(task_pool, experiment_name)
tasks = task_generating()
task_training(tasks, task_pool, experiment_name)
task_collecting(task_pool, experiment_name)
if __name__ == "__main__":
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire()

View File

@@ -70,89 +70,106 @@ task_xgboost_config = {
"record": record_config,
}
class RollingOnlineExample:
def print_online_model():
print("========== print_online_model ==========")
print("Current 'online' model:")
for rid, rec in list_recorders(exp_name).items():
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG:
print(rid)
print("Current 'next online' model:")
for rid, rec in list_recorders(exp_name).items():
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG:
print(rid)
def __init__(
    self,
    exp_name="rolling_exp",
    task_pool="rolling_task",
    provider_uri="~/.qlib/qlib_data/cn_data",
    region="cn",
    task_url="mongodb://10.0.0.4:27017/",
    task_db_name="rolling_db",
    rolling_step=550,
):
    """Initialise qlib and build every helper object the example needs.

    Args:
        exp_name (str): experiment name; given to the trainer and the online manager.
        task_pool (str): task pool name; given to the trainer and the task manager.
        provider_uri (str): data directory handed to ``qlib.init``.
        region (str): region handed to ``qlib.init``.
        task_url (str): MongoDB url, passed to ``qlib.init`` inside the ``mongo`` config.
        task_db_name (str): MongoDB database name, passed inside the ``mongo`` config.
        rolling_step (int): the ``step`` of the ``RollingGen`` task generator.
    """
    self.exp_name = exp_name
    self.task_pool = task_pool

    # qlib is initialised (including the MongoDB settings) before any helper is created
    qlib.init(
        provider_uri=provider_uri,
        region=region,
        mongo=dict(
            task_url=task_url,  # your MongoDB url
            task_db_name=task_db_name,  # database name
        ),
    )

    self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
    self.trainer = TrainerRM(self.exp_name, self.task_pool)
    self.task_manager = TaskManager(self.task_pool)
    self.rolling_online_manager = RollingOnlineManager(
        experiment_name=exp_name,
        rolling_gen=self.rolling_gen,
        trainer=self.trainer,
    )
def print_online_model(self):
    """Print the ids of the recorders tagged 'online', then of those tagged 'next online'."""
    print("========== print_online_model ==========")
    manager = self.rolling_online_manager
    sections = (
        ("Current 'online' model:", manager.ONLINE_TAG),
        ("Current 'next online' model:", manager.NEXT_ONLINE_TAG),
    )
    for title, wanted_tag in sections:
        print(title)
        # the recorders are listed afresh for every section
        for rid, rec in list_recorders(self.exp_name).items():
            if manager.get_online_tag(rec) == wanted_tag:
                print(rid)
# This part corresponds to "Task Generating" in the document
def task_generating():
# This part corresponds to "Task Generating" in the document
def task_generating(self):
print("========== task_generating ==========")
print("========== task_generating ==========")
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=rolling_gen, # generate different date segment
)
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=self.rolling_gen, # generate different date segment
)
pprint(tasks)
pprint(tasks)
return tasks
return tasks
def task_training(tasks):
trainer.train(tasks, exp_name, task_pool)
def task_training(self, tasks):
self.trainer.train(tasks)
# This part corresponds to "Task Collecting" in the document
def task_collecting():
print("========== task_collecting ==========")
# This part corresponds to "Task Collecting" in the document
def task_collecting(self):
print("========== task_collecting ==========")
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
)
print(artifact)
artifact = ens_workflow(
RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup()
)
print(artifact)
# Reset all things to the first status, be careful to save important data
def reset():
print("========== reset ==========")
task_manager.remove()
exp = R.get_exp(experiment_name=exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
self.task_manager.remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Run this firstly to see the workflow in Task Management
def first_run():
print("========== first_run ==========")
reset()
# Run this firstly to see the workflow in Task Management
def first_run(self):
print("========== first_run ==========")
self.reset()
tasks = task_generating()
task_training(tasks)
task_collecting()
tasks = self.task_generating()
self.task_training(tasks)
self.task_collecting()
latest_rec, _ = rolling_online_manager.list_latest_recorders()
rolling_online_manager.reset_online_tag(latest_rec.values())
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
def routine():
print("========== routine ==========")
print_online_model()
rolling_online_manager.routine()
print_online_model()
task_collecting()
def routine(self):
print("========== routine ==========")
self.print_online_model()
self.rolling_online_manager.routine()
self.print_online_model()
self.task_collecting()
if __name__ == "__main__":
@@ -161,26 +178,7 @@ if __name__ == "__main__":
####### to update the models and predictions after the trading time, use the command below
# python task_manager_rolling_with_updating.py after_day
#################### you need to finish the configurations below #########################
provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir
mongo_conf = {
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
"task_db_name": "rolling_db", # database name
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
rolling_step = 550
##########################################################################################
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
task_manager = TaskManager(task_pool=task_pool)
trainer = TrainerRM()
rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=rolling_gen, task_manager=task_manager, trainer=trainer
)
fire.Fire()
####### to define your own parameters, use `--`
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
fire.Fire(RollingOnlineExample)

View File

@@ -54,10 +54,10 @@ task = {
def first_train(experiment_name="online_srv"):
rid = task_train(task_config=task, experiment_name=experiment_name)
rec = task_train(task_config=task, experiment_name=experiment_name)
online_manager = OnlineManagerR(experiment_name)
online_manager.reset_online_tag(rid)
online_manager.reset_online_tag(rec)
def update_online_pred(experiment_name="online_srv"):
@@ -71,13 +71,17 @@ def update_online_pred(experiment_name="online_srv"):
online_manager.update_online_pred()
def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
    """Run the whole example: initialise qlib, train a model, then update the online predictions.

    Args:
        provider_uri (str): data directory handed to ``qlib.init``.
        region: region handed to ``qlib.init``.
        experiment_name (str): experiment name used for both training and updating.
    """
    # Use the caller's provider_uri as given; it must not be overwritten by a hard-coded path.
    qlib.init(provider_uri=provider_uri, region=region)
    first_train(experiment_name)
    update_online_pred(experiment_name)
if __name__ == "__main__":
## to train a model and set it to online model, use the command below
# python update_online_pred.py first_train
## to update online predictions once a day, use the command below
# python update_online_pred.py update_online_pred
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire()

View File

@@ -147,7 +147,7 @@ _default_config = {
"mongo": {
"task_url": "mongodb://localhost:27017/",
"task_db_name": "default_task_db",
}
},
}
MODE_CONF = {

View File

@@ -3,9 +3,10 @@ from typing import Callable, Union
import pandas as pd
from qlib.workflow.task.collect import Collector
from qlib.utils.serial import Serializable
def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_filter_func=None, *args, **kwargs):
def ens_workflow(collector: Collector, process_list, *args, **kwargs):
"""the ensemble workflow based on collector and different dict processors.
Args:
@@ -21,7 +22,7 @@ def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_fil
Returns:
dict: the ensemble dict
"""
collect_dict = collector.collect(artifacts_key=artifacts_key, rec_filter_func=rec_filter_func)
collect_dict = collector.collect()
if not isinstance(process_list, list):
process_list = [process_list]
@@ -37,23 +38,12 @@ def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_fil
return ensemble
class Ensemble:
class Ensemble(Serializable):
"""Merge the objects in an Ensemble."""
def __init__(self, merge_func=None):
"""init Ensemble
Args:
merge_func (Callable, optional): Given a dict and return the ensemble.
For example: {Rollinga_b: object, Rollingb_c: object} -> object
Defaults to None.
"""
self._merge = merge_func
def __call__(self, ensemble_dict: dict, *args, **kwargs):
"""Merge the ensemble_dict into an ensemble object.
For example: {Rollinga_b: object, Rollingb_c: object} -> object
Args:
ensemble_dict (dict): the ensemble dict waiting for merging like {name: things}
@@ -61,38 +51,29 @@ class Ensemble:
Returns:
object: the ensemble object
"""
if isinstance(getattr(self, "_merge", None), Callable):
return self._merge(ensemble_dict, *args, **kwargs)
else:
raise NotImplementedError(f"Please specify valid merge_func.")
raise NotImplementedError(f"Please implement the `__call__` method.")
class RollingEnsemble(Ensemble):
"""Merge the rolling objects in an Ensemble"""
@staticmethod
def rolling_merge(rolling_dict: dict):
def __call__(self, ensemble_dict: dict, *args, **kwargs):
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.Dataframe, and have the index "datetime"
Args:
rolling_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}.
ensemble_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}.
The key of the dict will be ignored.
Returns:
pd.Dataframe: the complete result of rolling.
"""
artifact_list = list(rolling_dict.values())
artifact_list = list(ensemble_dict.values())
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
artifact = pd.concat(artifact_list)
# If there are duplicated predictions, use the latest prediction
artifact = artifact[~artifact.index.duplicated(keep="last")]
artifact = artifact.sort_index()
return artifact
def __init__(self, merge_func=None):
super().__init__(merge_func=merge_func)
if merge_func is None:
self._merge = RollingEnsemble.rolling_merge

View File

@@ -1,8 +1,9 @@
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
from qlib.utils.serial import Serializable
class Group:
class Group(Serializable):
"""Group the objects based on dict"""
def __init__(self, group_func=None, ens: Ensemble = None):
@@ -17,8 +18,8 @@ class Group:
ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
"""
self._group = group_func
self._ens = ens
self.group = group_func
self.ens = ens
def __call__(self, ungrouped_dict: dict, *args, **kwargs):
"""Group the ungrouped_dict into different groups.
@@ -29,16 +30,16 @@ class Group:
Returns:
dict: grouped_dict like {G1: object, G2: object}
"""
if isinstance(getattr(self, "_group", None), Callable):
grouped_dict = self._group(ungrouped_dict, *args, **kwargs)
if self._ens is not None:
if isinstance(getattr(self, "group", None), Callable):
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
if self.ens is not None:
ens_dict = {}
for key, value in grouped_dict.items():
ens_dict[key] = self._ens(value)
ens_dict[key] = self.ens(value)
grouped_dict = ens_dict
return grouped_dict
else:
raise NotImplementedError(f"Please specify valid merge_func.")
raise NotImplementedError(f"Please specify valid group_func.")
class RollingGroup(Group):
@@ -65,4 +66,4 @@ class RollingGroup(Group):
def __init__(self, group_func=None, ens: Ensemble = RollingEnsemble()):
super().__init__(group_func=group_func, ens=ens)
if group_func is None:
self._group = RollingGroup.rolling_group
self.group = RollingGroup.rolling_group

View File

@@ -3,11 +3,12 @@
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.recorder import Recorder
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.task.manage import TaskManager, run_task
def task_train(task_config: dict, experiment_name: str) -> str:
def task_train(task_config: dict, experiment_name: str) -> Recorder:
"""
task based training
@@ -20,8 +21,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
Returns
----------
rid : str
The id of the recorder of this task
Recorder : The instance of the recorder
"""
# model initialization
@@ -80,30 +80,40 @@ class TrainerR(Trainer):
Assumption: models were defined by `task` and the results will saved to `Recorder`
"""
def train(self, tasks: list, experiment_name: str, train_func=task_train, *args, **kwargs):
def __init__(self, experiment_name, train_func=task_train):
    """Init TrainerR.

    Args:
        experiment_name (str): the experiment name passed to the train function for every task.
        train_func (Callable): the default train method, used when ``train`` is called without one.
    """
    self.train_func = train_func
    self.experiment_name = experiment_name
def train(self, tasks: list, train_func=None, *args, **kwargs):
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
Args:
tasks (list): a list of definition based on `task` dict
experiment_name (str): the experiment name
train_func (Callable): the train method which need at least `task` and `experiment_name`
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
Returns:
list: a list of Recorders
"""
if train_func is None:
train_func = self.train_func
recs = []
for task in tasks:
recs.append(train_func(task, experiment_name, *args, **kwargs))
recs.append(train_func(task, self.experiment_name, *args, **kwargs))
return recs
class TrainerRM(TrainerR):
class TrainerRM(Trainer):
"""Trainer based on (R)ecorder and Task(M)anager
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
"""
def train(self, tasks: list, experiment_name: str, task_pool: str, train_func=task_train, *args, **kwargs):
def __init__(self, experiment_name: str, task_pool: str, train_func=task_train):
    """Init TrainerRM.

    Args:
        experiment_name (str): the experiment name passed on to the train function.
        task_pool (str): the task pool name used to build the TaskManager and to run tasks.
        train_func (Callable): the default train method, used when ``train`` is called without one.
    """
    self.train_func = train_func
    self.task_pool = task_pool
    self.experiment_name = experiment_name
def train(self, tasks: list, train_func=None, *args, **kwargs):
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
This method defaults to a single process, but TaskManager offered a great way to parallel training.
@@ -111,17 +121,18 @@ class TrainerRM(TrainerR):
Args:
tasks (list): a list of definition based on `task` dict
experiment_name (str): the experiment name
train_func (Callable): the train method which need at least `task` and `experiment_name`
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
Returns:
list: a list of Recorders
"""
tm = TaskManager(task_pool=task_pool)
if train_func is None:
train_func = self.train_func
tm = TaskManager(task_pool=self.task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
run_task(train_func, task_pool, experiment_name=experiment_name, *args, **kwargs)
run_task(train_func, self.task_pool, experiment_name=self.experiment_name, *args, **kwargs)
recs = []
for _id in _id_list:
recs.append(tm.re_query(_id)["res"])
return recs
return recs

View File

@@ -20,7 +20,7 @@ class OnlineManager(Serializable):
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self, trainer: Trainer = None) -> None:
def __init__(self, trainer: Trainer = None):
self._trainer = trainer
self.logger = get_module_logger(self.__class__.__name__)
@@ -81,7 +81,8 @@ class OnlineManagerR(OnlineManager):
"""
def __init__(self, experiment_name: str, trainer: Trainer = None):
    """Init OnlineManagerR.

    Args:
        experiment_name (str): the experiment name.
        trainer (Trainer, optional): the trainer handed to the base class.
            If None, a ``TrainerR`` bound to ``experiment_name`` is created.
    """
    # Build the default trainer only when the caller did not supply one;
    # creating it unconditionally would silently discard the caller's trainer.
    if trainer is None:
        trainer = TrainerR(experiment_name)
    super().__init__(trainer)
    self.logger = get_module_logger(self.__class__.__name__)
    self.exp_name = experiment_name
@@ -105,20 +106,22 @@ class OnlineManagerR(OnlineManager):
the recorders you want to reset to 'online'. If not given, set the 'next online' model to 'online'. If there isn't any 'next online' model, then keep the existing 'online' model.
"""
if recorder is None:
recorder = list_recorders(
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
).values()
recorder = list(
list_recorders(
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
).values()
)
if isinstance(recorder, Recorder):
recorder = [recorder]
if len(recorder) == 0:
self.logger.info("No 'next online' model, just use current 'online' models.")
return
recs = list_recorders(self.exp_name)
self.set_online_tag(OnlineManager.OFFLINE_TAG, recs.values())
self.set_online_tag(OnlineManager.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
self.logger.info(f"Reset {len(recorder)} models to 'online'.")
def update_online_pred(self):
def update_online_pred(self, *args, **kwargs):
"""update all online model predictions to the latest day in Calendar"""
mu = ModelUpdater(self.exp_name)
cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
@@ -126,25 +129,24 @@ class OnlineManagerR(OnlineManager):
class RollingOnlineManager(OnlineManagerR):
"""An implementation of OnlineManager based on Rolling.
"""
"""An implementation of OnlineManager based on Rolling."""
def __init__(
    self,
    experiment_name: str,
    rolling_gen: RollingGen,
    trainer: Trainer = None,
):
    """Init RollingOnlineManager.

    Args:
        experiment_name (str): the experiment name.
        rolling_gen (RollingGen): the rolling task generator, stored as ``self.rg``.
        trainer (Trainer, optional): the trainer forwarded to the parent initialiser.
            If None, a ``TrainerR`` bound to ``experiment_name`` is created.
    """
    # Build the default trainer only when the caller did not supply one;
    # creating it unconditionally would silently discard the caller's trainer.
    if trainer is None:
        trainer = TrainerR(experiment_name)
    super().__init__(experiment_name, trainer)
    self.ta = TimeAdjuster()
    self.rg = rolling_gen
    self.logger = get_module_logger(self.__class__.__name__)
def prepare_signals(self):
def prepare_signals(self, *args, **kwargs):
pass
def prepare_tasks(self):
def prepare_tasks(self, *args, **kwargs):
"""prepare new tasks based on new date.
Returns:
@@ -155,7 +157,7 @@ class RollingOnlineManager(OnlineManagerR):
)
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return None
return []
calendar_latest = self.ta.last_date()
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
old_tasks = []
@@ -168,7 +170,7 @@ class RollingOnlineManager(OnlineManagerR):
new_tasks_tmp = task_generator(old_tasks, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return None
return []
def list_latest_recorders(self, rec_filter_func=None):
"""find latest recorders based on test segments.
@@ -187,4 +189,4 @@ class RollingOnlineManager(OnlineManagerR):
for rid, rec in recs_flt.items():
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec[rid] = rec
return latest_rec, max_test
return latest_rec, max_test

View File

@@ -1,9 +1,10 @@
from abc import abstractmethod
from typing import Callable, Union
from qlib.workflow.task.utils import list_recorders
from qlib.utils.serial import Serializable
class Collector:
class Collector(Serializable):
"""The collector to collect different results"""
def collect(self, *args, **kwargs):
@@ -25,33 +26,46 @@ class Collector:
class RecorderCollector(Collector):
def __init__(
    self,
    exp_name,
    artifacts_path=None,
    rec_key_func=None,
    artifacts_key=None,
    rec_filter_func=None,
):
    """init RecorderCollector

    Args:
        exp_name (str): the name of Experiment
        artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
        rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
        artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
        rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
    """
    # Create the default mapping per instance: a dict literal in the signature
    # would be one object shared (and mutable) across all collectors.
    if artifacts_path is None:
        artifacts_path = {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}
    if rec_key_func is None:
        rec_key_func = lambda rec: rec.info["id"]
    # Snapshot the keys as a list rather than keeping a live dict view,
    # so later changes to artifacts_path do not alter the stored default.
    if artifacts_key is None:
        artifacts_key = list(artifacts_path)

    self.exp_name = exp_name
    self.artifacts_path = artifacts_path
    self.rec_key = rec_key_func
    self.artifacts_key = artifacts_key
    self.rec_filter = rec_filter_func
def collect(self, artifacts_key=None, rec_filter_func=None): # ensemble, get_group_key_func,
def collect(self, artifacts_key=None, rec_filter_func=None):
"""Collect different artifacts based on recorder after filtering.
Args:
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default.
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default.
Returns:
dict: the dict after collected like {artifact: {rec_key: object}}
"""
if artifacts_key is None:
artifacts_key = self.artifacts_path.keys()
artifacts_key = self.artifacts_key
if rec_filter_func is None:
rec_filter_func = self.rec_filter
if isinstance(artifacts_key, str):
artifacts_key = [artifacts_key]
@@ -60,9 +74,9 @@ class RecorderCollector(Collector):
# filter records
recs_flt = list_recorders(self.exp_name, rec_filter_func)
for _, rec in recs_flt.items():
rec_key = self._get_key(rec)
rec_key = self.rec_key(rec)
for key in artifacts_key:
artifact = rec.load_object(self.artifacts_path[key])
collect_dict.setdefault(key, {})[rec_key] = artifact
return collect_dict
return collect_dict

View File

@@ -49,7 +49,7 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool=None):
def __init__(self, task_pool: str):
"""
init Task Manager, remember to make the statement of MongoDB url and database name firstly.
@@ -59,9 +59,13 @@ class TaskManager:
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
self.task_pool = task_pool
self.task_pool = getattr(self.mdb, task_pool)
self.logger = get_module_logger(self.__class__.__name__)
# @property
# def task_pool(self):
# return self._task_pool
def list(self):
return self.mdb.list_collection_names()
@@ -79,39 +83,39 @@ class TaskManager:
task[k] = pickle.loads(task[k])
return task
def _get_task_pool(self, task_pool=None):
if task_pool is None:
task_pool = self.task_pool
if task_pool is None:
raise ValueError("You must specify a task pool.")
if isinstance(task_pool, str):
return getattr(self.mdb, task_pool)
return task_pool
# def _get_task_pool(self, task_pool=None):
# if task_pool is None:
# task_pool = self.task_pool
# if task_pool is None:
# raise ValueError("You must specify a task pool.")
# if isinstance(task_pool, str):
# return getattr(self.mdb, task_pool)
# return task_pool
def _dict_to_str(self, flt):
    """Return a copy of ``flt`` in which every value has been converted with ``str``."""
    return {key: str(value) for key, value in flt.items()}
def replace_task(self, task, new_task, task_pool=None):
def replace_task(self, task, new_task):
# assume that the data out of interface was decoded and the data in interface was encoded
new_task = self._encode_task(new_task)
task_pool = self._get_task_pool(task_pool)
# task_pool = self._get_task_pool(task_pool)
query = {"_id": ObjectId(task["_id"])}
try:
task_pool.replace_one(query, new_task)
self.task_pool.replace_one(query, new_task)
except InvalidDocument:
task["filter"] = self._dict_to_str(task["filter"])
task_pool.replace_one(query, new_task)
self.task_pool.replace_one(query, new_task)
def insert_task(self, task, task_pool=None):
task_pool = self._get_task_pool(task_pool)
def insert_task(self, task):
# task_pool = self._get_task_pool(task_pool)
try:
insert_result = task_pool.insert_one(task)
insert_result = self.task_pool.insert_one(task)
except InvalidDocument:
task["filter"] = self._dict_to_str(task["filter"])
insert_result = task_pool.insert_one(task)
insert_result = self.task_pool.insert_one(task)
return insert_result
def insert_task_def(self, task_def, task_pool=None):
def insert_task_def(self, task_def):
"""
insert a task to task_pool
@@ -126,7 +130,7 @@ class TaskManager:
-------
"""
task_pool = self._get_task_pool(task_pool)
# task_pool = self._get_task_pool(task_pool)
task = self._encode_task(
{
"def": task_def,
@@ -134,10 +138,10 @@ class TaskManager:
"status": self.STATUS_WAITING,
}
)
insert_result = self.insert_task(task, task_pool)
insert_result = self.insert_task(task)
return insert_result
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
def create_task(self, task_def_l, dry_run=False, print_nt=False):
"""
if the tasks in task_def_l is new, then insert new tasks into the task_pool
@@ -156,13 +160,13 @@ class TaskManager:
list
a list of the _id of new tasks
"""
task_pool = self._get_task_pool(task_pool)
# task_pool = self._get_task_pool(task_pool)
new_tasks = []
for t in task_def_l:
try:
r = task_pool.find_one({"filter": t})
r = self.task_pool.find_one({"filter": t})
except InvalidDocument:
r = task_pool.find_one({"filter": self._dict_to_str(t)})
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
if r is None:
new_tasks.append(t)
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
@@ -176,18 +180,18 @@ class TaskManager:
_id_list = []
for t in new_tasks:
insert_result = self.insert_task_def(t, task_pool)
insert_result = self.insert_task_def(t)
_id_list.append(insert_result.inserted_id)
return _id_list
def fetch_task(self, query={}, task_pool=None):
task_pool = self._get_task_pool(task_pool)
def fetch_task(self, query={}):
# task_pool = self._get_task_pool(task_pool)
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query.update({"status": self.STATUS_WAITING})
task = task_pool.find_one_and_update(
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
)
# null sorts first under ASCENDING, so DESCENDING is used: the larger the number, the higher the priority
@@ -197,7 +201,7 @@ class TaskManager:
return self._decode_task(task)
@contextmanager
def safe_fetch_task(self, query={}, task_pool=None):
def safe_fetch_task(self, query={}):
"""
fetch task from task_pool using query with contextmanager
@@ -212,7 +216,7 @@ class TaskManager:
-------
"""
task = self.fetch_task(query=query, task_pool=task_pool)
task = self.fetch_task(query=query)
try:
yield task
except Exception:
@@ -229,7 +233,7 @@ class TaskManager:
break
yield task
def query(self, query={}, decode=True, task_pool=None):
def query(self, query={}, decode=True):
"""
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
@@ -248,29 +252,30 @@ class TaskManager:
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
task_pool = self._get_task_pool(task_pool)
for t in task_pool.find(query):
# task_pool = self._get_task_pool(task_pool)
for t in self.task_pool.find(query):
yield self._decode_task(t)
def re_query(self, _id, task_pool=None):
task_pool = self._get_task_pool(task_pool)
return task_pool.find_one({"_id": ObjectId(_id)})
def re_query(self, _id):
    """Look the task up again by its ``_id`` and return it decoded.

    Args:
        _id: the ``_id`` of the task document; it is wrapped in ``ObjectId`` for the lookup.

    Returns:
        dict: the decoded task, or None when the task pool holds no document with this ``_id``.
    """
    t = self.task_pool.find_one({"_id": ObjectId(_id)})
    # find_one yields None when nothing matches; hand that back instead of trying to decode it
    if t is None:
        return None
    return self._decode_task(t)
def commit_task_res(self, task, res, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)
def commit_task_res(self, task, res, status=None):
# task_pool = self._get_task_pool(task_pool)
# A workaround to use the class attribute.
if status is None:
status = TaskManager.STATUS_DONE
task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
def return_task(self, task, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)
def return_task(self, task, status=None):
# task_pool = self._get_task_pool(task_pool)
if status is None:
status = TaskManager.STATUS_WAITING
update_dict = {"$set": {"status": status}}
task_pool.update_one({"_id": task["_id"]}, update_dict)
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def remove(self, query={}, task_pool=None):
def remove(self, query={}):
"""
remove the task using query
@@ -286,16 +291,16 @@ class TaskManager:
"""
query = query.copy()
task_pool = self._get_task_pool(task_pool)
# task_pool = self._get_task_pool(task_pool)
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
task_pool.delete_many(query)
self.task_pool.delete_many(query)
def task_stat(self, query={}, task_pool=None):
def task_stat(self, query={}):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
tasks = self.query(task_pool=task_pool, query=query, decode=False)
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
@@ -306,14 +311,14 @@ class TaskManager:
# default query
if "status" not in query:
query["status"] = self.STATUS_RUNNING
return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)
return self.reset_status(query=query, status=self.STATUS_WAITING)
def reset_status(self, query, status, task_pool=None):
def reset_status(self, query, status):
query = query.copy()
task_pool = self._get_task_pool(task_pool)
# task_pool = self._get_task_pool(task_pool)
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
print(task_pool.update_many(query, {"$set": {"status": status}}))
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def _get_undone_n(self, task_stat):
    """Return how many tasks are still waiting or running according to ``task_stat``."""
    waiting = task_stat.get(self.STATUS_WAITING, 0)
    running = task_stat.get(self.STATUS_RUNNING, 0)
    return waiting + running
@@ -321,14 +326,14 @@ class TaskManager:
def _get_total(self, task_stat):
    """Return the total number of tasks: the sum of the per-status counts in ``task_stat``."""
    counts = task_stat.values()
    return sum(counts)
def wait(self, query={}, task_pool=None):
task_stat = self.task_stat(query, task_pool)
def wait(self, query={}):
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)
undone_n = self._get_undone_n(self.task_stat(query, task_pool))
undone_n = self._get_undone_n(self.task_stat(query))
pbar.update(last_undone_n - undone_n)
last_undone_n = undone_n
if undone_n == 0:
@@ -365,7 +370,7 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
break
get_module_logger("run_task").info(task["def"])
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # what this means?
res = executor.submit(task_func, task["def"], *args, **kwargs).result()
else:
res = task_func(task["def"], *args, **kwargs)