mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Online serving V11
This commit is contained in:
@@ -7,10 +7,10 @@ This examples is about how can simulate the OnlineManager based on rolling tasks
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerRM
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingAverageStrategy
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
@@ -115,7 +115,7 @@ class OnlineSimulationExample:
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
begin_time=self.start_time,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ import pickle
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingAverageStrategy
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
@@ -97,7 +97,7 @@ class RollingOnlineExample:
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategy.append(
|
||||
RollingAverageStrategy(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
|
||||
@@ -172,7 +172,10 @@ class DatasetH(Dataset):
|
||||
----------
|
||||
slc : slice
|
||||
"""
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
if hasattr(self, "fetch_kwargs"):
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
else:
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
|
||||
@@ -7,7 +7,7 @@ Ensemble can merge the objects in an Ensemble. For example, if there are many su
|
||||
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.utils import FLATTEN_TUPLE, flatten_dict
|
||||
|
||||
|
||||
class Ensemble:
|
||||
@@ -90,7 +90,7 @@ class AverageEnsemble(Ensemble):
|
||||
pd.DataFrame: the complete result of averaging and standardizing.
|
||||
"""
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict)
|
||||
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
|
||||
values = list(ensemble_dict.values())
|
||||
results = pd.concat(values, axis=1)
|
||||
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
|
||||
|
||||
@@ -416,7 +416,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
tasks=tasks,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
|
||||
@@ -716,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
return df.sort_index(axis=axis)
|
||||
|
||||
|
||||
FLATTEN_TUPLE = "_FLATTEN_TUPLE"
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="."):
|
||||
"""flatten_dict.
|
||||
"""
|
||||
Flatten a nested dict.
|
||||
|
||||
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
|
||||
>>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d :
|
||||
d
|
||||
parent_key :
|
||||
parent_key
|
||||
sep :
|
||||
sep
|
||||
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)
|
||||
>>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10}
|
||||
|
||||
Args:
|
||||
d (dict): the dict waiting for flatting
|
||||
parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "".
|
||||
sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting.
|
||||
|
||||
Returns:
|
||||
dict: flatten dict
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = parent_key + sep + str(k) if parent_key else k
|
||||
if sep == FLATTEN_TUPLE:
|
||||
new_key = (parent_key, k) if parent_key else k
|
||||
else:
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, collections.abc.MutableMapping):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
|
||||
@@ -16,9 +16,10 @@ class Serializable:
|
||||
"""
|
||||
|
||||
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
|
||||
default_dump_all = False # if dump all things
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = False
|
||||
self._dump_all = self.default_dump_all
|
||||
self._exclude = []
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
@@ -77,12 +78,7 @@ class Serializable:
|
||||
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
if self.pickle_backend == "pickle":
|
||||
pickle.dump(self, f)
|
||||
elif self.pickle_backend == "dill":
|
||||
dill.dump(self, f)
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
self.get_backend().dump(self, f)
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
@@ -99,13 +95,24 @@ class Serializable:
|
||||
Collector: the instance of Collector
|
||||
"""
|
||||
with open(filepath, "rb") as f:
|
||||
if cls.pickle_backend == "pickle":
|
||||
object = pickle.load(f)
|
||||
elif cls.pickle_backend == "dill":
|
||||
object = dill.load(f)
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
object = cls.get_backend().load(f)
|
||||
if isinstance(object, cls):
|
||||
return object
|
||||
else:
|
||||
raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!")
|
||||
|
||||
@classmethod
|
||||
def get_backend(cls):
|
||||
"""
|
||||
Return the backend of a Serializable class. The value will be "pickle" or "dill".
|
||||
|
||||
Returns:
|
||||
str: The value of "pickle" or "dill"
|
||||
"""
|
||||
if cls.pickle_backend == "pickle":
|
||||
return pickle
|
||||
elif cls.pickle_backend == "dill":
|
||||
return dill
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class OnlineManager(Serializable):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: Union[OnlineStrategy, List[OnlineStrategy]],
|
||||
strategies: Union[OnlineStrategy, List[OnlineStrategy]],
|
||||
trainer: Trainer = None,
|
||||
begin_time: Union[str, pd.Timestamp] = None,
|
||||
freq="day",
|
||||
@@ -45,15 +45,15 @@ class OnlineManager(Serializable):
|
||||
One OnlineManager must have at least one OnlineStrategy.
|
||||
|
||||
Args:
|
||||
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
|
||||
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
|
||||
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
|
||||
trainer (Trainer): the trainer to train task. None for using DelayTrainerR.
|
||||
freq (str, optional): data frequency. Defaults to "day".
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if not isinstance(strategy, list):
|
||||
strategy = [strategy]
|
||||
self.strategy = strategy
|
||||
if not isinstance(strategies, list):
|
||||
strategies = [strategies]
|
||||
self.strategies = strategies
|
||||
self.freq = freq
|
||||
if begin_time is None:
|
||||
begin_time = D.calendar(freq=self.freq).max()
|
||||
@@ -77,7 +77,7 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
models_list = []
|
||||
if strategies is None:
|
||||
strategies = self.strategy
|
||||
strategies = self.strategies
|
||||
for strategy in strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
tasks = strategy.first_tasks()
|
||||
@@ -114,21 +114,22 @@ class OnlineManager(Serializable):
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
models_list = []
|
||||
for strategy in self.strategy:
|
||||
for strategy in self.strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if not delay:
|
||||
strategy.tool.update_online_pred()
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
models_list.append(models)
|
||||
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs)
|
||||
|
||||
if not delay:
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
for strategy, models in zip(self.strategy, models_list):
|
||||
self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs)
|
||||
|
||||
def prepare_online_models(
|
||||
self, strategy: OnlineStrategy, models: list, delay: bool = False, model_kwargs: dict = {}
|
||||
):
|
||||
@@ -141,14 +142,9 @@ class OnlineManager(Serializable):
|
||||
delay (bool, optional): if delay prepare models. Defaults to False.
|
||||
model_kwargs (dict, optional): the params for `prepare_online_models`.
|
||||
"""
|
||||
if not models:
|
||||
return
|
||||
if not delay:
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
else:
|
||||
# just set every models as online models temporarily before ``prepare_online_models``
|
||||
online_models = models
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
@@ -160,21 +156,21 @@ class OnlineManager(Serializable):
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategy:
|
||||
for strategy in self.strategies:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def add_strategy(self, strategy: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
"""
|
||||
Add some new strategies to online manager.
|
||||
|
||||
Args:
|
||||
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy
|
||||
"""
|
||||
if not isinstance(strategy, list):
|
||||
strategy = [strategy]
|
||||
self.first_train(strategy)
|
||||
self.strategy.extend(strategy)
|
||||
if not isinstance(strategies, list):
|
||||
strategies = [strategies]
|
||||
self.first_train(strategies)
|
||||
self.strategies.extend(strategies)
|
||||
|
||||
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
|
||||
"""
|
||||
@@ -258,7 +254,8 @@ class OnlineManager(Serializable):
|
||||
if self.trainer.is_delay():
|
||||
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
|
||||
|
||||
set_global_logger_level(logging.INFO)
|
||||
# FIXME: get logging level firstly and restore it here
|
||||
set_global_logger_level(logging.DEBUG)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
return self.get_collector()
|
||||
|
||||
|
||||
@@ -44,28 +44,23 @@ class OnlineStrategy:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_online_models(self, models, cur_time=None, check_func=None) -> List[object]:
|
||||
def prepare_online_models(self, models, cur_time=None) -> List[object]:
|
||||
"""
|
||||
A typically implementation, but maybe you will need old models by online_tool.
|
||||
Select some models as the online models from the trained models.
|
||||
Select some models from trained models and set them to online models.
|
||||
This is a typically implementation to online all trained models, you can override it to implement complex method.
|
||||
You can find last online models by OnlineTool.online_models if you still need them.
|
||||
|
||||
NOTE: This method offline all models and online the online models prepared by this method (if have). So you can find last online models by OnlineTool.online_models if you still need them.
|
||||
NOTE: Reset all online models to trained model. If there is no trained models, then do nothing.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of tasks.
|
||||
check_func: the method to judge if a model can be online.
|
||||
The parameter is the model record and return True for online.
|
||||
None for online every models.
|
||||
models (list): a list of models.
|
||||
cur_time (pd.Dataframe): current time from OnlineManger. None for latest.
|
||||
|
||||
Returns:
|
||||
List[object]: a list of selected models.
|
||||
List[object]: a list of online models.
|
||||
"""
|
||||
if check_func is not None:
|
||||
online_models = []
|
||||
for model in models:
|
||||
if check_func(model, cur_time):
|
||||
online_models.append(model)
|
||||
models = online_models
|
||||
if not models:
|
||||
return self.tool.online_models()
|
||||
self.tool.reset_online_tag(models)
|
||||
return models
|
||||
|
||||
@@ -89,10 +84,10 @@ class OnlineStrategy:
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
|
||||
class RollingAverageStrategy(OnlineStrategy):
|
||||
class RollingStrategy(OnlineStrategy):
|
||||
|
||||
"""
|
||||
This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models
|
||||
This example strategy always use latest rolling model as online model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -102,7 +97,7 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
rolling_gen: RollingGen,
|
||||
):
|
||||
"""
|
||||
Init RollingAverageStrategy.
|
||||
Init RollingStrategy.
|
||||
|
||||
Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one.
|
||||
|
||||
|
||||
@@ -137,7 +137,9 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
if start_time >= self.to_date:
|
||||
self.logger.info(f"The prediction in {self.record.info['id']} are latest. No need to update.")
|
||||
self.logger.info(
|
||||
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
|
||||
)
|
||||
return
|
||||
|
||||
# load dataset
|
||||
|
||||
@@ -158,6 +158,10 @@ class OnlineToolR(OnlineTool):
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
for rec in online_models:
|
||||
PredUpdater(rec, to_date=to_date).update()
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
if task["dataset"]["class"] == "TSDatasetH":
|
||||
hist_ref = task["dataset"]["kwargs"]["step_len"]
|
||||
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
|
||||
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
|
||||
@@ -91,18 +91,21 @@ class MergeCollector(Collector):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = []):
|
||||
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
|
||||
"""
|
||||
Args:
|
||||
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
|
||||
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
|
||||
merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
|
||||
None for use tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
self.collector_dict = collector_dict
|
||||
self.merge_func = merge_func
|
||||
|
||||
def collect(self) -> dict:
|
||||
"""
|
||||
Collect all result of collector_dict and change the outermost key to "``collector_key``_``key``" (like merge them, but rename every key)
|
||||
Collect all result of collector_dict and change the outermost key to a recombination key.
|
||||
|
||||
Returns:
|
||||
dict: the dict after collecting.
|
||||
@@ -111,7 +114,10 @@ class MergeCollector(Collector):
|
||||
for collector_key, collector in self.collector_dict.items():
|
||||
tmp_dict = collector()
|
||||
for key, value in tmp_dict.items():
|
||||
collect_dict[collector_key + "_" + str(key)] = value
|
||||
if self.merge_func is not None:
|
||||
collect_dict[self.merge_func(collector_key, key)] = value
|
||||
else:
|
||||
collect_dict[(collector_key, key)] = value
|
||||
return collect_dict
|
||||
|
||||
|
||||
@@ -146,16 +152,18 @@ class RecorderCollector(Collector):
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
if artifacts_key is None:
|
||||
artifacts_key = list(self.artifacts_path.keys())
|
||||
self._rec_key_func = rec_key_func
|
||||
self.rec_key_func = rec_key_func
|
||||
self.artifacts_key = artifacts_key
|
||||
self._rec_filter_func = rec_filter_func
|
||||
self.rec_filter_func = rec_filter_func
|
||||
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None) -> dict:
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
|
||||
"""Collect different artifacts based on recorder after filtering.
|
||||
|
||||
Args:
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default.
|
||||
only_exist (bool, optional): if only collect the artifacts when a recorder really have.
|
||||
If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception.
|
||||
|
||||
Returns:
|
||||
dict: the dict after collected like {artifact: {rec_key: object}}
|
||||
@@ -163,7 +171,7 @@ class RecorderCollector(Collector):
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_key
|
||||
if rec_filter_func is None:
|
||||
rec_filter_func = self._rec_filter_func
|
||||
rec_filter_func = self.rec_filter_func
|
||||
|
||||
if isinstance(artifacts_key, str):
|
||||
artifacts_key = [artifacts_key]
|
||||
@@ -177,16 +185,18 @@ class RecorderCollector(Collector):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self._rec_key_func(rec)
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
if self.ART_KEY_RAW == key:
|
||||
artifact = rec
|
||||
else:
|
||||
# only collect existing artifact
|
||||
try:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
if only_exist:
|
||||
# only collect existing artifact
|
||||
continue
|
||||
raise e
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
@@ -155,7 +155,8 @@ class TaskManager:
|
||||
|
||||
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
|
||||
"""
|
||||
If the tasks in task_def_l is new, then insert new tasks into the task_pool
|
||||
If the tasks in task_def_l is new, then insert new tasks into the task_pool, and record inserted_id.
|
||||
If a task is not new, then query its _id.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -169,9 +170,10 @@ class TaskManager:
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
a list of the _id of new tasks
|
||||
a list of the _id of task_def_l
|
||||
"""
|
||||
new_tasks = []
|
||||
_id_list = []
|
||||
for t in task_def_l:
|
||||
try:
|
||||
r = self.task_pool.find_one({"filter": t})
|
||||
@@ -179,6 +181,14 @@ class TaskManager:
|
||||
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
|
||||
if r is None:
|
||||
new_tasks.append(t)
|
||||
if not dry_run:
|
||||
insert_result = self.insert_task_def(t)
|
||||
_id_list.append(insert_result.inserted_id)
|
||||
else:
|
||||
_id_list.append(None)
|
||||
else:
|
||||
_id_list.append(self._decode_task(r)["_id"])
|
||||
|
||||
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
|
||||
|
||||
if print_nt: # print new task
|
||||
@@ -188,11 +198,6 @@ class TaskManager:
|
||||
if dry_run:
|
||||
return []
|
||||
|
||||
_id_list = []
|
||||
for t in new_tasks:
|
||||
insert_result = self.insert_task_def(t)
|
||||
_id_list.append(insert_result.inserted_id)
|
||||
|
||||
return _id_list
|
||||
|
||||
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
|
||||
@@ -388,7 +393,7 @@ class TaskManager:
|
||||
def run_task(
|
||||
task_func: Callable,
|
||||
task_pool: str,
|
||||
tasks: List[dict] = None,
|
||||
query: dict = {},
|
||||
force_release: bool = False,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
@@ -414,8 +419,8 @@ def run_task(
|
||||
the function to run the task
|
||||
task_pool : str
|
||||
the name of the task pool (Collection in MongoDB)
|
||||
tasks: List[dict]
|
||||
will only train these tasks config, None for train all tasks.
|
||||
query: dict
|
||||
will use this dict to query task_pool when fetching task
|
||||
force_release : bool
|
||||
will the program force to release the resource
|
||||
before_status : str:
|
||||
@@ -428,9 +433,6 @@ def run_task(
|
||||
tm = TaskManager(task_pool)
|
||||
|
||||
ever_run = False
|
||||
query = {}
|
||||
if tasks is not None:
|
||||
query = {"filter": {"$in": tasks}}
|
||||
|
||||
while True:
|
||||
with tm.safe_fetch_task(status=before_status, query=query) as task:
|
||||
|
||||
Reference in New Issue
Block a user