mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
online serving v10
This commit is contained in:
@@ -55,6 +55,7 @@ More information of ``Task Manager`` can be found in `here <../reference/api.htm
|
||||
|
||||
Task Training
|
||||
===============
|
||||
#FIXME: Trainer
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which are in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in the task pool; however, users can also customize how tasks are executed.
|
||||
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
|
||||
|
||||
@@ -8,6 +8,7 @@ This examples is about how can simulate the OnlineManager based on rolling tasks
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingAverageStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
@@ -110,23 +111,29 @@ class OnlineSimulationExample:
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, modify_end_time=False
|
||||
) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31.
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingAverageStrategy(
|
||||
exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False
|
||||
),
|
||||
RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen, need_log=False),
|
||||
trainer=self.trainer,
|
||||
begin_time=self.start_time,
|
||||
need_log=False,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this to run all workflow automatically
|
||||
def main(self):
|
||||
print("========== reset ==========")
|
||||
self.rolling_online_manager.reset()
|
||||
self.reset()
|
||||
print("========== simulate ==========")
|
||||
self.rolling_online_manager.simulate(end_time=self.end_time)
|
||||
print("========== collect results ==========")
|
||||
@@ -134,7 +141,7 @@ class OnlineSimulationExample:
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
print("========== online history ==========")
|
||||
print(self.rolling_online_manager.get_online_history(self.exp_name))
|
||||
print(self.rolling_online_manager.history)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -18,8 +18,6 @@ from qlib.workflow.online.strategy import RollingAverageStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
@@ -86,7 +84,7 @@ class RollingOnlineExample:
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=[task_xgboost_config], # , task_lgb_config],
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
):
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
@@ -103,7 +101,6 @@ class RollingOnlineExample:
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
TrainerRM(experiment_name=name_id, task_pool=name_id),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -116,9 +113,8 @@ class RollingOnlineExample:
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
for task in self.tasks:
|
||||
name_id = task["model"]["class"] + "_" + str(self.rolling_step)
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
@@ -127,12 +123,9 @@ class RollingOnlineExample:
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def first_run(self):
|
||||
print("========== reset ==========")
|
||||
self.rolling_online_manager.reset()
|
||||
self.reset()
|
||||
print("========== first_run ==========")
|
||||
self.rolling_online_manager.first_train()
|
||||
print("========== dump ==========")
|
||||
|
||||
@@ -7,6 +7,7 @@ Ensemble can merge the objects in an Ensemble. For example, if there are many su
|
||||
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
from qlib.utils import flatten_dict
|
||||
|
||||
|
||||
class Ensemble:
|
||||
@@ -77,19 +78,22 @@ class RollingEnsemble(Ensemble):
|
||||
class AverageEnsemble(Ensemble):
|
||||
def __call__(self, ensemble_dict: dict):
|
||||
"""
|
||||
Average a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
|
||||
Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime"
|
||||
NOTE: The values of the dict must be pd.DataFrame and have the index "datetime". If the dict is nested, it is flattened first.
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
|
||||
The key of the dict will be ignored.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: the complete result of averaging.
|
||||
pd.DataFrame: the complete result of averaging and standardizing.
|
||||
"""
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict)
|
||||
values = list(ensemble_dict.values())
|
||||
results = pd.concat(values, axis=1)
|
||||
results = results.mean(axis=1).to_frame("score")
|
||||
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
|
||||
results = results.mean(axis=1)
|
||||
results = results.sort_index()
|
||||
return results
|
||||
|
||||
@@ -36,20 +36,36 @@ class Group:
|
||||
self._ens_func = ens
|
||||
|
||||
def group(self, *args, **kwargs) -> dict:
    """
    Group a set of objects into a dict by delegating to ``self._group_func``.

    For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}

    Args:
        *args: positional arguments forwarded unchanged to ``self._group_func``.
        **kwargs: keyword arguments forwarded unchanged to ``self._group_func``.

    Raises:
        NotImplementedError: if ``self._group_func`` is missing or is not callable.

    Returns:
        dict: whatever ``self._group_func`` returns (expected to be the grouped dict)
    """
    # TODO: such design is weird when `_group_func` is the only configurable part in the class
    group_func = getattr(self, "_group_func", None)
    # Guard clause: refuse to continue when no usable grouping function was configured.
    if not isinstance(group_func, Callable):
        raise NotImplementedError(f"Please specify valid `group_func`.")
    return group_func(*args, **kwargs)
|
||||
|
||||
def reduce(self, *args, **kwargs) -> dict:
    """
    Reduce a grouped dict by delegating to ``self._ens_func``.

    For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}

    Args:
        *args: positional arguments forwarded unchanged to ``self._ens_func``.
        **kwargs: keyword arguments forwarded unchanged to ``self._ens_func``.

    Raises:
        NotImplementedError: if ``self._ens_func`` is missing or is not callable.

    Returns:
        dict: whatever ``self._ens_func`` returns (expected to be the reduced dict)
    """
    ens_func = getattr(self, "_ens_func", None)
    # Guard clause: refuse to continue when no usable ensemble function was configured.
    if not isinstance(ens_func, Callable):
        raise NotImplementedError(f"Please specify valid `_ens_func`.")
    return ens_func(*args, **kwargs)
|
||||
|
||||
def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs) -> dict:
|
||||
"""Group the ungrouped_dict into different groups.
|
||||
"""
|
||||
Group the ungrouped_dict into different groups.
|
||||
|
||||
Args:
|
||||
ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things}
|
||||
|
||||
@@ -12,7 +12,6 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
|
||||
"""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
@@ -145,12 +144,6 @@ class Trainer:
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the Trainer status.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""
|
||||
@@ -160,42 +153,52 @@ class TrainerR(Trainer):
|
||||
Assumption: models were defined by `task` and the results will saved to `Recorder`
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, train_func: Callable = task_train):
|
||||
# These tags help you distinguish whether the Recorder has finished training
|
||||
STATUS_KEY = "train_status"
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the name of experiment.
|
||||
experiment_name (str, optional): the default name of experiment.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.train_func = train_func
|
||||
|
||||
def train(self, tasks: list, train_func: Callable = None, **kwargs) -> List[Recorder]:
|
||||
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for train_func.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
recs = []
|
||||
for task in tasks:
|
||||
rec = train_func(task, self.experiment_name, **kwargs)
|
||||
rec.set_tags(**{"train_status": "begin_task_train"})
|
||||
rec = train_func(task, experiment_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
def end_train(self, recs: list, **kwargs) -> list:
|
||||
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
|
||||
for rec in recs:
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
|
||||
@@ -204,12 +207,12 @@ class DelayTrainerR(TrainerR):
|
||||
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
"""
|
||||
Init DelayTrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the name of experiment.
|
||||
experiment_name (str): the default name of experiment.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
"""
|
||||
@@ -217,7 +220,7 @@ class DelayTrainerR(TrainerR):
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def end_train(self, recs, end_train_func=None, **kwargs) -> List[Recorder]:
|
||||
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
@@ -225,6 +228,7 @@ class DelayTrainerR(TrainerR):
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
Returns:
|
||||
@@ -232,9 +236,13 @@ class DelayTrainerR(TrainerR):
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
for rec in recs:
|
||||
end_train_func(rec, **kwargs)
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
|
||||
continue
|
||||
end_train_func(rec, experiment_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
|
||||
@@ -246,13 +254,18 @@ class TrainerRM(Trainer):
|
||||
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, task_pool: str, train_func=task_train):
|
||||
# These tags help you distinguish whether the Recorder has finished training
|
||||
STATUS_KEY = "train_status"
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
"""
|
||||
Init TrainerRM.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the name of experiment.
|
||||
task_pool (str): task pool name in TaskManager.
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -264,6 +277,7 @@ class TrainerRM(Trainer):
|
||||
self,
|
||||
tasks: list,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
**kwargs,
|
||||
@@ -277,6 +291,7 @@ class TrainerRM(Trainer):
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
kwargs: the params for train_func.
|
||||
@@ -284,14 +299,21 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
tm = TaskManager(task_pool=self.task_pool)
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(
|
||||
train_func,
|
||||
self.task_pool,
|
||||
experiment_name=self.experiment_name,
|
||||
task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
@@ -300,23 +322,15 @@ class TrainerRM(Trainer):
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{"train_status": "begin_task_train"})
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
def end_train(self, recs: list, **kwargs) -> list:
|
||||
for rec in recs:
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
.. note::
|
||||
this method will delete all task in this task_pool!
|
||||
"""
|
||||
tm = TaskManager(task_pool=self.task_pool)
|
||||
tm.remove()
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -324,30 +338,57 @@ class DelayTrainerRM(TrainerRM):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str = None,
|
||||
task_pool: str = None,
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def train(self, tasks: list, train_func=None, **kwargs):
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs):
|
||||
"""
|
||||
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, **kwargs)
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
tasks,
|
||||
train_func=train_func,
|
||||
experiment_name=experiment_name,
|
||||
after_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def end_train(self, recs, end_train_func=None, **kwargs):
|
||||
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs):
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
Returns:
|
||||
@@ -356,13 +397,23 @@ class DelayTrainerRM(TrainerRM):
|
||||
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
|
||||
run_task(
|
||||
end_train_func,
|
||||
self.task_pool,
|
||||
experiment_name=self.experiment_name,
|
||||
task_pool,
|
||||
tasks=tasks,
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
for rec in recs:
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
@@ -732,7 +732,7 @@ def flatten_dict(d, parent_key="", sep="."):
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
new_key = parent_key + sep + str(k) if parent_key else k
|
||||
if isinstance(v, collections.abc.MutableMapping):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import dill
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -14,6 +15,8 @@ class Serializable:
|
||||
- For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk
|
||||
"""
|
||||
|
||||
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = False
|
||||
self._exclude = []
|
||||
@@ -74,4 +77,35 @@ class Serializable:
|
||||
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
pickle.dump(self, f)
|
||||
if self.pickle_backend == "pickle":
|
||||
pickle.dump(self, f)
|
||||
elif self.pickle_backend == "dill":
|
||||
dill.dump(self, f)
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
|
||||
@classmethod
def load(cls, filepath):
    """
    Load a pickled instance of this class from a file.

    The file is read with the backend named by ``cls.pickle_backend``
    ("pickle" or "dill").

    NOTE: unpickling can execute arbitrary code; only load files from trusted sources.

    Args:
        filepath (str): the path of the pickled file

    Raises:
        ValueError: if ``cls.pickle_backend`` is neither "pickle" nor "dill"
        TypeError: if the unpickled object is not an instance of ``cls``

    Returns:
        the unpickled object, which is an instance of ``cls``
    """
    with open(filepath, "rb") as f:
        if cls.pickle_backend == "pickle":
            obj = pickle.load(f)
        elif cls.pickle_backend == "dill":
            obj = dill.load(f)
        else:
            raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
    if isinstance(obj, cls):
        return obj
    # Report the class that was expected (``cls``), not ``type(cls)``, which is always the metaclass.
    raise TypeError(f"The instance of {type(obj)} is not a valid `{cls.__name__}`!")
|
||||
|
||||
@@ -12,15 +12,17 @@ This module also provide a method to simulate `Online Strategy <#Online Strategy
|
||||
Which means you can verify your strategy or find a better one.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Union
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble
|
||||
from qlib.model.ens.ensemble import AverageEnsemble
|
||||
from qlib.model.trainer import DelayTrainerR, Trainer
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow.online.strategy import OnlineStrategy
|
||||
from qlib.workflow.task.collect import HyperCollector
|
||||
from qlib.workflow.task.collect import MergeCollector
|
||||
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
@@ -32,6 +34,7 @@ class OnlineManager(Serializable):
|
||||
def __init__(
|
||||
self,
|
||||
strategy: Union[OnlineStrategy, List[OnlineStrategy]],
|
||||
trainer: Trainer = None,
|
||||
begin_time: Union[str, pd.Timestamp] = None,
|
||||
freq="day",
|
||||
need_log=True,
|
||||
@@ -43,6 +46,7 @@ class OnlineManager(Serializable):
|
||||
Args:
|
||||
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
|
||||
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
|
||||
trainer (Trainer): the trainer to train task. None for using DelayTrainerR.
|
||||
freq (str, optional): data frequency. Defaults to "day".
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
@@ -56,96 +60,166 @@ class OnlineManager(Serializable):
|
||||
begin_time = D.calendar(freq=self.freq).max()
|
||||
self.begin_time = pd.Timestamp(begin_time)
|
||||
self.cur_time = self.begin_time
|
||||
self.history = {}
|
||||
# The history of online models, which is a dict like {begin_time, {strategy, [online_models]}}
|
||||
# begin_time means when online_models are onlined
|
||||
self.history = {}
|
||||
if trainer is None:
|
||||
trainer = DelayTrainerR()
|
||||
self.trainer = trainer
|
||||
self.signals = None
|
||||
|
||||
def first_train(self):
|
||||
def first_train(self, strategies:List[OnlineStrategy]=None, model_kwargs: dict = {}):
|
||||
"""
|
||||
Run every strategy first_train method and record the online history.
|
||||
Get tasks from every strategy's first_tasks method and train them.
|
||||
If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
|
||||
|
||||
Args:
|
||||
strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies.
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
"""
|
||||
for strategy in self.strategy:
|
||||
models_list = []
|
||||
if strategies is None:
|
||||
strategies = self.strategy
|
||||
for strategy in strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
online_models = strategy.first_train()
|
||||
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
|
||||
tasks = strategy.first_tasks()
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
|
||||
def routine(self, cur_time: Union[str, pd.Timestamp] = None, task_kwargs: dict = {}, model_kwargs: dict = {}):
|
||||
for strategy, models in zip(strategies, models_list):
|
||||
self.prepare_online_models(strategy, models, model_kwargs=model_kwargs)
|
||||
|
||||
def routine(
|
||||
self,
|
||||
cur_time: Union[str, pd.Timestamp] = None,
|
||||
delay: bool = False,
|
||||
task_kwargs: dict = {},
|
||||
model_kwargs: dict = {},
|
||||
signal_kwargs: dict = {},
|
||||
):
|
||||
"""
|
||||
Run typical update process for every strategy and record the online history.
|
||||
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
The process is: Prepare signals -> Prepare tasks -> Prepare online models.
|
||||
|
||||
If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.
|
||||
|
||||
Args:
|
||||
cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None.
|
||||
delay (bool): whether to delay preparing signals and models
|
||||
task_kwargs (dict): the params for `prepare_tasks`
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
signal_kwargs (dict): the params for `prepare_signals`
|
||||
"""
|
||||
if cur_time is None:
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
models_list = []
|
||||
for strategy in self.strategy:
|
||||
if not delay:
|
||||
strategy.tool.update_online_pred()
|
||||
if self.need_log:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if not strategy.trainer.is_delay():
|
||||
strategy.prepare_signals()
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
online_models = strategy.prepare_online_models(tasks, **model_kwargs)
|
||||
if len(online_models) > 0:
|
||||
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
|
||||
|
||||
def get_collector(self) -> HyperCollector:
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
models_list.append(models)
|
||||
|
||||
if not delay:
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
for strategy, models in zip(self.strategy, models_list):
|
||||
self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs)
|
||||
|
||||
def prepare_online_models(
|
||||
self, strategy: OnlineStrategy, models: list, delay: bool = False, model_kwargs: dict = {}
|
||||
):
|
||||
"""
|
||||
Prepare online model for strategy, including end_train, reset_online_tag and add history.
|
||||
|
||||
Args:
|
||||
strategy (OnlineStrategy): the instance of strategy.
|
||||
models (list): a list of models.
|
||||
delay (bool, optional): if delay prepare models. Defaults to False.
|
||||
model_kwargs (dict, optional): the params for `prepare_online_models`.
|
||||
"""
|
||||
if not delay:
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
else:
|
||||
# just set every models as online models temporarily before ``prepare_online_models``
|
||||
online_models = models
|
||||
if len(online_models) > 0:
|
||||
strategy.tool.reset_online_tag(online_models)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
|
||||
Returns:
|
||||
HyperCollector: the collector to collect other collectors (using SingleKeyEnsemble() to make results more readable).
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategy:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
return HyperCollector(collector_dict, process_list=SingleKeyEnsemble())
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def get_online_history(self, strategy_name_id: str) -> list:
|
||||
def add_strategy(self, strategy: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
"""
|
||||
Get the online history based on strategy_name_id.
|
||||
Add some new strategies to online manager.
|
||||
|
||||
Args:
|
||||
strategy_name_id (str): the name_id of strategy
|
||||
|
||||
Returns:
|
||||
list: a list like [(begin_time, [online_models])]
|
||||
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy
|
||||
"""
|
||||
history_dict = self.history[strategy_name_id]
|
||||
history = []
|
||||
for time in sorted(history_dict):
|
||||
models = history_dict[time]
|
||||
history.append((time, models))
|
||||
return history
|
||||
if not isinstance(strategy, list):
|
||||
strategy = [strategy]
|
||||
self.first_train(strategy)
|
||||
self.strategy.extend(strategy)
|
||||
|
||||
def delay_prepare(self, delay_kwargs={}):
|
||||
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
|
||||
"""
|
||||
Prepare all models and signals if there is something waiting to be prepared.
|
||||
After preparing the data of the last routine (a box in box-plot), which means the end of the routine, we can prepare trading signals for the next routine.
|
||||
|
||||
NOTE: Given a set of predictions, all signals before the end time of these predictions will be prepared well.
|
||||
|
||||
Even if the latest signal already exists, the latest calculation result will be overwritten.
|
||||
|
||||
.. note::
|
||||
|
||||
Given a prediction of a certain time, all signals before this time will be prepared well.
|
||||
|
||||
Args:
|
||||
delay_kwargs: the params for `delay_prepare`
|
||||
"""
|
||||
for strategy in self.strategy:
|
||||
strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs)
|
||||
|
||||
def get_signals(self) -> pd.DataFrame:
|
||||
"""
|
||||
Average all strategy signals as the online signals.
|
||||
|
||||
Assumption: the signals from every strategy is pd.DataFrame. Override this function to change.
|
||||
prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results after mergecollector must be {xxx:pred}.
|
||||
over_write (bool, optional): If True, the new signals will overwrite. If False, the new signals will append to the end of signals. Defaults to False.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: signals
|
||||
pd.DataFrame: the signals.
|
||||
"""
|
||||
signals_dict = {}
|
||||
for strategy in self.strategy:
|
||||
signals_dict[strategy.name_id] = strategy.get_signals()
|
||||
return AverageEnsemble()(signals_dict)
|
||||
signals = prepare_func(self.get_collector()())
|
||||
old_signals = self.signals
|
||||
if old_signals is not None and not over_write:
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
|
||||
self.signals = signals
|
||||
return new_signals
|
||||
|
||||
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}) -> HyperCollector:
|
||||
def get_signals(self) -> pd.Series:
|
||||
"""
|
||||
Get prepared online signals.
|
||||
|
||||
Returns:
|
||||
pd.Series: signals
|
||||
"""
|
||||
return self.signals
|
||||
|
||||
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}):
|
||||
"""
|
||||
Starting from current time, this method will simulate every routine in OnlineManager until end time.
|
||||
|
||||
@@ -153,6 +227,13 @@ class OnlineManager(Serializable):
|
||||
|
||||
The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.
|
||||
|
||||
Args:
|
||||
end_time: the time the simulation will end
|
||||
frequency: the calendar frequency
|
||||
task_kwargs (dict): the params for `prepare_tasks`
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
signal_kwargs (dict): the params for `prepare_signals`
|
||||
|
||||
Returns:
|
||||
HyperCollector: the OnlineManager's collector
|
||||
"""
|
||||
@@ -160,18 +241,30 @@ class OnlineManager(Serializable):
|
||||
self.first_train()
|
||||
for cur_time in cal:
|
||||
self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs)
|
||||
self.delay_prepare(delay_kwargs=delay_kwargs)
|
||||
self.routine(
|
||||
cur_time,
|
||||
delay=self.trainer.is_delay(),
|
||||
task_kwargs=task_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
signal_kwargs=signal_kwargs,
|
||||
)
|
||||
# delay prepare the models and signals
|
||||
if self.trainer.is_delay():
|
||||
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
return self.get_collector()
|
||||
|
||||
def reset(self):
|
||||
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
|
||||
"""
|
||||
This method will reset all strategy!
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
|
||||
**Be careful to use it.**
|
||||
Args:
|
||||
model_kwargs: the params for `prepare_online_models`
|
||||
signal_kwargs: the params for `prepare_signals`
|
||||
"""
|
||||
self.cur_time = self.begin_time
|
||||
self.history = {}
|
||||
for strategy in self.strategy:
|
||||
strategy.reset()
|
||||
for cur_time, strategy_models in self.history.items():
|
||||
self.cur_time = cur_time
|
||||
for strategy, models in strategy_models.items():
|
||||
self.prepare_online_models(strategy, models, delay=False, model_kwargs=model_kwargs)
|
||||
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
@@ -7,19 +7,14 @@ OnlineStrategy is a set of strategy for online serving.
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib.data.data import D
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import Trainer, TrainerR
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
|
||||
|
||||
class OnlineStrategy:
|
||||
@@ -27,7 +22,7 @@ class OnlineStrategy:
|
||||
OnlineStrategy works with `Online Manager <#Online Manager>`_ and is responsible for how the tasks are generated, the models are updated and the signals are prepared.
|
||||
"""
|
||||
|
||||
def __init__(self, name_id: str, trainer: Trainer = None, need_log=True):
|
||||
def __init__(self, name_id: str, need_log=True):
|
||||
"""
|
||||
Init OnlineStrategy.
|
||||
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finish model training.
|
||||
@@ -38,34 +33,22 @@ class OnlineStrategy:
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.name_id = name_id
|
||||
self.trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.tool = OnlineTool()
|
||||
self.tool = OnlineTool(need_log)
|
||||
|
||||
def prepare_signals(self, delay: bool = False):
|
||||
def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
|
||||
"""
|
||||
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
|
||||
|
||||
NOTE: Given a set prediction, all signals before these prediction end time will be prepared well.
|
||||
|
||||
Args:
|
||||
delay: bool
|
||||
If this method was called by `delay_prepare`
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks.
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for latest).
|
||||
Return the new tasks waiting for training.
|
||||
|
||||
You can find last online models by OnlineTool.online_models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_online_models(self, tasks, check_func=None, **kwargs):
|
||||
def prepare_online_models(self, models, cur_time=None, check_func=None, **kwargs):
|
||||
"""
|
||||
A typical implementation, but you may still need the old models from online_tool.
|
||||
Use trainer to train a list of tasks and set the trained model to `online`.
|
||||
|
||||
NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them.
|
||||
@@ -78,64 +61,34 @@ class OnlineStrategy:
|
||||
**kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
|
||||
"""
|
||||
if check_func is None:
|
||||
check_func = lambda x: True
|
||||
online_models = []
|
||||
if len(tasks) > 0:
|
||||
new_models = self.trainer.train(tasks, **kwargs)
|
||||
for model in new_models:
|
||||
if check_func(model):
|
||||
if check_func is not None:
|
||||
online_models = []
|
||||
for model in models:
|
||||
if check_func(model, cur_time):
|
||||
online_models.append(model)
|
||||
self.tool.reset_online_tag(online_models)
|
||||
return online_models
|
||||
models = online_models
|
||||
self.tool.reset_online_tag(models)
|
||||
return models
|
||||
|
||||
def first_train(self):
|
||||
def first_tasks(self) -> List[dict]:
|
||||
"""
|
||||
Train a series of models firstly and set some of them as online models.
|
||||
Generate a series of tasks firstly and return them.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `first_train` method.")
|
||||
raise NotImplementedError(f"Please implement the `first_tasks` method.")
|
||||
|
||||
def get_collector(self) -> Collector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results of online serving.
|
||||
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy.
|
||||
|
||||
For example:
|
||||
1) collect predictions in Recorder
|
||||
2) collect signals in .txt file
|
||||
2) collect signals in a txt file
|
||||
|
||||
Returns:
|
||||
Collector
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
def delay_prepare(self, history: list, **kwargs):
|
||||
"""
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
|
||||
Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way.
|
||||
|
||||
Args:
|
||||
history (list): an online models list likes [begin_time:[online models]].
|
||||
**kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
"""
|
||||
for begin_time, recs_list in history:
|
||||
self.trainer.end_train(recs_list, **kwargs)
|
||||
self.tool.reset_online_tag(recs_list)
|
||||
self.prepare_signals(delay=True)
|
||||
|
||||
def get_signals(self):
|
||||
"""
|
||||
Get prepared signals.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_signals` method.")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Delete all things and set them to default status. This method is convenient to explore the strategy for online simulation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RollingAverageStrategy(OnlineStrategy):
|
||||
|
||||
@@ -148,9 +101,7 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
name_id: str,
|
||||
task_template: Union[dict, List[dict]],
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = None,
|
||||
need_log=True,
|
||||
signal_exp_name="OnlineManagerSignals",
|
||||
):
|
||||
"""
|
||||
Init RollingAverageStrategy.
|
||||
@@ -161,22 +112,16 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
name_id (str): a unique name or id. Will be also the name of Experiment.
|
||||
task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
|
||||
rolling_gen (RollingGen): an instance of RollingGen
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
signal_exp_path (str): a specific experiment to save signals of different experiment.
|
||||
"""
|
||||
super().__init__(name_id=name_id, trainer=trainer, need_log=need_log)
|
||||
super().__init__(name_id=name_id, need_log=need_log)
|
||||
self.exp_name = self.name_id
|
||||
if not isinstance(task_template, list):
|
||||
task_template = [task_template]
|
||||
self.task_template = task_template
|
||||
self.signal_exp_name = signal_exp_name
|
||||
self.rg = rolling_gen
|
||||
self.tool = OnlineToolR(self.exp_name)
|
||||
self.tool = OnlineToolR(self.exp_name, need_log)
|
||||
self.ta = TimeAdjuster()
|
||||
with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True):
|
||||
self.signal_rec = R.get_recorder() # the recorder to record signals
|
||||
self.signal_rec.save_objects(**{"signals": None})
|
||||
|
||||
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
|
||||
"""
|
||||
@@ -209,18 +154,17 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
|
||||
return artifacts_collector
|
||||
|
||||
def first_train(self) -> List[Recorder]:
|
||||
def first_tasks(self) -> List[dict]:
|
||||
"""
|
||||
Use rolling_gen to generate different tasks based on task_template and trained them.
|
||||
Use rolling_gen to generate different tasks based on task_template.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorder.
|
||||
List[dict]: a list of tasks
|
||||
"""
|
||||
tasks = task_generator(
|
||||
return task_generator(
|
||||
tasks=self.task_template,
|
||||
generators=self.rg, # generate different date segment
|
||||
)
|
||||
return self.prepare_online_models(tasks)
|
||||
|
||||
def prepare_tasks(self, cur_time) -> List[dict]:
|
||||
"""
|
||||
@@ -255,57 +199,6 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
return new_tasks
|
||||
return []
|
||||
|
||||
def prepare_signals(self, delay=False, over_write=False) -> pd.DataFrame:
|
||||
"""
|
||||
Average the predictions of online models and offer a trading signals every routine.
|
||||
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
|
||||
Even if the latest signal already exists, the latest calculation result will be overwritten.
|
||||
|
||||
.. note::
|
||||
|
||||
Given a prediction of a certain time, all signals before this time will be prepared well.
|
||||
|
||||
Args:
|
||||
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
|
||||
Returns:
|
||||
pd.DataFrame: the signals.
|
||||
"""
|
||||
if not delay:
|
||||
self.tool.update_online_pred()
|
||||
|
||||
# Get a collector to average online models predictions
|
||||
online_collector = self.get_collector(
|
||||
process_list=[AverageEnsemble()],
|
||||
rec_filter_func=lambda x: True if self.tool.get_online_tag(x) == self.tool.ONLINE_TAG else False,
|
||||
artifacts_key="pred",
|
||||
)
|
||||
online_results = online_collector()
|
||||
signals = online_results["pred"]
|
||||
|
||||
old_signals = self.get_signals()
|
||||
if old_signals is not None and not over_write:
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}."
|
||||
)
|
||||
self.signal_rec.save_objects(**{"signals": signals})
|
||||
return signals
|
||||
|
||||
def get_signals(self) -> object:
|
||||
"""
|
||||
Get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP)
|
||||
|
||||
Returns:
|
||||
object: signals
|
||||
"""
|
||||
signals = self.signal_rec.load_object("signals")
|
||||
return signals
|
||||
|
||||
def _list_latest(self, rec_list: List[Recorder]):
|
||||
"""
|
||||
List the latest recorders from rec_list
|
||||
@@ -324,16 +217,3 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec.append(rec)
|
||||
return latest_rec, max_test
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
NOTE: This method will delete all recorder in Experiment and reset the Trainer!
|
||||
"""
|
||||
self.trainer.reset()
|
||||
# delete models
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
# delete signals
|
||||
for rid in list_recorders(self.signal_exp_name, lambda x: True if x.info["name"] == self.exp_name else False):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@@ -17,7 +17,7 @@ from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
class OnlineTool:
|
||||
"""
|
||||
OnlineTool.
|
||||
OnlineTool will manage `online` models in an experiment which includes the models recorder.
|
||||
"""
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
@@ -92,7 +92,7 @@ class OnlineToolR(OnlineTool):
|
||||
The implementation of OnlineTool based on (R)ecorder.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, need_log=True):
|
||||
def __init__(self, experiment_name:str, need_log=True):
|
||||
"""
|
||||
Init OnlineToolR.
|
||||
|
||||
|
||||
@@ -5,14 +5,16 @@
|
||||
Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on.
|
||||
"""
|
||||
|
||||
from qlib.model.ens.ensemble import SingleKeyEnsemble
|
||||
from typing import Callable, Dict, List
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
import dill as pickle
|
||||
|
||||
|
||||
class Collector:
|
||||
class Collector(Serializable):
|
||||
"""The collector to collect different results"""
|
||||
|
||||
pickle_backend = "dill" # use dill to dump user method
|
||||
|
||||
def __init__(self, process_list=[]):
|
||||
"""
|
||||
Args:
|
||||
@@ -74,65 +76,42 @@ class Collector:
|
||||
collected = self.collect()
|
||||
return self.process_collect(collected, self.process_list, *args, **kwargs)
|
||||
|
||||
def save(self, filepath):
|
||||
"""
|
||||
save the collector into a file
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Returns:
|
||||
bool: if succeeded
|
||||
"""
|
||||
try:
|
||||
with open(filepath, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def load(filepath):
|
||||
"""
|
||||
load the collector from a file
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Raises:
|
||||
TypeError: the pickled file must be `Collector`
|
||||
|
||||
Returns:
|
||||
Collector: the instance of Collector
|
||||
"""
|
||||
with open(filepath, "rb") as f:
|
||||
collector = pickle.load(f)
|
||||
if isinstance(collector, Collector):
|
||||
return collector
|
||||
else:
|
||||
raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!")
|
||||
|
||||
|
||||
class HyperCollector(Collector):
|
||||
class MergeCollector(Collector):
|
||||
"""
|
||||
A collector to collect the results of other Collectors
|
||||
|
||||
For example:
|
||||
|
||||
We have 2 collectors, named A and B.
|
||||
A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}.
|
||||
Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
|
||||
|
||||
......
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, collector_dict, process_list=[]):
|
||||
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = []):
|
||||
"""
|
||||
Args:
|
||||
collector_dict (dict): the dict like {collector_key, Collector}
|
||||
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
NOTE: process_list = [SingleKeyEnsemble()] can ignore key and use value directly if there is only one {k,v} in a dict.
|
||||
This can make the result more readable. If you want to keep the result as it is, just give an empty process list.
|
||||
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
|
||||
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
self.collector_dict = collector_dict
|
||||
|
||||
def collect(self) -> dict:
|
||||
"""
|
||||
Collect all result of collector_dict and change the outermost key to "``collector_key``_``key``" (like merge them, but rename every key)
|
||||
|
||||
Returns:
|
||||
dict: the dict after collecting.
|
||||
"""
|
||||
collect_dict = {}
|
||||
for key, collector in self.collector_dict.items():
|
||||
collect_dict[key] = collector()
|
||||
for collector_key, collector in self.collector_dict.items():
|
||||
tmp_dict = collector()
|
||||
for key, value in tmp_dict.items():
|
||||
collect_dict[collector_key + "_" + str(key)] = value
|
||||
return collect_dict
|
||||
|
||||
|
||||
@@ -145,7 +124,7 @@ class RecorderCollector(Collector):
|
||||
process_list=[],
|
||||
rec_key_func=None,
|
||||
rec_filter_func=None,
|
||||
artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"},
|
||||
artifacts_path={"pred": "pred.pkl"},
|
||||
artifacts_key=None,
|
||||
):
|
||||
"""init RecorderCollector
|
||||
@@ -203,7 +182,11 @@ class RecorderCollector(Collector):
|
||||
if self.ART_KEY_RAW == key:
|
||||
artifact = rec
|
||||
else:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
# only collect existing artifact
|
||||
try:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
except Exception:
|
||||
continue
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
@@ -5,7 +5,7 @@ Task generator can generate many tasks based on TaskGen and some task templates.
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
import typing
|
||||
from typing import List, Union, Callable
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class TaskGen(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate(self, task: dict) -> typing.List[dict]:
|
||||
def generate(self, task: dict) -> List[dict]:
|
||||
"""
|
||||
generate different tasks based on a task template
|
||||
|
||||
@@ -87,11 +87,34 @@ class TaskGen(metaclass=abc.ABCMeta):
|
||||
return self.generate(*args, **kwargs)
|
||||
|
||||
|
||||
def handler_mod(task: dict, rg):
    """
    Help to modify the handler end time when using RollingGen.

    If the handler's ``end_time`` is earlier than the end of the test segment
    (i.e. ``rg.ta.cal_interval(end_time, test_end) < 0``), ``end_time`` is
    replaced in place by a copy of the test segment end, so that more data can
    be loaded. Tasks that carry no dict-style handler config are left untouched.

    Args:
        task (dict): a task template; it is modified in place
        rg (RollingGen): an instance of RollingGen (its ``ta`` and ``test_key`` are used)
    """
    try:
        ds_kwargs = task["dataset"]["kwargs"]
        handler_kwargs = ds_kwargs["handler"]["kwargs"]
        end_time = handler_kwargs["end_time"]
        test_end = ds_kwargs["segments"][rg.test_key][1]
    except (KeyError, TypeError):
        # KeyError: maybe the dataset does not have a handler (or an end_time), then do nothing.
        # TypeError: the handler is not a dict config (e.g. a string path to a dumped
        # handler), so ``handler["kwargs"]`` cannot be indexed; there is nothing to modify.
        return

    # Only the config look-ups above are guarded, so an error raised while
    # computing the interval is reported instead of being silently swallowed.
    # if end_time < the end of test_segments, then change end_time to allow load more data
    if rg.ta.cal_interval(end_time, test_end) < 0:
        handler_kwargs["end_time"] = copy.deepcopy(test_end)
