1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

trainer & group & collect & ensemble

This commit is contained in:
lzh222333
2021-04-02 04:27:14 +00:00
parent edcd7b1ff9
commit bd7a1c11b9
12 changed files with 319 additions and 261 deletions

View File

@@ -8,9 +8,11 @@ from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.workflow.task.ensemble import RollingEnsemble
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
import pandas as pd
from qlib.workflow.task.utils import list_recorders
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
data_handler_config = {
"start_time": "2008-01-01",
@@ -94,24 +96,16 @@ def task_generating():
return tasks
# This part corresponds to "Task Storing" in the document
def task_storing(tasks, task_pool, exp_name):
print("========== task_storing ==========")
tm = TaskManager(task_pool=task_pool)
tm.create_task(tasks) # all tasks will be saved to MongoDB
# This part corresponds to "Task Running" in the document
def task_running(task_pool, exp_name):
print("========== task_running ==========")
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
def task_training(tasks, task_pool, exp_name):
trainer = TrainerRM()
trainer.train(tasks, exp_name, task_pool)
# This part corresponds to "Task Collecting" in the document
def task_collecting(task_pool, exp_name):
print("========== task_collecting ==========")
def get_group_key_func(recorder):
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
@@ -119,14 +113,14 @@ def task_collecting(task_pool, exp_name):
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = get_group_key_func(recorder)
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
collector = RecorderCollector(exp_name)
# group tasks by "get_task_key" and filter tasks by "my_filter"
artifact = collector.collect(RollingEnsemble(), get_group_key_func, rec_filter_func=my_filter)
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
)
print(artifact)
@@ -143,10 +137,9 @@ def main(
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
reset(task_pool, exp_name)
tasks = task_generating()
task_storing(tasks, task_pool, exp_name)
task_running(task_pool, exp_name)
# reset(task_pool, exp_name)
# tasks = task_generating()
# task_training(tasks, task_pool, exp_name)
task_collecting(task_pool, exp_name)

View File

@@ -6,10 +6,10 @@ from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow import R
from qlib.workflow.task.collect import RecorderCollector
from qlib.workflow.task.ensemble import RollingEnsemble
from qlib.model.ens.ensemble import RollingEnsemble
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.online import RollingOnlineManager
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.task.utils import list_recorders
data_handler_config = {
@@ -155,10 +155,10 @@ def first_run():
rolling_online_manager.reset_online_tag(latest_rec.values())
def after_day():
def routine():
print("========== after_day ==========")
print_online_model()
rolling_online_manager.after_day()
rolling_online_manager.routine()
print_online_model()
task_collecting()

View File

@@ -2,7 +2,7 @@ import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.task.online import OnlineManagerR
from qlib.workflow.online.manager import OnlineManagerR
from qlib.workflow.task.utils import list_recorders
data_handler_config = {
@@ -52,7 +52,7 @@ task = {
}
def first_train(experiment_name="online_svr"):
def first_train(experiment_name="online_srv"):
rid = task_train(task_config=task, experiment_name=experiment_name)
@@ -60,7 +60,7 @@ def first_train(experiment_name="online_svr"):
online_manager.reset_online_tag(rid)
def update_online_pred(experiment_name="online_svr"):
def update_online_pred(experiment_name="online_srv"):
online_manager = OnlineManagerR(experiment_name)

View File

@@ -0,0 +1,98 @@
from abc import abstractmethod
from typing import Callable, Union
import pandas as pd
from qlib.workflow.task.collect import Collector
def ens_workflow(collector: "Collector", process_list, artifacts_key=None, rec_filter_func=None, *args, **kwargs):
    """Collect artifacts and pipe every collected value through a chain of dict processors.

    Args:
        collector (Collector): the collector to collect the result into {result_key: things}
        process_list (list or Callable): the list of processors or the instance of processor to process dict.
            The processor order is same as the list order.
            For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
        artifacts_key (list, optional): the artifacts key you want to get. If None, get all artifacts.
        rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
        *args, **kwargs: forwarded unchanged to every processor call.

    Returns:
        dict: the ensemble dict, {artifact_key: value returned by the last processor}

    Raises:
        NotImplementedError: if any entry of `process_list` is not callable.
    """
    if not isinstance(process_list, list):
        process_list = [process_list]

    # Validate the whole chain before collecting, so a bad processor is reported
    # even when nothing is collected and before any artifact has been loaded.
    for process in process_list:
        if not callable(process):
            raise NotImplementedError(f"{type(process)} is not supported in `ens_workflow`.")

    collect_dict = collector.collect(artifacts_key=artifacts_key, rec_filter_func=rec_filter_func)

    ensemble = {}
    for artifact, value in collect_dict.items():
        # Each processor consumes the output of the previous one.
        for process in process_list:
            value = process(value, *args, **kwargs)
        ensemble[artifact] = value
    return ensemble
class Ensemble:
    """Merge the objects in a dict into one ensemble object using a pluggable merge function."""

    def __init__(self, merge_func=None):
        """Init Ensemble.

        Args:
            merge_func (Callable, optional): Given a dict and return the ensemble.
                For example: {Rollinga_b: object, Rollingb_c: object} -> object
                Defaults to None, in which case calling the instance raises
                NotImplementedError unless `_merge` is set afterwards.
        """
        self._merge = merge_func

    def __call__(self, ensemble_dict: dict, *args, **kwargs):
        """Merge the ensemble_dict into an ensemble object.

        Args:
            ensemble_dict (dict): the ensemble dict waiting for merging like {name: things}
            *args, **kwargs: forwarded unchanged to the merge function.

        Returns:
            object: the ensemble object

        Raises:
            NotImplementedError: if no callable merge function is available.
        """
        # getattr keeps this safe for instances on which `_merge` was never set.
        merge = getattr(self, "_merge", None)
        if not callable(merge):
            raise NotImplementedError("Please specify valid merge_func.")
        return merge(ensemble_dict, *args, **kwargs)
class RollingEnsemble(Ensemble):
    """Ensemble whose default merge function stitches rolling dataframes together."""

    @staticmethod
    def rolling_merge(rolling_dict: dict):
        """Concatenate rolling dataframes such as `prediction` or `IC` into one result.

        NOTE: The values of dict must be pd.DataFrame, and have the index "datetime"

        Args:
            rolling_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
                The key of the dict will be ignored.

        Returns:
            pd.DataFrame: the complete result of rolling.
        """
        # Order the frames by their earliest datetime (stable sort), so that for a
        # duplicated index entry the frame that starts later comes last.
        ordered = sorted(rolling_dict.values(), key=lambda frame: frame.index.get_level_values("datetime").min())
        merged = pd.concat(ordered)
        # If there are duplicated predictions, use the latest prediction
        is_duplicate = merged.index.duplicated(keep="last")
        return merged[~is_duplicate].sort_index()

    def __init__(self, merge_func=None):
        """Init RollingEnsemble.

        Args:
            merge_func (Callable, optional): custom merge function. If None,
                `RollingEnsemble.rolling_merge` is used.
        """
        super().__init__(merge_func=merge_func)
        if merge_func is None:
            self._merge = RollingEnsemble.rolling_merge

68
qlib/model/ens/group.py Normal file
View File

@@ -0,0 +1,68 @@
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
class Group:
    """Group the objects of a dict with a pluggable group function, optionally ensembling every group."""

    def __init__(self, group_func=None, ens: "Ensemble" = None):
        """Init Group.

        Args:
            group_func (Callable, optional): Given an ungrouped dict, return a dict that maps every
                group key to the elements of that group.
                For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
                Defaults to None.
            ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
        """
        self._group = group_func
        self._ens = ens

    def __call__(self, ungrouped_dict: dict, *args, **kwargs):
        """Group the ungrouped_dict into different groups.

        Args:
            ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things}
            *args, **kwargs: forwarded to the group function only; the ensemble
                receives just the grouped value.

        Returns:
            dict: grouped_dict like {G1: object, G2: object}

        Raises:
            NotImplementedError: if no callable group function is available.
        """
        # getattr keeps this safe for instances on which `_group` was never set.
        group = getattr(self, "_group", None)
        if not callable(group):
            raise NotImplementedError("Please specify valid group_func.")

        grouped_dict = group(ungrouped_dict, *args, **kwargs)
        if self._ens is None:
            return grouped_dict
        # Replace every group by its ensemble.
        return {key: self._ens(value) for key, value in grouped_dict.items()}
class RollingGroup(Group):
    """Group the rolling dict by every key element except the last (rolling) one."""

    # Sentinel meaning "`ens` was not passed". It is distinct from None, which
    # Group interprets as "do not ensemble the grouped values".
    _DEFAULT_ENS = object()

    @staticmethod
    def rolling_group(rolling_dict: dict):
        """Given a rolling dict like {(A,B,R): things}, return the grouped dict like {(A,B): {R: things}}.

        NOTE: It is assumed that the rolling key is the last element of the key tuple, because the
        rolling results always need to be ensembled first.

        Args:
            rolling_dict (dict): a rolling dict. Entries whose key is not a tuple are not
                grouped and do not appear in the result.

        Returns:
            dict: grouped dict
        """
        grouped_dict = {}
        for key, values in rolling_dict.items():
            if isinstance(key, tuple):
                # Outer key: everything but the last element; inner key: the last element.
                grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
        return grouped_dict

    def __init__(self, group_func=None, ens: Ensemble = _DEFAULT_ENS):
        """Init RollingGroup.

        Args:
            group_func (Callable, optional): custom group function. If None,
                `RollingGroup.rolling_group` is used.
            ens (Ensemble, optional): ensemble applied to every group. When omitted, a new
                `RollingEnsemble` is created for this instance; pass None to skip ensembling.
        """
        # Build the default per instance instead of sharing one object created
        # when the class body was executed.
        if ens is RollingGroup._DEFAULT_ENS:
            ens = RollingEnsemble()
        super().__init__(group_func=group_func, ens=ens)
        if group_func is None:
            self._group = RollingGroup.rolling_group

View File

@@ -4,6 +4,7 @@
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.task.manage import TaskManager, run_task
def task_train(task_config: dict, experiment_name: str) -> str:
@@ -57,3 +58,70 @@ def task_train(task_config: dict, experiment_name: str) -> str:
ar.generate()
return recorder
class Trainer:
    """Base interface of a trainer, which trains a list of models."""

    def train(self, *args, **kwargs):
        """Train the given model definitions and return their results.

        Subclasses must override this method; this base implementation always raises.

        Returns:
            list: a list of trained results

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Please implement the `train` method.")
class TrainerR(Trainer):
    """Trainer based on (R)ecorder.

    Assumption: models are defined by `task` and the results will be saved to `Recorder`.
    """

    def train(self, tasks: list, experiment_name: str, train_func=task_train, *args, **kwargs):
        """Train every `task` one after another and gather what `train_func` returns.

        Args:
            tasks (list): a list of definition based on `task` dict
            experiment_name (str): the experiment name
            train_func (Callable): the train method which need at least `task` and `experiment_name`
            *args, **kwargs: extra arguments forwarded to every `train_func` call.

        Returns:
            list: the return value of `train_func` for every task, in the order of `tasks`
                (a list of Recorders with the default `task_train`).
        """
        return [train_func(task, experiment_name, *args, **kwargs) for task in tasks]
class TrainerRM(TrainerR):
    """Trainer based on (R)ecorder and Task(M)anager.

    Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager.
    """

    def train(self, tasks: list, experiment_name: str, task_pool: str, train_func=task_train, *args, **kwargs):
        """Store the `task`s in a task pool, run the pool and return the stored results.

        This method defaults to a single process, but TaskManager offered a great way to parallel training.
        Users can customize their train_func to realize multiple processes or even multiple machines.

        Args:
            tasks (list): a list of definition based on `task` dict
            experiment_name (str): the experiment name
            task_pool (str): the name of the task pool handed to `TaskManager` and `run_task`
            train_func (Callable): the train method which need at least `task` and `experiment_name`
            *args, **kwargs: extra arguments forwarded to `run_task`.

        Returns:
            list: the "res" field of every task created by this call, in creation order.

        NOTE(review): `create_task` appears to return ids only for tasks that were not already
        in the pool, so the result may be shorter than `tasks` -- confirm against TaskManager.
        """
        manager = TaskManager(task_pool=task_pool)
        task_ids = manager.create_task(tasks)  # all tasks will be saved to MongoDB
        run_task(train_func, task_pool, experiment_name=experiment_name, *args, **kwargs)
        # Read each finished task back from the pool to pick up its stored result.
        return [manager.re_query(task_id)["res"] for task_id in task_ids]

View File

View File

@@ -3,7 +3,7 @@ from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import MLflowRecorder, Recorder
from qlib.workflow.task.update import ModelUpdater
from qlib.workflow.online.update import ModelUpdater
from qlib.workflow.task.utils import TimeAdjuster
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
@@ -37,6 +37,16 @@ class OnlineManager(Serializable):
def get_online_tag(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
def routine(self, *args, **kwargs):
self.prepare_signals(*args, **kwargs)
self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(*args, **kwargs)
self.update_online_pred(*args, **kwargs)
self.reset_online_tag(*args, **kwargs)
class OnlineManagerR(OnlineManager):
"""
@@ -86,21 +96,18 @@ class OnlineManagerR(OnlineManager):
cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
def after_day(self, *args, **kwargs):
self.prepare_signals(*args, **kwargs)
self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(*args, **kwargs)
self.update_online_pred(*args, **kwargs)
self.reset_online_tag()
class RollingOnlineManager(OnlineManagerR):
def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None:
# FIXME: TaskManager should not be tightly coupled with OnlineManager
def __init__(
self, experiment_name: str, rolling_gen: RollingGen, task_manager: TaskManager, trainer=run_task
) -> None:
super().__init__(experiment_name)
self.ta = TimeAdjuster()
self.rg = rolling_gen
self.tm = TaskManager(task_pool=task_pool)
self.tm = task_manager
self.logger = get_module_logger(self.__class__.__name__)
self.trainer = trainer
def prepare_signals(self):
pass
@@ -122,13 +129,13 @@ class RollingOnlineManager(OnlineManagerR):
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
old_tasks.append(task)
new_tasks = task_generator(old_tasks, self.rg)
new_num = self.tm.create_task(new_tasks)
self.logger.info(f"Finished prepare {new_num} tasks.")
self.tm.create_task(new_tasks)
def prepare_new_models(self):
"""prepare(train) new models based on online model"""
run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
run_task(task_train, task_pool=self.tm.task_pool, experiment_name=self.exp_name)
latest_records, _ = self.list_latest_recorders()
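# FIXME: in the current workflow this is still called even when there is no model to update; the previous models are then set to next_online before pred is updated, but by that time no online model is left, so pred cannot be updated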
self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values())
self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.")

View File

@@ -45,8 +45,8 @@ class ModelUpdater:
"""
segments = {"test": (start_time, end_time)}
dataset = recorder.load_object("dataset")
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time})
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}, segments=segments)
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
return dataset
def update_pred(self, recorder: Recorder, frequency="day"):

View File

@@ -1,49 +1,54 @@
from abc import abstractmethod
from typing import Callable, Union
import pandas as pd
from qlib import get_module_logger
from qlib.workflow.task.utils import list_recorders
class Collector:
"""The collector to collect different results based on experiment backend and ensemble method"""
"""The collector to collect different results"""
def collect(self, ensemble, get_group_key_func, *args, **kwargs):
"""To collect the results, we need to get the experiment record firstly and divided them into
different groups. Then use ensemble methods to merge the group.
def collect(self, *args, **kwargs):
"""Collect the results and return a dict like {key: things}
Args:
ensemble (Ensemble): an instance of Ensemble
get_group_key_func (Callable): a function to get the group of a experiment record
Returns:
dict: the dict after collected.
For example:
{"prediction": pd.Series}
{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
......
"""
raise NotImplementedError(f"Please implement the `collect` method.")
class RecorderCollector(Collector):
def __init__(self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None:
def __init__(
self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, rec_key_func=None
) -> None:
"""init RecorderCollector
Args:
exp_name (str): the name of Experiment
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
"""
self.exp_name = exp_name
self.artifacts_path = artifacts_path
if rec_key_func is None:
rec_key_func = lambda rec: rec.info["id"]
self._get_key = rec_key_func
def collect(self, ensemble, get_group_key_func, artifacts_key=None, rec_filter_func=None):
"""Collect different artifacts based on recorder after filtering and ensemble method.
Group recorder by get_group_key_func.
def collect(self, artifacts_key=None, rec_filter_func=None): # ensemble, get_group_key_func,
"""Collect different artifacts based on recorder after filtering.
Args:
ensemble (Ensemble): an instance of Ensemble
get_group_key_func (Callable): a function to get the group of a experiment record
artifacts_key (str or List, optional): the artifacts key you want to get. Defaults to None.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
Returns:
dict: the dict after collected.
dict: the dict after collected like {artifact: {rec_key: object}}
"""
if artifacts_key is None:
artifacts_key = self.artifacts_path.keys()
@@ -51,22 +56,13 @@ class RecorderCollector(Collector):
if isinstance(artifacts_key, str):
artifacts_key = [artifacts_key]
# prepare_ensemble
ensemble_dict = {}
for key in artifacts_key:
ensemble_dict.setdefault(key, {})
collect_dict = {}
# filter records
recs_flt = list_recorders(self.exp_name, rec_filter_func)
for _, rec in recs_flt.items():
group_key = get_group_key_func(rec)
rec_key = self._get_key(rec)
for key in artifacts_key:
artifact = rec.load_object(self.artifacts_path[key])
ensemble_dict[key][group_key] = artifact
collect_dict.setdefault(key, {})[rec_key] = artifact
if isinstance(artifacts_key, str):
return ensemble(ensemble_dict[artifacts_key])
collect_dict = {}
for key in artifacts_key:
collect_dict[key] = ensemble(ensemble_dict[key])
return collect_dict
return collect_dict

View File

@@ -1,176 +0,0 @@
from abc import abstractmethod
from typing import Callable, Union
import pandas as pd
from qlib import get_module_logger
from qlib.workflow.task.utils import list_recorders
from typing import Dict
class Ensemble:
    """Merge the objects in an Ensemble."""

    def __init__(self, merge_func=None, get_grouped_key_func=None) -> None:
        """Init Ensemble.

        Args:
            merge_func (Callable, optional): The specific merge function. Defaults to None.
            get_grouped_key_func (Callable, optional): Get group_inner_key and group_outer_key by group_key. Defaults to None.
        """
        self.logger = get_module_logger(self.__class__.__name__)
        # A provided function is stored on the instance, where it shadows the
        # same-named method that would otherwise raise NotImplementedError.
        if merge_func is not None:
            self.merge_func = merge_func
        if get_grouped_key_func is not None:
            self.get_grouped_key_func = get_grouped_key_func

    def merge_func(self, group_inner_dict):
        """Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object},
        merge it to object

        Args:
            group_inner_dict (dict): the inner group dict

        Raises:
            NotImplementedError: always, unless overridden or replaced in `__init__`.
        """
        raise NotImplementedError("Please implement the `merge_func` method.")

    def get_grouped_key_func(self, group_key):
        """Given a group_key and return the group_outer_key, group_inner_key.

        For example:
            (A,B,Rolling) -> (A,B):Rolling
            (A,B) -> C:(A,B)

        Args:
            group_key (tuple or str): the group key

        Raises:
            NotImplementedError: always, unless overridden or replaced in `__init__`.
        """
        raise NotImplementedError("Please implement the `get_grouped_key_func` method.")

    def group(self, group_dict: Dict[Union[tuple, str], object]) -> Dict[Union[tuple, str], Dict[Union[tuple, str], object]]:
        """In a group of dict, further divide them into outgroups and innergroup.

        For example:

        .. code-block:: python

            RollingEnsemble:
                input:
                {
                    (ModelA,Horizon5,Rollinga_b): object
                    (ModelA,Horizon5,Rollingb_c): object
                    (ModelA,Horizon10,Rollinga_b): object
                    (ModelA,Horizon10,Rollingb_c): object
                    (ModelB,Horizon5,Rollinga_b): object
                    (ModelB,Horizon5,Rollingb_c): object
                    (ModelB,Horizon10,Rollinga_b): object
                    (ModelB,Horizon10,Rollingb_c): object
                }
                output:
                {
                    (ModelA,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelA,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                }

        Args:
            group_dict (Dict[tuple or str, object]): a group of dict

        Returns:
            Dict[tuple or str, Dict[tuple or str, object]]: the dict after `group`
        """
        grouped_dict = {}
        for group_key, artifact in group_dict.items():
            group_outer_key, group_inner_key = self.get_grouped_key_func(group_key)  # (A,B,Rolling) -> (A,B):Rolling
            grouped_dict.setdefault(group_outer_key, {})[group_inner_key] = artifact
        return grouped_dict

    def reduce(self, grouped_dict: dict):
        """After grouping, reduce the innergroup.

        For example:

        .. code-block:: python

            RollingEnsemble:
                input:
                {
                    (ModelA,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelA,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                }
                output:
                {
                    (ModelA,Horizon5): object
                    (ModelA,Horizon10): object
                    (ModelB,Horizon5): object
                    (ModelB,Horizon10): object
                }

        Args:
            grouped_dict (dict): the dict after `group`

        Returns:
            dict: the dict after `reduce`
        """
        return {
            group_outer_key: self.merge_func(group_inner_dict)
            for group_outer_key, group_inner_dict in grouped_dict.items()
        }

    def __call__(self, group_dict):
        """The process of Ensemble is group it firstly and then reduce it, like MapReduce.

        Args:
            group_dict (Dict[tuple or str, object]): a group of dict

        Returns:
            dict: the dict after `reduce`
        """
        grouped_dict = self.group(group_dict)
        return self.reduce(grouped_dict)
class RollingEnsemble(Ensemble):
    """A specific implementation of Ensemble for Rolling."""

    def merge_func(self, group_inner_dict):
        """Merge group_inner_dict by datetime.

        Args:
            group_inner_dict (dict): the inner group dict

        Returns:
            object: the artifact after merging
        """
        # Order the artifacts by their earliest datetime (stable sort), so that for a
        # duplicated index entry the artifact that starts later comes last.
        ordered = sorted(group_inner_dict.values(), key=lambda item: item.index.get_level_values("datetime").min())
        merged = pd.concat(ordered)
        # If there are duplicated predictions, use the latest prediction
        is_duplicate = merged.index.duplicated(keep="last")
        return merged[~is_duplicate].sort_index()

    def get_grouped_key_func(self, group_key):
        """Split a group key into (outer key, inner key); the final axis of group_key must be the Rolling key.

        When `collect`, get_group_key_func can add the statement below.

        .. code-block:: python

            def get_group_key_func(recorder):
                task_config = recorder.load_object("task")
                ......
                rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
                return ......, rolling_key

        Args:
            group_key (tuple or str): the group key

        Returns:
            tuple or str, tuple or str: group_outer_key, group_inner_key
        """
        assert len(group_key) >= 2
        group_outer_key = group_key[:-1]
        group_inner_key = group_key[-1]
        return group_outer_key, group_inner_key

View File

@@ -60,7 +60,7 @@ class TaskManager:
"""
self.mdb = get_mongodb()
self.task_pool = task_pool
self.logger = get_module_logger("TaskManager")
self.logger = get_module_logger(self.__class__.__name__)
def list(self):
return self.mdb.list_collection_names()
@@ -105,10 +105,11 @@ class TaskManager:
def insert_task(self, task, task_pool=None):
task_pool = self._get_task_pool(task_pool)
try:
task_pool.insert_one(task)
insert_result = task_pool.insert_one(task)
except InvalidDocument:
task["filter"] = self._dict_to_str(task["filter"])
task_pool.insert_one(task)
insert_result = task_pool.insert_one(task)
return insert_result
def insert_task_def(self, task_def, task_pool=None):
"""
@@ -133,7 +134,8 @@ class TaskManager:
"status": self.STATUS_WAITING,
}
)
self.insert_task(task, task_pool)
insert_result = self.insert_task(task, task_pool)
return insert_result
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
"""
@@ -151,8 +153,8 @@ class TaskManager:
if print new task
Returns
-------
int
the length of new tasks
list
a list of the _id of new tasks
"""
task_pool = self._get_task_pool(task_pool)
new_tasks = []
@@ -163,7 +165,7 @@ class TaskManager:
r = task_pool.find_one({"filter": self._dict_to_str(t)})
if r is None:
new_tasks.append(t)
print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
if print_nt: # print new task
for t in new_tasks:
@@ -172,10 +174,12 @@ class TaskManager:
if dry_run:
return
_id_list = []
for t in new_tasks:
self.insert_task_def(t, task_pool)
insert_result = self.insert_task_def(t, task_pool)
_id_list.append(insert_result.inserted_id)
return len(new_tasks)
return _id_list
def fetch_task(self, query={}, task_pool=None):
task_pool = self._get_task_pool(task_pool)
@@ -248,9 +252,9 @@ class TaskManager:
for t in task_pool.find(query):
yield self._decode_task(t)
def re_query(self, task, task_pool=None):
def re_query(self, _id, task_pool=None):
task_pool = self._get_task_pool(task_pool)
return task_pool.find_one({"_id": ObjectId(task["_id"])})
return task_pool.find_one({"_id": ObjectId(_id)})
def commit_task_res(self, task, res, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)