mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Online Serving V4
This commit is contained in:
@@ -77,7 +77,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installatin <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.config import C
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.collect import RollingCollector
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow import R
|
||||
from pprint import pprint
|
||||
from qlib.workflow.task.collect import RollingCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
@@ -66,14 +66,14 @@ task_xgboost_config = {
|
||||
}
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset():
|
||||
def reset(task_pool, exp_name):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=task_pool).remove()
|
||||
|
||||
# exp = R.get_exp(experiment_name=exp_name)
|
||||
exp, _ = R.exp_manager._get_or_create_exp(experiment_name=exp_name)
|
||||
|
||||
# for rid in R.list_recorders():
|
||||
# exp.delete_recorder(rid)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
@@ -92,51 +92,58 @@ def task_generating():
|
||||
|
||||
|
||||
# This part corresponds to "Task Storing" in the document
|
||||
def task_storing(tasks):
|
||||
def task_storing(tasks, task_pool, exp_name):
|
||||
print("========== task_storing ==========")
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
|
||||
|
||||
# This part corresponds to "Task Running" in the document
|
||||
def task_running():
|
||||
def task_running(task_pool, exp_name):
|
||||
print("========== task_running ==========")
|
||||
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting():
|
||||
def task_collecting(task_pool, exp_name):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def get_task_key(task_config):
|
||||
def get_group_key_func(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
return task_config["model"]["class"]
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
task_key = get_task_key(rolling_collector.get_task(recorder))
|
||||
task_key = get_group_key_func(recorder)
|
||||
if task_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
rolling_collector = RollingCollector(exp_name)
|
||||
# group tasks by "get_task_key" and filter tasks by "my_filter"
|
||||
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
|
||||
pred_rolling = rolling_collector.collect(get_group_key_func, my_filter)
|
||||
print(pred_rolling)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
def main(
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
exp_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
):
|
||||
mongo_conf = {
|
||||
"task_url": "mongodb://10.0.0.4:27017/", # maybe you need to change it to your url
|
||||
"task_db_name": "rolling_db",
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
|
||||
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
reset()
|
||||
reset(task_pool, exp_name)
|
||||
tasks = task_generating()
|
||||
task_storing(tasks)
|
||||
task_running()
|
||||
task_collecting()
|
||||
task_storing(tasks, task_pool, exp_name)
|
||||
task_running(task_pool, exp_name)
|
||||
task_collecting(task_pool, exp_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire()
|
||||
@@ -1,16 +1,15 @@
|
||||
import qlib
|
||||
import fire
|
||||
import mlflow
|
||||
from qlib.config import C
|
||||
from qlib.workflow import R
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.collect import RollingCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.online import RollingOnlineManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
@@ -70,12 +69,15 @@ task_xgboost_config = {
|
||||
|
||||
|
||||
def print_online_model():
|
||||
print("========== print_online_model ==========")
|
||||
print("Current 'online' model:")
|
||||
for online in rolling_online_manager.list_online_model().values():
|
||||
print(online.info["id"])
|
||||
for rid, rec in list_recorders(exp_name).items():
|
||||
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG:
|
||||
print(rid)
|
||||
print("Current 'next online' model:")
|
||||
for online in rolling_online_manager.list_next_online_model().values():
|
||||
print(online.info["id"])
|
||||
for rid, rec in list_recorders(exp_name).items():
|
||||
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG:
|
||||
print(rid)
|
||||
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
@@ -110,119 +112,76 @@ def task_running():
|
||||
def task_collecting():
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def get_task_key(task_config):
|
||||
def get_group_key_func(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
return task_config["model"]["class"]
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
task_key = get_task_key(rolling_collector.get_task(recorder))
|
||||
task_key = get_group_key_func(recorder)
|
||||
if task_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
rolling_collector = RollingCollector(exp_name)
|
||||
# group tasks by "get_task_key" and filter tasks by "my_filter"
|
||||
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
|
||||
pred_rolling = rolling_collector.collect(get_group_key_func, my_filter)
|
||||
print(pred_rolling)
|
||||
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(force_end=False):
|
||||
def reset():
|
||||
print("========== reset ==========")
|
||||
task_manager.remove()
|
||||
for error in task_manager.query():
|
||||
assert False
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
recs = exp.list_recorders()
|
||||
|
||||
for rid in recs:
|
||||
exp, _ = R.exp_manager._get_or_create_exp(experiment_name=exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
try:
|
||||
if force_end:
|
||||
mlflow.end_run()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run():
|
||||
print("========== first_run ==========")
|
||||
reset(force_end=True)
|
||||
reset()
|
||||
|
||||
tasks = task_generating()
|
||||
task_storing(tasks)
|
||||
task_running()
|
||||
task_collecting()
|
||||
|
||||
rolling_online_manager.set_latest_model_to_next_online()
|
||||
rolling_online_manager.reset_online_model()
|
||||
|
||||
|
||||
# Update the predictions of online model
|
||||
def update_predictions():
|
||||
print("========== update_predictions ==========")
|
||||
rolling_online_manager.update_online_pred()
|
||||
task_collecting()
|
||||
# if there are some next_online_model, then online them. if no, still use current online_model.
|
||||
print_online_model()
|
||||
rolling_online_manager.reset_online_model()
|
||||
print_online_model()
|
||||
|
||||
|
||||
# Update the models using the latest date and set them to online model
|
||||
def update_model():
|
||||
print("========== update_model ==========")
|
||||
rolling_online_manager.prepare_new_models()
|
||||
print_online_model()
|
||||
rolling_online_manager.set_latest_model_to_next_online()
|
||||
print_online_model()
|
||||
latest_rec, _ = rolling_online_manager.list_latest_recorders()
|
||||
rolling_online_manager.reset_online_tag(latest_rec.values())
|
||||
|
||||
|
||||
def after_day():
|
||||
rolling_online_manager.prepare_signals()
|
||||
update_model()
|
||||
update_predictions()
|
||||
|
||||
|
||||
# Run whole workflow completely
|
||||
def whole_workflow():
|
||||
print("========== whole_workflow ==========")
|
||||
# run this at the first time
|
||||
first_run()
|
||||
# run this every day after trading
|
||||
after_day()
|
||||
print("========== after_day ==========")
|
||||
print_online_model()
|
||||
rolling_online_manager.after_day()
|
||||
print_online_model()
|
||||
task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
# python task_manager_rolling_with_updating.py first_run
|
||||
|
||||
####### to update the models using the latest date, use the command below
|
||||
# python task_manager_rolling_with_updating.py update_model
|
||||
|
||||
####### to update the predictions to the latest date, use the command below
|
||||
# python task_manager_rolling_with_updating.py update_predictions
|
||||
|
||||
####### to run whole workflow completely, use the command below
|
||||
# python task_manager_rolling_with_updating.py whole_workflow
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python task_manager_rolling_with_updating.py after_day
|
||||
|
||||
#################### you need to finish the configurations below #########################
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
C["mongo"] = {
|
||||
mongo_conf = {
|
||||
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
|
||||
"task_db_name": "online", # database name
|
||||
"task_db_name": "rolling_db", # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
|
||||
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
|
||||
rolling_step = 550
|
||||
|
||||
##########################################################################################
|
||||
rolling_gen = RollingGen(step=550, rtype=RollingGen.ROLL_SD)
|
||||
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool
|
||||
)
|
||||
@@ -1,9 +1,9 @@
|
||||
import qlib
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.task.online import OnlineManager
|
||||
from qlib.config import REG_CN
|
||||
import fire
|
||||
from qlib.workflow import R
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.task.online import OnlineManagerR
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
@@ -56,19 +56,20 @@ def first_train(experiment_name="online_svr"):
|
||||
|
||||
rid = task_train(task_config=task, experiment_name=experiment_name)
|
||||
|
||||
rom = OnlineManager(experiment_name)
|
||||
rom.reset_online_model(rid)
|
||||
online_manager = OnlineManagerR(experiment_name)
|
||||
online_manager.reset_online_tag(rid)
|
||||
|
||||
|
||||
def update_online_pred(experiment_name="online_svr"):
|
||||
|
||||
rom = OnlineManager(experiment_name)
|
||||
online_manager = OnlineManagerR(experiment_name)
|
||||
|
||||
print("Here are the online models waiting for update:")
|
||||
for rid, rec in rom.list_online_model().items():
|
||||
print(rid)
|
||||
for rid, rec in list_recorders(experiment_name).items():
|
||||
if online_manager.get_online_tag(rec) == OnlineManagerR.ONLINE_TAG:
|
||||
print(rid)
|
||||
|
||||
rom.update_online_pred()
|
||||
online_manager.update_online_pred()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -134,7 +134,7 @@ _default_config = {
|
||||
},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
@@ -143,6 +143,11 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
}
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
|
||||
@@ -27,6 +27,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
model = init_instance_by_config(task_config["model"])
|
||||
dataset = init_instance_by_config(task_config["dataset"])
|
||||
datahandler = dataset.handler
|
||||
dataset.config(exclude=["handler"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name=experiment_name):
|
||||
@@ -37,10 +38,8 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
|
||||
artifact_uri = recorder.get_artifact_uri()[7:] # delete "file://"
|
||||
dataset.to_pickle(artifact_uri + "/dataset", exclude=["handler"])
|
||||
datahandler.to_pickle(artifact_uri + "/datahandler")
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
R.save_objects(**{"datahandler": datahandler})
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
|
||||
@@ -1,116 +1,172 @@
|
||||
from qlib.workflow import R
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from typing import Callable
|
||||
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
class TaskCollector:
|
||||
class Collector:
|
||||
"""
|
||||
Collect the record (or its results) of the tasks
|
||||
This class will divide disorderly records or anything worth collecting into different groups based on the group_key.
|
||||
After grouping, we can reduce the useful information from different groups.
|
||||
"""
|
||||
|
||||
def group(self, *args, **kwargs):
|
||||
"""
|
||||
According to the get_group_key_func, divide disorderly things into different groups.
|
||||
|
||||
For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
input:
|
||||
[thing1, thing2, thing3, thing4, thing5]
|
||||
|
||||
output:
|
||||
{
|
||||
"group_name1": [thing3, thing5, thing1]
|
||||
"group_name2": [thing2, thing4]
|
||||
}
|
||||
|
||||
Args:
|
||||
get_group_key_func (Callable): get a group key based on a thing
|
||||
things_list (list): a list of things
|
||||
|
||||
Returns:
|
||||
dict: a dict including the group key and members of the group.
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `group` method.")
|
||||
|
||||
def reduce(self, things_group: dict):
|
||||
"""
|
||||
Using the dict from `group`, reduce useful information.
|
||||
|
||||
Args:
|
||||
things_group (dict): a dict after grouping
|
||||
|
||||
Returns:
|
||||
dict: a dict including the group key, the information key and the information value
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reduce` method.")
|
||||
|
||||
def collect(self, *args, **kwargs):
|
||||
"""group and reduce
|
||||
|
||||
Returns:
|
||||
dict: a dict including the group key, the information key and the information value
|
||||
"""
|
||||
grouped = self.group(*args, **kwargs)
|
||||
return self.reduce(grouped)
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
"""
|
||||
The Recorder's Collector. This class is a implementation of Collector, collecting some artifacts saved by Recorder.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str) -> None:
|
||||
self.exp_name = experiment_name
|
||||
self.exp = R.get_exp(experiment_name=experiment_name)
|
||||
self.logger = get_module_logger("TaskCollector")
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list_recorders(self, rec_filter_func=None):
|
||||
_artifacts_key_path = {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}
|
||||
_artifacts_key_merge_method = {}
|
||||
|
||||
recs = self.exp.list_recorders()
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
def default_merge(self, artifact_list):
|
||||
"""Merge disorderly artifacts in artifact list.
|
||||
|
||||
return recs_flt
|
||||
Args:
|
||||
artifact_list (list): A artifact list.
|
||||
|
||||
def list_recorders_by_task(self, task_filter_func=None):
|
||||
def rec_filter(recorder):
|
||||
return task_filter_func(self.get_task(recorder))
|
||||
|
||||
return self.list_recorders(rec_filter)
|
||||
|
||||
def list_latest_recorders(self, rec_filter_func=None):
|
||||
recs_flt = self.list_recorders(rec_filter_func)
|
||||
max_test = self.latest_time(recs_flt)
|
||||
latest_rec = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
if self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec[rid] = rec
|
||||
return latest_rec
|
||||
|
||||
def get_recorder_by_id(self, recorder_id):
|
||||
return self.exp.get_recorder(recorder_id, create=False)
|
||||
|
||||
def get_task(self, recorder):
|
||||
if isinstance(recorder, str):
|
||||
recorder = self.get_recorder_by_id(recorder_id=recorder)
|
||||
try:
|
||||
task = recorder.load_object("task")
|
||||
except OSError:
|
||||
raise OSError(f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?")
|
||||
return task
|
||||
|
||||
def latest_time(self, recorders):
|
||||
if len(recorders) == 0:
|
||||
raise Exception(f"Can't find any recorder in {self.exp_name}")
|
||||
max_test = max(self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] for rec in recorders.values())
|
||||
return max_test
|
||||
|
||||
|
||||
class RollingCollector(TaskCollector):
|
||||
"""
|
||||
Collect the record results of the rolling tasks
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
) -> None:
|
||||
super().__init__(experiment_name)
|
||||
self.logger = get_module_logger("RollingCollector")
|
||||
|
||||
def collect_rolling_predictions(self, get_key_func, rec_filter_func=None):
|
||||
"""For rolling tasks, the predictions will be in the diffierent recorder.
|
||||
To collect and concat the predictions of one rolling task, get_key_func will help this method see which group a recorder will be.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
get_key_func : Callable[dict,str]
|
||||
a function that get task config and return its group str
|
||||
rec_filter_func : Callable[Recorder,bool], optional
|
||||
a function that decide whether filter a recorder, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
a dict of {group: predictions}
|
||||
Raises:
|
||||
NotImplementedError: [description]
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `default_merge` method.")
|
||||
|
||||
def group(self, get_group_key_func, rec_filter_func=None):
|
||||
"""
|
||||
Filter recorders and group recorders by group key.
|
||||
|
||||
Args:
|
||||
get_group_key_func (Callable): get a group key based on a recorder
|
||||
rec_filter_func (Callable, optional): filter the recorders in this experiment. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: a dict including the group key and recorders of the group
|
||||
"""
|
||||
# filter records
|
||||
recs_flt = self.list_recorders(rec_filter_func)
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
for _, rec in recs_flt.items():
|
||||
task = self.get_task(rec)
|
||||
group_key = get_key_func(task)
|
||||
group_key = get_group_key_func(rec)
|
||||
recs_group.setdefault(group_key, []).append(rec)
|
||||
|
||||
# reduce group
|
||||
reduce_group = {}
|
||||
for k, rec_l in recs_group.items():
|
||||
pred_l = []
|
||||
for rec in rec_l:
|
||||
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
|
||||
# Make sure the pred are sorted according to the rolling start time
|
||||
pred_l.sort(key=lambda pred: pred.index.get_level_values("datetime").min())
|
||||
pred = pd.concat(pred_l)
|
||||
# If there are duplicated predition, we use the latest perdiction
|
||||
pred = pred[~pred.index.duplicated(keep="last")]
|
||||
pred = pred.sort_index()
|
||||
reduce_group[k] = pred
|
||||
return recs_group
|
||||
|
||||
return reduce_group
|
||||
def reduce(self, recs_group: dict, artifact_keys_list: list = None):
|
||||
"""
|
||||
Reduce artifacts based on the dict of grouped recorder.
|
||||
The artifacts need be declared by artifact_keys_list.
|
||||
The artifacts path in recorder need be declared by _artifacts_key_path.
|
||||
If there is no declartion in _artifacts_key_merge_method, then use default_merge method to merge it.
|
||||
|
||||
Args:
|
||||
recs_group (dict): The dict grouped by `group`
|
||||
artifact_keys_list (list): The list of artifact keys. If it is None, then use all artifacts in _artifacts_key_path.
|
||||
|
||||
Returns:
|
||||
a dict including the group key, the artifact key and the artifact value.
|
||||
|
||||
For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
group_key: {"pred": <VALUE>, "IC": <VALUE>}
|
||||
}
|
||||
"""
|
||||
if artifact_keys_list == None:
|
||||
artifact_keys_list = self._artifacts_key_path.keys()
|
||||
reduce_group = {}
|
||||
for group_key, recorder_list in recs_group.items():
|
||||
reduced_artifacts = {}
|
||||
for artifact_key in artifact_keys_list:
|
||||
artifact_list = []
|
||||
for recorder in recorder_list:
|
||||
artifact_list.append(recorder.load_object(self._artifacts_key_path[artifact_key]))
|
||||
merge_method = self._artifacts_key_merge_method.get(artifact_key, self.default_merge)
|
||||
artifact = merge_method(artifact_list)
|
||||
reduced_artifacts[artifact_key] = artifact
|
||||
reduce_group[group_key] = reduced_artifacts
|
||||
return reduce_group
|
||||
|
||||
|
||||
class RollingCollector(RecorderCollector):
|
||||
"""
|
||||
Collect the record results of the rolling tasks
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str):
|
||||
super().__init__(experiment_name)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def default_merge(self, artifact_list):
|
||||
"""merge disorderly artifacts based on the datetime.
|
||||
|
||||
Args:
|
||||
artifact_list (list): a list of artifacts from different recorders
|
||||
|
||||
Returns:
|
||||
merged artifact
|
||||
"""
|
||||
# Make sure the pred are sorted according to the rolling start time
|
||||
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
|
||||
artifact = pd.concat(artifact_list)
|
||||
# If there are duplicated predition, we use the latest perdiction
|
||||
artifact = artifact[~artifact.index.duplicated(keep="last")]
|
||||
artifact = artifact.sort_index()
|
||||
return artifact
|
||||
|
||||
@@ -19,10 +19,10 @@ def task_generator(tasks, generators) -> list:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tasks : List[dict]
|
||||
a list of task templates
|
||||
generators : List[TaskGen]
|
||||
a list of TaskGen
|
||||
tasks : List[dict] or dict
|
||||
a list of task templates or a single task
|
||||
generators : List[TaskGen] or TaskGen
|
||||
a list of TaskGen or a single TaskGen
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -151,7 +151,8 @@ class TaskManager:
|
||||
if print new task
|
||||
Returns
|
||||
-------
|
||||
|
||||
int
|
||||
the length of new tasks
|
||||
"""
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
new_tasks = []
|
||||
@@ -173,6 +174,8 @@ class TaskManager:
|
||||
|
||||
for t in new_tasks:
|
||||
self.insert_task_def(t, task_pool)
|
||||
|
||||
return len(new_tasks)
|
||||
|
||||
def fetch_task(self, query={}, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
@@ -245,10 +248,9 @@ class TaskManager:
|
||||
for t in task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def get_task_result(self, task, task_pool=None):
|
||||
def re_query(self, task, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
result = task_pool.find_one({"filter": task})
|
||||
return self._decode_task(result)["res"]
|
||||
return task_pool.find_one({"_id":ObjectId(task["_id"])})
|
||||
|
||||
def commit_task_res(self, task, res, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
|
||||
@@ -3,147 +3,140 @@ from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import MLflowRecorder, Recorder
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
from qlib.workflow.task.update import ModelUpdater
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.utils.serial import Serializable
|
||||
|
||||
|
||||
class OnlineManager:
|
||||
def prepare_new_models(self, tasks: List[dict]):
|
||||
"""prepare(train) new models
|
||||
class OnlineManager(Serializable):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tasks : List[dict]
|
||||
a list of tasks
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
|
||||
|
||||
ONLINE_KEY = "online_status" # the tag key in recorder
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_new_models(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
|
||||
|
||||
def update_online_pred(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
|
||||
|
||||
def set_online_tag(self, tag, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
|
||||
class OnlineManagerR(OnlineManager):
|
||||
"""
|
||||
The implementation of OnlineManager based on (R)ecorder.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str) -> None:
|
||||
"""ModelUpdater needs experiment name to find the records
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
experiment name string
|
||||
"""
|
||||
self.logger = get_module_logger("OnlineManagement")
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.exp_name = experiment_name
|
||||
self.tc = TaskCollector(experiment_name)
|
||||
|
||||
def set_next_online_model(self, recorder: MLflowRecorder):
|
||||
recorder.set_tags(**{self.ONLINE_KEY: self.NEXT_ONLINE_TAG})
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
for rec in recorder:
|
||||
rec.set_tags(**{self.ONLINE_KEY: tag})
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
|
||||
def set_online_model(self, recorder: MLflowRecorder):
|
||||
"""online model will be identified at the tags of the record"""
|
||||
recorder.set_tags(**{self.ONLINE_KEY: self.ONLINE_TAG})
|
||||
def get_online_tag(self, recorder: Recorder):
|
||||
tags = recorder.list_tags()
|
||||
return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG)
|
||||
|
||||
def set_offline_model(self, recorder: MLflowRecorder):
|
||||
recorder.set_tags(**{self.ONLINE_KEY: self.OFFLINE_TAG})
|
||||
|
||||
def offline_all_model(self):
|
||||
recs = self.tc.list_recorders()
|
||||
for rid, rec in recs.items():
|
||||
self.set_offline_model(rec)
|
||||
|
||||
def reset_online_model(self, recorders: Union[List, Dict] = None):
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List] = None):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Args:
|
||||
recorders (Union[List, Dict], optional):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
"""
|
||||
if recorders is None:
|
||||
recorders = self.list_next_online_model()
|
||||
if len(recorders) == 0:
|
||||
if recorder is None:
|
||||
recorder = list_recorders(
|
||||
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
|
||||
).values()
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
if len(recorder) == 0:
|
||||
self.logger.info("No 'next online' model, just use current 'online' models.")
|
||||
return
|
||||
self.offline_all_model()
|
||||
if isinstance(recorders, dict):
|
||||
recorders = recorders.values()
|
||||
for rec in recorders:
|
||||
self.set_online_model(rec)
|
||||
self.logger.info(f"Reset {len(recorders)} models to 'online'.")
|
||||
|
||||
def set_latest_model_to_next_online(self):
|
||||
latest_rec = self.tc.list_latest_recorders()
|
||||
for rid, rec in latest_rec.items():
|
||||
self.set_next_online_model(rec)
|
||||
self.logger.info(f"Set {len(latest_rec)} latest models to 'next online'.")
|
||||
|
||||
@staticmethod
|
||||
def online_filter(recorder):
|
||||
tags = recorder.list_tags()
|
||||
if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.ONLINE_TAG:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def next_online_filter(recorder):
|
||||
tags = recorder.list_tags()
|
||||
if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.NEXT_ONLINE_TAG:
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_online_model(self):
|
||||
"""list the record of online model
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{rid : recorder of the online model}
|
||||
"""
|
||||
|
||||
return self.tc.list_recorders(rec_filter_func=self.online_filter)
|
||||
|
||||
def list_next_online_model(self):
|
||||
return self.tc.list_recorders(rec_filter_func=self.next_online_filter)
|
||||
recs = list_recorders(self.exp_name)
|
||||
self.set_online_tag(OnlineManager.OFFLINE_TAG, recs.values())
|
||||
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
|
||||
self.logger.info(f"Reset {len(recorder)} models to 'online'.")
|
||||
|
||||
def update_online_pred(self):
|
||||
"""update all online model predictions to the latest day in Calendar"""
|
||||
mu = ModelUpdater(self.exp_name)
|
||||
cnt = mu.update_all_pred(self.online_filter)
|
||||
cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
|
||||
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
|
||||
|
||||
def after_day(self, *args, **kwargs):
|
||||
self.prepare_signals(*args, **kwargs)
|
||||
self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(*args, **kwargs)
|
||||
self.update_online_pred(*args, **kwargs)
|
||||
self.reset_online_tag()
|
||||
|
||||
class RollingOnlineManager(OnlineManager):
|
||||
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None:
|
||||
super().__init__(experiment_name)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.tm = TaskManager(task_pool=task_pool)
|
||||
self.logger = get_module_logger("RollingOnlineManager")
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def prepare_new_models(self):
|
||||
"""prepare(train) new models based on online model"""
|
||||
latest_records = self.tc.list_latest_recorders(self.online_filter) # if we need online_filter here?
|
||||
max_test = self.tc.latest_time(latest_records)
|
||||
def prepare_signals(self):
|
||||
pass
|
||||
|
||||
def prepare_tasks(self):
|
||||
latest_records, max_test = self.list_latest_recorders(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest_recorders.")
|
||||
return
|
||||
calendar_latest = self.ta.last_date()
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
|
||||
old_tasks = []
|
||||
for rid, rec in latest_records.items():
|
||||
task = self.tc.get_task(rec)
|
||||
task = rec.load_object("task")
|
||||
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
old_tasks.append(task)
|
||||
new_tasks = task_generator(old_tasks, self.rg)
|
||||
self.tm.create_task(new_tasks)
|
||||
run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
|
||||
self.logger.info(f"Finished prepare {len(new_tasks)} new models.")
|
||||
return new_tasks
|
||||
self.logger.info("No need to prepare any new models.")
|
||||
return []
|
||||
new_num = self.tm.create_task(new_tasks)
|
||||
self.logger.info(f"Finished prepare {new_num} tasks.")
|
||||
|
||||
def prepare_signals(self):
|
||||
# prepare the signals of today
|
||||
pass
|
||||
def prepare_new_models(self):
|
||||
"""prepare(train) new models based on online model"""
|
||||
run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
|
||||
latest_records, _ = self.list_latest_recorders()
|
||||
self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values())
|
||||
self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.")
|
||||
|
||||
def list_latest_recorders(self, rec_filter_func=None):
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
if len(recs_flt) == 0:
|
||||
return recs_flt, None
|
||||
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values())
|
||||
latest_rec = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec[rid] = rec
|
||||
return latest_rec, max_test
|
||||
@@ -6,8 +6,7 @@ from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
class ModelUpdater:
|
||||
"""
|
||||
@@ -23,8 +22,7 @@ class ModelUpdater:
|
||||
experiment name string
|
||||
"""
|
||||
self.exp_name = experiment_name
|
||||
self.logger = get_module_logger("ModelUpdater")
|
||||
self.tc = TaskCollector(experiment_name)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def _reload_dataset(self, recorder, start_time, end_time):
|
||||
"""reload dataset from pickle file
|
||||
@@ -53,7 +51,7 @@ class ModelUpdater:
|
||||
datahandler.init(datahandler.IT_LS)
|
||||
return dataset
|
||||
|
||||
def update_pred(self, recorder: Recorder):
|
||||
def update_pred(self, recorder: Recorder, frequency='day'):
|
||||
"""update predictions to the latest day in Calendar based on rid
|
||||
|
||||
Parameters
|
||||
@@ -65,7 +63,10 @@ class ModelUpdater:
|
||||
last_end = old_pred.index.get_level_values("datetime").max()
|
||||
|
||||
# updated to the latest trading day
|
||||
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
|
||||
if frequency=='day':
|
||||
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
|
||||
else:
|
||||
raise NotImplementedError("Now Qlib only support update daily frequency prediction")
|
||||
|
||||
if len(cal) == 0:
|
||||
self.logger.info(
|
||||
@@ -113,7 +114,7 @@ class ModelUpdater:
|
||||
the count of updated record
|
||||
|
||||
"""
|
||||
recs = self.tc.list_recorders(rec_filter_func=rec_filter_func)
|
||||
recs = list_recorders(self.exp_name, rec_filter_func=rec_filter_func)
|
||||
for rid, rec in recs.items():
|
||||
self.update_pred(rec)
|
||||
return len(recs)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import bisect
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
from qlib.workflow import R
|
||||
from qlib.config import C
|
||||
from qlib.log import get_module_logger
|
||||
from pymongo import MongoClient
|
||||
@@ -29,6 +30,25 @@ def get_mongodb():
|
||||
client = MongoClient(cfg["task_url"])
|
||||
return client.get_database(name=cfg["task_db_name"])
|
||||
|
||||
def list_recorders(experiment, rec_filter_func=None):
|
||||
"""list all recorders which can pass the filter in a experiment.
|
||||
|
||||
Args:
|
||||
experiment (str or Experiment): the name of a Experiment or a instance
|
||||
rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: a dict {rid: recorder} after filtering.
|
||||
"""
|
||||
if isinstance(experiment, str):
|
||||
experiment, _ = R.exp_manager._get_or_create_exp(experiment_name=experiment)
|
||||
recs = experiment.list_recorders()
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
return recs_flt
|
||||
|
||||
class TimeAdjuster:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user