1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

online_serving V3

This commit is contained in:
lzh222333
2021-03-18 09:30:01 +00:00
parent d33041dc24
commit 8abdd63869
9 changed files with 333 additions and 273 deletions

View File

@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
"default_exp_name": "Experiment",
}
})
- `mongo`
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
Users need finished `installatin <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
.. code-block:: Python
# For example, you can initialize qlib below
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
"task_url": "mongodb://localhost:27017/", # your mongo url
"task_db_name": "rolling_db", # the database name of Task Management
})

View File

@@ -3,6 +3,11 @@ from qlib.config import REG_CN
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.config import C
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.collect import RollingCollector
from qlib.model.trainer import task_train
from qlib.workflow import R
from pprint import pprint
data_handler_config = {
"start_time": "2008-01-01",
@@ -60,51 +65,78 @@ task_xgboost_config = {
"record": record_config,
}
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)
# Reset all things to the first status, be careful to save important data
def reset():
print("========== reset ==========")
TaskManager(task_pool=task_pool).remove()
C["mongo"] = {
"task_url": "mongodb://localhost:27017/", # maybe you need to change it to your url
"task_db_name": "rolling_db",
}
# exp = R.get_exp(experiment_name=exp_name)
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
tasks = task_generator(
task_xgboost_config, # default task name
RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment
task_lgb=task_lgb_config, # use "task_lgb" as the task name
)
# Uncomment next two lines to see the generated tasks
# from pprint import pprint
# pprint(tasks)
tm = TaskManager(task_pool=task_pool)
tm.create_task(tasks) # all tasks will be saved to MongoDB
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.collect import TaskCollector
from qlib.model.trainer import task_train
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
# for rid in R.list_recorders():
# exp.delete_recorder(rid)
def get_task_key(task_config):
task_key = task_config["task_key"]
rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1]
return task_key, rolling_end_timestamp.strftime("%Y-%m-%d")
# This part corresponds to "Task Generating" in the document
def task_generating():
print("========== task_generating ==========")
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment
)
pprint(tasks)
return tasks
def my_filter(task_config):
# only choose the results of "task_lgb" and test in 2019 from all tasks
task_key, rolling_end = get_task_key(task_config)
if task_key == "task_lgb" and rolling_end.startswith("2019"):
return True
return False
# This part corresponds to "Task Storing" in the document
def task_storing(tasks):
print("========== task_storing ==========")
tm = TaskManager(task_pool=task_pool)
tm.create_task(tasks) # all tasks will be saved to MongoDB
# name tasks by "get_task_key" and filter tasks by "my_filter"
pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter)
pred_rolling
# This part corresponds to "Task Running" in the document
def task_running():
print("========== task_running ==========")
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
# This part corresponds to "Task Collecting" in the document
def task_collecting():
print("========== task_collecting ==========")
def get_task_key(task_config):
return task_config["model"]["class"]
def my_filter(recorder):
# only choose the results of "LGBModel"
task_key = get_task_key(rolling_collector.get_task(recorder))
if task_key == "LGBModel":
return True
return False
rolling_collector = RollingCollector(exp_name)
# group tasks by "get_task_key" and filter tasks by "my_filter"
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
print(pred_rolling)
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
mongo_conf = {
"task_url": "mongodb://10.0.0.4:27017/", # maybe you need to change it to your url
"task_db_name": "rolling_db",
}
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
reset()
tasks = task_generating()
task_storing(tasks)
task_running()
task_collecting()

View File

@@ -3,15 +3,14 @@ import fire
import mlflow
from qlib.config import C
from qlib.workflow import R
from pprint import pprint
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.utils import TimeAdjuster
from qlib.workflow.task.update import ModelUpdater
from qlib.workflow.task.collect import TaskCollector
from qlib.workflow.task.collect import RollingCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.online import RollingOnlineManager
data_handler_config = {
"start_time": "2013-01-01",
@@ -33,7 +32,7 @@ dataset_config = {
"segments": {
"train": ("2013-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2015-12-31"),
"test": ("2016-01-01", "2017-01-01"),
"test": ("2016-01-01", "2020-07-10"),
},
},
}
@@ -69,16 +68,25 @@ task_xgboost_config = {
"record": record_config,
}
def print_online_model():
print("Current 'online' model:")
for online in rolling_online_manager.list_online_model().values():
print(online.info["id"])
print("Current 'next online' model:")
for online in rolling_online_manager.list_next_online_model().values():
print(online.info["id"])
# This part corresponds to "Task Generating" in the document
def task_generating(**kwargs):
print("========================================= task_generating =========================================")
def task_generating():
rolling_generator = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_EX)
print("========== task_generating ==========")
tasks = task_generator(rolling_generator, **kwargs)
# See the generated tasks in a easy way
from pprint import pprint
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=rolling_gen, # generate different date segment
)
pprint(tasks)
@@ -87,49 +95,45 @@ def task_generating(**kwargs):
# This part corresponds to "Task Storing" in the document
def task_storing(tasks):
print("========================================= task_storing =========================================")
print("========== task_storing ==========")
tm = TaskManager(task_pool=task_pool)
tm.create_task(tasks) # all tasks will be saved to MongoDB
# This part corresponds to "Task Running" in the document
def task_running():
print("========================================= task_running =========================================")
print("========== task_running ==========")
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
# This part corresponds to "Task Collecting" in the document
def task_collecting():
print("========================================= task_collecting =========================================")
print("========== task_collecting ==========")
def get_task_key(task_config):
task_key = task_config["task_key"]
rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1]
if rolling_end_timestamp == None:
rolling_end_timestamp = TimeAdjuster().last_date()
return task_key, rolling_end_timestamp.strftime("%Y-%m-%d")
return task_config["model"]["class"]
def lgb_filter(task_config):
# only choose the results of "task_lgb"
task_key, rolling_end = get_task_key(task_config)
if task_key == "task_lgb":
def my_filter(recorder):
# only choose the results of "LGBModel"
task_key = get_task_key(rolling_collector.get_task(recorder))
if task_key == "LGBModel":
return True
return False
task_collector = TaskCollector(exp_name)
pred_rolling = task_collector.collect_predictions(
get_task_key, lgb_filter
) # name tasks by "get_task_key" and filter tasks by "my_filter"
rolling_collector = RollingCollector(exp_name)
# group tasks by "get_task_key" and filter tasks by "my_filter"
pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter)
print(pred_rolling)
# Reset all things to the first status, be careful to save important data
def reset(force_end=False):
print("========================================= reset =========================================")
TaskManager(task_pool=task_pool).remove()
print("========== reset ==========")
task_manager.remove()
for error in task_manager.query():
assert False
exp = R.get_exp(experiment_name=exp_name)
recs = TaskCollector(exp_name).list_recorders(only_finished=True)
recs = exp.list_recorders()
for rid in recs:
exp.delete_recorder(rid)
@@ -141,82 +145,60 @@ def reset(force_end=False):
pass
def set_online_model_to_latest():
print(
"========================================= set_online_model_to_latest ========================================="
)
model_updater = ModelUpdater(experiment_name=exp_name)
latest_records, latest_test = model_updater.collect_latest_records()
model_updater.reset_online_model(latest_records.values())
# Run this firstly to see the workflow in Task Management
def first_run():
print("========================================= first_run =========================================")
print("========== first_run ==========")
reset(force_end=True)
# use "task_lgb" and "task_xgboost" as the task name
tasks = task_generating(**{"task_xgboost": task_xgboost_config, "task_lgb": task_lgb_config})
tasks = task_generating()
task_storing(tasks)
task_running()
task_collecting()
set_online_model_to_latest()
rolling_online_manager.set_latest_model_to_next_online()
rolling_online_manager.reset_online_model()
# Update the predictions of online model
def update_predictions():
print("========================================= update_predictions =========================================")
model_updater = ModelUpdater(experiment_name=exp_name)
model_updater.update_online_pred()
print("========== update_predictions ==========")
rolling_online_manager.update_online_pred()
task_collecting()
# if there are some next_online_model, then online them. if no, still use current online_model.
print_online_model()
rolling_online_manager.reset_online_model()
print_online_model()
# Update the models using the latest date and set them to online model
def update_model():
print("========================================= update_model =========================================")
# get the latest recorders
model_updater = ModelUpdater(experiment_name=exp_name)
latest_records, latest_test = model_updater.collect_latest_records()
# date adjustment based on trade day of Calendar in Qlib
time_adjuster = TimeAdjuster()
calendar_latest = time_adjuster.last_date()
print("The latest date is ", calendar_latest)
if time_adjuster.cal_interval(calendar_latest, latest_test[0]) > rolling_step:
print("Need update models!")
tasks = {}
for rid, rec in latest_records.items():
old_task = rec.task
test_begin = old_task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
old_task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks[old_task["task_key"]] = old_task
print("========== update_model ==========")
rolling_online_manager.prepare_new_models()
print_online_model()
rolling_online_manager.set_latest_model_to_next_online()
print_online_model()
# retrain the latest model
new_tasks = task_generating(**tasks)
task_storing(new_tasks)
task_running()
task_collecting()
latest_records, _ = model_updater.collect_latest_records()
# set the latest model to online model
model_updater.reset_online_model(latest_records.values())
def after_day():
rolling_online_manager.prepare_signals()
update_model()
update_predictions()
# Run whole workflow completely
def whole_workflow():
print("========================================= whole_workflow =========================================")
print("========== whole_workflow ==========")
# run this at the first time
first_run()
# run this every day
update_predictions()
# run this every "rolling_steps" day
update_model()
# run this every day after trading
after_day()
if __name__ == "__main__":
####### to train the first version's models, use the command below
# python task_manager_rolling_with_updating.py first_run
####### to update the models using the latest date and set them to online model, use the command below
####### to update the models using the latest date, use the command below
# python task_manager_rolling_with_updating.py update_model
####### to update the predictions to the latest date, use the command below
@@ -231,8 +213,8 @@ if __name__ == "__main__":
qlib.init(provider_uri=provider_uri, region=REG_CN)
C["mongo"] = {
"task_url": "mongodb://localhost:27017/", # your MongoDB url
"task_db_name": "rolling_db", # database name
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
"task_db_name": "online", # database name
}
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
@@ -240,5 +222,9 @@ if __name__ == "__main__":
rolling_step = 550
##########################################################################################
rolling_gen = RollingGen(step=550, rtype=RollingGen.ROLL_SD)
rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool
)
task_manager = TaskManager(task_pool=task_pool)
fire.Fire()

View File

@@ -60,4 +60,4 @@ def task_train(task_config: dict, experiment_name: str) -> str:
ar = init_instance_by_config(record)
ar.generate()
return recorder.info["id"]
return recorder

View File

@@ -8,7 +8,7 @@ from qlib import get_module_logger
class TaskCollector:
"""
Collect the record results of the finished tasks with key and filter
Collect the record (or its results) of the tasks
"""
def __init__(self, experiment_name: str) -> None:
@@ -17,7 +17,7 @@ class TaskCollector:
self.logger = get_module_logger("TaskCollector")
def list_recorders(self, rec_filter_func=None):
""""""
recs = self.exp.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
@@ -26,57 +26,77 @@ class TaskCollector:
return recs_flt
def list_recorders_by_task(self, task_filter_func=None):
def rec_filter(recorder):
return task_filter_func(self.get_task(recorder))
return self.list_recorders(rec_filter)
def list_latest_recorders(self, rec_filter_func=None):
recs_flt = self.list_recorders(rec_filter_func)
max_test = self.latest_time(recs_flt)
latest_rec = {}
for rid, rec in recs_flt.items():
if self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec[rid] = rec
return latest_rec
def get_recorder_by_id(self, recorder_id):
return self.exp.get_recorder(recorder_id, create=False)
def list_recorders_by_task(self, task_filter_func):
"""[summary]
def get_task(self, recorder):
if isinstance(recorder, str):
recorder = self.get_recorder_by_id(recorder_id=recorder)
try:
task = recorder.load_object("task")
except OSError:
raise OSError(f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?")
return task
Parameters
----------
task_filter_func : [type], optional
[description], by default None
"""
def latest_time(self, recorders):
if len(recorders) == 0:
raise Exception(f"Can't find any recorder in {self.exp_name}")
max_test = max(self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] for rec in recorders.values())
return max_test
def rec_filter_func(recorder):
try:
task = recorder.load_object("task")
except OSError:
raise OSError(
f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?"
)
return task_filter_func(task)
return self.list_recorders(rec_filter_func)
class RollingCollector(TaskCollector):
"""
Collect the record results of the rolling tasks
"""
def collect_predictions(
def __init__(
self,
get_key_func,
task_filter_func=None,
):
"""
Collect predictions using a filter and a key function.
experiment_name: str,
) -> None:
super().__init__(experiment_name)
self.logger = get_module_logger("RollingCollector")
def collect_rolling_predictions(self, get_key_func, rec_filter_func=None):
"""For rolling tasks, the predictions will be in the diffierent recorder.
To collect and concat the predictions of one rolling task, get_key_func will help this method see which group a recorder will be.
Parameters
----------
experiment_name : str
get_key_func : Callable[[dict], bool] -> Union[Number, str, tuple]
get the key of a task when collect it
filter_func : Callable[[dict], bool] -> bool
to judge a task will be collected or not
get_key_func : Callable[dict,str]
a function that get task config and return its group str
rec_filter_func : Callable[Recorder,bool], optional
a function that decide whether filter a recorder, by default None
Returns
-------
dict
the dict of predictions
a dict of {group: predictions}
"""
recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True)
# filter records
recs_flt = self.list_recorders(rec_filter_func)
# group
recs_group = {}
for _, rec in recs_flt.items():
params = rec.task
group_key = get_key_func(params)
task = self.get_task(rec)
group_key = get_key_func(task)
recs_group.setdefault(group_key, []).append(rec)
# reduce group
@@ -85,39 +105,12 @@ class TaskCollector:
pred_l = []
for rec in rec_l:
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
pred = pd.concat(pred_l).sort_index()
# Make sure the pred are sorted according to the rolling start time
pred_l.sort(key=lambda pred: pred.index.get_level_values("datetime").min())
pred = pd.concat(pred_l)
# If there are duplicated predition, we use the latest perdiction
pred = pred[~pred.index.duplicated(keep="last")]
pred = pred.sort_index()
reduce_group[k] = pred
self.logger.info(f"Collect {len(reduce_group)} predictions in {self.exp_name}")
return reduce_group
def collect_latest_records(
self,
task_filter_func=None,
):
"""Collect latest recorders using a filter.
Parameters
----------
task_filter_func : Callable[[dict], bool], optional
to judge a task will be collected or not, by default None
Returns
-------
dict, tuple
a dict of recorders and a tuple of test segments
"""
recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True)
if len(recs_flt) == 0:
self.logger.warning("Can not collect any recorders...")
return None, None
max_test = max(rec.task["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values())
latest_record = {}
for rid, rec in recs_flt.items():
if rec.task["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_record[rid] = rec
self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}")
return latest_record, max_test
return reduce_group

View File

@@ -9,56 +9,40 @@ import typing
from .utils import TimeAdjuster
def task_generator(*args, **kwargs) -> list:
"""
Accept the dict of task config and the TaskGen to generate different tasks.
There is no limit to the number and position of input.
The key of input will add to task config.
def task_generator(tasks, generators) -> list:
"""Use a list of TaskGen and a list of task templates to generate different tasks.
for example:
There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple.
task_generator(a_key=a, b_key=b, c_key=c, A, B) will finally generate 3*2*3 = 18 task_config.
For examples:
There are 3 task templates a,b,c and 2 TaskGen A,B. A will generates 2 tasks from a template and B will generates 3 tasks from a template.
task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks.
Parameters
----------
args : dict or TaskGen
kwargs : dict or TaskGen
tasks : List[dict]
a list of task templates
generators : List[TaskGen]
a list of TaskGen
Returns
-------
gen_task_list : list
a list of task config after generating
list
a list of tasks
"""
tasks_list = []
gen_list = []
tmp_id = 1
for task in args:
if isinstance(task, dict):
task["task_key"] = tmp_id
tmp_id += 1
tasks_list.append(task)
elif isinstance(task, TaskGen):
gen_list.append(task)
else:
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
for key, task in kwargs.items():
if isinstance(task, dict):
task["task_key"] = key
tasks_list.append(task)
elif isinstance(task, TaskGen):
gen_list.append(task)
else:
raise NotImplementedError(f"{type(task)} is not supported in task_generator")
if isinstance(tasks, dict):
tasks = [tasks]
if isinstance(generators, TaskGen):
generators = [generators]
# generate gen_task_list
gen_task_list = []
for gen in gen_list:
for gen in generators:
new_task_list = []
for task in tasks_list:
for task in tasks:
new_task_list.extend(gen.generate(task))
gen_task_list = new_task_list
return gen_task_list
@@ -144,7 +128,13 @@ class RollingGen(TaskGen):
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
"kwargs": {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
},
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
@@ -153,8 +143,12 @@ class RollingGen(TaskGen):
},
},
},
# You shoud record the data in specific sequence
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
"record": [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
]
}
"""
res = []

View File

@@ -245,6 +245,11 @@ class TaskManager:
for t in task_pool.find(query):
yield self._decode_task(t)
def get_task_result(self, task, task_pool=None):
task_pool = self._get_task_pool(task_pool)
result = task_pool.find_one({"filter": task})
return self._decode_task(result)["res"]
def commit_task_res(self, task, res, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)
# A workaround to use the class attribute.

View File

@@ -1,10 +1,14 @@
from typing import Union, List
from typing import Dict, Union, List
from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import Recorder
from qlib.workflow.recorder import MLflowRecorder, Recorder
from qlib.workflow.task.collect import TaskCollector
from qlib.workflow.task.update import ModelUpdater
from qlib.workflow.task.utils import TimeAdjuster
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import run_task
class OnlineManager:
@@ -19,9 +23,10 @@ class OnlineManager:
"""
raise NotImplementedError(f"Please implement the `prepare_new_models` method.")
ONLINE_TAG = "online_model"
ONLINE_TAG_TRUE = "True"
ONLINE_TAG_FALSE = "False"
ONLINE_KEY = "online_status" # the tag key in recorder
ONLINE_TAG = "online" # the 'online' model
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self, experiment_name: str) -> None:
"""ModelUpdater needs experiment name to find the records
@@ -35,45 +40,57 @@ class OnlineManager:
self.exp_name = experiment_name
self.tc = TaskCollector(experiment_name)
def set_online_model(self, recorder: Union[str, Recorder]):
"""online model will be identified at the tags of the record
def set_next_online_model(self, recorder: MLflowRecorder):
recorder.set_tags(**{self.ONLINE_KEY: self.NEXT_ONLINE_TAG})
Parameters
----------
recorder: Union[str,Recorder]
the id of a Recorder or the Recorder instance
"""
if isinstance(recorder, str):
recorder = self.tc.get_recorder_by_id(recorder_id=recorder)
recorder.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_TRUE})
def set_online_model(self, recorder: MLflowRecorder):
"""online model will be identified at the tags of the record"""
recorder.set_tags(**{self.ONLINE_KEY: self.ONLINE_TAG})
def cancel_online_model(self, recorder: Union[str, Recorder]):
if isinstance(recorder, str):
recorder = self.tc.get_recorder_by_id(recorder_id=recorder)
recorder.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_FALSE})
def set_offline_model(self, recorder: MLflowRecorder):
recorder.set_tags(**{self.ONLINE_KEY: self.OFFLINE_TAG})
def cancel_all_online_model(self):
def offline_all_model(self):
recs = self.tc.list_recorders()
for rid, rec in recs.items():
self.cancel_online_model(rec)
self.set_offline_model(rec)
def reset_online_model(self, recorders: Union[str, List[Union[str, Recorder]]]):
"""cancel all online model and reset the given model to online model
def reset_online_model(self, recorders: Union[List, Dict] = None):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
Parameters
----------
recorders: List[Union[str,Recorder]]
the list of the id of a Recorder or the Recorder instance
Args:
recorders (Union[List, Dict], optional):
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
"""
self.cancel_all_online_model()
if isinstance(recorders, str):
recorders = [recorders]
for rec_or_rid in recorders:
self.set_online_model(rec_or_rid)
if recorders is None:
recorders = self.list_next_online_model()
if len(recorders) == 0:
self.logger.info("No 'next online' model, just use current 'online' models.")
return
self.offline_all_model()
if isinstance(recorders, dict):
recorders = recorders.values()
for rec in recorders:
self.set_online_model(rec)
self.logger.info(f"Reset {len(recorders)} models to 'online'.")
def online_filter(self, recorder):
def set_latest_model_to_next_online(self):
latest_rec = self.tc.list_latest_recorders()
for rid, rec in latest_rec.items():
self.set_next_online_model(rec)
self.logger.info(f"Set {len(latest_rec)} latest models to 'next online'.")
@staticmethod
def online_filter(recorder):
tags = recorder.list_tags()
if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE:
if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.ONLINE_TAG:
return True
return False
@staticmethod
def next_online_filter(recorder):
tags = recorder.list_tags()
if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.NEXT_ONLINE_TAG:
return True
return False
@@ -88,21 +105,45 @@ class OnlineManager:
return self.tc.list_recorders(rec_filter_func=self.online_filter)
def list_next_online_model(self):
return self.tc.list_recorders(rec_filter_func=self.next_online_filter)
def update_online_pred(self):
"""update all online model predictions to the latest day in Calendar."""
"""update all online model predictions to the latest day in Calendar"""
mu = ModelUpdater(self.exp_name)
cnt = mu.update_all_pred(self.online_filter)
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
class RollingOnlineManager(OnlineManager):
def prepare_new_models(self, tasks: List[dict]):
"""prepare(train) new models
def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None:
super().__init__(experiment_name)
self.ta = TimeAdjuster()
self.rg = rolling_gen
self.tm = TaskManager(task_pool=task_pool)
self.logger = get_module_logger("RollingOnlineManager")
Parameters
----------
tasks : List[dict]
a list of tasks
def prepare_new_models(self):
"""prepare(train) new models based on online model"""
latest_records = self.tc.list_latest_recorders(self.online_filter) # if we need online_filter here?
max_test = self.tc.latest_time(latest_records)
calendar_latest = self.ta.last_date()
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
old_tasks = []
for rid, rec in latest_records.items():
task = self.tc.get_task(rec)
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
old_tasks.append(task)
new_tasks = task_generator(old_tasks, self.rg)
self.tm.create_task(new_tasks)
run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
self.logger.info(f"Finished prepare {len(new_tasks)} new models.")
return new_tasks
self.logger.info("No need to prepare any new models.")
return []
"""
def prepare_signals(self):
# prepare the signals of today
pass

View File

@@ -53,7 +53,7 @@ class ModelUpdater:
datahandler.init(datahandler.IT_LS)
return dataset
def update_pred(self, recorder: Union[str, Recorder]):
def update_pred(self, recorder: Recorder):
"""update predictions to the latest day in Calendar based on rid
Parameters
@@ -61,8 +61,6 @@ class ModelUpdater:
recorder: Union[str,Recorder]
the id of a Recorder or the Recorder instance
"""
if isinstance(recorder, str):
recorder = self.tc.get_recorder_by_id(recorder_id=recorder)
old_pred = recorder.load_object("pred.pkl")
last_end = old_pred.index.get_level_values("datetime").max()