diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst index 95ab7f77d..32c17ff83 100644 --- a/docs/start/initialization.rst +++ b/docs/start/initialization.rst @@ -77,7 +77,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo }) - `mongo` Type: dict, optional parameter, the setting of `MongoDB `_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing. - Users need finished `installatin `_ firstly, and run it in a fixed URL. + Users need finished `installation `_ firstly, and run it in a fixed URL. .. code-block:: Python diff --git a/examples/taskmanager/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py similarity index 75% rename from examples/taskmanager/task_manager_rolling.py rename to examples/model_rolling/task_manager_rolling.py index ffa88d75e..70a4f7d7e 100644 --- a/examples/taskmanager/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -1,13 +1,13 @@ +from pprint import pprint + +import fire import qlib from qlib.config import REG_CN -from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager -from qlib.config import C -from qlib.workflow.task.manage import run_task -from qlib.workflow.task.collect import RollingCollector from qlib.model.trainer import task_train from qlib.workflow import R -from pprint import pprint +from qlib.workflow.task.collect import RollingCollector +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.manage import TaskManager, run_task data_handler_config = { "start_time": "2008-01-01", @@ -66,14 +66,14 @@ task_xgboost_config = { } # Reset all things to the first status, be careful to save important data -def reset(): +def reset(task_pool, exp_name): print("========== reset ==========") TaskManager(task_pool=task_pool).remove() - # exp = 
R.get_exp(experiment_name=exp_name) + exp, _ = R.exp_manager._get_or_create_exp(experiment_name=exp_name) - # for rid in R.list_recorders(): - # exp.delete_recorder(rid) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) # This part corresponds to "Task Generating" in the document @@ -92,51 +92,58 @@ def task_generating(): # This part corresponds to "Task Storing" in the document -def task_storing(tasks): +def task_storing(tasks, task_pool, exp_name): print("========== task_storing ==========") tm = TaskManager(task_pool=task_pool) tm.create_task(tasks) # all tasks will be saved to MongoDB # This part corresponds to "Task Running" in the document -def task_running(): +def task_running(task_pool, exp_name): print("========== task_running ==========") run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method # This part corresponds to "Task Collecting" in the document -def task_collecting(): +def task_collecting(task_pool, exp_name): print("========== task_collecting ==========") - def get_task_key(task_config): + def get_group_key_func(recorder): + task_config = recorder.load_object("task") return task_config["model"]["class"] def my_filter(recorder): # only choose the results of "LGBModel" - task_key = get_task_key(rolling_collector.get_task(recorder)) + task_key = get_group_key_func(recorder) if task_key == "LGBModel": return True return False rolling_collector = RollingCollector(exp_name) # group tasks by "get_task_key" and filter tasks by "my_filter" - pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter) + pred_rolling = rolling_collector.collect(get_group_key_func, my_filter) print(pred_rolling) -if __name__ == "__main__": - - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir +def main( + provider_uri="~/.qlib/qlib_data/cn_data", + task_url="mongodb://10.0.0.4:27017/", + task_db_name="rolling_db", + exp_name="rolling_exp", + task_pool="rolling_task", +): 
mongo_conf = { - "task_url": "mongodb://10.0.0.4:27017/", # maybe you need to change it to your url - "task_db_name": "rolling_db", + "task_url": task_url, + "task_db_name": task_db_name, } - exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow - task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf) - reset() + reset(task_pool, exp_name) tasks = task_generating() - task_storing(tasks) - task_running() - task_collecting() + task_storing(tasks, task_pool, exp_name) + task_running(task_pool, exp_name) + task_collecting(task_pool, exp_name) + + +if __name__ == "__main__": + fire.Fire() diff --git a/examples/taskmanager/task_manager_rolling_with_updating.py b/examples/online_svr/task_manager_rolling_with_updating.py similarity index 63% rename from examples/taskmanager/task_manager_rolling_with_updating.py rename to examples/online_svr/task_manager_rolling_with_updating.py index 27e3ad269..24bc38a02 100644 --- a/examples/taskmanager/task_manager_rolling_with_updating.py +++ b/examples/online_svr/task_manager_rolling_with_updating.py @@ -1,16 +1,15 @@ -import qlib -import fire -import mlflow -from qlib.config import C -from qlib.workflow import R from pprint import pprint + +import fire +import qlib from qlib.config import REG_CN from qlib.model.trainer import task_train -from qlib.workflow.task.manage import run_task -from qlib.workflow.task.manage import TaskManager +from qlib.workflow import R from qlib.workflow.task.collect import RollingCollector from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.online import RollingOnlineManager +from qlib.workflow.task.utils import list_recorders data_handler_config = { "start_time": "2013-01-01", @@ -70,12 +69,15 @@ task_xgboost_config = { def print_online_model(): + print("========== print_online_model 
==========") print("Current 'online' model:") - for online in rolling_online_manager.list_online_model().values(): - print(online.info["id"]) + for rid, rec in list_recorders(exp_name).items(): + if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG: + print(rid) print("Current 'next online' model:") - for online in rolling_online_manager.list_next_online_model().values(): - print(online.info["id"]) + for rid, rec in list_recorders(exp_name).items(): + if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG: + print(rid) # This part corresponds to "Task Generating" in the document @@ -110,119 +112,76 @@ def task_running(): def task_collecting(): print("========== task_collecting ==========") - def get_task_key(task_config): + def get_group_key_func(recorder): + task_config = recorder.load_object("task") return task_config["model"]["class"] def my_filter(recorder): # only choose the results of "LGBModel" - task_key = get_task_key(rolling_collector.get_task(recorder)) + task_key = get_group_key_func(recorder) if task_key == "LGBModel": return True return False rolling_collector = RollingCollector(exp_name) # group tasks by "get_task_key" and filter tasks by "my_filter" - pred_rolling = rolling_collector.collect_rolling_predictions(get_task_key, my_filter) + pred_rolling = rolling_collector.collect(get_group_key_func, my_filter) print(pred_rolling) # Reset all things to the first status, be careful to save important data -def reset(force_end=False): +def reset(): print("========== reset ==========") task_manager.remove() - for error in task_manager.query(): - assert False - exp = R.get_exp(experiment_name=exp_name) - recs = exp.list_recorders() - - for rid in recs: + exp, _ = R.exp_manager._get_or_create_exp(experiment_name=exp_name) + for rid in exp.list_recorders(): exp.delete_recorder(rid) - try: - if force_end: - mlflow.end_run() - except Exception: - pass - # Run this firstly to see the workflow in Task 
Management def first_run(): print("========== first_run ==========") - reset(force_end=True) + reset() tasks = task_generating() task_storing(tasks) task_running() task_collecting() - rolling_online_manager.set_latest_model_to_next_online() - rolling_online_manager.reset_online_model() - - -# Update the predictions of online model -def update_predictions(): - print("========== update_predictions ==========") - rolling_online_manager.update_online_pred() - task_collecting() - # if there are some next_online_model, then online them. if no, still use current online_model. - print_online_model() - rolling_online_manager.reset_online_model() - print_online_model() - - -# Update the models using the latest date and set them to online model -def update_model(): - print("========== update_model ==========") - rolling_online_manager.prepare_new_models() - print_online_model() - rolling_online_manager.set_latest_model_to_next_online() - print_online_model() + latest_rec, _ = rolling_online_manager.list_latest_recorders() + rolling_online_manager.reset_online_tag(latest_rec.values()) def after_day(): - rolling_online_manager.prepare_signals() - update_model() - update_predictions() - - -# Run whole workflow completely -def whole_workflow(): - print("========== whole_workflow ==========") - # run this at the first time - first_run() - # run this every day after trading - after_day() + print("========== after_day ==========") + print_online_model() + rolling_online_manager.after_day() + print_online_model() + task_collecting() if __name__ == "__main__": ####### to train the first version's models, use the command below # python task_manager_rolling_with_updating.py first_run - ####### to update the models using the latest date, use the command below - # python task_manager_rolling_with_updating.py update_model - - ####### to update the predictions to the latest date, use the command below - # python task_manager_rolling_with_updating.py update_predictions - - ####### to run 
whole workflow completely, use the command below - # python task_manager_rolling_with_updating.py whole_workflow + ####### to update the models and predictions after the trading time, use the command below + # python task_manager_rolling_with_updating.py after_day #################### you need to finish the configurations below ######################### provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir - qlib.init(provider_uri=provider_uri, region=REG_CN) - - C["mongo"] = { + mongo_conf = { "task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url - "task_db_name": "online", # database name + "task_db_name": "rolling_db", # database name } + qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf) exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB rolling_step = 550 ########################################################################################## - rolling_gen = RollingGen(step=550, rtype=RollingGen.ROLL_SD) + rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) rolling_online_manager = RollingOnlineManager( experiment_name=exp_name, rolling_gen=rolling_gen, task_pool=task_pool ) diff --git a/examples/taskmanager/update_online_pred.py b/examples/online_svr/update_online_pred.py similarity index 78% rename from examples/taskmanager/update_online_pred.py rename to examples/online_svr/update_online_pred.py index 5ce963fbc..ac86b48e8 100644 --- a/examples/taskmanager/update_online_pred.py +++ b/examples/online_svr/update_online_pred.py @@ -1,9 +1,9 @@ -import qlib -from qlib.model.trainer import task_train -from qlib.workflow.task.online import OnlineManager -from qlib.config import REG_CN import fire -from qlib.workflow import R +import qlib +from qlib.config import REG_CN +from qlib.model.trainer import task_train +from qlib.workflow.task.online import OnlineManagerR +from qlib.workflow.task.utils import 
list_recorders data_handler_config = { "start_time": "2008-01-01", @@ -56,19 +56,20 @@ def first_train(experiment_name="online_svr"): rid = task_train(task_config=task, experiment_name=experiment_name) - rom = OnlineManager(experiment_name) - rom.reset_online_model(rid) + online_manager = OnlineManagerR(experiment_name) + online_manager.reset_online_tag(rid) def update_online_pred(experiment_name="online_svr"): - rom = OnlineManager(experiment_name) + online_manager = OnlineManagerR(experiment_name) print("Here are the online models waiting for update:") - for rid, rec in rom.list_online_model().items(): - print(rid) + for rid, rec in list_recorders(experiment_name).items(): + if online_manager.get_online_tag(rec) == OnlineManagerR.ONLINE_TAG: + print(rid) - rom.update_online_pred() + online_manager.update_online_pred() if __name__ == "__main__": diff --git a/qlib/config.py b/qlib/config.py index b245cc1df..95fdaf645 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -134,7 +134,7 @@ _default_config = { }, "loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}}, }, - # Defatult config for experiment manager + # Default config for experiment manager "exp_manager": { "class": "MLflowExpManager", "module_path": "qlib.workflow.expm", @@ -143,6 +143,11 @@ _default_config = { "default_exp_name": "Experiment", }, }, + # Default config for MongoDB + "mongo": { + "task_url": "mongodb://localhost:27017/", + "task_db_name": "default_task_db", + } } MODE_CONF = { diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index c18145073..60f56609f 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -27,6 +27,7 @@ def task_train(task_config: dict, experiment_name: str) -> str: model = init_instance_by_config(task_config["model"]) dataset = init_instance_by_config(task_config["dataset"]) datahandler = dataset.handler + dataset.config(exclude=["handler"]) # start exp with R.start(experiment_name=experiment_name): @@ -37,10 +38,8 @@ def 
task_train(task_config: dict, experiment_name: str) -> str: recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) R.save_objects(**{"task": task_config}) # keep the original format and datatype - - artifact_uri = recorder.get_artifact_uri()[7:] # delete "file://" - dataset.to_pickle(artifact_uri + "/dataset", exclude=["handler"]) - datahandler.to_pickle(artifact_uri + "/datahandler") + R.save_objects(**{"dataset": dataset}) + R.save_objects(**{"datahandler": datahandler}) # generate records: prediction, backtest, and analysis records = task_config.get("record", []) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index fb7ff0b0b..0a007cc5c 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,116 +1,172 @@ -from qlib.workflow import R +from abc import abstractmethod +from typing import Callable, Union + import pandas as pd -from typing import Union -from typing import Callable - from qlib import get_module_logger +from qlib.workflow.task.utils import list_recorders -class TaskCollector: +class Collector: """ - Collect the record (or its results) of the tasks + This class will divide disorderly records or anything worth collecting into different groups based on the group_key. + After grouping, we can reduce the useful information from different groups. + """ + + def group(self, *args, **kwargs): + """ + According to the get_group_key_func, divide disorderly things into different groups. + + For example: + + .. code-block:: python + + input: + [thing1, thing2, thing3, thing4, thing5] + + output: + { + "group_name1": [thing3, thing5, thing1] + "group_name2": [thing2, thing4] + } + + Args: + get_group_key_func (Callable): get a group key based on a thing + things_list (list): a list of things + + Returns: + dict: a dict including the group key and members of the group. 
+ + """ + raise NotImplementedError(f"Please implement the `group` method.") + + def reduce(self, things_group: dict): + """ + Using the dict from `group`, reduce useful information. + + Args: + things_group (dict): a dict after grouping + + Returns: + dict: a dict including the group key, the information key and the information value + + """ + raise NotImplementedError(f"Please implement the `reduce` method.") + + def collect(self, *args, **kwargs): + """group and reduce + + Returns: + dict: a dict including the group key, the information key and the information value + """ + grouped = self.group(*args, **kwargs) + return self.reduce(grouped) + + +class RecorderCollector(Collector): + """ + The Recorder's Collector. This class is an implementation of Collector, collecting some artifacts saved by Recorder. + """ + + def __init__(self, experiment_name: str) -> None: self.exp_name = experiment_name - self.exp = R.get_exp(experiment_name=experiment_name) - self.logger = get_module_logger("TaskCollector") + self.logger = get_module_logger(self.__class__.__name__) - def list_recorders(self, rec_filter_func=None): + _artifacts_key_path = {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"} + _artifacts_key_merge_method = {} - recs = self.exp.list_recorders() - recs_flt = {} - for rid, rec in recs.items(): - if rec_filter_func is None or rec_filter_func(rec): - recs_flt[rid] = rec + def default_merge(self, artifact_list): + """Merge disorderly artifacts in artifact list. - return recs_flt + Args: + artifact_list (list): An artifact list. 
- def list_recorders_by_task(self, task_filter_func=None): - def rec_filter(recorder): - return task_filter_func(self.get_task(recorder)) - - return self.list_recorders(rec_filter) - - def list_latest_recorders(self, rec_filter_func=None): - recs_flt = self.list_recorders(rec_filter_func) - max_test = self.latest_time(recs_flt) - latest_rec = {} - for rid, rec in recs_flt.items(): - if self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] == max_test: - latest_rec[rid] = rec - return latest_rec - - def get_recorder_by_id(self, recorder_id): - return self.exp.get_recorder(recorder_id, create=False) - - def get_task(self, recorder): - if isinstance(recorder, str): - recorder = self.get_recorder_by_id(recorder_id=recorder) - try: - task = recorder.load_object("task") - except OSError: - raise OSError(f"Can't find task in {recorder.info['id']}, have you trained with model.trainer.task_train?") - return task - - def latest_time(self, recorders): - if len(recorders) == 0: - raise Exception(f"Can't find any recorder in {self.exp_name}") - max_test = max(self.get_task(rec)["dataset"]["kwargs"]["segments"]["test"] for rec in recorders.values()) - return max_test - - -class RollingCollector(TaskCollector): - """ - Collect the record results of the rolling tasks - """ - - def __init__( - self, - experiment_name: str, - ) -> None: - super().__init__(experiment_name) - self.logger = get_module_logger("RollingCollector") - - def collect_rolling_predictions(self, get_key_func, rec_filter_func=None): - """For rolling tasks, the predictions will be in the diffierent recorder. - To collect and concat the predictions of one rolling task, get_key_func will help this method see which group a recorder will be. 
- - Parameters - ---------- - get_key_func : Callable[dict,str] - a function that get task config and return its group str - rec_filter_func : Callable[Recorder,bool], optional - a function that decide whether filter a recorder, by default None - - Returns - ------- - dict - a dict of {group: predictions} + Raises: + NotImplementedError: [description] """ + raise NotImplementedError(f"Please implement the `default_merge` method.") + def group(self, get_group_key_func, rec_filter_func=None): + """ + Filter recorders and group recorders by group key. + + Args: + get_group_key_func (Callable): get a group key based on a recorder + rec_filter_func (Callable, optional): filter the recorders in this experiment. Defaults to None. + + Returns: + dict: a dict including the group key and recorders of the group + """ # filter records - recs_flt = self.list_recorders(rec_filter_func) + recs_flt = list_recorders(self.exp_name, rec_filter_func) # group recs_group = {} for _, rec in recs_flt.items(): - task = self.get_task(rec) - group_key = get_key_func(task) + group_key = get_group_key_func(rec) recs_group.setdefault(group_key, []).append(rec) - # reduce group - reduce_group = {} - for k, rec_l in recs_group.items(): - pred_l = [] - for rec in rec_l: - pred_l.append(rec.load_object("pred.pkl").iloc[:, 0]) - # Make sure the pred are sorted according to the rolling start time - pred_l.sort(key=lambda pred: pred.index.get_level_values("datetime").min()) - pred = pd.concat(pred_l) - # If there are duplicated predition, we use the latest perdiction - pred = pred[~pred.index.duplicated(keep="last")] - pred = pred.sort_index() - reduce_group[k] = pred + return recs_group - return reduce_group \ No newline at end of file + def reduce(self, recs_group: dict, artifact_keys_list: list = None): + """ + Reduce artifacts based on the dict of grouped recorder. + The artifacts need be declared by artifact_keys_list. + The artifacts path in recorder need be declared by _artifacts_key_path. 
+ If there is no declaration in _artifacts_key_merge_method, then use default_merge method to merge it. + + Args: + recs_group (dict): The dict grouped by `group` + artifact_keys_list (list): The list of artifact keys. If it is None, then use all artifacts in _artifacts_key_path. + + Returns: + a dict including the group key, the artifact key and the artifact value. + + For example: + + .. code-block:: python + + { + group_key: {"pred": , "IC": } + } + """ + if artifact_keys_list == None: + artifact_keys_list = self._artifacts_key_path.keys() + reduce_group = {} + for group_key, recorder_list in recs_group.items(): + reduced_artifacts = {} + for artifact_key in artifact_keys_list: + artifact_list = [] + for recorder in recorder_list: + artifact_list.append(recorder.load_object(self._artifacts_key_path[artifact_key])) + merge_method = self._artifacts_key_merge_method.get(artifact_key, self.default_merge) + artifact = merge_method(artifact_list) + reduced_artifacts[artifact_key] = artifact + reduce_group[group_key] = reduced_artifacts + return reduce_group + + +class RollingCollector(RecorderCollector): + """ + Collect the record results of the rolling tasks + """ + + def __init__(self, experiment_name: str): + super().__init__(experiment_name) + self.logger = get_module_logger(self.__class__.__name__) + + def default_merge(self, artifact_list): + """merge disorderly artifacts based on the datetime. 
+ + Args: + artifact_list (list): a list of artifacts from different recorders + + Returns: + merged artifact + """ + # Make sure the pred are sorted according to the rolling start time + artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) + artifact = pd.concat(artifact_list) + # If there are duplicated predictions, we use the latest prediction + artifact = artifact[~artifact.index.duplicated(keep="last")] + artifact = artifact.sort_index() + return artifact diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 63000d77d..1d363d7f1 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -19,10 +19,10 @@ def task_generator(tasks, generators) -> list: Parameters ---------- - tasks : List[dict] - a list of task templates - generators : List[TaskGen] - a list of TaskGen + tasks : List[dict] or dict + a list of task templates or a single task + generators : List[TaskGen] or TaskGen + a list of TaskGen or a single TaskGen Returns ------- diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index db4c15038..6e9fa6571 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -151,7 +151,8 @@ class TaskManager: if print new task Returns ------- - + int + the length of new tasks """ task_pool = self._get_task_pool(task_pool) new_tasks = [] @@ -173,6 +174,8 @@ class TaskManager: for t in new_tasks: self.insert_task_def(t, task_pool) + + return len(new_tasks) def fetch_task(self, query={}, task_pool=None): task_pool = self._get_task_pool(task_pool) @@ -245,10 +248,9 @@ class TaskManager: for t in task_pool.find(query): yield self._decode_task(t) - def get_task_result(self, task, task_pool=None): + def re_query(self, task, task_pool=None): task_pool = self._get_task_pool(task_pool) - result = task_pool.find_one({"filter": task}) - return self._decode_task(result)["res"] + return task_pool.find_one({"_id":ObjectId(task["_id"])}) def commit_task_res(self, task, res, status=None, 
task_pool=None): task_pool = self._get_task_pool(task_pool) diff --git a/qlib/workflow/task/online.py b/qlib/workflow/task/online.py index 8d551e858..d23fc88c8 100644 --- a/qlib/workflow/task/online.py +++ b/qlib/workflow/task/online.py @@ -3,147 +3,140 @@ from qlib import get_module_logger from qlib.workflow import R from qlib.model.trainer import task_train from qlib.workflow.recorder import MLflowRecorder, Recorder -from qlib.workflow.task.collect import TaskCollector from qlib.workflow.task.update import ModelUpdater from qlib.workflow.task.utils import TimeAdjuster from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager from qlib.workflow.task.manage import run_task +from qlib.workflow.task.utils import list_recorders +from qlib.utils.serial import Serializable -class OnlineManager: - def prepare_new_models(self, tasks: List[dict]): - """prepare(train) new models +class OnlineManager(Serializable): - Parameters - ---------- - tasks : List[dict] - a list of tasks - - """ - raise NotImplementedError(f"Please implement the `prepare_new_models` method.") - - ONLINE_KEY = "online_status" # the tag key in recorder + ONLINE_KEY = "online_status" # the online status key in recorder ONLINE_TAG = "online" # the 'online' model NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model OFFLINE_TAG = "offline" # the 'offline' model, not for online serving + def prepare_signals(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `prepare_signals` method.") + + def prepare_tasks(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `prepare_tasks` method.") + + def prepare_new_models(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `prepare_new_models` method.") + + def update_online_pred(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `update_online_pred` method.") + 
+ def set_online_tag(self, tag, *args, **kwargs): + raise NotImplementedError(f"Please implement the `set_online_tag` method.") + + def get_online_tag(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `get_online_tag` method.") + + +class OnlineManagerR(OnlineManager): + """ + The implementation of OnlineManager based on (R)ecorder. + + """ + def __init__(self, experiment_name: str) -> None: - """ModelUpdater needs experiment name to find the records - - Parameters - ---------- - experiment_name : str - experiment name string - """ - self.logger = get_module_logger("OnlineManagement") + self.logger = get_module_logger(self.__class__.__name__) self.exp_name = experiment_name - self.tc = TaskCollector(experiment_name) - def set_next_online_model(self, recorder: MLflowRecorder): - recorder.set_tags(**{self.ONLINE_KEY: self.NEXT_ONLINE_TAG}) + def set_online_tag(self, tag, recorder: Union[Recorder, List]): + if isinstance(recorder, Recorder): + recorder = [recorder] + for rec in recorder: + rec.set_tags(**{self.ONLINE_KEY: tag}) + self.logger.info(f"Set {len(recorder)} models to '{tag}'.") - def set_online_model(self, recorder: MLflowRecorder): - """online model will be identified at the tags of the record""" - recorder.set_tags(**{self.ONLINE_KEY: self.ONLINE_TAG}) + def get_online_tag(self, recorder: Recorder): + tags = recorder.list_tags() + return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) - def set_offline_model(self, recorder: MLflowRecorder): - recorder.set_tags(**{self.ONLINE_KEY: self.OFFLINE_TAG}) - - def offline_all_model(self): - recs = self.tc.list_recorders() - for rid, rec in recs.items(): - self.set_offline_model(rec) - - def reset_online_model(self, recorders: Union[List, Dict] = None): + def reset_online_tag(self, recorder: Union[Recorder, List] = None): """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. 
Args: recorders (Union[List, Dict], optional): the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. """ - if recorders is None: - recorders = self.list_next_online_model() - if len(recorders) == 0: + if recorder is None: + recorder = list_recorders( + self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG + ).values() + if isinstance(recorder, Recorder): + recorder = [recorder] + if len(recorder) == 0: self.logger.info("No 'next online' model, just use current 'online' models.") return - self.offline_all_model() - if isinstance(recorders, dict): - recorders = recorders.values() - for rec in recorders: - self.set_online_model(rec) - self.logger.info(f"Reset {len(recorders)} models to 'online'.") - - def set_latest_model_to_next_online(self): - latest_rec = self.tc.list_latest_recorders() - for rid, rec in latest_rec.items(): - self.set_next_online_model(rec) - self.logger.info(f"Set {len(latest_rec)} latest models to 'next online'.") - - @staticmethod - def online_filter(recorder): - tags = recorder.list_tags() - if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.ONLINE_TAG: - return True - return False - - @staticmethod - def next_online_filter(recorder): - tags = recorder.list_tags() - if tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) == OnlineManager.NEXT_ONLINE_TAG: - return True - return False - - def list_online_model(self): - """list the record of online model - - Returns - ------- - dict - {rid : recorder of the online model} - """ - - return self.tc.list_recorders(rec_filter_func=self.online_filter) - - def list_next_online_model(self): - return self.tc.list_recorders(rec_filter_func=self.next_online_filter) + recs = list_recorders(self.exp_name) + self.set_online_tag(OnlineManager.OFFLINE_TAG, recs.values()) + self.set_online_tag(OnlineManager.ONLINE_TAG, 
recorder) + self.logger.info(f"Reset {len(recorder)} models to 'online'.") def update_online_pred(self): """update all online model predictions to the latest day in Calendar""" mu = ModelUpdater(self.exp_name) - cnt = mu.update_all_pred(self.online_filter) + cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG) self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.") + def after_day(self, *args, **kwargs): + self.prepare_signals(*args, **kwargs) + self.prepare_tasks(*args, **kwargs) + self.prepare_new_models(*args, **kwargs) + self.update_online_pred(*args, **kwargs) + self.reset_online_tag() -class RollingOnlineManager(OnlineManager): + +class RollingOnlineManager(OnlineManagerR): def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None: super().__init__(experiment_name) self.ta = TimeAdjuster() self.rg = rolling_gen self.tm = TaskManager(task_pool=task_pool) - self.logger = get_module_logger("RollingOnlineManager") + self.logger = get_module_logger(self.__class__.__name__) - def prepare_new_models(self): - """prepare(train) new models based on online model""" - latest_records = self.tc.list_latest_recorders(self.online_filter) # if we need online_filter here? 
- max_test = self.tc.latest_time(latest_records) + def prepare_signals(self): + pass + + def prepare_tasks(self): + latest_records, max_test = self.list_latest_recorders(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG) + if max_test is None: + self.logger.warn(f"No latest_recorders.") + return calendar_latest = self.ta.last_date() if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step: old_tasks = [] for rid, rec in latest_records.items(): - task = self.tc.get_task(rec) + task = rec.load_object("task") test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] # modify the test segment to generate new tasks task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) old_tasks.append(task) new_tasks = task_generator(old_tasks, self.rg) - self.tm.create_task(new_tasks) - run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name) - self.logger.info(f"Finished prepare {len(new_tasks)} new models.") - return new_tasks - self.logger.info("No need to prepare any new models.") - return [] + new_num = self.tm.create_task(new_tasks) + self.logger.info(f"Finished prepare {new_num} tasks.") - def prepare_signals(self): - # prepare the signals of today - pass + def prepare_new_models(self): + """prepare(train) new models based on online model""" + run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name) + latest_records, _ = self.list_latest_recorders() + self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values()) + self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.") + + def list_latest_recorders(self, rec_filter_func=None): + recs_flt = list_recorders(self.exp_name, rec_filter_func) + if len(recs_flt) == 0: + return recs_flt, None + max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values()) + latest_rec = {} + for rid, rec in recs_flt.items(): + if 
rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: + latest_rec[rid] = rec + return latest_rec, max_test \ No newline at end of file diff --git a/qlib/workflow/task/update.py b/qlib/workflow/task/update.py index fcee84349..b8190bca0 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/task/update.py @@ -6,8 +6,7 @@ from qlib import get_module_logger from qlib.workflow import R from qlib.model.trainer import task_train from qlib.workflow.recorder import Recorder -from qlib.workflow.task.collect import TaskCollector - +from qlib.workflow.task.utils import list_recorders class ModelUpdater: """ @@ -23,8 +22,7 @@ class ModelUpdater: experiment name string """ self.exp_name = experiment_name - self.logger = get_module_logger("ModelUpdater") - self.tc = TaskCollector(experiment_name) + self.logger = get_module_logger(self.__class__.__name__) def _reload_dataset(self, recorder, start_time, end_time): """reload dataset from pickle file @@ -53,7 +51,7 @@ class ModelUpdater: datahandler.init(datahandler.IT_LS) return dataset - def update_pred(self, recorder: Recorder): + def update_pred(self, recorder: Recorder, frequency='day'): """update predictions to the latest day in Calendar based on rid Parameters @@ -65,7 +63,10 @@ class ModelUpdater: last_end = old_pred.index.get_level_values("datetime").max() # updated to the latest trading day - cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None) + if frequency=='day': + cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None) + else: + raise NotImplementedError("Now Qlib only support update daily frequency prediction") if len(cal) == 0: self.logger.info( @@ -113,7 +114,7 @@ class ModelUpdater: the count of updated record """ - recs = self.tc.list_recorders(rec_filter_func=rec_filter_func) + recs = list_recorders(self.exp_name, rec_filter_func=rec_filter_func) for rid, rec in recs.items(): self.update_pred(rec) return len(recs) diff --git 
a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 272f219ec..15123a291 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -3,6 +3,7 @@ import bisect import pandas as pd from qlib.data import D +from qlib.workflow import R from qlib.config import C from qlib.log import get_module_logger from pymongo import MongoClient @@ -29,6 +30,25 @@ def get_mongodb(): client = MongoClient(cfg["task_url"]) return client.get_database(name=cfg["task_db_name"]) +def list_recorders(experiment, rec_filter_func=None): + """list all recorders which can pass the filter in a experiment. + + Args: + experiment (str or Experiment): the name of a Experiment or a instance + rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None. + + Returns: + dict: a dict {rid: recorder} after filtering. + """ + if isinstance(experiment, str): + experiment, _ = R.exp_manager._get_or_create_exp(experiment_name=experiment) + recs = experiment.list_recorders() + recs_flt = {} + for rid, rec in recs.items(): + if rec_filter_func is None or rec_filter_func(rec): + recs_flt[rid] = rec + + return recs_flt class TimeAdjuster: """