1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Online Serving V4

This commit is contained in:
lzh222333
2021-03-26 04:20:25 +00:00
parent 8abdd63869
commit 46cd57688e
12 changed files with 366 additions and 323 deletions

View File

@@ -77,7 +77,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
})
- `mongo`
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
Users need finished `installatin <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
.. code-block:: Python

View File

@@ -1,13 +1,13 @@
from pprint import pprint
import fire
import qlib
from qlib.config import REG_CN
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.config import C
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.collect import RollingCollector
from qlib.model.trainer import task_train
from qlib.workflow import R
from pprint import pprint
from qlib.workflow.task.collect import RollingCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
data_handler_config = {
"start_time": "2008-01-01",
@@ -66,14 +66,14 @@ task_xgboost_config = {
}
# Reset all things to the first status, be careful to save important data
def reset():
def reset(task_pool, exp_name):
print("========== reset ==========")
TaskManager(task_pool=task_pool).remove()
# exp = R.get_exp(experiment_name=exp_name)
exp, _ = R.exp_manager._get_or_create_exp(experiment_name=exp_name)
# for rid in R.list_recorders():
# exp.delete_recorder(rid)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# This part corresponds to "Task Generating" in the document
@@ -92,51 +92,58 @@ def task_generating():
# This part corresponds to "Task Storing" in the document
def task_storing(tasks):
def task_storing(tasks, task_pool, exp_name):
print("========== task_storing ==========")
tm = TaskManager(task_pool=task_pool)
tm.create_task(tasks) # all tasks will be saved to MongoDB
# This part corresponds to "Task Running" in the document
def task_running():
def task_running(task_pool, exp_name):
print("========== task_running ==========")
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
# This part corresponds to "Task Collecting" in the document
def task_collecting():
def task_collecting(task_pool, exp_name):
print("========== task_collecting ==========")
def get_task_key(task_config):
def get_group_key_func(recorder):
task_config = recorder.load_object("task")
return task_config["model"]["class"]
def my_filter(recorder):
# only choose the results of "LGBModel"
task_key = get_task_key(rolling_collector.get_task(recorder))
task_key = get_group_key_func(recorder)
if task_key == "LGBModel":
return True
return False
rolling_collector = RollingCollector(exp_name)
# group tasks by "get_task_key" and filter tasks by "my_filter"
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
pred_rolling = rolling_collector.collect(get_group_key_func, my_filter)
print(pred_rolling)
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
def main(
provider_uri="~/.qlib/qlib_data/cn_data",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
exp_name="rolling_exp",
task_pool="rolling_task",
):
mongo_conf = {
"task_url": "mongodb://10.0.0.4:27017/", # maybe you need to change it to your url
"task_db_name": "rolling_db",
"task_url": task_url,
"task_db_name": task_db_name,
}
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
reset()
reset(task_pool, exp_name)
tasks = task_generating()
task_storing(tasks)
task_running()
task_collecting()
task_storing(tasks, task_pool, exp_name)
task_running(task_pool, exp_name)
task_collecting(task_pool, exp_name)
if __name__ == "__main__":
fire.Fire()

View File

@@ -1,16 +1,15 @@
import qlib
import fire
import mlflow
from qlib.config import C
from qlib.workflow import R
from pprint import pprint
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.manage import TaskManager
from qlib.workflow import R
from qlib.workflow.task.collect import RollingCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.online import RollingOnlineManager
from qlib.workflow.task.utils import list_recorders
data_handler_config = {
"start_time": "2013-01-01",
@@ -70,12 +69,15 @@ task_xgboost_config = {
def print_online_model():
print("========== print_online_model ==========")
print("Current 'online' model:")
for online in rolling_online_manager.list_online_model().values():
print(online.info["id"])
for rid, rec in list_recorders(exp_name).items():
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG:
print(rid)
print("Current 'next online' model:")
for online in rolling_online_manager.list_next_online_model().values():
print(online.info["id"])
for rid, rec in list_recorders(exp_name).items():
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG:
print(rid)
# This part corresponds to "Task Generating" in the document
@@ -110,119 +112,76 @@ def task_running():
def task_collecting():
print("========== task_collecting ==========")
def get_task_key(task_config):
def get_group_key_func(recorder):
task_config = recorder.load_object("task")
return task_config["model"]["class"]
def my_filter(recorder):
# only choose the results of "LGBModel"
task_key = get_task_key(rolling_collector.get_task(recorder))
task_key = get_group_key_func(recorder)
if task_key == "LGBModel":
return True
return False
rolling_collector = RollingCollector(exp_name)
# group tasks by "get_task_key" and filter tasks by "my_filter"
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
pred_rolling = rolling_collector.collect(get_group_key_func, my_filter)
print(pred_rolling)
# Reset all things to the first status, be careful to save important data
def reset(force_end=False):
def reset():
print("========== reset ==========")
task_manager.remove()
for error in task_manager.query():
assert False
exp = R.get_exp(experiment_name=exp_name)
recs = exp.list_recorders()
for rid in recs:
exp, _ = R.exp_manager._get_or_create_exp(experiment_name=exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
try:
if force_end:
mlflow.end_run()
except Exception:
pass
# Run this firstly to see the workflow in Task Management
def first_run():
print("========== first_run ==========")
reset(force_end=True)
reset()
tasks = task_generating()
task_storing(tasks)
task_running()
task_collecting()
rolling_online_manager.set_latest_model_to_next_online()
rolling_online_manager.reset_online_model()
# Update the predictions of online model
def update_predictions():
print("========== update_predictions ==========")
rolling_online_manager.update_online_pred()
task_collecting()
# if there are some next_online_model, then online them. if no, still use current online_model.
print_online_model()
rolling_online_manager.reset_online_model()
print_online_model()
# Update the models using the latest date and set them to online model
def update_model():
print("========== update_model ==========")
rolling_online_manager.prepare_new_models()
print_online_model()
rolling_online_manager.set_latest_model_to_next_online()
print_online_model()
latest_rec, _ = rolling_online_manager.list_latest_recorders()
rolling_online_manager.reset_online_tag(latest_rec.values())
def after_day():
rolling_online_manager.prepare_signals()
update_model()
update_predictions()
# Run whole workflow completely
def whole_workflow():
print("========== whole_workflow ==========")
# run this at the first time
first_run()
# run this every day after trading
after_day()
print("========== after_day ==========")
print_online_model()
rolling_online_manager.after_day()
print_online_model()
task_collecting()
if __name__ == "__main__":
####### to train the first version's models, use the command below
# python task_manager_rolling_with_updating.py first_run
####### to update the models using the latest date, use the command below
# python task_manager_rolling_with_updating.py update_model
####### to update the predictions to the latest date, use the command below
# python task_manager_rolling_with_updating.py update_predictions
####### to run whole workflow completely, use the command below
# python task_manager_rolling_with_updating.py whole_workflow
####### to update the models and predictions after the trading time, use the command below
# python task_manager_rolling_with_updating.py after_day
#################### you need to finish the configurations below #########################
provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)
C["mongo"] = {
mongo_conf = {
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
"task_db_name": "online", # database name
"task_db_name": "rolling_db", # database name
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
rolling_step = 550
##########################################################################################
rolling_gen = RollingGen(step=550, rtype=RollingGen.ROLL_SD)
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool
)

View File

@@ -1,9 +1,9 @@
import qlib
from qlib.model.trainer import task_train
from qlib.workflow.task.online import OnlineManager
from qlib.config import REG_CN
import fire
from qlib.workflow import R
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.task.online import OnlineManagerR
from qlib.workflow.task.utils import list_recorders
data_handler_config = {
"start_time": "2008-01-01",
@@ -56,19 +56,20 @@ def first_train(experiment_name="online_svr"):
rid = task_train(task_config=task, experiment_name=experiment_name)
rom = OnlineManager(experiment_name)
rom.reset_online_model(rid)
online_manager = OnlineManagerR(experiment_name)
online_manager.reset_online_tag(rid)
def update_online_pred(experiment_name="online_svr"):
rom = OnlineManager(experiment_name)
online_manager = OnlineManagerR(experiment_name)
print("Here are the online models waiting for update:")
for rid, rec in rom.list_online_model().items():
print(rid)
for rid, rec in list_recorders(experiment_name).items():
if online_manager.get_online_tag(rec) == OnlineManagerR.ONLINE_TAG:
print(rid)
rom.update_online_pred()
online_manager.update_online_pred()
if __name__ == "__main__":

View File

@@ -134,7 +134,7 @@ _default_config = {
},
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
},
# Defatult config for experiment manager
# Default config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
@@ -143,6 +143,11 @@ _default_config = {
"default_exp_name": "Experiment",
},
},
# Default config for MongoDB
"mongo": {
"task_url": "mongodb://localhost:27017/",
"task_db_name": "default_task_db",
}
}
MODE_CONF = {

View File

@@ -27,6 +27,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
model = init_instance_by_config(task_config["model"])
dataset = init_instance_by_config(task_config["dataset"])
datahandler = dataset.handler
dataset.config(exclude=["handler"])
# start exp
with R.start(experiment_name=experiment_name):
@@ -37,10 +38,8 @@ def task_train(task_config: dict, experiment_name: str) -> str:
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
R.save_objects(**{"task": task_config}) # keep the original format and datatype
artifact_uri = recorder.get_artifact_uri()[7:] # delete "file://"
dataset.to_pickle(artifact_uri + "/dataset", exclude=["handler"])
datahandler.to_pickle(artifact_uri + "/datahandler")
R.save_objects(**{"dataset": dataset})
R.save_objects(**{"datahandler": datahandler})
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])

View File

@@ -1,116 +1,172 @@
from qlib.workflow import R
from abc import abstractmethod
from typing import Callable, Union
import pandas as pd
from typing import Union
from typing import Callable
from qlib import get_module_logger
from qlib.workflow.task.utils import list_recorders
class TaskCollector:
class Collector:
"""
Collect the record (or its results) of the tasks
This class will divide disorderly records or anything worth collecting into different groups based on the group_key.
After grouping, we can reduce the useful information from different groups.
"""
def group(self, *args, **kwargs):
"""
According to the get_group_key_func, divide disorderly things into different groups.
For example:
.. code-block:: python
input:
[thing1, thing2, thing3, thing4, thing5]
output:
{
"group_name1": [thing3, thing5, thing1]
"group_name2": [thing2, thing4]
}
Args:
get_group_key_func (Callable): get a group key based on a thing
things_list (list): a list of things
Returns:
dict: a dict including the group key and members of the group.
"""
raise NotImplementedError(f"Please implement the `group` method.")
def reduce(self, things_group: dict):
"""
Using the dict from `group`, reduce useful information.
Args:
things_group (dict): a dict after grouping
Returns:
dict: a dict including the group key, the information key and the information value
"""
raise NotImplementedError(f"Please implement the `reduce` method.")
def collect(self, *args, **kwargs):
"""group and reduce
Returns:
dict: a dict including the group key, the information key and the information value
"""
grouped = self.group(*args, **kwargs)
return self.reduce(grouped)
class RecorderCollector(Collector):
"""
The Recorder's Collector. This class is a implementation of Collector, collecting some artifacts saved by Recorder.
"""
def __init__(self, experiment_name: str) -> None:
self.exp_name = experiment_name
self.exp = R.get_exp(experiment_name=experiment_name)
self.logger = get_module_logger("TaskCollector")
self.logger = get_module_logger(self.__class__.__name__)
def list_recorders(self, rec_filter_func=None):
_artifacts_key_path = {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}
_artifacts_key_merge_method = {}
recs = self.exp.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
def default_merge(self, artifact_list):
"""Merge disorderly artifacts in artifact list.
return recs_flt
Args:
artifact_list (list): A artifact list.
def list_recorders_by_task(self, task_filter_func=None):
def rec_filter(recorder):
return task_filter_func(self.get_task(recorder))
return self.list_recorders(rec_filter)
def list_latest_recorders(self, rec_filter_func=None):
recs_flt = self.list_recorders(rec_filter_func)
max_test = self.latest_time(recs_flt)
latest_rec = {}
for rid, rec in recs_flt.items():
if self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec[rid] = rec
return latest_rec
def get_recorder_by_id(self, recorder_id):
return self.exp.get_recorder(recorder_id, create=False)
def get_task(self, recorder):
if isinstance(recorder, str):
recorder = self.get_recorder_by_id(recorder_id=recorder)
try:
task = recorder.load_object("task")
except OSError:
raise OSError(f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?")
return task
def latest_time(self, recorders):
if len(recorders) == 0:
raise Exception(f"Can't find any recorder in {self.exp_name}")
max_test = max(self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] for rec in recorders.values())
return max_test
class RollingCollector(TaskCollector):
"""
Collect the record results of the rolling tasks
"""
def __init__(
self,
experiment_name: str,
) -> None:
super().__init__(experiment_name)
self.logger = get_module_logger("RollingCollector")
def collect_rolling_predictions(self, get_key_func, rec_filter_func=None):
"""For rolling tasks, the predictions will be in the diffierent recorder.
To collect and concat the predictions of one rolling task, get_key_func will help this method see which group a recorder will be.
Parameters
----------
get_key_func : Callable[dict,str]
a function that get task config and return its group str
rec_filter_func : Callable[Recorder,bool], optional
a function that decide whether filter a recorder, by default None
Returns
-------
dict
a dict of {group: predictions}
Raises:
NotImplementedError: [description]
"""
raise NotImplementedError(f"Please implement the `default_merge` method.")
def group(self, get_group_key_func, rec_filter_func=None):
"""
Filter recorders and group recorders by group key.
Args:
get_group_key_func (Callable): get a group key based on a recorder
rec_filter_func (Callable, optional): filter the recorders in this experiment. Defaults to None.
Returns:
dict: a dict including the group key and recorders of the group
"""
# filter records
recs_flt = self.list_recorders(rec_filter_func)
recs_flt = list_recorders(self.exp_name, rec_filter_func)
# group
recs_group = {}
for _, rec in recs_flt.items():
task = self.get_task(rec)
group_key = get_key_func(task)
group_key = get_group_key_func(rec)
recs_group.setdefault(group_key, []).append(rec)
# reduce group
reduce_group = {}
for k, rec_l in recs_group.items():
pred_l = []
for rec in rec_l:
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
# Make sure the pred are sorted according to the rolling start time
pred_l.sort(key=lambda pred: pred.index.get_level_values("datetime").min())
pred = pd.concat(pred_l)
# If there are duplicated predition, we use the latest perdiction
pred = pred[~pred.index.duplicated(keep="last")]
pred = pred.sort_index()
reduce_group[k] = pred
return recs_group
return reduce_group
def reduce(self, recs_group: dict, artifact_keys_list: list = None):
"""
Reduce artifacts based on the dict of grouped recorder.
The artifacts need be declared by artifact_keys_list.
The artifacts path in recorder need be declared by _artifacts_key_path.
If there is no declartion in _artifacts_key_merge_method, then use default_merge method to merge it.
Args:
recs_group (dict): The dict grouped by `group`
artifact_keys_list (list): The list of artifact keys. If it is None, then use all artifacts in _artifacts_key_path.
Returns:
a dict including the group key, the artifact key and the artifact value.
For example:
.. code-block:: python
{
group_key: {"pred": <VALUE>, "IC": <VALUE>}
}
"""
if artifact_keys_list == None:
artifact_keys_list = self._artifacts_key_path.keys()
reduce_group = {}
for group_key, recorder_list in recs_group.items():
reduced_artifacts = {}
for artifact_key in artifact_keys_list:
artifact_list = []
for recorder in recorder_list:
artifact_list.append(recorder.load_object(self._artifacts_key_path[artifact_key]))
merge_method = self._artifacts_key_merge_method.get(artifact_key, self.default_merge)
artifact = merge_method(artifact_list)
reduced_artifacts[artifact_key] = artifact
reduce_group[group_key] = reduced_artifacts
return reduce_group
class RollingCollector(RecorderCollector):
"""
Collect the record results of the rolling tasks
"""
def __init__(self, experiment_name: str):
super().__init__(experiment_name)
self.logger = get_module_logger(self.__class__.__name__)
def default_merge(self, artifact_list):
"""merge disorderly artifacts based on the datetime.
Args:
artifact_list (list): a list of artifacts from different recorders
Returns:
merged artifact
"""
# Make sure the pred are sorted according to the rolling start time
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
artifact = pd.concat(artifact_list)
# If there are duplicated predition, we use the latest perdiction
artifact = artifact[~artifact.index.duplicated(keep="last")]
artifact = artifact.sort_index()
return artifact

View File

@@ -19,10 +19,10 @@ def task_generator(tasks, generators) -> list:
Parameters
----------
tasks : List[dict]
a list of task templates
generators : List[TaskGen]
a list of TaskGen
tasks : List[dict] or dict
a list of task templates or a single task
generators : List[TaskGen] or TaskGen
a list of TaskGen or a single TaskGen
Returns
-------

View File

@@ -151,7 +151,8 @@ class TaskManager:
if print new task
Returns
-------
int
the length of new tasks
"""
task_pool = self._get_task_pool(task_pool)
new_tasks = []
@@ -173,6 +174,8 @@ class TaskManager:
for t in new_tasks:
self.insert_task_def(t, task_pool)
return len(new_tasks)
def fetch_task(self, query={}, task_pool=None):
task_pool = self._get_task_pool(task_pool)
@@ -245,10 +248,9 @@ class TaskManager:
for t in task_pool.find(query):
yield self._decode_task(t)
def get_task_result(self, task, task_pool=None):
def re_query(self, task, task_pool=None):
task_pool = self._get_task_pool(task_pool)
result = task_pool.find_one({"filter": task})
return self._decode_task(result)["res"]
return task_pool.find_one({"_id":ObjectId(task["_id"])})
def commit_task_res(self, task, res, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)

View File

@@ -3,147 +3,140 @@ from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import MLflowRecorder, Recorder
from qlib.workflow.task.collect import TaskCollector
from qlib.workflow.task.update import ModelUpdater
from qlib.workflow.task.utils import TimeAdjuster
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.utils import list_recorders
from qlib.utils.serial import Serializable
class OnlineManager:
def prepare_new_models(self, tasks: List[dict]):
"""prepare(train) new models
class OnlineManager(Serializable):
Parameters
----------
tasks : List[dict]
a list of tasks
"""
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
ONLINE_KEY = "online_status" # the tag key in recorder
ONLINE_KEY = "online_status" # the online status key in recorder
ONLINE_TAG = "online" # the 'online' model
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def prepare_signals(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
def prepare_tasks(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_new_models(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
def update_online_pred(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
def set_online_tag(self, tag, *args, **kwargs):
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
class OnlineManagerR(OnlineManager):
"""
The implementation of OnlineManager based on (R)ecorder.
"""
def __init__(self, experiment_name: str) -> None:
"""ModelUpdater needs experiment name to find the records
Parameters
----------
experiment_name : str
experiment name string
"""
self.logger = get_module_logger("OnlineManagement")
self.logger = get_module_logger(self.__class__.__name__)
self.exp_name = experiment_name
self.tc = TaskCollector(experiment_name)
def set_next_online_model(self, recorder: MLflowRecorder):
recorder.set_tags(**{self.ONLINE_KEY: self.NEXT_ONLINE_TAG})
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
if isinstance(recorder, Recorder):
recorder = [recorder]
for rec in recorder:
rec.set_tags(**{self.ONLINE_KEY: tag})
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
def set_online_model(self, recorder: MLflowRecorder):
"""online model will be identified at the tags of the record"""
recorder.set_tags(**{self.ONLINE_KEY: self.ONLINE_TAG})
def get_online_tag(self, recorder: Recorder):
tags = recorder.list_tags()
return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG)
def set_offline_model(self, recorder: MLflowRecorder):
recorder.set_tags(**{self.ONLINE_KEY: self.OFFLINE_TAG})
def offline_all_model(self):
recs = self.tc.list_recorders()
for rid, rec in recs.items():
self.set_offline_model(rec)
def reset_online_model(self, recorders: Union[List, Dict] = None):
def reset_online_tag(self, recorder: Union[Recorder, List] = None):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
Args:
recorders (Union[List, Dict], optional):
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
"""
if recorders is None:
recorders = self.list_next_online_model()
if len(recorders) == 0:
if recorder is None:
recorder = list_recorders(
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
).values()
if isinstance(recorder, Recorder):
recorder = [recorder]
if len(recorder) == 0:
self.logger.info("No 'next online' model, just use current 'online' models.")
return
self.offline_all_model()
if isinstance(recorders, dict):
recorders = recorders.values()
for rec in recorders:
self.set_online_model(rec)
self.logger.info(f"Reset {len(recorders)} models to 'online'.")
def set_latest_model_to_next_online(self):
latest_rec = self.tc.list_latest_recorders()
for rid, rec in latest_rec.items():
self.set_next_online_model(rec)
self.logger.info(f"Set {len(latest_rec)} latest models to 'next online'.")
@staticmethod
def online_filter(recorder):
tags = recorder.list_tags()
if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.ONLINE_TAG:
return True
return False
@staticmethod
def next_online_filter(recorder):
tags = recorder.list_tags()
if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.NEXT_ONLINE_TAG:
return True
return False
def list_online_model(self):
"""list the record of online model
Returns
-------
dict
{rid : recorder of the online model}
"""
return self.tc.list_recorders(rec_filter_func=self.online_filter)
def list_next_online_model(self):
return self.tc.list_recorders(rec_filter_func=self.next_online_filter)
recs = list_recorders(self.exp_name)
self.set_online_tag(OnlineManager.OFFLINE_TAG, recs.values())
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
self.logger.info(f"Reset {len(recorder)} models to 'online'.")
def update_online_pred(self):
"""update all online model predictions to the latest day in Calendar"""
mu = ModelUpdater(self.exp_name)
cnt = mu.update_all_pred(self.online_filter)
cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
def after_day(self, *args, **kwargs):
self.prepare_signals(*args, **kwargs)
self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(*args, **kwargs)
self.update_online_pred(*args, **kwargs)
self.reset_online_tag()
class RollingOnlineManager(OnlineManager):
class RollingOnlineManager(OnlineManagerR):
def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None:
super().__init__(experiment_name)
self.ta = TimeAdjuster()
self.rg = rolling_gen
self.tm = TaskManager(task_pool=task_pool)
self.logger = get_module_logger("RollingOnlineManager")
self.logger = get_module_logger(self.__class__.__name__)
def prepare_new_models(self):
"""prepare(train) new models based on online model"""
latest_records = self.tc.list_latest_recorders(self.online_filter) # if we need online_filter here?
max_test = self.tc.latest_time(latest_records)
def prepare_signals(self):
pass
def prepare_tasks(self):
latest_records, max_test = self.list_latest_recorders(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
if max_test is None:
self.logger.warn(f"No latest_recorders.")
return
calendar_latest = self.ta.last_date()
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
old_tasks = []
for rid, rec in latest_records.items():
task = self.tc.get_task(rec)
task = rec.load_object("task")
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
old_tasks.append(task)
new_tasks = task_generator(old_tasks, self.rg)
self.tm.create_task(new_tasks)
run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
self.logger.info(f"Finished prepare {len(new_tasks)} new models.")
return new_tasks
self.logger.info("No need to prepare any new models.")
return []
new_num = self.tm.create_task(new_tasks)
self.logger.info(f"Finished prepare {new_num} tasks.")
def prepare_signals(self):
# prepare the signals of today
pass
def prepare_new_models(self):
"""prepare(train) new models based on online model"""
run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
latest_records, _ = self.list_latest_recorders()
self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values())
self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.")
def list_latest_recorders(self, rec_filter_func=None):
recs_flt = list_recorders(self.exp_name, rec_filter_func)
if len(recs_flt) == 0:
return recs_flt, None
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values())
latest_rec = {}
for rid, rec in recs_flt.items():
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec[rid] = rec
return latest_rec, max_test

View File

@@ -6,8 +6,7 @@ from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import TaskCollector
from qlib.workflow.task.utils import list_recorders
class ModelUpdater:
"""
@@ -23,8 +22,7 @@ class ModelUpdater:
experiment name string
"""
self.exp_name = experiment_name
self.logger = get_module_logger("ModelUpdater")
self.tc = TaskCollector(experiment_name)
self.logger = get_module_logger(self.__class__.__name__)
def _reload_dataset(self, recorder, start_time, end_time):
"""reload dataset from pickle file
@@ -53,7 +51,7 @@ class ModelUpdater:
datahandler.init(datahandler.IT_LS)
return dataset
def update_pred(self, recorder: Recorder):
def update_pred(self, recorder: Recorder, frequency='day'):
"""update predictions to the latest day in Calendar based on rid
Parameters
@@ -65,7 +63,10 @@ class ModelUpdater:
last_end = old_pred.index.get_level_values("datetime").max()
# updated to the latest trading day
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
if frequency=='day':
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
else:
raise NotImplementedError("Now Qlib only support update daily frequency prediction")
if len(cal) == 0:
self.logger.info(
@@ -113,7 +114,7 @@ class ModelUpdater:
the count of updated record
"""
recs = self.tc.list_recorders(rec_filter_func=rec_filter_func)
recs = list_recorders(self.exp_name, rec_filter_func=rec_filter_func)
for rid, rec in recs.items():
self.update_pred(rec)
return len(recs)

View File

@@ -3,6 +3,7 @@
import bisect
import pandas as pd
from qlib.data import D
from qlib.workflow import R
from qlib.config import C
from qlib.log import get_module_logger
from pymongo import MongoClient
@@ -29,6 +30,25 @@ def get_mongodb():
client = MongoClient(cfg["task_url"])
return client.get_database(name=cfg["task_db_name"])
def list_recorders(experiment, rec_filter_func=None):
"""list all recorders which can pass the filter in a experiment.
Args:
experiment (str or Experiment): the name of a Experiment or a instance
rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.
Returns:
dict: a dict {rid: recorder} after filtering.
"""
if isinstance(experiment, str):
experiment, _ = R.exp_manager._get_or_create_exp(experiment_name=experiment)
recs = experiment.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
return recs_flt
class TimeAdjuster:
"""