1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00

Online serving V11

This commit is contained in:
lzh222333
2021-05-13 09:43:42 +00:00
parent 370b6aad74
commit d71a666904
13 changed files with 130 additions and 100 deletions

View File

@@ -7,10 +7,10 @@ This examples is about how can simulate the OnlineManager based on rolling tasks
import fire
import qlib
from qlib.model.trainer import DelayTrainerRM
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingAverageStrategy
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
@@ -115,7 +115,7 @@ class OnlineSimulationExample:
) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
self.rolling_online_manager = OnlineManager(
RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
begin_time=self.start_time,
)

View File

@@ -14,7 +14,7 @@ import pickle
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingAverageStrategy
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.online.manager import OnlineManager
@@ -97,7 +97,7 @@ class RollingOnlineExample:
for task in tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategy.append(
RollingAverageStrategy(
RollingStrategy(
name_id,
task,
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),

View File

@@ -172,7 +172,10 @@ class DatasetH(Dataset):
----------
slc : slice
"""
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
if hasattr(self, "fetch_kwargs"):
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
else:
return self.handler.fetch(slc, **kwargs)
def prepare(
self,

View File

@@ -7,7 +7,7 @@ Ensemble can merge the objects in an Ensemble. For example, if there are many su
from typing import Union
import pandas as pd
from qlib.utils import flatten_dict
from qlib.utils import FLATTEN_TUPLE, flatten_dict
class Ensemble:
@@ -90,7 +90,7 @@ class AverageEnsemble(Ensemble):
pd.DataFrame: the complete result of averaging and standardizing.
"""
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict)
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
values = list(ensemble_dict.values())
results = pd.concat(values, axis=1)
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())

View File

@@ -416,7 +416,7 @@ class DelayTrainerRM(TrainerRM):
run_task(
end_train_func,
task_pool,
tasks=tasks,
query={"filter": {"$in": tasks}}, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,

View File

@@ -716,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
return df.sort_index(axis=axis)
FLATTEN_TUPLE = "_FLATTEN_TUPLE"
def flatten_dict(d, parent_key="", sep="."):
"""flatten_dict.
"""
Flatten a nested dict.
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
>>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
Parameters
----------
d :
d
parent_key :
parent_key
sep :
sep
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)
>>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10}
Args:
d (dict): the dict waiting for flatting
parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "".
sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting.
Returns:
dict: flatten dict
"""
items = []
for k, v in d.items():
new_key = parent_key + sep + str(k) if parent_key else k
if sep == FLATTEN_TUPLE:
new_key = (parent_key, k) if parent_key else k
else:
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:

View File

@@ -16,9 +16,10 @@ class Serializable:
"""
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
default_dump_all = False # if dump all things
def __init__(self):
self._dump_all = False
self._dump_all = self.default_dump_all
self._exclude = []
def __getstate__(self) -> dict:
@@ -77,12 +78,7 @@ class Serializable:
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
self.config(dump_all=dump_all, exclude=exclude)
with Path(path).open("wb") as f:
if self.pickle_backend == "pickle":
pickle.dump(self, f)
elif self.pickle_backend == "dill":
dill.dump(self, f)
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
self.get_backend().dump(self, f)
@classmethod
def load(cls, filepath):
@@ -99,13 +95,24 @@ class Serializable:
Collector: the instance of Collector
"""
with open(filepath, "rb") as f:
if cls.pickle_backend == "pickle":
object = pickle.load(f)
elif cls.pickle_backend == "dill":
object = dill.load(f)
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
object = cls.get_backend().load(f)
if isinstance(object, cls):
return object
else:
raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!")
@classmethod
def get_backend(cls):
"""
Return the backend of a Serializable class. The value will be "pickle" or "dill".
Returns:
str: The value of "pickle" or "dill"
"""
if cls.pickle_backend == "pickle":
return pickle
elif cls.pickle_backend == "dill":
return dill
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")

View File

@@ -35,7 +35,7 @@ class OnlineManager(Serializable):
def __init__(
self,
strategy: Union[OnlineStrategy, List[OnlineStrategy]],
strategies: Union[OnlineStrategy, List[OnlineStrategy]],
trainer: Trainer = None,
begin_time: Union[str, pd.Timestamp] = None,
freq="day",
@@ -45,15 +45,15 @@ class OnlineManager(Serializable):
One OnlineManager must have at least one OnlineStrategy.
Args:
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
trainer (Trainer): the trainer to train task. None for using DelayTrainerR.
freq (str, optional): data frequency. Defaults to "day".
"""
self.logger = get_module_logger(self.__class__.__name__)
if not isinstance(strategy, list):
strategy = [strategy]
self.strategy = strategy
if not isinstance(strategies, list):
strategies = [strategies]
self.strategies = strategies
self.freq = freq
if begin_time is None:
begin_time = D.calendar(freq=self.freq).max()
@@ -77,7 +77,7 @@ class OnlineManager(Serializable):
"""
models_list = []
if strategies is None:
strategies = self.strategy
strategies = self.strategies
for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
@@ -114,21 +114,22 @@ class OnlineManager(Serializable):
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
models_list = []
for strategy in self.strategy:
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if not delay:
strategy.tool.update_online_pred()
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
self.logger.info(f"Finished training {len(models)} models.")
models_list.append(models)
for strategy, models in zip(self.strategies, models_list):
self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs)
if not delay:
self.prepare_signals(**signal_kwargs)
for strategy, models in zip(self.strategy, models_list):
self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs)
def prepare_online_models(
self, strategy: OnlineStrategy, models: list, delay: bool = False, model_kwargs: dict = {}
):
@@ -141,14 +142,9 @@ class OnlineManager(Serializable):
delay (bool, optional): if delay prepare models. Defaults to False.
model_kwargs (dict, optional): the params for `prepare_online_models`.
"""
if not models:
return
if not delay:
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
online_models = strategy.prepare_online_models(models, **model_kwargs)
else:
# just set every models as online models temporarily before ``prepare_online_models``
online_models = models
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
def get_collector(self) -> MergeCollector:
@@ -160,21 +156,21 @@ class OnlineManager(Serializable):
MergeCollector: the collector to merge other collectors.
"""
collector_dict = {}
for strategy in self.strategy:
for strategy in self.strategies:
collector_dict[strategy.name_id] = strategy.get_collector()
return MergeCollector(collector_dict, process_list=[])
def add_strategy(self, strategy: Union[OnlineStrategy, List[OnlineStrategy]]):
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
"""
Add some new strategies to online manager.
Args:
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy
"""
if not isinstance(strategy, list):
strategy = [strategy]
self.first_train(strategy)
self.strategy.extend(strategy)
if not isinstance(strategies, list):
strategies = [strategies]
self.first_train(strategies)
self.strategies.extend(strategies)
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
"""
@@ -258,7 +254,8 @@ class OnlineManager(Serializable):
if self.trainer.is_delay():
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
set_global_logger_level(logging.INFO)
# FIXME: get logging level firstly and restore it here
set_global_logger_level(logging.DEBUG)
self.logger.info(f"Finished preparing signals")
return self.get_collector()