|
||||
|
||||
|
||||
class RollingGen(TaskGen):
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
|
||||
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX, modify_end_time=True):
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod):
|
||||
"""
|
||||
Generate tasks for rolling
|
||||
|
||||
@@ -101,19 +124,19 @@ class RollingGen(TaskGen):
|
||||
step to rolling
|
||||
rtype : str
|
||||
rolling type (expanding, sliding)
|
||||
modify_end_time: bool
|
||||
Whether the data set configuration needs to be modified when the required scope exceeds the original data set scope
|
||||
ds_extra_mod_func: Callable
|
||||
A method like: handler_mod(task: dict, rg: RollingGen)
|
||||
Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of handler of dataset.
|
||||
"""
|
||||
self.step = step
|
||||
self.rtype = rtype
|
||||
self.modify_end_time = modify_end_time
|
||||
# TODO: Ask pengrong to update future date in dataset
|
||||
self.ds_extra_mod_func = ds_extra_mod_func
|
||||
self.ta = TimeAdjuster(future=True)
|
||||
|
||||
self.test_key = "test"
|
||||
self.train_key = "train"
|
||||
|
||||
def generate(self, task: dict) -> typing.List[dict]:
|
||||
def generate(self, task: dict) -> List[dict]:
|
||||
"""
|
||||
Converting the task into a rolling task.
|
||||
|
||||
@@ -200,18 +223,8 @@ class RollingGen(TaskGen):
|
||||
|
||||
# update segments of this task
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
|
||||
try:
|
||||
interval = self.ta.cal_interval(
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
|
||||
t["dataset"]["kwargs"]["segments"][self.test_key][1],
|
||||
)
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if self.modify_end_time and interval < 0:
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1])
|
||||
except KeyError:
|
||||
# Maybe the user dataset has no handler or end_time
|
||||
pass
|
||||
prev_seg = segments
|
||||
if self.ds_extra_mod_func is not None:
|
||||
self.ds_extra_mod_func(t, self)
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
@@ -388,6 +388,7 @@ class TaskManager:
|
||||
def run_task(
|
||||
task_func: Callable,
|
||||
task_pool: str,
|
||||
tasks: List[dict] = None,
|
||||
force_release: bool = False,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
@@ -413,6 +414,8 @@ def run_task(
|
||||
the function to run the task
|
||||
task_pool : str
|
||||
the name of the task pool (Collection in MongoDB)
|
||||
tasks: List[dict]
|
||||
only these task configs will be trained; None means all tasks will be trained.
|
||||
force_release : bool
|
||||
will the program force to release the resource
|
||||
before_status : str:
|
||||
@@ -425,9 +428,12 @@ def run_task(
|
||||
tm = TaskManager(task_pool)
|
||||
|
||||
ever_run = False
|
||||
query = {}
|
||||
if tasks is not None:
|
||||
query = {"filter": {"$in": tasks}}
|
||||
|
||||
while True:
|
||||
with tm.safe_fetch_task(status=before_status) as task:
|
||||
with tm.safe_fetch_task(status=before_status, query=query) as task:
|
||||
if task is None:
|
||||
break
|
||||
get_module_logger("run_task").info(task["def"])
|
||||
|
||||
Reference in New Issue
Block a user