1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

multiprocessing support

This commit is contained in:
lzh222333
2021-05-24 05:07:38 +00:00
parent 8c3a08b18d
commit b24af7fff6
6 changed files with 145 additions and 26 deletions

View File

@@ -4,6 +4,7 @@
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
"""
from pprint import pprint
@@ -13,10 +14,10 @@ import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
from qlib.model.trainer import TrainerRM, task_train
data_handler_config = {
@@ -122,6 +123,11 @@ class RollingTaskExample:
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
def worker(self):
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
print("========== worker ==========")
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
def task_collecting(self):
print("========== task_collecting ==========")

View File

@@ -78,8 +78,8 @@ class OnlineSimulationExample:
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
@@ -113,7 +113,7 @@ class OnlineSimulationExample:
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
@@ -139,6 +139,15 @@ class OnlineSimulationExample:
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
def worker(self):
# train tasks by other progress or machines for multiprocessing
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
self.trainer.worker()
else:
print(f"{type(self.trainer)} is not supported for worker.")
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below

View File

@@ -13,10 +13,12 @@ Finally, the OnlineManager will finish second routine and update all strategies.
import os
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.task.manage import TaskManager, run_task
data_handler_config = {
"start_time": "2013-01-01",
@@ -80,8 +82,9 @@ class RollingOnlineExample:
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
rolling_step=550,
tasks=[task_xgboost_config],
add_tasks=[task_lgb_config],
@@ -104,17 +107,28 @@ class RollingOnlineExample:
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager = OnlineManager(strategies)
self.trainer = trainer
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
def worker(self):
# train tasks by other progress or machines for multiprocessing
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
self.trainer.worker(experiment_name=name_id)
else:
print(f"{type(self.trainer)} is not supported for worker.")
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
TaskManager(task_pool=name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)

View File

@@ -12,9 +12,11 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
"""
import socket
import time
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
@@ -190,6 +192,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -213,6 +217,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -250,6 +256,8 @@ class DelayTrainerR(TrainerR):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -315,6 +323,8 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -329,12 +339,24 @@ class TrainerRM(Trainer):
run_task(
train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
# FIXME: reset to waiting automatically
for _id in _id_list:
is_prn = False
while tm.re_query(_id)["status"] == "running":
if not is_prn:
get_module_logger("TrainerRM").warn(
f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others."
)
is_prn = True
time.sleep(10)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
@@ -352,10 +374,33 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(
self,
train_func: Callable = None,
experiment_name: str = None,
):
"""
The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
Args:
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
"""
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
class DelayTrainerRM(TrainerRM):
"""
@@ -395,6 +440,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
return super().train(
@@ -410,8 +457,6 @@ class DelayTrainerRM(TrainerRM):
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
@@ -421,7 +466,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -441,6 +487,42 @@ class DelayTrainerRM(TrainerRM):
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
# FIXME: reset to waiting automatically
tm = TaskManager(task_pool=task_pool)
for query_task in tm.query({"filter": {"$in": tasks}}):
_id = query_task["_id"]
is_prn = False
while tm.re_query(_id)["status"] == "running":
if not is_prn:
get_module_logger("DelayTrainerRM").warn(
f"A task (_id: {_id}) is not being trained by this Trainer. Ignore this message if it is being trained by others."
)
is_prn = True
time.sleep(10)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(self, end_train_func=None, experiment_name: str = None):
"""
The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
Args:
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(
end_train_func,
task_pool=task_pool,
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
)

View File

@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
========================= ===================================================================================
Situations Description
========================= ===================================================================================
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
will train models task by task and strategy by strategy.
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
nothing until all tasks have been prepared. It makes user can train all tasks in
the end of `routine` or `first_train`.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
@@ -103,17 +105,21 @@ class OnlineManager(Serializable):
"""
if strategies is None:
strategies = self.strategies
for strategy in strategies:
models_list = []
for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
@@ -139,20 +145,22 @@ class OnlineManager(Serializable):
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
models_list = []
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.trainer.is_delay():
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.prepare_signals(**signal_kwargs)
def get_collector(self) -> MergeCollector:
@@ -297,6 +305,7 @@ class OnlineManager(Serializable):
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
# FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning.
self.logger.warn(
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)

View File

@@ -69,7 +69,7 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool: str = None):
def __init__(self, task_pool: str):
"""
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
@@ -79,8 +79,7 @@ class TaskManager:
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
if task_pool is not None:
self.task_pool = getattr(self.mdb, task_pool)
self.task_pool = getattr(self.mdb, task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self) -> list:
@@ -288,7 +287,7 @@ class TaskManager:
for t in self.task_pool.find(query):
yield self._decode_task(t)
def re_query(self, _id):
def re_query(self, _id) -> dict:
"""
Use _id to query task.