View File

@@ -44,28 +44,23 @@ class OnlineStrategy:
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_online_models(self, models, cur_time=None, check_func=None) -> List[object]:
def prepare_online_models(self, models, cur_time=None) -> List[object]:
"""
A typically implementation, but maybe you will need old models by online_tool.
Select some models as the online models from the trained models.
Select some models from trained models and set them to online models.
This is a typically implementation to online all trained models, you can override it to implement complex method.
You can find last online models by OnlineTool.online_models if you still need them.
NOTE: This method offline all models and online the online models prepared by this method (if have). So you can find last online models by OnlineTool.online_models if you still need them.
NOTE: Reset all online models to trained model. If there is no trained models, then do nothing.
Args:
tasks (list): a list of tasks.
check_func: the method to judge if a model can be online.
The parameter is the model record and return True for online.
None for online every models.
models (list): a list of models.
cur_time (pd.Dataframe): current time from OnlineManger. None for latest.
Returns:
List[object]: a list of selected models.
List[object]: a list of online models.
"""
if check_func is not None:
online_models = []
for model in models:
if check_func(model, cur_time):
online_models.append(model)
models = online_models
if not models:
return self.tool.online_models()
self.tool.reset_online_tag(models)
return models
@@ -89,10 +84,10 @@ class OnlineStrategy:
raise NotImplementedError(f"Please implement the `get_collector` method.")
class RollingAverageStrategy(OnlineStrategy):
class RollingStrategy(OnlineStrategy):
"""
This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models
This example strategy always use latest rolling model as online model.
"""
def __init__(
@@ -102,7 +97,7 @@ class RollingAverageStrategy(OnlineStrategy):
rolling_gen: RollingGen,
):
"""
Init RollingAverageStrategy.
Init RollingStrategy.
Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one.

View File

@@ -137,7 +137,9 @@ class PredUpdater(RecordUpdater):
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time >= self.to_date:
self.logger.info(f"The prediction in {self.record.info['id']} are latest. No need to update.")
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
)
return
# load dataset

View File

@@ -158,6 +158,10 @@ class OnlineToolR(OnlineTool):
"""
online_models = self.online_models()
for rec in online_models:
PredUpdater(rec, to_date=to_date).update()
hist_ref = 0
task = rec.load_object("task")
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")

