From 8abdd63869c9eb329e78e72eef0850ac147b83e7 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Thu, 18 Mar 2021 09:30:01 +0000 Subject: [PATCH] online_serving V3 --- docs/start/initialization.rst | 11 ++ examples/taskmanager/task_manager_rolling.py | 112 ++++++++----- .../task_manager_rolling_with_updating.py | 150 ++++++++---------- qlib/model/trainer.py | 2 +- qlib/workflow/task/collect.py | 127 +++++++-------- qlib/workflow/task/gen.py | 70 ++++---- qlib/workflow/task/manage.py | 5 + qlib/workflow/task/online.py | 125 ++++++++++----- qlib/workflow/task/update.py | 4 +- 9 files changed, 333 insertions(+), 273 deletions(-) diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst index 15aa957d1..95ab7f77d 100644 --- a/docs/start/initialization.rst +++ b/docs/start/initialization.rst @@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo "default_exp_name": "Experiment", } }) +- `mongo` + Type: dict, optional parameter, the setting of `MongoDB `_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing. + Users need finished `installatin `_ firstly, and run it in a fixed URL. + + .. code-block:: Python + + # For example, you can initialize qlib below + qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={ + "task_url": "mongodb://localhost:27017/", # your mongo url + "task_db_name": "rolling_db", # the database name of Task Management + }) diff --git a/examples/taskmanager/task_manager_rolling.py b/examples/taskmanager/task_manager_rolling.py index 9223ed818..ffa88d75e 100644 --- a/examples/taskmanager/task_manager_rolling.py +++ b/examples/taskmanager/task_manager_rolling.py @@ -3,6 +3,11 @@ from qlib.config import REG_CN from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager from qlib.config import C +from qlib.workflow.task.manage import run_task +from qlib.workflow.task.collect import RollingCollector +from qlib.model.trainer import task_train +from qlib.workflow import R +from pprint import pprint data_handler_config = { "start_time": "2008-01-01", @@ -60,51 +65,78 @@ task_xgboost_config = { "record": record_config, } -provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir -qlib.init(provider_uri=provider_uri, region=REG_CN) +# Reset all things to the first status, be careful to save important data +def reset(): + print("========== reset ==========") + TaskManager(task_pool=task_pool).remove() -C["mongo"] = { - "task_url": "mongodb://localhost:27017/", # maybe you need to change it to your url - "task_db_name": "rolling_db", -} + # exp = R.get_exp(experiment_name=exp_name) -exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow -task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB - -tasks = task_generator( - task_xgboost_config, # default task name - RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment - task_lgb=task_lgb_config, # use "task_lgb" as the task name -) - -# Uncomment next two lines to see the generated tasks -# from pprint import pprint -# pprint(tasks) - -tm = TaskManager(task_pool=task_pool) -tm.create_task(tasks) # all tasks will be saved to MongoDB - -from qlib.workflow.task.manage import run_task -from qlib.workflow.task.collect import TaskCollector -from qlib.model.trainer import task_train - -run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method + # for rid in R.list_recorders(): + # exp.delete_recorder(rid) -def get_task_key(task_config): - task_key = task_config["task_key"] - rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1] - return task_key, rolling_end_timestamp.strftime("%Y-%m-%d") +# This part corresponds to "Task Generating" in the document +def task_generating(): + + print("========== task_generating ==========") + + tasks = task_generator( + tasks=[task_xgboost_config, task_lgb_config], + generators=RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment + ) + + pprint(tasks) + + return tasks -def my_filter(task_config): - # only choose the results of "task_lgb" and test in 2019 from all tasks - task_key, rolling_end = get_task_key(task_config) - if task_key == "task_lgb" and rolling_end.startswith("2019"): - return True - return False +# This part corresponds to "Task Storing" in the document +def task_storing(tasks): + print("========== task_storing ==========") + tm = TaskManager(task_pool=task_pool) + tm.create_task(tasks) # all tasks will be saved to MongoDB -# name tasks by "get_task_key" and filter tasks by "my_filter" -pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter) -pred_rolling +# This part corresponds to "Task Running" in the document +def task_running(): + print("========== task_running ==========") + run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method + + +# This part corresponds to "Task Collecting" in the document +def task_collecting(): + print("========== task_collecting ==========") + + def get_task_key(task_config): + return task_config["model"]["class"] + + def my_filter(recorder): + # only choose the results of "LGBModel" + task_key = get_task_key(rolling_collector.get_task(recorder)) + if task_key == "LGBModel": + return True + return False + + rolling_collector = RollingCollector(exp_name) + # group tasks by "get_task_key" and filter tasks by "my_filter" + pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter) + print(pred_rolling) + + +if __name__ == "__main__": + + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + mongo_conf = { + "task_url": "mongodb://10.0.0.4:27017/", # maybe you need to change it to your url + "task_db_name": "rolling_db", + } + exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow + task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB + qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf) + + reset() + tasks = task_generating() + task_storing(tasks) + task_running() + task_collecting() diff --git a/examples/taskmanager/task_manager_rolling_with_updating.py b/examples/taskmanager/task_manager_rolling_with_updating.py index c69b558bc..27e3ad269 100644 --- a/examples/taskmanager/task_manager_rolling_with_updating.py +++ b/examples/taskmanager/task_manager_rolling_with_updating.py @@ -3,15 +3,14 @@ import fire import mlflow from qlib.config import C from qlib.workflow import R +from pprint import pprint from qlib.config import REG_CN from qlib.model.trainer import task_train from qlib.workflow.task.manage import run_task from qlib.workflow.task.manage import TaskManager -from qlib.workflow.task.utils import TimeAdjuster -from qlib.workflow.task.update import ModelUpdater -from qlib.workflow.task.collect import TaskCollector +from qlib.workflow.task.collect import RollingCollector from qlib.workflow.task.gen import RollingGen, task_generator - +from qlib.workflow.task.online import RollingOnlineManager data_handler_config = { "start_time": "2013-01-01", @@ -33,7 +32,7 @@ dataset_config = { "segments": { "train": ("2013-01-01", "2014-12-31"), "valid": ("2015-01-01", "2015-12-31"), - "test": ("2016-01-01", "2017-01-01"), + "test": ("2016-01-01", "2020-07-10"), }, }, } @@ -69,16 +68,25 @@ task_xgboost_config = { "record": record_config, } + +def print_online_model(): + print("Current 'online' model:") + for online in rolling_online_manager.list_online_model().values(): + print(online.info["id"]) + print("Current 'next online' model:") + for online in rolling_online_manager.list_next_online_model().values(): + print(online.info["id"]) + + # This part corresponds to "Task Generating" in the document -def task_generating(**kwargs): - print("========================================= task_generating =========================================") +def task_generating(): - rolling_generator = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_EX) + print("========== task_generating ==========") - tasks = task_generator(rolling_generator, **kwargs) - - # See the generated tasks in a easy way - from pprint import pprint + tasks = task_generator( + tasks=[task_xgboost_config, task_lgb_config], + generators=rolling_gen, # generate different date segment + ) pprint(tasks) @@ -87,49 +95,45 @@ def task_generating(**kwargs): # This part corresponds to "Task Storing" in the document def task_storing(tasks): - print("========================================= task_storing =========================================") + print("========== task_storing ==========") tm = TaskManager(task_pool=task_pool) tm.create_task(tasks) # all tasks will be saved to MongoDB # This part corresponds to "Task Running" in the document def task_running(): - print("========================================= task_running =========================================") + print("========== task_running ==========") run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method # This part corresponds to "Task Collecting" in the document def task_collecting(): - print("========================================= task_collecting =========================================") + print("========== task_collecting ==========") def get_task_key(task_config): - task_key = task_config["task_key"] - rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1] - if rolling_end_timestamp == None: - rolling_end_timestamp = TimeAdjuster().last_date() - return task_key, rolling_end_timestamp.strftime("%Y-%m-%d") + return task_config["model"]["class"] - def lgb_filter(task_config): - # only choose the results of "task_lgb" - task_key, rolling_end = get_task_key(task_config) - if task_key == "task_lgb": + def my_filter(recorder): + # only choose the results of "LGBModel" + task_key = get_task_key(rolling_collector.get_task(recorder)) + if task_key == "LGBModel": return True return False - task_collector = TaskCollector(exp_name) - pred_rolling = task_collector.collect_predictions( - get_task_key, lgb_filter - ) # name tasks by "get_task_key" and filter tasks by "my_filter" + rolling_collector = RollingCollector(exp_name) + # group tasks by "get_task_key" and filter tasks by "my_filter" + pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter) print(pred_rolling) # Reset all things to the first status, be careful to save important data def reset(force_end=False): - print("========================================= reset =========================================") - TaskManager(task_pool=task_pool).remove() - + print("========== reset ==========") + task_manager.remove() + for error in task_manager.query(): + assert False exp = R.get_exp(experiment_name=exp_name) - recs = TaskCollector(exp_name).list_recorders(only_finished=True) + recs = exp.list_recorders() for rid in recs: exp.delete_recorder(rid) @@ -141,82 +145,60 @@ def reset(force_end=False): pass -def set_online_model_to_latest(): - print( - "========================================= set_online_model_to_latest =========================================" - ) - model_updater = ModelUpdater(experiment_name=exp_name) - latest_records, latest_test = model_updater.collect_latest_records() - model_updater.reset_online_model(latest_records.values()) - - # Run this firstly to see the workflow in Task Management def first_run(): - print("========================================= first_run =========================================") + print("========== first_run ==========") reset(force_end=True) - # use "task_lgb" and "task_xgboost" as the task name - tasks = task_generating(**{"task_xgboost": task_xgboost_config, "task_lgb": task_lgb_config}) + tasks = task_generating() task_storing(tasks) task_running() task_collecting() - set_online_model_to_latest() + + rolling_online_manager.set_latest_model_to_next_online() + rolling_online_manager.reset_online_model() # Update the predictions of online model def update_predictions(): - print("========================================= update_predictions =========================================") - model_updater = ModelUpdater(experiment_name=exp_name) - model_updater.update_online_pred() + print("========== update_predictions ==========") + rolling_online_manager.update_online_pred() + task_collecting() + # if there are some next_online_model, then online them. if no, still use current online_model. + print_online_model() + rolling_online_manager.reset_online_model() + print_online_model() # Update the models using the latest date and set them to online model def update_model(): - print("========================================= update_model =========================================") - # get the latest recorders - model_updater = ModelUpdater(experiment_name=exp_name) - latest_records, latest_test = model_updater.collect_latest_records() - # date adjustment based on trade day of Calendar in Qlib - time_adjuster = TimeAdjuster() - calendar_latest = time_adjuster.last_date() - print("The latest date is ", calendar_latest) - if time_adjuster.cal_interval(calendar_latest, latest_test[0]) > rolling_step: - print("Need update models!") - tasks = {} - for rid, rec in latest_records.items(): - old_task = rec.task - test_begin = old_task["dataset"]["kwargs"]["segments"]["test"][0] - # modify the test segment to generate new tasks - old_task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) - tasks[old_task["task_key"]] = old_task + print("========== update_model ==========") + rolling_online_manager.prepare_new_models() + print_online_model() + rolling_online_manager.set_latest_model_to_next_online() + print_online_model() - # retrain the latest model - new_tasks = task_generating(**tasks) - task_storing(new_tasks) - task_running() - task_collecting() - latest_records, _ = model_updater.collect_latest_records() - # set the latest model to online model - model_updater.reset_online_model(latest_records.values()) +def after_day(): + rolling_online_manager.prepare_signals() + update_model() + update_predictions() # Run whole workflow completely def whole_workflow(): - print("========================================= whole_workflow =========================================") + print("========== whole_workflow ==========") # run this at the first time first_run() - # run this every day - update_predictions() - # run this every "rolling_steps" day - update_model() + # run this every day after trading + after_day() if __name__ == "__main__": ####### to train the first version's models, use the command below # python task_manager_rolling_with_updating.py first_run - ####### to update the models using the latest date and set them to online model, use the command below + ####### to update the models using the latest date, use the command below # python task_manager_rolling_with_updating.py update_model ####### to update the predictions to the latest date, use the command below @@ -231,8 +213,8 @@ if __name__ == "__main__": qlib.init(provider_uri=provider_uri, region=REG_CN) C["mongo"] = { - "task_url": "mongodb://localhost:27017/", # your MongoDB url - "task_db_name": "rolling_db", # database name + "task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url + "task_db_name": "online", # database name } exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow @@ -240,5 +222,9 @@ if __name__ == "__main__": rolling_step = 550 ########################################################################################## - + rolling_gen = RollingGen(step=550, rtype=RollingGen.ROLL_SD) + rolling_online_manager = RollingOnlineManager( + experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool + ) + task_manager = TaskManager(task_pool=task_pool) fire.Fire() diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 5c5609eb0..c18145073 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -60,4 +60,4 @@ def task_train(task_config: dict, experiment_name: str) -> str: ar = init_instance_by_config(record) ar.generate() - return recorder.info["id"] + return recorder diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 21639e7f8..fb7ff0b0b 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -8,7 +8,7 @@ from qlib import get_module_logger class TaskCollector: """ - Collect the record results of the finished tasks with key and filter + Collect the record (or its results) of the tasks """ def __init__(self, experiment_name: str) -> None: @@ -17,7 +17,7 @@ class TaskCollector: self.logger = get_module_logger("TaskCollector") def list_recorders(self, rec_filter_func=None): - """""" + recs = self.exp.list_recorders() recs_flt = {} for rid, rec in recs.items(): @@ -26,57 +26,77 @@ class TaskCollector: return recs_flt + def list_recorders_by_task(self, task_filter_func=None): + def rec_filter(recorder): + return task_filter_func(self.get_task(recorder)) + + return self.list_recorders(rec_filter) + + def list_latest_recorders(self, rec_filter_func=None): + recs_flt = self.list_recorders(rec_filter_func) + max_test = self.latest_time(recs_flt) + latest_rec = {} + for rid, rec in recs_flt.items(): + if self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] == max_test: + latest_rec[rid] = rec + return latest_rec + def get_recorder_by_id(self, recorder_id): return self.exp.get_recorder(recorder_id, create=False) - def list_recorders_by_task(self, task_filter_func): - """[summary] + def get_task(self, recorder): + if isinstance(recorder, str): + recorder = self.get_recorder_by_id(recorder_id=recorder) + try: + task = recorder.load_object("task") + except OSError: + raise OSError(f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?") + return task - Parameters - ---------- - task_filter_func : [type], optional - [description], by default None - """ + def latest_time(self, recorders): + if len(recorders) == 0: + raise Exception(f"Can't find any recorder in {self.exp_name}") + max_test = max(self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] for rec in recorders.values()) + return max_test - def rec_filter_func(recorder): - try: - task = recorder.load_object("task") - except OSError: - raise OSError( - f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?" - ) - return task_filter_func(task) - return self.list_recorders(rec_filter_func) +class RollingCollector(TaskCollector): + """ + Collect the record results of the rolling tasks + """ - def collect_predictions( + def __init__( self, - get_key_func, - task_filter_func=None, - ): - """ - Collect predictions using a filter and a key function. + experiment_name: str, + ) -> None: + super().__init__(experiment_name) + self.logger = get_module_logger("RollingCollector") + + def collect_rolling_predictions(self, get_key_func, rec_filter_func=None): + """For rolling tasks, the predictions will be in the diffierent recorder. + To collect and concat the predictions of one rolling task, get_key_func will help this method see which group a recorder will be. Parameters ---------- - experiment_name : str - get_key_func : Callable[[dict], bool] -> Union[Number, str, tuple] - get the key of a task when collect it - filter_func : Callable[[dict], bool] -> bool - to judge a task will be collected or not + get_key_func : Callable[dict,str] + a function that get task config and return its group str + rec_filter_func : Callable[Recorder,bool], optional + a function that decide whether filter a recorder, by default None Returns ------- dict - the dict of predictions + a dict of {group: predictions} """ - recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True) + + # filter records + recs_flt = self.list_recorders(rec_filter_func) # group recs_group = {} for _, rec in recs_flt.items(): - params = rec.task - group_key = get_key_func(params) + task = self.get_task(rec) + group_key = get_key_func(task) recs_group.setdefault(group_key, []).append(rec) # reduce group @@ -85,39 +105,12 @@ class TaskCollector: pred_l = [] for rec in rec_l: pred_l.append(rec.load_object("pred.pkl").iloc[:, 0]) - pred = pd.concat(pred_l).sort_index() + # Make sure the pred are sorted according to the rolling start time + pred_l.sort(key=lambda pred: pred.index.get_level_values("datetime").min()) + pred = pd.concat(pred_l) + # If there are duplicated predition, we use the latest perdiction + pred = pred[~pred.index.duplicated(keep="last")] + pred = pred.sort_index() reduce_group[k] = pred - self.logger.info(f"Collect {len(reduce_group)} predictions in {self.exp_name}") - return reduce_group - - def collect_latest_records( - self, - task_filter_func=None, - ): - """Collect latest recorders using a filter. - - Parameters - ---------- - task_filter_func : Callable[[dict], bool], optional - to judge a task will be collected or not, by default None - - Returns - ------- - dict, tuple - a dict of recorders and a tuple of test segments - """ - recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True) - - if len(recs_flt) == 0: - self.logger.warning("Can not collect any recorders...") - return None, None - max_test = max(rec.task["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values()) - - latest_record = {} - for rid, rec in recs_flt.items(): - if rec.task["dataset"]["kwargs"]["segments"]["test"] == max_test: - latest_record[rid] = rec - - self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}") - return latest_record, max_test + return reduce_group \ No newline at end of file diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 96448cefe..63000d77d 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -9,56 +9,40 @@ import typing from .utils import TimeAdjuster -def task_generator(*args, **kwargs) -> list: - """ - Accept the dict of task config and the TaskGen to generate different tasks. - There is no limit to the number and position of input. - The key of input will add to task config. +def task_generator(tasks, generators) -> list: + """Use a list of TaskGen and a list of task templates to generate different tasks. - for example: - There are 3 task_config(a,b,c) and 2 TaskGen(A,B). A will double the task_config and B will triple. - task_generator(a_key=a, b_key=b, c_key=c, A, B) will finally generate 3*2*3 = 18 task_config. + For examples: + + There are 3 task templates a,b,c and 2 TaskGen A,B. A will generates 2 tasks from a template and B will generates 3 tasks from a template. + task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks. Parameters ---------- - args : dict or TaskGen - kwargs : dict or TaskGen + tasks : List[dict] + a list of task templates + generators : List[TaskGen] + a list of TaskGen Returns ------- - gen_task_list : list - a list of task config after generating + list + a list of tasks """ - tasks_list = [] - gen_list = [] - tmp_id = 1 - for task in args: - if isinstance(task, dict): - task["task_key"] = tmp_id - tmp_id += 1 - tasks_list.append(task) - elif isinstance(task, TaskGen): - gen_list.append(task) - else: - raise NotImplementedError(f"{type(task)} is not supported in task_generator") - - for key, task in kwargs.items(): - if isinstance(task, dict): - task["task_key"] = key - tasks_list.append(task) - elif isinstance(task, TaskGen): - gen_list.append(task) - else: - raise NotImplementedError(f"{type(task)} is not supported in task_generator") + if isinstance(tasks, dict): + tasks = [tasks] + if isinstance(generators, TaskGen): + generators = [generators] # generate gen_task_list gen_task_list = [] - for gen in gen_list: + for gen in generators: new_task_list = [] - for task in tasks_list: + for task in tasks: new_task_list.extend(gen.generate(task)) gen_task_list = new_task_list + return gen_task_list @@ -144,7 +128,13 @@ class RollingGen(TaskGen): "handler": { "class": "Alpha158", "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, + "kwargs": { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": "csi100", + }, }, "segments": { "train": ("2008-01-01", "2014-12-31"), @@ -153,8 +143,12 @@ class RollingGen(TaskGen): }, }, }, - # You shoud record the data in specific sequence - # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'], + "record": [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + ] } """ res = [] diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index e97fdb774..db4c15038 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -245,6 +245,11 @@ class TaskManager: for t in task_pool.find(query): yield self._decode_task(t) + def get_task_result(self, task, task_pool=None): + task_pool = self._get_task_pool(task_pool) + result = task_pool.find_one({"filter": task}) + return self._decode_task(result)["res"] + def commit_task_res(self, task, res, status=None, task_pool=None): task_pool = self._get_task_pool(task_pool) # A workaround to use the class attribute. diff --git a/qlib/workflow/task/online.py b/qlib/workflow/task/online.py index f2b8e5706..8d551e858 100644 --- a/qlib/workflow/task/online.py +++ b/qlib/workflow/task/online.py @@ -1,10 +1,14 @@ -from typing import Union, List +from typing import Dict, Union, List from qlib import get_module_logger from qlib.workflow import R from qlib.model.trainer import task_train -from qlib.workflow.recorder import Recorder +from qlib.workflow.recorder import MLflowRecorder, Recorder from qlib.workflow.task.collect import TaskCollector from qlib.workflow.task.update import ModelUpdater +from qlib.workflow.task.utils import TimeAdjuster +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.manage import TaskManager +from qlib.workflow.task.manage import run_task class OnlineManager: @@ -19,9 +23,10 @@ class OnlineManager: """ raise NotImplementedError(f"Please implement the `prepare_new_models` method.") - ONLINE_TAG = "online_model" - ONLINE_TAG_TRUE = "True" - ONLINE_TAG_FALSE = "False" + ONLINE_KEY = "online_status" # the tag key in recorder + ONLINE_TAG = "online" # the 'online' model + NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model + OFFLINE_TAG = "offline" # the 'offline' model, not for online serving def __init__(self, experiment_name: str) -> None: """ModelUpdater needs experiment name to find the records @@ -35,45 +40,57 @@ class OnlineManager: self.exp_name = experiment_name self.tc = TaskCollector(experiment_name) - def set_online_model(self, recorder: Union[str, Recorder]): - """online model will be identified at the tags of the record + def set_next_online_model(self, recorder: MLflowRecorder): + recorder.set_tags(**{self.ONLINE_KEY: self.NEXT_ONLINE_TAG}) - Parameters - ---------- - recorder: Union[str,Recorder] - the id of a Recorder or the Recorder instance - """ - if isinstance(recorder, str): - recorder = self.tc.get_recorder_by_id(recorder_id=recorder) - recorder.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_TRUE}) + def set_online_model(self, recorder: MLflowRecorder): + """online model will be identified at the tags of the record""" + recorder.set_tags(**{self.ONLINE_KEY: self.ONLINE_TAG}) - def cancel_online_model(self, recorder: Union[str, Recorder]): - if isinstance(recorder, str): - recorder = self.tc.get_recorder_by_id(recorder_id=recorder) - recorder.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_FALSE}) + def set_offline_model(self, recorder: MLflowRecorder): + recorder.set_tags(**{self.ONLINE_KEY: self.OFFLINE_TAG}) - def cancel_all_online_model(self): + def offline_all_model(self): recs = self.tc.list_recorders() for rid, rec in recs.items(): - self.cancel_online_model(rec) + self.set_offline_model(rec) - def reset_online_model(self, recorders: Union[str, List[Union[str, Recorder]]]): - """cancel all online model and reset the given model to online model + def reset_online_model(self, recorders: Union[List, Dict] = None): + """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. - Parameters - ---------- - recorders: List[Union[str,Recorder]] - the list of the id of a Recorder or the Recorder instance + Args: + recorders (Union[List, Dict], optional): + the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. """ - self.cancel_all_online_model() - if isinstance(recorders, str): - recorders = [recorders] - for rec_or_rid in recorders: - self.set_online_model(rec_or_rid) + if recorders is None: + recorders = self.list_next_online_model() + if len(recorders) == 0: + self.logger.info("No 'next online' model, just use current 'online' models.") + return + self.offline_all_model() + if isinstance(recorders, dict): + recorders = recorders.values() + for rec in recorders: + self.set_online_model(rec) + self.logger.info(f"Reset {len(recorders)} models to 'online'.") - def online_filter(self, recorder): + def set_latest_model_to_next_online(self): + latest_rec = self.tc.list_latest_recorders() + for rid, rec in latest_rec.items(): + self.set_next_online_model(rec) + self.logger.info(f"Set {len(latest_rec)} latest models to 'next online'.") + + @staticmethod + def online_filter(recorder): tags = recorder.list_tags() - if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE: + if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.ONLINE_TAG: + return True + return False + + @staticmethod + def next_online_filter(recorder): + tags = recorder.list_tags() + if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.NEXT_ONLINE_TAG: return True return False @@ -88,21 +105,45 @@ class OnlineManager: return self.tc.list_recorders(rec_filter_func=self.online_filter) + def list_next_online_model(self): + return self.tc.list_recorders(rec_filter_func=self.next_online_filter) + def update_online_pred(self): - """update all online model predictions to the latest day in Calendar.""" + """update all online model predictions to the latest day in Calendar""" mu = ModelUpdater(self.exp_name) cnt = mu.update_all_pred(self.online_filter) self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.") class RollingOnlineManager(OnlineManager): - def prepare_new_models(self, tasks: List[dict]): - """prepare(train) new models + def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None: + super().__init__(experiment_name) + self.ta = TimeAdjuster() + self.rg = rolling_gen + self.tm = TaskManager(task_pool=task_pool) + self.logger = get_module_logger("RollingOnlineManager") - Parameters - ---------- - tasks : List[dict] - a list of tasks + def prepare_new_models(self): + """prepare(train) new models based on online model""" + latest_records = self.tc.list_latest_recorders(self.online_filter) # if we need online_filter here? + max_test = self.tc.latest_time(latest_records) + calendar_latest = self.ta.last_date() + if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step: + old_tasks = [] + for rid, rec in latest_records.items(): + task = self.tc.get_task(rec) + test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] + # modify the test segment to generate new tasks + task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) + old_tasks.append(task) + new_tasks = task_generator(old_tasks, self.rg) + self.tm.create_task(new_tasks) + run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name) + self.logger.info(f"Finished prepare {len(new_tasks)} new models.") + return new_tasks + self.logger.info("No need to prepare any new models.") + return [] - """ + def prepare_signals(self): + # prepare the signals of today pass diff --git a/qlib/workflow/task/update.py b/qlib/workflow/task/update.py index 9f68dbd0a..fcee84349 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/task/update.py @@ -53,7 +53,7 @@ class ModelUpdater: datahandler.init(datahandler.IT_LS) return dataset - def update_pred(self, recorder: Union[str, Recorder]): + def update_pred(self, recorder: Recorder): """update predictions to the latest day in Calendar based on rid Parameters @@ -61,8 +61,6 @@ class ModelUpdater: recorder: Union[str,Recorder] the id of a Recorder or the Recorder instance """ - if isinstance(recorder, str): - recorder = self.tc.get_recorder_by_id(recorder_id=recorder) old_pred = recorder.load_object("pred.pkl") last_end = old_pred.index.get_level_values("datetime").max()