mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
online_serving V3
This commit is contained in:
@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"default_exp_name": "Experiment",
|
||||
}
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need to finish the `installation <https://www.mongodb.com/try/download/community>`_ first, and run it at a fixed URL.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, you can initialize qlib below
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
@@ -3,6 +3,11 @@ from qlib.config import REG_CN
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.config import C
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.collect import RollingCollector
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow import R
|
||||
from pprint import pprint
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
@@ -60,51 +65,78 @@ task_xgboost_config = {
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset():
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=task_pool).remove()
|
||||
|
||||
C["mongo"] = {
|
||||
"task_url": "mongodb://localhost:27017/", # maybe you need to change it to your url
|
||||
"task_db_name": "rolling_db",
|
||||
}
|
||||
# exp = R.get_exp(experiment_name=exp_name)
|
||||
|
||||
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
|
||||
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
|
||||
|
||||
tasks = task_generator(
|
||||
task_xgboost_config, # default task name
|
||||
RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment
|
||||
task_lgb=task_lgb_config, # use "task_lgb" as the task name
|
||||
)
|
||||
|
||||
# Uncomment next two lines to see the generated tasks
|
||||
# from pprint import pprint
|
||||
# pprint(tasks)
|
||||
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
from qlib.model.trainer import task_train
|
||||
|
||||
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
|
||||
# for rid in R.list_recorders():
|
||||
# exp.delete_recorder(rid)
|
||||
|
||||
|
||||
def get_task_key(task_config):
|
||||
task_key = task_config["task_key"]
|
||||
rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1]
|
||||
return task_key, rolling_end_timestamp.strftime("%Y-%m-%d")
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating():
|
||||
|
||||
print("========== task_generating ==========")
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment
|
||||
)
|
||||
|
||||
pprint(tasks)
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
def my_filter(task_config):
|
||||
# only choose the results of "task_lgb" and test in 2019 from all tasks
|
||||
task_key, rolling_end = get_task_key(task_config)
|
||||
if task_key == "task_lgb" and rolling_end.startswith("2019"):
|
||||
return True
|
||||
return False
|
||||
# This part corresponds to "Task Storing" in the document
|
||||
def task_storing(tasks):
|
||||
print("========== task_storing ==========")
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
|
||||
|
||||
# name tasks by "get_task_key" and filter tasks by "my_filter"
|
||||
pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter)
|
||||
pred_rolling
|
||||
# This part corresponds to "Task Running" in the document
|
||||
def task_running():
|
||||
print("========== task_running ==========")
|
||||
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting():
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def get_task_key(task_config):
|
||||
return task_config["model"]["class"]
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
task_key = get_task_key(rolling_collector.get_task(recorder))
|
||||
if task_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
rolling_collector = RollingCollector(exp_name)
|
||||
# group tasks by "get_task_key" and filter tasks by "my_filter"
|
||||
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
|
||||
print(pred_rolling)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
mongo_conf = {
|
||||
"task_url": "mongodb://10.0.0.4:27017/", # maybe you need to change it to your url
|
||||
"task_db_name": "rolling_db",
|
||||
}
|
||||
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
|
||||
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
reset()
|
||||
tasks = task_generating()
|
||||
task_storing(tasks)
|
||||
task_running()
|
||||
task_collecting()
|
||||
|
||||
@@ -3,15 +3,14 @@ import fire
|
||||
import mlflow
|
||||
from qlib.config import C
|
||||
from qlib.workflow import R
|
||||
from pprint import pprint
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.workflow.task.update import ModelUpdater
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
from qlib.workflow.task.collect import RollingCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
|
||||
from qlib.workflow.task.online import RollingOnlineManager
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
@@ -33,7 +32,7 @@ dataset_config = {
|
||||
"segments": {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2017-01-01"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -69,16 +68,25 @@ task_xgboost_config = {
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
|
||||
def print_online_model():
|
||||
print("Current 'online' model:")
|
||||
for online in rolling_online_manager.list_online_model().values():
|
||||
print(online.info["id"])
|
||||
print("Current 'next online' model:")
|
||||
for online in rolling_online_manager.list_next_online_model().values():
|
||||
print(online.info["id"])
|
||||
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating(**kwargs):
|
||||
print("========================================= task_generating =========================================")
|
||||
def task_generating():
|
||||
|
||||
rolling_generator = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_EX)
|
||||
print("========== task_generating ==========")
|
||||
|
||||
tasks = task_generator(rolling_generator, **kwargs)
|
||||
|
||||
# See the generated tasks in an easy way
|
||||
from pprint import pprint
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=rolling_gen, # generate different date segment
|
||||
)
|
||||
|
||||
pprint(tasks)
|
||||
|
||||
@@ -87,49 +95,45 @@ def task_generating(**kwargs):
|
||||
|
||||
# This part corresponds to "Task Storing" in the document
|
||||
def task_storing(tasks):
|
||||
print("========================================= task_storing =========================================")
|
||||
print("========== task_storing ==========")
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
|
||||
|
||||
# This part corresponds to "Task Running" in the document
|
||||
def task_running():
|
||||
print("========================================= task_running =========================================")
|
||||
print("========== task_running ==========")
|
||||
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting():
|
||||
print("========================================= task_collecting =========================================")
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def get_task_key(task_config):
|
||||
task_key = task_config["task_key"]
|
||||
rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1]
|
||||
if rolling_end_timestamp == None:
|
||||
rolling_end_timestamp = TimeAdjuster().last_date()
|
||||
return task_key, rolling_end_timestamp.strftime("%Y-%m-%d")
|
||||
return task_config["model"]["class"]
|
||||
|
||||
def lgb_filter(task_config):
|
||||
# only choose the results of "task_lgb"
|
||||
task_key, rolling_end = get_task_key(task_config)
|
||||
if task_key == "task_lgb":
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
task_key = get_task_key(rolling_collector.get_task(recorder))
|
||||
if task_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
task_collector = TaskCollector(exp_name)
|
||||
pred_rolling = task_collector.collect_predictions(
|
||||
get_task_key, lgb_filter
|
||||
) # name tasks by "get_task_key" and filter tasks by "my_filter"
|
||||
rolling_collector = RollingCollector(exp_name)
|
||||
# group tasks by "get_task_key" and filter tasks by "my_filter"
|
||||
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
|
||||
print(pred_rolling)
|
||||
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(force_end=False):
|
||||
print("========================================= reset =========================================")
|
||||
TaskManager(task_pool=task_pool).remove()
|
||||
|
||||
print("========== reset ==========")
|
||||
task_manager.remove()
|
||||
for error in task_manager.query():
|
||||
assert False
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
recs = TaskCollector(exp_name).list_recorders(only_finished=True)
|
||||
recs = exp.list_recorders()
|
||||
|
||||
for rid in recs:
|
||||
exp.delete_recorder(rid)
|
||||
@@ -141,82 +145,60 @@ def reset(force_end=False):
|
||||
pass
|
||||
|
||||
|
||||
def set_online_model_to_latest():
|
||||
print(
|
||||
"========================================= set_online_model_to_latest ========================================="
|
||||
)
|
||||
model_updater = ModelUpdater(experiment_name=exp_name)
|
||||
latest_records, latest_test = model_updater.collect_latest_records()
|
||||
model_updater.reset_online_model(latest_records.values())
|
||||
|
||||
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run():
|
||||
print("========================================= first_run =========================================")
|
||||
print("========== first_run ==========")
|
||||
reset(force_end=True)
|
||||
|
||||
# use "task_lgb" and "task_xgboost" as the task name
|
||||
tasks = task_generating(**{"task_xgboost": task_xgboost_config, "task_lgb": task_lgb_config})
|
||||
tasks = task_generating()
|
||||
task_storing(tasks)
|
||||
task_running()
|
||||
task_collecting()
|
||||
set_online_model_to_latest()
|
||||
|
||||
rolling_online_manager.set_latest_model_to_next_online()
|
||||
rolling_online_manager.reset_online_model()
|
||||
|
||||
|
||||
# Update the predictions of online model
|
||||
def update_predictions():
|
||||
print("========================================= update_predictions =========================================")
|
||||
model_updater = ModelUpdater(experiment_name=exp_name)
|
||||
model_updater.update_online_pred()
|
||||
print("========== update_predictions ==========")
|
||||
rolling_online_manager.update_online_pred()
|
||||
task_collecting()
|
||||
# if there are some next_online_model, then online them. if no, still use current online_model.
|
||||
print_online_model()
|
||||
rolling_online_manager.reset_online_model()
|
||||
print_online_model()
|
||||
|
||||
|
||||
# Update the models using the latest date and set them to online model
|
||||
def update_model():
|
||||
print("========================================= update_model =========================================")
|
||||
# get the latest recorders
|
||||
model_updater = ModelUpdater(experiment_name=exp_name)
|
||||
latest_records, latest_test = model_updater.collect_latest_records()
|
||||
# date adjustment based on trade day of Calendar in Qlib
|
||||
time_adjuster = TimeAdjuster()
|
||||
calendar_latest = time_adjuster.last_date()
|
||||
print("The latest date is ", calendar_latest)
|
||||
if time_adjuster.cal_interval(calendar_latest, latest_test[0]) > rolling_step:
|
||||
print("Need update models!")
|
||||
tasks = {}
|
||||
for rid, rec in latest_records.items():
|
||||
old_task = rec.task
|
||||
test_begin = old_task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
old_task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
tasks[old_task["task_key"]] = old_task
|
||||
print("========== update_model ==========")
|
||||
rolling_online_manager.prepare_new_models()
|
||||
print_online_model()
|
||||
rolling_online_manager.set_latest_model_to_next_online()
|
||||
print_online_model()
|
||||
|
||||
# retrain the latest model
|
||||
new_tasks = task_generating(**tasks)
|
||||
task_storing(new_tasks)
|
||||
task_running()
|
||||
task_collecting()
|
||||
latest_records, _ = model_updater.collect_latest_records()
|
||||
|
||||
# set the latest model to online model
|
||||
model_updater.reset_online_model(latest_records.values())
|
||||
def after_day():
|
||||
rolling_online_manager.prepare_signals()
|
||||
update_model()
|
||||
update_predictions()
|
||||
|
||||
|
||||
# Run whole workflow completely
|
||||
def whole_workflow():
|
||||
print("========================================= whole_workflow =========================================")
|
||||
print("========== whole_workflow ==========")
|
||||
# run this at the first time
|
||||
first_run()
|
||||
# run this every day
|
||||
update_predictions()
|
||||
# run this every "rolling_steps" day
|
||||
update_model()
|
||||
# run this every day after trading
|
||||
after_day()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
# python task_manager_rolling_with_updating.py first_run
|
||||
|
||||
####### to update the models using the latest date and set them to online model, use the command below
|
||||
####### to update the models using the latest date, use the command below
|
||||
# python task_manager_rolling_with_updating.py update_model
|
||||
|
||||
####### to update the predictions to the latest date, use the command below
|
||||
@@ -231,8 +213,8 @@ if __name__ == "__main__":
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
C["mongo"] = {
|
||||
"task_url": "mongodb://localhost:27017/", # your MongoDB url
|
||||
"task_db_name": "rolling_db", # database name
|
||||
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
|
||||
"task_db_name": "online", # database name
|
||||
}
|
||||
|
||||
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
|
||||
@@ -240,5 +222,9 @@ if __name__ == "__main__":
|
||||
rolling_step = 550
|
||||
|
||||
##########################################################################################
|
||||
|
||||
rolling_gen = RollingGen(step=550, rtype=RollingGen.ROLL_SD)
|
||||
rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool
|
||||
)
|
||||
task_manager = TaskManager(task_pool=task_pool)
|
||||
fire.Fire()
|
||||
|
||||
@@ -60,4 +60,4 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
|
||||
return recorder.info["id"]
|
||||
return recorder
|
||||
|
||||
@@ -8,7 +8,7 @@ from qlib import get_module_logger
|
||||
|
||||
class TaskCollector:
|
||||
"""
|
||||
Collect the record results of the finished tasks with key and filter
|
||||
Collect the record (or its results) of the tasks
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str) -> None:
|
||||
@@ -17,7 +17,7 @@ class TaskCollector:
|
||||
self.logger = get_module_logger("TaskCollector")
|
||||
|
||||
def list_recorders(self, rec_filter_func=None):
|
||||
""""""
|
||||
|
||||
recs = self.exp.list_recorders()
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
@@ -26,57 +26,77 @@ class TaskCollector:
|
||||
|
||||
return recs_flt
|
||||
|
||||
def list_recorders_by_task(self, task_filter_func=None):
    """List the recorders whose task config passes ``task_filter_func``.

    Parameters
    ----------
    task_filter_func : Callable[[dict], bool], optional
        Called with the task obtained from ``self.get_task(recorder)``.
        By default None, in which case no task-based filter is applied.

    Returns
    -------
    Whatever ``self.list_recorders`` returns for the resulting recorder filter.
    """
    if task_filter_func is None:
        # The default value is not callable; delegate without a filter instead
        # of failing inside the nested filter with a TypeError.
        return self.list_recorders()

    def rec_filter(recorder):
        # Translate the task-level predicate into a recorder-level predicate.
        return task_filter_func(self.get_task(recorder))

    return self.list_recorders(rec_filter)
|
||||
|
||||
def list_latest_recorders(self, rec_filter_func=None):
|
||||
recs_flt = self.list_recorders(rec_filter_func)
|
||||
max_test = self.latest_time(recs_flt)
|
||||
latest_rec = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
if self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec[rid] = rec
|
||||
return latest_rec
|
||||
|
||||
def get_recorder_by_id(self, recorder_id):
|
||||
return self.exp.get_recorder(recorder_id, create=False)
|
||||
|
||||
def list_recorders_by_task(self, task_filter_func):
|
||||
"""[summary]
|
||||
def get_task(self, recorder):
    """Load the ``"task"`` object stored in a recorder.

    Parameters
    ----------
    recorder : str or recorder object
        A recorder id (resolved through ``self.get_recorder_by_id``) or an
        object exposing ``load_object`` and ``info``.

    Returns
    -------
    The object returned by ``recorder.load_object("task")``.

    Raises
    ------
    OSError
        If loading the task fails with an OSError; the original error is
        attached as the cause.
    """
    if isinstance(recorder, str):
        recorder = self.get_recorder_by_id(recorder_id=recorder)
    try:
        task = recorder.load_object("task")
    except OSError as err:
        # Chain explicitly so the underlying I/O failure stays visible as the cause.
        raise OSError(
            f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?"
        ) from err
    return task
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_filter_func : [type], optional
|
||||
[description], by default None
|
||||
"""
|
||||
def latest_time(self, recorders):
|
||||
if len(recorders) == 0:
|
||||
raise Exception(f"Can't find any recorder in {self.exp_name}")
|
||||
max_test = max(self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] for rec in recorders.values())
|
||||
return max_test
|
||||
|
||||
def rec_filter_func(recorder):
|
||||
try:
|
||||
task = recorder.load_object("task")
|
||||
except OSError:
|
||||
raise OSError(
|
||||
f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?"
|
||||
)
|
||||
return task_filter_func(task)
|
||||
|
||||
return self.list_recorders(rec_filter_func)
|
||||
class RollingCollector(TaskCollector):
|
||||
"""
|
||||
Collect the record results of the rolling tasks
|
||||
"""
|
||||
|
||||
def collect_predictions(
|
||||
def __init__(
|
||||
self,
|
||||
get_key_func,
|
||||
task_filter_func=None,
|
||||
):
|
||||
"""
|
||||
Collect predictions using a filter and a key function.
|
||||
experiment_name: str,
|
||||
) -> None:
|
||||
super().__init__(experiment_name)
|
||||
self.logger = get_module_logger("RollingCollector")
|
||||
|
||||
def collect_rolling_predictions(self, get_key_func, rec_filter_func=None):
|
||||
"""For rolling tasks, the predictions will be in different recorders.
|
||||
To collect and concat the predictions of one rolling task, get_key_func tells this method which group a recorder belongs to.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
get_key_func : Callable[[dict], bool] -> Union[Number, str, tuple]
|
||||
get the key of a task when collect it
|
||||
filter_func : Callable[[dict], bool] -> bool
|
||||
to judge a task will be collected or not
|
||||
get_key_func : Callable[dict,str]
|
||||
a function that gets a task config and returns its group str
|
||||
rec_filter_func : Callable[Recorder,bool], optional
|
||||
a function that decides whether a recorder will be collected, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the dict of predictions
|
||||
a dict of {group: predictions}
|
||||
"""
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True)
|
||||
|
||||
# filter records
|
||||
recs_flt = self.list_recorders(rec_filter_func)
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
for _, rec in recs_flt.items():
|
||||
params = rec.task
|
||||
group_key = get_key_func(params)
|
||||
task = self.get_task(rec)
|
||||
group_key = get_key_func(task)
|
||||
recs_group.setdefault(group_key, []).append(rec)
|
||||
|
||||
# reduce group
|
||||
@@ -85,39 +105,12 @@ class TaskCollector:
|
||||
pred_l = []
|
||||
for rec in rec_l:
|
||||
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
|
||||
pred = pd.concat(pred_l).sort_index()
|
||||
# Make sure the pred are sorted according to the rolling start time
|
||||
pred_l.sort(key=lambda pred: pred.index.get_level_values("datetime").min())
|
||||
pred = pd.concat(pred_l)
|
||||
# If there are duplicated predictions, we use the latest prediction
|
||||
pred = pred[~pred.index.duplicated(keep="last")]
|
||||
pred = pred.sort_index()
|
||||
reduce_group[k] = pred
|
||||
|
||||
self.logger.info(f"Collect {len(reduce_group)} predictions in {self.exp_name}")
|
||||
return reduce_group
|
||||
|
||||
def collect_latest_records(
|
||||
self,
|
||||
task_filter_func=None,
|
||||
):
|
||||
"""Collect latest recorders using a filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_filter_func : Callable[[dict], bool], optional
|
||||
to judge a task will be collected or not, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict, tuple
|
||||
a dict of recorders and a tuple of test segments
|
||||
"""
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True)
|
||||
|
||||
if len(recs_flt) == 0:
|
||||
self.logger.warning("Can not collect any recorders...")
|
||||
return None, None
|
||||
max_test = max(rec.task["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values())
|
||||
|
||||
latest_record = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
if rec.task["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_record[rid] = rec
|
||||
|
||||
self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}")
|
||||
return latest_record, max_test
|
||||
return reduce_group
|
||||
@@ -9,56 +9,40 @@ import typing
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
def task_generator(*args, **kwargs) -> list:
|
||||
"""
|
||||
Accept the dict of task config and the TaskGen to generate different tasks.
|
||||
There is no limit to the number and position of input.
|
||||
The key of input will add to task config.
|
||||
def task_generator(tasks, generators) -> list:
|
||||
"""Use a list of TaskGen and a list of task templates to generate different tasks.
|
||||
|
||||
for example:
|
||||
There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple.
|
||||
task_generator(a_key=a, b_key=b, c_key=c, A, B) will finally generate 3*2*3 = 18 task_config.
|
||||
For example:
|
||||
|
||||
There are 3 task templates a,b,c and 2 TaskGen A,B. A will generate 2 tasks from a template and B will generate 3 tasks from a template.
|
||||
task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : dict or TaskGen
|
||||
kwargs : dict or TaskGen
|
||||
tasks : List[dict]
|
||||
a list of task templates
|
||||
generators : List[TaskGen]
|
||||
a list of TaskGen
|
||||
|
||||
Returns
|
||||
-------
|
||||
gen_task_list : list
|
||||
a list of task config after generating
|
||||
list
|
||||
a list of tasks
|
||||
"""
|
||||
tasks_list = []
|
||||
gen_list = []
|
||||
|
||||
tmp_id = 1
|
||||
for task in args:
|
||||
if isinstance(task, dict):
|
||||
task["task_key"] = tmp_id
|
||||
tmp_id += 1
|
||||
tasks_list.append(task)
|
||||
elif isinstance(task, TaskGen):
|
||||
gen_list.append(task)
|
||||
else:
|
||||
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
|
||||
|
||||
for key, task in kwargs.items():
|
||||
if isinstance(task, dict):
|
||||
task["task_key"] = key
|
||||
tasks_list.append(task)
|
||||
elif isinstance(task, TaskGen):
|
||||
gen_list.append(task)
|
||||
else:
|
||||
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if isinstance(generators, TaskGen):
|
||||
generators = [generators]
|
||||
|
||||
# generate gen_task_list
|
||||
gen_task_list = []
|
||||
for gen in gen_list:
|
||||
for gen in generators:
|
||||
new_task_list = []
|
||||
for task in tasks_list:
|
||||
for task in tasks:
|
||||
new_task_list.extend(gen.generate(task))
|
||||
gen_task_list = new_task_list
|
||||
|
||||
return gen_task_list
|
||||
|
||||
|
||||
@@ -144,7 +128,13 @@ class RollingGen(TaskGen):
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
"kwargs": {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
@@ -153,8 +143,12 @@ class RollingGen(TaskGen):
|
||||
},
|
||||
},
|
||||
},
|
||||
# You should record the data in a specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
"record": [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
}
|
||||
"""
|
||||
res = []
|
||||
|
||||
@@ -245,6 +245,11 @@ class TaskManager:
|
||||
for t in task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def get_task_result(self, task, task_pool=None):
    """Return the stored result (``"res"``) of a task.

    Parameters
    ----------
    task : dict
        The task definition; it is matched against the ``"filter"`` field of
        the stored documents.
    task_pool : optional
        Passed to ``self._get_task_pool`` to resolve the pool to query.
        NOTE(review): presumably a MongoDB collection or its name - confirm.

    Returns
    -------
    The ``"res"`` entry of the decoded task document.

    Raises
    ------
    LookupError
        If no stored document matches ``task``.
    """
    task_pool = self._get_task_pool(task_pool)
    result = task_pool.find_one({"filter": task})
    if result is None:
        # find_one yields None when nothing matches; fail with a clear message
        # instead of passing None on to the decoder.
        raise LookupError(f"Can't find the task in the task pool: {task}")
    return self._decode_task(result)["res"]
|
||||
|
||||
def commit_task_res(self, task, res, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
# A workaround to use the class attribute.
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from typing import Union, List
|
||||
from typing import Dict, Union, List
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.recorder import MLflowRecorder, Recorder
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
from qlib.workflow.task.update import ModelUpdater
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import run_task
|
||||
|
||||
|
||||
class OnlineManager:
|
||||
@@ -19,9 +23,10 @@ class OnlineManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
|
||||
|
||||
ONLINE_TAG = "online_model"
|
||||
ONLINE_TAG_TRUE = "True"
|
||||
ONLINE_TAG_FALSE = "False"
|
||||
ONLINE_KEY = "online_status" # the tag key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, experiment_name: str) -> None:
|
||||
"""ModelUpdater needs experiment name to find the records
|
||||
@@ -35,45 +40,57 @@ class OnlineManager:
|
||||
self.exp_name = experiment_name
|
||||
self.tc = TaskCollector(experiment_name)
|
||||
|
||||
def set_online_model(self, recorder: Union[str, Recorder]):
|
||||
"""online model will be identified at the tags of the record
|
||||
def set_next_online_model(self, recorder: MLflowRecorder):
|
||||
recorder.set_tags(**{self.ONLINE_KEY: self.NEXT_ONLINE_TAG})
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder: Union[str,Recorder]
|
||||
the id of a Recorder or the Recorder instance
|
||||
"""
|
||||
if isinstance(recorder, str):
|
||||
recorder = self.tc.get_recorder_by_id(recorder_id=recorder)
|
||||
recorder.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_TRUE})
|
||||
def set_online_model(self, recorder: MLflowRecorder):
|
||||
"""online model will be identified at the tags of the record"""
|
||||
recorder.set_tags(**{self.ONLINE_KEY: self.ONLINE_TAG})
|
||||
|
||||
def cancel_online_model(self, recorder: Union[str, Recorder]):
|
||||
if isinstance(recorder, str):
|
||||
recorder = self.tc.get_recorder_by_id(recorder_id=recorder)
|
||||
recorder.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_FALSE})
|
||||
def set_offline_model(self, recorder: MLflowRecorder):
|
||||
recorder.set_tags(**{self.ONLINE_KEY: self.OFFLINE_TAG})
|
||||
|
||||
def cancel_all_online_model(self):
|
||||
def offline_all_model(self):
|
||||
recs = self.tc.list_recorders()
|
||||
for rid, rec in recs.items():
|
||||
self.cancel_online_model(rec)
|
||||
self.set_offline_model(rec)
|
||||
|
||||
def reset_online_model(self, recorders: Union[str, List[Union[str, Recorder]]]):
|
||||
"""cancel all online model and reset the given model to online model
|
||||
def reset_online_model(self, recorders: Union[List, Dict] = None):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorders: List[Union[str,Recorder]]
|
||||
the list of the id of a Recorder or the Recorder instance
|
||||
Args:
|
||||
recorders (Union[List, Dict], optional):
|
||||
the recorders you want to reset to 'online'. If not given, the 'next online' models are set to 'online'. If there isn't any 'next online' model, the existing 'online' models are kept.
|
||||
"""
|
||||
self.cancel_all_online_model()
|
||||
if isinstance(recorders, str):
|
||||
recorders = [recorders]
|
||||
for rec_or_rid in recorders:
|
||||
self.set_online_model(rec_or_rid)
|
||||
if recorders is None:
|
||||
recorders = self.list_next_online_model()
|
||||
if len(recorders) == 0:
|
||||
self.logger.info("No 'next online' model, just use current 'online' models.")
|
||||
return
|
||||
self.offline_all_model()
|
||||
if isinstance(recorders, dict):
|
||||
recorders = recorders.values()
|
||||
for rec in recorders:
|
||||
self.set_online_model(rec)
|
||||
self.logger.info(f"Reset {len(recorders)} models to 'online'.")
|
||||
|
||||
def online_filter(self, recorder):
|
||||
def set_latest_model_to_next_online(self):
    """Mark every latest recorder as 'next online' and log how many were marked."""
    newest = self.tc.list_latest_recorders()
    # only the recorder objects are needed; the ids (keys) are unused
    for recorder in newest.values():
        self.set_next_online_model(recorder)
    self.logger.info(f"Set {len(newest)} latest models to 'next online'.")
|
||||
|
||||
@staticmethod
|
||||
def online_filter(recorder):
|
||||
tags = recorder.list_tags()
|
||||
if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE:
|
||||
if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.ONLINE_TAG:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
def next_online_filter(recorder):
    """Tell whether a recorder is tagged as 'next online'.

    Parameters
    ----------
    recorder
        an object exposing ``list_tags()``; a missing ``ONLINE_KEY`` tag is
        treated as ``OFFLINE_TAG``

    Returns
    -------
    bool
        True when the ``ONLINE_KEY`` tag equals ``NEXT_ONLINE_TAG``
    """
    status = recorder.list_tags().get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG)
    return status == OnlineManager.NEXT_ONLINE_TAG
|
||||
|
||||
@@ -88,21 +105,45 @@ class OnlineManager:
|
||||
|
||||
return self.tc.list_recorders(rec_filter_func=self.online_filter)
|
||||
|
||||
def list_next_online_model(self):
    """Return the recorders accepted by ``self.next_online_filter``, as listed by ``self.tc``."""
    next_online = self.tc.list_recorders(rec_filter_func=self.next_online_filter)
    return next_online
|
||||
|
||||
def update_online_pred(self):
|
||||
"""update all online model predictions to the latest day in Calendar."""
|
||||
"""update all online model predictions to the latest day in Calendar"""
|
||||
mu = ModelUpdater(self.exp_name)
|
||||
cnt = mu.update_all_pred(self.online_filter)
|
||||
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManager):
|
||||
def prepare_new_models(self, tasks: List[dict]):
|
||||
"""prepare(train) new models
|
||||
def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None:
    """Set up the rolling online manager.

    Parameters
    ----------
    experiment_name : str
        forwarded to the parent class initializer
    rolling_gen : RollingGen
        the rolling generator, kept as ``self.rg``
    task_pool
        passed to ``TaskManager(task_pool=...)``, kept as ``self.tm``
    """
    super().__init__(experiment_name)
    self.rg = rolling_gen
    self.ta = TimeAdjuster()
    self.tm = TaskManager(task_pool=task_pool)
    # assigned after the parent initializer has run
    self.logger = get_module_logger("RollingOnlineManager")
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tasks : List[dict]
|
||||
a list of tasks
|
||||
def prepare_new_models(self):
    """Prepare (train) new models based on the current online models.

    The latest recorders accepted by ``self.online_filter`` are compared with
    the last calendar date. When the calendar interval between that date and
    ``max_test[0]`` exceeds ``self.rg.step``, the test segment of each
    recorder's task is extended to the last calendar date, new tasks are
    generated from them with ``self.rg``, stored through ``self.tm`` and
    trained with ``run_task``.

    Returns
    -------
    list
        the newly generated tasks, or an empty list when no training was needed
    """
    # NOTE(review): the original author left open whether online_filter is needed here.
    online_latest = self.tc.list_latest_recorders(self.online_filter)
    max_test = self.tc.latest_time(online_latest)
    calendar_latest = self.ta.last_date()

    # presumably max_test[0] is the start of the most recent test segment; confirm against latest_time
    behind_calendar = self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step
    if not behind_calendar:
        self.logger.info("No need to prepare any new models.")
        return []

    stretched_tasks = []
    for recorder in online_latest.values():
        task = self.tc.get_task(recorder)
        segments = task["dataset"]["kwargs"]["segments"]
        # keep the start of the test segment and move its end to the latest calendar date
        segments["test"] = (segments["test"][0], calendar_latest)
        stretched_tasks.append(task)

    new_tasks = task_generator(stretched_tasks, self.rg)
    self.tm.create_task(new_tasks)
    run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
    self.logger.info(f"Finished prepare {len(new_tasks)} new models.")
    return new_tasks
|
||||
|
||||
"""
|
||||
def prepare_signals(self):
    """Prepare the signals of today.

    Placeholder: the method currently does nothing and returns None.
    """
    return None
|
||||
|
||||
@@ -53,7 +53,7 @@ class ModelUpdater:
|
||||
datahandler.init(datahandler.IT_LS)
|
||||
return dataset
|
||||
|
||||
def update_pred(self, recorder: Union[str, Recorder]):
|
||||
def update_pred(self, recorder: Recorder):
|
||||
"""update predictions to the latest day in Calendar based on rid
|
||||
|
||||
Parameters
|
||||
@@ -61,8 +61,6 @@ class ModelUpdater:
|
||||
recorder: Union[str,Recorder]
|
||||
the id of a Recorder or the Recorder instance
|
||||
"""
|
||||
if isinstance(recorder, str):
|
||||
recorder = self.tc.get_recorder_by_id(recorder_id=recorder)
|
||||
old_pred = recorder.load_object("pred.pkl")
|
||||
last_end = old_pred.index.get_level_values("datetime").max()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user