View File

@@ -91,18 +91,21 @@ class MergeCollector(Collector):
"""
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = []):
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
"""
Args:
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
None for use tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
"""
super().__init__(process_list=process_list)
self.collector_dict = collector_dict
self.merge_func = merge_func
def collect(self) -> dict:
"""
Collect all result of collector_dict and change the outermost key to "``collector_key``_``key``" (like merge them, but rename every key)
Collect all result of collector_dict and change the outermost key to a recombination key.
Returns:
dict: the dict after collecting.
@@ -111,7 +114,10 @@ class MergeCollector(Collector):
for collector_key, collector in self.collector_dict.items():
tmp_dict = collector()
for key, value in tmp_dict.items():
collect_dict[collector_key + "_" + str(key)] = value
if self.merge_func is not None:
collect_dict[self.merge_func(collector_key, key)] = value
else:
collect_dict[(collector_key, key)] = value
return collect_dict
@@ -146,16 +152,18 @@ class RecorderCollector(Collector):
rec_key_func = lambda rec: rec.info["id"]
if artifacts_key is None:
artifacts_key = list(self.artifacts_path.keys())
self._rec_key_func = rec_key_func
self.rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self._rec_filter_func = rec_filter_func
self.rec_filter_func = rec_filter_func
def collect(self, artifacts_key=None, rec_filter_func=None) -> dict:
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
"""Collect different artifacts based on recorder after filtering.
Args:
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default.
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default.
only_exist (bool, optional): if only collect the artifacts when a recorder really have.
If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception.
Returns:
dict: the dict after collected like {artifact: {rec_key: object}}
@@ -163,7 +171,7 @@ class RecorderCollector(Collector):
if artifacts_key is None:
artifacts_key = self.artifacts_key
if rec_filter_func is None:
rec_filter_func = self._rec_filter_func
rec_filter_func = self.rec_filter_func
if isinstance(artifacts_key, str):
artifacts_key = [artifacts_key]
@@ -177,16 +185,18 @@ class RecorderCollector(Collector):
recs_flt[rid] = rec
for _, rec in recs_flt.items():
rec_key = self._rec_key_func(rec)
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
if self.ART_KEY_RAW == key:
artifact = rec
else:
# only collect existing artifact
try:
artifact = rec.load_object(self.artifacts_path[key])
except Exception:
continue
except Exception as e:
if only_exist:
# only collect existing artifact
continue
raise e
collect_dict.setdefault(key, {})[rec_key] = artifact
return collect_dict

View File

@@ -155,7 +155,8 @@ class TaskManager:
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
"""
If the tasks in task_def_l is new, then insert new tasks into the task_pool
If the tasks in task_def_l is new, then insert new tasks into the task_pool, and record inserted_id.
If a task is not new, then query its _id.
Parameters
----------
@@ -169,9 +170,10 @@ class TaskManager:
Returns
-------
list
a list of the _id of new tasks
a list of the _id of task_def_l
"""
new_tasks = []
_id_list = []
for t in task_def_l:
try:
r = self.task_pool.find_one({"filter": t})
@@ -179,6 +181,14 @@ class TaskManager:
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
if r is None:
new_tasks.append(t)
if not dry_run:
insert_result = self.insert_task_def(t)
_id_list.append(insert_result.inserted_id)
else:
_id_list.append(None)
else:
_id_list.append(self._decode_task(r)["_id"])
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
if print_nt: # print new task
@@ -188,11 +198,6 @@ class TaskManager:
if dry_run:
return []
_id_list = []
for t in new_tasks:
insert_result = self.insert_task_def(t)
_id_list.append(insert_result.inserted_id)
return _id_list
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
@@ -388,7 +393,7 @@ class TaskManager:
def run_task(
task_func: Callable,
task_pool: str,
tasks: List[dict] = None,
query: dict = {},
force_release: bool = False,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
@@ -414,8 +419,8 @@ def run_task(
the function to run the task
task_pool : str
the name of the task pool (Collection in MongoDB)
tasks: List[dict]
will only train these tasks config, None for train all tasks.
query: dict
will use this dict to query task_pool when fetching task
force_release : bool
will the program force to release the resource
before_status : str:
@@ -428,9 +433,6 @@ def run_task(
tm = TaskManager(task_pool)
ever_run = False
query = {}
if tasks is not None:
query = {"filter": {"$in": tasks}}
while True:
with tm.safe_fetch_task(status=before_status, query=query) as task: