mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Merge pull request #435 from you-n-g/online_srv
Multiprocessing support for Online Serving
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -13,7 +14,7 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
@@ -68,6 +69,11 @@ class RollingTaskExample:
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
self.trainer.worker()
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
|
||||
@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -25,16 +27,17 @@ class RollingOnlineExample:
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
@@ -53,17 +56,28 @@ class RollingOnlineExample:
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
self.trainer = trainer
|
||||
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
self.trainer.worker(experiment_name=name_id)
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(task_pool=name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@@ -12,9 +12,11 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
|
||||
"""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
@@ -190,6 +192,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -213,6 +217,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
@@ -250,6 +256,8 @@ class DelayTrainerR(TrainerR):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -275,6 +283,9 @@ class TrainerRM(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
# This tag is the _id in TaskManager to distinguish tasks.
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
"""
|
||||
Init TrainerR.
|
||||
@@ -315,6 +326,8 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -326,19 +339,24 @@ class TrainerRM(Trainer):
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
tm.wait(query=query)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
rec.set_tags(**{self.TM_ID: _id})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
@@ -352,10 +370,33 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(
|
||||
self,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
):
|
||||
"""
|
||||
The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -395,6 +436,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
@@ -410,8 +453,6 @@ class DelayTrainerRM(TrainerRM):
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
@@ -421,7 +462,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -429,18 +471,44 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
_id_list = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
_id_list.append(rec.list_tags()[self.TM_ID])
|
||||
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
TaskManager(task_pool=task_pool).wait(query=query)
|
||||
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(self, end_train_func=None, experiment_name: str = None):
|
||||
"""
|
||||
The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool=task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
)
|
||||
|
||||
@@ -43,17 +43,29 @@ RECORD_CONFIG = [
|
||||
]
|
||||
|
||||
|
||||
def get_data_handler_config(market=CSI300_MARKET):
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
"instruments": instruments,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
|
||||
def get_dataset_config(
|
||||
dataset_class=DATASET_ALPHA158_CLASS,
|
||||
train=("2008-01-01", "2014-12-31"),
|
||||
valid=("2015-01-01", "2016-12-31"),
|
||||
test=("2017-01-01", "2020-08-01"),
|
||||
handler_kwargs={"instruments": CSI300_MARKET},
|
||||
):
|
||||
return {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
|
||||
"handler": {
|
||||
"class": dataset_class,
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": get_data_handler_config(market),
|
||||
"kwargs": get_data_handler_config(**handler_kwargs),
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
"train": train,
|
||||
"valid": valid,
|
||||
"test": test,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_gbdt_task(market=CSI300_MARKET):
|
||||
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
}
|
||||
|
||||
|
||||
def get_record_lgb_config(market=CSI300_MARKET):
|
||||
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_record_xgboost_config(market=CSI300_MARKET):
|
||||
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
|
||||
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
|
||||
# use for rolling_online_managment.py
|
||||
ROLLING_HANDLER_CONFIG = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ROLLING_DATASET_CONFIG = {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
# use for online_management_simulate.py
|
||||
ONLINE_HANDLER_CONFIG = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ONLINE_DATASET_CONFIG = {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
|
||||
========================= ===================================================================================
|
||||
Situations Description
|
||||
========================= ===================================================================================
|
||||
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
|
||||
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
|
||||
will train models task by task and strategy by strategy.
|
||||
|
||||
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
|
||||
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
|
||||
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
|
||||
nothing until all tasks have been prepared. It makes user can train all tasks in
|
||||
the end of `routine` or `first_train`.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
@@ -103,17 +105,21 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
if strategies is None:
|
||||
strategies = self.strategies
|
||||
for strategy in strategies:
|
||||
|
||||
models_list = []
|
||||
for strategy in strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
tasks = strategy.first_tasks()
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
|
||||
def routine(
|
||||
self,
|
||||
cur_time: Union[str, pd.Timestamp] = None,
|
||||
@@ -139,33 +145,38 @@ class OnlineManager(Serializable):
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
|
||||
models_list = []
|
||||
for strategy in self.strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.trainer.is_delay():
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
def get_collector(self, **kwargs) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
This collector can be a basis as the signals preparation.
|
||||
|
||||
Args:
|
||||
**kwargs: the params for get_collector.
|
||||
|
||||
Returns:
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategies:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
@@ -297,6 +308,7 @@ class OnlineManager(Serializable):
|
||||
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
if signals_time > cur_time:
|
||||
# FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning.
|
||||
self.logger.warn(
|
||||
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
|
||||
)
|
||||
|
||||
@@ -69,28 +69,29 @@ class TaskManager:
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool: str = None):
|
||||
def __init__(self, task_pool: str):
|
||||
"""
|
||||
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
A TaskManager instance serves a specific task pool.
|
||||
The static method of this module serves the whole MongoDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
if task_pool is not None:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.task_pool = getattr(get_mongodb(), task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self) -> list:
|
||||
@staticmethod
|
||||
def list() -> list:
|
||||
"""
|
||||
List the all collection(task_pool) of the db
|
||||
List the all collection(task_pool) of the db.
|
||||
|
||||
Returns:
|
||||
list
|
||||
"""
|
||||
return self.mdb.list_collection_names()
|
||||
return get_mongodb().list_collection_names()
|
||||
|
||||
def _encode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
@@ -109,6 +110,25 @@ class TaskManager:
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def _decode_query(self, query):
|
||||
"""
|
||||
If the query includes any `_id`, then it needs `ObjectId` to decode.
|
||||
For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
|
||||
|
||||
Args:
|
||||
query (dict): query dict. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
dict: the query after decoding.
|
||||
"""
|
||||
if "_id" in query:
|
||||
if isinstance(query["_id"], dict):
|
||||
for key in query["_id"]:
|
||||
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
|
||||
else:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
return query
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
"""
|
||||
Use a new task to replace a old one
|
||||
@@ -224,8 +244,7 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
query.update({"status": status})
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
@@ -283,12 +302,11 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id):
|
||||
def re_query(self, _id) -> dict:
|
||||
"""
|
||||
Use _id to query task.
|
||||
|
||||
@@ -339,8 +357,7 @@ class TaskManager:
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
self.task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}) -> dict:
|
||||
@@ -354,8 +371,7 @@ class TaskManager:
|
||||
dict
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
tasks = self.query(query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
@@ -377,8 +393,7 @@ class TaskManager:
|
||||
|
||||
def reset_status(self, query, status):
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
@@ -402,9 +417,19 @@ class TaskManager:
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}):
|
||||
"""
|
||||
When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks.
|
||||
So main progress should wait until all tasks are trained well by other progress or machines.
|
||||
|
||||
Args:
|
||||
query (dict, optional): the query dict. Defaults to {}.
|
||||
"""
|
||||
task_stat = self.task_stat(query)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
if last_undone_n == 0:
|
||||
return
|
||||
self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
Reference in New Issue
Block a user