diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 4508a8788..9c1cbf891 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -15,6 +15,11 @@ from qlib.workflow.task.utils import list_recorders from qlib.model.ens.group import RollingGroup from qlib.model.trainer import TrainerRM +""" +This example shows how a Trainer work based on TaskManager with rolling tasks. +After training, how to collect the rolling results will be showed in task_collecting. +""" + data_handler_config = { "start_time": "2008-01-01", "end_time": "2020-08-01", @@ -71,81 +76,83 @@ task_xgboost_config = { "record": record_config, } -# Reset all things to the first status, be careful to save important data -def reset(task_pool, exp_name): - print("========== reset ==========") - TaskManager(task_pool=task_pool).remove() - exp = R.get_exp(experiment_name=exp_name) +class RollingTaskExample: + def __init__( + self, + provider_uri="~/.qlib/qlib_data/cn_data", + region=REG_CN, + task_url="mongodb://10.0.0.4:27017/", + task_db_name="rolling_db", + experiment_name="rolling_exp", + task_pool="rolling_task", + task_config=[task_xgboost_config, task_lgb_config], + rolling_step=550, + rolling_type=RollingGen.ROLL_SD, + ): + # TaskManager config + mongo_conf = { + "task_url": task_url, + "task_db_name": task_db_name, + } + qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) + self.experiment_name = experiment_name + self.task_pool = task_pool + self.task_config = task_config + self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type) - for rid in exp.list_recorders(): - exp.delete_recorder(rid) + # Reset all things to the first status, be careful to save important data + def reset(self): + print("========== reset ==========") + TaskManager(task_pool=self.task_pool).remove() + exp = R.get_exp(experiment_name=self.experiment_name) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) + def task_generating(self): + print("========== task_generating ==========") + tasks = task_generator( + tasks=self.task_config, + generators=self.rolling_gen, # generate different date segments + ) + pprint(tasks) + return tasks -# This part corresponds to "Task Generating" in the document -def task_generating(): + def task_training(self, tasks): + print("========== task_training ==========") + trainer = TrainerRM(self.experiment_name, self.task_pool) + trainer.train(tasks) - print("========== task_generating ==========") + def task_collecting(self): + print("========== task_collecting ==========") - tasks = task_generator( - tasks=[task_xgboost_config, task_lgb_config], - generators=RollingGen(step=550, rtype=RollingGen.ROLL_SD), # generate different date segment - ) + def rec_key(recorder): + task_config = recorder.load_object("task") + model_key = task_config["model"]["class"] + rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] + return model_key, rolling_key - pprint(tasks) + def my_filter(recorder): + # only choose the results of "LGBModel" + model_key, rolling_key = rec_key(recorder) + if model_key == "LGBModel": + return True + return False - return tasks + artifact = ens_workflow( + RecorderCollector(exp_name=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter), + RollingGroup(), + ) + print(artifact) - -def task_training(tasks, task_pool, exp_name): - trainer = TrainerRM(exp_name, task_pool) - trainer.train(tasks) - - -# This part corresponds to "Task Collecting" in the document -def task_collecting(exp_name): - print("========== task_collecting ==========") - - def rec_key(recorder): - task_config = recorder.load_object("task") - model_key = task_config["model"]["class"] - rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, rolling_key - - def my_filter(recorder): - # only choose the results of "LGBModel" - model_key, rolling_key = rec_key(recorder) - if model_key == "LGBModel": - return True - return False - - artifact = ens_workflow( - RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), - RollingGroup(), - ) - print(artifact) - - -def main( - provider_uri="~/.qlib/qlib_data/cn_data", - task_url="mongodb://10.0.0.4:27017/", - task_db_name="rolling_db", - experiment_name="rolling_exp", - task_pool="rolling_task", -): - mongo_conf = { - "task_url": task_url, - "task_db_name": task_db_name, - } - qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf) - - reset(task_pool, experiment_name) - tasks = task_generating() - task_training(tasks, task_pool, experiment_name) - task_collecting(experiment_name) + def main(self): + self.reset() + tasks = self.task_generating() + self.task_training(tasks) + self.task_collecting() if __name__ == "__main__": ## to see the whole process with your own parameters, use the command below - # python update_online_pred.py main --experiment_name="your_exp_name" - fire.Fire() + # python task_manager_rolling.py main --experiment_name="your_exp_name" + fire.Fire(RollingTaskExample) diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index d3a132879..9b5fbcc03 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -11,7 +11,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager """ -This examples is about the OnlineManager and OnlineSimulator based on Rolling tasks. +This examples is about the OnlineManager and OnlineSimulator based on rolling tasks. The OnlineManager will focus on the updating of your online models. The OnlineSimulator will focus on the simulating real updating routine of your online models. """ diff --git a/examples/online_srv/task_manager_rolling_with_updating.py b/examples/online_srv/rolling_online_management.py similarity index 52% rename from examples/online_srv/task_manager_rolling_with_updating.py rename to examples/online_srv/rolling_online_management.py index 076c1a467..6c30f3af3 100644 --- a/examples/online_srv/task_manager_rolling_with_updating.py +++ b/examples/online_srv/rolling_online_management.py @@ -1,18 +1,21 @@ -from pprint import pprint - +import os +from pathlib import Path +import pickle import fire import qlib -from qlib.config import REG_CN -from qlib.model.trainer import task_train from qlib.workflow import R -from qlib.workflow.task.collect import RecorderCollector -from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow -from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager, run_task +from qlib.workflow.task.gen import RollingGen +from qlib.workflow.task.manage import TaskManager from qlib.workflow.online.manager import RollingOnlineManager from qlib.workflow.task.utils import list_recorders from qlib.model.trainer import TrainerRM -from qlib.model.ens.group import RollingGroup + +""" +This example show how RollingOnlineManager works with rolling tasks. +There are two parts including first train and routine. +Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models. +Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models +""" data_handler_config = { "start_time": "2013-01-01", @@ -89,92 +92,38 @@ class RollingOnlineExample: "task_db_name": task_db_name, # database name } qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) - - self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) - self.trainer = TrainerRM(self.exp_name, self.task_pool) - self.task_manager = TaskManager(self.task_pool) self.rolling_online_manager = RollingOnlineManager( - experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer + experiment_name=exp_name, + rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), + trainer=TrainerRM(self.exp_name, self.task_pool), ) - def print_online_model(self): - print("========== print_online_model ==========") - print("Current 'online' model:") - - for rec in self.rolling_online_manager.online_models(): - print(rec.info["id"]) - print("Current 'next online' model:") - for rid, rec in list_recorders(self.exp_name).items(): - if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG: - print(rid) - - # This part corresponds to "Task Generating" in the document - def task_generating(self): - - print("========== task_generating ==========") - - tasks = task_generator( - tasks=[task_xgboost_config, task_lgb_config], - generators=self.rolling_gen, # generate different date segment - ) - - pprint(tasks) - - return tasks - - def task_training(self, tasks): - # self.trainer.train(tasks) - self.rolling_online_manager.prepare_new_models(tasks, tag=RollingOnlineManager.ONLINE_TAG) - - # This part corresponds to "Task Collecting" in the document - def task_collecting(self): - print("========== task_collecting ==========") - - def rec_key(recorder): - task_config = recorder.load_object("task") - model_key = task_config["model"]["class"] - rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, rolling_key - - def my_filter(recorder): - # only choose the results of "LGBModel" - model_key, rolling_key = rec_key(recorder) - if model_key == "LGBModel": - return True - return False - - artifact = ens_workflow( - RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup() - ) - print(artifact) + _ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine. # Reset all things to the first status, be careful to save important data def reset(self): print("========== reset ==========") - self.task_manager.remove() + TaskManager(self.task_pool).remove() exp = R.get_exp(experiment_name=self.exp_name) for rid in exp.list_recorders(): exp.delete_recorder(rid) - # Run this firstly to see the workflow in Task Management + if os.path.exists(self._ROLLING_MANAGER_PATH): + os.remove(self._ROLLING_MANAGER_PATH) + def first_run(self): print("========== first_run ==========") self.reset() - - tasks = self.task_generating() - pprint(tasks) - self.task_training(tasks) - self.task_collecting() - - # latest_rec, _ = self.rolling_online_manager.list_latest_recorders() - # self.rolling_online_manager.reset_online_tag(list(latest_rec.values())) + self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config]) + self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) + print(self.rolling_online_manager.collect_artifact()) def routine(self): print("========== routine ==========") - self.print_online_model() + with Path(self._ROLLING_MANAGER_PATH).open("rb") as f: + self.rolling_online_manager = pickle.load(f) self.rolling_online_manager.routine() - self.print_online_model() - self.task_collecting() + print(self.rolling_online_manager.collect_artifact()) def main(self): self.first_run() diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index 0f075abcd..ed2ad6997 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -5,6 +5,13 @@ from qlib.model.trainer import task_train from qlib.workflow.online.manager import OnlineManagerR from qlib.workflow.task.utils import list_recorders +""" +This example show how OnlineManager works when we need update prediction. +There are two parts including first_train and update_online_pred. +Firstly, the RollingOnlineManager will finish the first training and set the trained model to `online` model. +Next, the RollingOnlineManager will finish updating online prediction +""" + data_handler_config = { "start_time": "2008-01-01", "end_time": "2020-08-01", @@ -52,31 +59,25 @@ task = { } -def first_train(experiment_name="online_srv"): +class UpdatePredExample: + def __init__( + self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task + ): + qlib.init(provider_uri=provider_uri, region=region) + self.experiment_name = experiment_name + self.online_manager = OnlineManagerR(self.experiment_name) + self.task_config = task_config - rec = task_train(task_config=task, experiment_name=experiment_name) + def first_train(self): + rec = task_train(self.task_config, experiment_name=self.experiment_name) + self.online_manager.reset_online_tag(rec) # set to online model - online_manager = OnlineManagerR(experiment_name) - online_manager.reset_online_tag(rec) + def update_online_pred(self): + self.online_manager.update_online_pred() - -def update_online_pred(experiment_name="online_srv"): - - online_manager = OnlineManagerR(experiment_name) - - print("Here are the online models waiting for update:") - for rid, rec in list_recorders(experiment_name).items(): - if online_manager.get_online_tag(rec) == OnlineManagerR.ONLINE_TAG: - print(rid) - - online_manager.update_online_pred() - - -def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"): - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - qlib.init(provider_uri=provider_uri, region=region) - first_train(experiment_name) - update_online_pred(experiment_name) + def main(self): + self.first_train() + self.update_online_pred() if __name__ == "__main__": @@ -86,4 +87,4 @@ if __name__ == "__main__": # python update_online_pred.py update_online_pred ## to see the whole process with your own parameters, use the command below # python update_online_pred.py main --experiment_name="your_exp_name" - fire.Fire() + fire.Fire(UpdatePredExample) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 2182497f5..348f6b521 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -135,3 +135,12 @@ class TrainerRM(Trainer): for _id in _id_list: recs.append(tm.re_query(_id)["res"]) return recs + + +class DelayTrainer(Trainer): + def fake_train(self): + self.fake_trained = [] + + def train(self): + for rec in self.fake_trained: + pass diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index f7f9b62d5..e74488040 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -1,16 +1,29 @@ from copy import deepcopy +from operator import index +import pandas as pd +from qlib.model.ens.ensemble import ens_workflow +from qlib.model.ens.group import RollingGroup +from qlib.utils.serial import Serializable from typing import Dict, List, Union from qlib import get_module_logger from qlib.data.data import D from qlib.model.trainer import Trainer, TrainerR, task_train +from qlib.workflow import R from qlib.workflow.online.update import PredUpdater from qlib.workflow.recorder import Recorder -from qlib.workflow.task.collect import Collector +from qlib.workflow.task.collect import Collector, RecorderCollector from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.utils import TimeAdjuster, list_recorders +""" +This class is a component of online serving, it can manage a series of models dynamically. +With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models. +In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated. +So this module provide a series methods to control this process. +""" -class OnlineManager: + +class OnlineManager(Serializable): ONLINE_KEY = "online_status" # the online status key in recorder ONLINE_TAG = "online" # the 'online' model @@ -18,26 +31,28 @@ class OnlineManager: NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model OFFLINE_TAG = "offline" # the 'offline' model, not for online serving - def __init__(self, trainer: Trainer = None, collector: Collector = None, need_log=True): + SIGNAL_EXP = "OnlineManagerSignals" # a specific experiment to save signals of different experiment. + + def __init__(self, trainer: Trainer = None, need_log=True): """ init OnlineManager. Args: trainer (Trainer, optional): a instance of Trainer. Defaults to None. - collector (Collector, optional): a instance of Collector. Defaults to None. need_log (bool, optional): print log or not. Defaults to True. """ self.trainer = trainer self.logger = get_module_logger(self.__class__.__name__) self.need_log = need_log self.delay_signals = {} - self.collector = collector self.cur_time = None def prepare_signals(self, *args, **kwargs): """ After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. + Must use `pass` even though there is nothing to do. """ + raise NotImplementedError(f"Please implement the `prepare_signals` method.") def prepare_tasks(self, *args, **kwargs): @@ -47,7 +62,7 @@ class OnlineManager: """ raise NotImplementedError(f"Please implement the `prepare_tasks` method.") - def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG): + def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None): """ Use trainer to train a list of tasks and set the trained model to `tag`. @@ -57,14 +72,20 @@ class OnlineManager: `ONLINE_TAG` for first train or additional train `NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag` `OFFLINE_TAG` for train but offline those models + check_func: the method to judge if a model can be online. + The parameter is the model record and return True for online. + None for online every models. + """ - # TODO: 回调 - if not (tasks is None or len(tasks) == 0): + if check_func is None: + check_func = lambda x: True + if len(tasks) > 0: if self.trainer is not None: new_models = self.trainer.train(tasks) - self.set_online_tag(tag, new_models) - if self.need_log: - self.logger.info(f"Finished prepare {len(new_models)} new models and set them to {tag}.") + if check_func(new_models): + self.set_online_tag(tag, new_models) + if self.need_log: + self.logger.info(f"Finished preparing {len(new_models)} new models and set them to {tag}.") else: self.logger.warn("No trainer to train new tasks.") @@ -101,6 +122,12 @@ class OnlineManager: """ raise NotImplementedError(f"Please implement the `online_models` method.") + def first_train(self): + """ + Train a series of models firstly and set some of them into online models. + """ + raise NotImplementedError(f"Please implement the `first_train` method.") + def get_collector(self): """ Return the collector. @@ -108,7 +135,7 @@ class OnlineManager: Returns: Collector """ - return self.collector + raise NotImplementedError(f"Please implement the `get_collector` method.") def run_delay_signals(self): """ @@ -122,9 +149,10 @@ class OnlineManager: def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs): """ The typical update process after a routine, such as day by day or month by month. - Prepare signals -> prepare tasks -> prepare new models -> update online prediction -> reset online models + update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models """ self.cur_time = cur_time # None for latest date + self.update_online_pred() if not delay_prepare: self.prepare_signals(*args, **kwargs) else: @@ -134,7 +162,7 @@ class OnlineManager: raise ValueError("Can not delay prepare when cur_time is None") tasks = self.prepare_tasks(*args, **kwargs) self.prepare_new_models(tasks) - self.update_online_pred() + return self.reset_online_tag() @@ -144,19 +172,18 @@ class OnlineManagerR(OnlineManager): """ - def __init__(self, experiment_name: str, trainer: Trainer = None, collector: Collector = None, need_log=True): + def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True): """ init OnlineManagerR. Args: experiment_name (str): the experiment name. trainer (Trainer, optional): a instance of Trainer. Defaults to None. - collector (Collector, optional): a instance of Collector. Defaults to None. need_log (bool, optional): print log or not. Defaults to True. """ if trainer is None: trainer = TrainerR(experiment_name) - super().__init__(trainer=trainer, collector=collector, need_log=need_log) + super().__init__(trainer=trainer, need_log=need_log) self.exp_name = experiment_name def set_online_tag(self, tag, recorder: Union[Recorder, List]): @@ -212,7 +239,40 @@ class OnlineManagerR(OnlineManager): PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update() if self.need_log: - self.logger.info(f"Finish updating {len(online_models)} online model predictions of {self.exp_name}.") + self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") + + def prepare_signals(self, over_write=False): + """ + Average the predictions of online models and offer a trading signals every routine. + The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` + + Args: + over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. + """ + + with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): + recorder = R.get_recorder() + pred = [] + + try: + old_signals = recorder.load_object("signals") + except OSError: + old_signals = None + + for rec in self.online_models(): + pred.append(rec.load_object("pred.pkl")) + + signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") + signals = signals.sort_index() + if old_signals is not None and not over_write: + # signals = old_signals.reindex(signals.index).combine_first(signals) + old_max = old_signals.index.get_level_values("datetime").max() + new_signals = signals.loc[old_max:] + signals = pd.concat([old_signals, new_signals], axis=0) + else: + new_signals = signals + self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.") + recorder.save_objects(**{"signals": signals}) class RollingOnlineManager(OnlineManagerR): @@ -223,7 +283,6 @@ class RollingOnlineManager(OnlineManagerR): experiment_name: str, rolling_gen: RollingGen, trainer: Trainer = None, - collector: Collector = None, need_log=True, ): """ @@ -238,24 +297,64 @@ class RollingOnlineManager(OnlineManagerR): """ if trainer is None: trainer = TrainerR(experiment_name) - super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log) + super().__init__(experiment_name=experiment_name, trainer=trainer, need_log=need_log) self.ta = TimeAdjuster() self.rg = rolling_gen self.logger = get_module_logger(self.__class__.__name__) - def prepare_signals(self, *args, **kwargs): + def get_collector(self, rec_key_func=None, rec_filter_func=None): """ - Average the online models prediction and save them into a recorder - + get the instance of collector to collect results - Must use `pass` even though there is nothing to do. + Args: + rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. """ - # 检查recorder是否存在,如果不存在就创建一个 - # 检查recorder的上一个信号时间,如果没有那就从上线模型的共同最早时间开始出信号 - # 从recorder的上一个信号时间开始出信号,出到self.cur_time - for model in self.online_models(): - pass + def rec_key(recorder): + task_config = recorder.load_object("task") + model_key = task_config["model"]["class"] + rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] + return model_key, rolling_key + + if rec_key_func is None: + rec_key_func = rec_key + + return RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func) + + def collect_artifact(self, rec_key_func=None, rec_filter_func=None): + """ + collecting artifact based on the collector and RollingGroup. + + Args: + rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + + Returns: + dict: the artifact dict after rolling ensemble + """ + artifact = ens_workflow( + self.get_collector(rec_key_func=rec_key_func, rec_filter_func=rec_filter_func), RollingGroup() + ) + return artifact + + def first_train(self, task_configs: list): + """ + Use rolling_gen to generate different tasks based on task_configs and trained them. + + Args: + task_configs (list or dict): a list of task configs or a task config + + Returns: + Collector: a instance of a Collector. + """ + tasks = task_generator( + tasks=task_configs, + generators=self.rg, # generate different date segment + ) + self.prepare_new_models(tasks, tag=self.ONLINE_TAG) + self.prepare_signals(over_write=True) + return self.get_collector() def prepare_tasks(self, *args, **kwargs): """ @@ -264,7 +363,6 @@ class RollingOnlineManager(OnlineManagerR): Returns: list: a list of new tasks. """ - #TODO: max_test = self.cur_time latest_records, max_test = self.list_latest_recorders( lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG ) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index f651ef8d8..ef6a7a7d4 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -49,7 +49,7 @@ class RecorderCollector(Collector): if rec_key_func is None: rec_key_func = lambda rec: rec.info["id"] if artifacts_key is None: - artifacts_key = self.artifacts_path.keys() + artifacts_key = list(self.artifacts_path.keys()) self._rec_key_func = rec_key_func self.artifacts_key = artifacts_key self._rec_filter_func = rec_filter_func diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index ad7a16218..9e273b74f 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -194,6 +194,15 @@ class RollingGen(TaskGen): # update segments of this task t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) + # if end_time < the end of test_segments, then change end_time to allow load more data + if ( + self.ta.cal_interval( + t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"], + t["dataset"]["kwargs"]["segments"][self.test_key][1], + ) + < 0 + ): + t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1]) prev_seg = segments res.append(t) return res