diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 3a87e01c4..48433c6d5 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -7,10 +7,10 @@ This examples is about how can simulate the OnlineManager based on rolling tasks import fire import qlib -from qlib.model.trainer import DelayTrainerRM +from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM from qlib.workflow import R from qlib.workflow.online.manager import OnlineManager -from qlib.workflow.online.strategy import RollingAverageStrategy +from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager @@ -115,7 +115,7 @@ class OnlineSimulationExample: ) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time. 
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) self.rolling_online_manager = OnlineManager( - RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen), + RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen), trainer=self.trainer, begin_time=self.start_time, ) diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index ebf1ab59a..e15daeb29 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -14,7 +14,7 @@ import pickle import fire import qlib from qlib.workflow import R -from qlib.workflow.online.strategy import RollingAverageStrategy +from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager from qlib.workflow.online.manager import OnlineManager @@ -97,7 +97,7 @@ class RollingOnlineExample: for task in tasks: name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy strategy.append( - RollingAverageStrategy( + RollingStrategy( name_id, task, RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 021311d47..206561aed 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -172,7 +172,10 @@ class DatasetH(Dataset): ---------- slc : slice """ - return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs) + if hasattr(self, "fetch_kwargs"): + return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs) + else: + return self.handler.fetch(slc, **kwargs) def prepare( self, diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index a7b837ea5..6040517e2 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -7,7 +7,7 @@ Ensemble can merge the objects in an Ensemble. 
For example, if there are many su from typing import Union import pandas as pd -from qlib.utils import flatten_dict +from qlib.utils import FLATTEN_TUPLE, flatten_dict class Ensemble: @@ -90,7 +90,7 @@ class AverageEnsemble(Ensemble): pd.DataFrame: the complete result of averaging and standardizing. """ # need to flatten the nested dict - ensemble_dict = flatten_dict(ensemble_dict) + ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE) values = list(ensemble_dict.values()) results = pd.concat(values, axis=1) results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std()) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 0b64d3b30..f261a4b4e 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -416,7 +416,7 @@ class DelayTrainerRM(TrainerRM): run_task( end_train_func, task_pool, - tasks=tasks, + query={"filter": {"$in": tasks}}, # only train these tasks experiment_name=experiment_name, before_status=TaskManager.STATUS_PART_DONE, **kwargs, diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 8583e946f..2ff687737 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -716,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame: return df.sort_index(axis=axis) +FLATTEN_TUPLE = "_FLATTEN_TUPLE" + + def flatten_dict(d, parent_key="", sep="."): - """flatten_dict. + """ + Flatten a nested dict. + >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}) >>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10} - Parameters - ---------- - d : - d - parent_key : - parent_key - sep : - sep + >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE) + >>> {'a': 1, ('c','a'): 2, (('c','b'),'x'): 5, 'd': [1, 2, 3], (('c','b'),'y'): 10} + + Args: + d (dict): the dict to flatten + parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "". 
+ sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting. + + Returns: + dict: flatten dict """ items = [] for k, v in d.items(): - new_key = parent_key + sep + str(k) if parent_key else k + if sep == FLATTEN_TUPLE: + new_key = (parent_key, k) if parent_key else k + else: + new_key = parent_key + sep + k if parent_key else k if isinstance(v, collections.abc.MutableMapping): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 9c5fc9ac2..352949198 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -16,9 +16,10 @@ class Serializable: """ pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python. + default_dump_all = False # if dump all things def __init__(self): - self._dump_all = False + self._dump_all = self.default_dump_all self._exclude = [] def __getstate__(self) -> dict: @@ -77,12 +78,7 @@ class Serializable: def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None): self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: - if self.pickle_backend == "pickle": - pickle.dump(self, f) - elif self.pickle_backend == "dill": - dill.dump(self, f) - else: - raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.") + self.get_backend().dump(self, f) @classmethod def load(cls, filepath): @@ -99,13 +95,24 @@ class Serializable: Collector: the instance of Collector """ with open(filepath, "rb") as f: - if cls.pickle_backend == "pickle": - object = pickle.load(f) - elif cls.pickle_backend == "dill": - object = dill.load(f) - else: - raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.") + object = cls.get_backend().load(f) if isinstance(object, cls): return object else: raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!") + + @classmethod + def get_backend(cls): + """ + Return the backend of a 
Serializable class. The value will be the ``pickle`` or ``dill`` module. + + Returns: + module: the ``pickle`` or ``dill`` module selected by ``pickle_backend`` + """ + if cls.pickle_backend == "pickle": + return pickle + elif cls.pickle_backend == "dill": + return dill + else: + raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.") + diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index e41c3f20a..63169b58d 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -35,7 +35,7 @@ class OnlineManager(Serializable): def __init__( self, - strategy: Union[OnlineStrategy, List[OnlineStrategy]], + strategies: Union[OnlineStrategy, List[OnlineStrategy]], trainer: Trainer = None, begin_time: Union[str, pd.Timestamp] = None, freq="day", @@ -45,15 +45,15 @@ class OnlineManager(Serializable): One OnlineManager must have at least one OnlineStrategy. Args: - strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy + strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date. trainer (Trainer): the trainer to train task. None for using DelayTrainerR. freq (str, optional): data frequency. Defaults to "day". 
""" self.logger = get_module_logger(self.__class__.__name__) - if not isinstance(strategy, list): - strategy = [strategy] - self.strategy = strategy + if not isinstance(strategies, list): + strategies = [strategies] + self.strategies = strategies self.freq = freq if begin_time is None: begin_time = D.calendar(freq=self.freq).max() @@ -77,7 +77,7 @@ class OnlineManager(Serializable): """ models_list = [] if strategies is None: - strategies = self.strategy + strategies = self.strategies for strategy in strategies: self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") tasks = strategy.first_tasks() @@ -114,21 +114,22 @@ class OnlineManager(Serializable): cur_time = D.calendar(freq=self.freq).max() self.cur_time = pd.Timestamp(cur_time) # None for latest date models_list = [] - for strategy in self.strategy: + for strategy in self.strategies: + self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") if not delay: strategy.tool.update_online_pred() - self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) models = self.trainer.train(tasks) + self.logger.info(f"Finished training {len(models)} models.") models_list.append(models) + for strategy, models in zip(self.strategies, models_list): + self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs) + if not delay: self.prepare_signals(**signal_kwargs) - for strategy, models in zip(self.strategy, models_list): - self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs) - def prepare_online_models( self, strategy: OnlineStrategy, models: list, delay: bool = False, model_kwargs: dict = {} ): @@ -141,14 +142,9 @@ class OnlineManager(Serializable): delay (bool, optional): if delay prepare models. Defaults to False. model_kwargs (dict, optional): the params for `prepare_online_models`. 
""" - if not models: - return if not delay: models = self.trainer.end_train(models, experiment_name=strategy.name_id) - online_models = strategy.prepare_online_models(models, **model_kwargs) - else: - # just set every models as online models temporarily before ``prepare_online_models`` - online_models = models + online_models = strategy.prepare_online_models(models, **model_kwargs) self.history.setdefault(self.cur_time, {})[strategy] = online_models def get_collector(self) -> MergeCollector: @@ -160,21 +156,21 @@ class OnlineManager(Serializable): MergeCollector: the collector to merge other collectors. """ collector_dict = {} - for strategy in self.strategy: + for strategy in self.strategies: collector_dict[strategy.name_id] = strategy.get_collector() return MergeCollector(collector_dict, process_list=[]) - def add_strategy(self, strategy: Union[OnlineStrategy, List[OnlineStrategy]]): + def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]): """ Add some new strategies to online manager. 
Args: - strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy + strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy """ - if not isinstance(strategy, list): - strategy = [strategy] - self.first_train(strategy) - self.strategy.extend(strategy) + if not isinstance(strategies, list): + strategies = [strategies] + self.first_train(strategies) + self.strategies.extend(strategies) def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False): """ @@ -258,7 +254,8 @@ class OnlineManager(Serializable): if self.trainer.is_delay(): self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs) - set_global_logger_level(logging.INFO) + # FIXME: get logging level firstly and restore it here + set_global_logger_level(logging.DEBUG) self.logger.info(f"Finished preparing signals") return self.get_collector() diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 9f657427d..04c854f79 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -44,28 +44,23 @@ class OnlineStrategy: """ raise NotImplementedError(f"Please implement the `prepare_tasks` method.") - def prepare_online_models(self, models, cur_time=None, check_func=None) -> List[object]: + def prepare_online_models(self, models, cur_time=None) -> List[object]: """ - A typically implementation, but maybe you will need old models by online_tool. - Select some models as the online models from the trained models. + Select some models from trained models and set them to online models. + This is a typical implementation that sets all trained models online; override it to implement a more complex selection. + You can find last online models by OnlineTool.online_models if you still need them. - NOTE: This method offline all models and online the online models prepared by this method (if have). So you can find last online models by OnlineTool.online_models if you still need them. + NOTE: Reset all online models to trained model. 
If there are no trained models, nothing is reset and `self.tool.online_models()` is returned. Args: - tasks (list): a list of tasks. - check_func: the method to judge if a model can be online. - The parameter is the model record and return True for online. - None for online every models. + models (list): a list of models. + cur_time (pd.Timestamp): current time from OnlineManager. None for latest. Returns: - List[object]: a list of selected models. + List[object]: a list of online models. """ - if check_func is not None: - online_models = [] - for model in models: - if check_func(model, cur_time): - online_models.append(model) - models = online_models + if not models: + return self.tool.online_models() self.tool.reset_online_tag(models) return models @@ -89,10 +84,10 @@ class OnlineStrategy: raise NotImplementedError(f"Please implement the `get_collector` method.") -class RollingAverageStrategy(OnlineStrategy): +class RollingStrategy(OnlineStrategy): """ - This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models + This example strategy always uses the latest rolling model as the online model. """ def __init__( @@ -102,7 +97,7 @@ class RollingAverageStrategy(OnlineStrategy): rolling_gen: RollingGen, ): """ - Init RollingAverageStrategy. + Init RollingStrategy. Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one. diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index a69e1005f..9cb294169 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -137,7 +137,9 @@ class PredUpdater(RecordUpdater): start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) if start_time >= self.to_date: - self.logger.info(f"The prediction in {self.record.info['id']} are latest. No need to update.") + self.logger.info( + f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}." 
+ ) return # load dataset diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index c3af9d1ca..3c2774cec 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -158,6 +158,10 @@ class OnlineToolR(OnlineTool): """ online_models = self.online_models() for rec in online_models: - PredUpdater(rec, to_date=to_date).update() + hist_ref = 0 + task = rec.load_object("task") + if task["dataset"]["class"] == "TSDatasetH": + hist_ref = task["dataset"]["kwargs"]["step_len"] + PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update() self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index b40ee0164..1080d07f4 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -91,18 +91,21 @@ class MergeCollector(Collector): """ - def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = []): + def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None): """ Args: collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector} process_list (List[Callable]): the list of processors or the instance of processor to process dict. + merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting. + None for use tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")). """ super().__init__(process_list=process_list) self.collector_dict = collector_dict + self.merge_func = merge_func def collect(self) -> dict: """ - Collect all result of collector_dict and change the outermost key to "``collector_key``_``key``" (like merge them, but rename every key) + Collect all result of collector_dict and change the outermost key to a recombination key. Returns: dict: the dict after collecting. 
@@ -111,7 +114,10 @@ class MergeCollector(Collector): for collector_key, collector in self.collector_dict.items(): tmp_dict = collector() for key, value in tmp_dict.items(): - collect_dict[collector_key + "_" + str(key)] = value + if self.merge_func is not None: + collect_dict[self.merge_func(collector_key, key)] = value + else: + collect_dict[(collector_key, key)] = value return collect_dict @@ -146,16 +152,18 @@ class RecorderCollector(Collector): rec_key_func = lambda rec: rec.info["id"] if artifacts_key is None: artifacts_key = list(self.artifacts_path.keys()) - self._rec_key_func = rec_key_func + self.rec_key_func = rec_key_func self.artifacts_key = artifacts_key - self._rec_filter_func = rec_filter_func + self.rec_filter_func = rec_filter_func - def collect(self, artifacts_key=None, rec_filter_func=None) -> dict: + def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict: """Collect different artifacts based on recorder after filtering. Args: artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default. rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default. + only_exist (bool, optional): if only collect the artifacts when a recorder really have. + If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception. 
Returns: dict: the dict after collected like {artifact: {rec_key: object}} @@ -163,7 +171,7 @@ class RecorderCollector(Collector): if artifacts_key is None: artifacts_key = self.artifacts_key if rec_filter_func is None: - rec_filter_func = self._rec_filter_func + rec_filter_func = self.rec_filter_func if isinstance(artifacts_key, str): artifacts_key = [artifacts_key] @@ -177,16 +185,18 @@ class RecorderCollector(Collector): recs_flt[rid] = rec for _, rec in recs_flt.items(): - rec_key = self._rec_key_func(rec) + rec_key = self.rec_key_func(rec) for key in artifacts_key: if self.ART_KEY_RAW == key: artifact = rec else: - # only collect existing artifact try: artifact = rec.load_object(self.artifacts_path[key]) - except Exception: - continue + except Exception as e: + if only_exist: + # only collect existing artifact + continue + raise e collect_dict.setdefault(key, {})[rec_key] = artifact return collect_dict diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 025dfa85c..40f868295 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -155,7 +155,8 @@ class TaskManager: def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]: """ - If the tasks in task_def_l is new, then insert new tasks into the task_pool + If the tasks in task_def_l is new, then insert new tasks into the task_pool, and record inserted_id. + If a task is not new, then query its _id. 
Parameters ---------- @@ -169,9 +170,10 @@ class TaskManager: Returns ------- list - a list of the _id of new tasks + a list of the _id of task_def_l """ new_tasks = [] + _id_list = [] for t in task_def_l: try: r = self.task_pool.find_one({"filter": t}) @@ -179,6 +181,14 @@ class TaskManager: r = self.task_pool.find_one({"filter": self._dict_to_str(t)}) if r is None: new_tasks.append(t) + if not dry_run: + insert_result = self.insert_task_def(t) + _id_list.append(insert_result.inserted_id) + else: + _id_list.append(None) + else: + _id_list.append(self._decode_task(r)["_id"]) + self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}") if print_nt: # print new task @@ -188,11 +198,6 @@ class TaskManager: if dry_run: return [] - _id_list = [] - for t in new_tasks: - insert_result = self.insert_task_def(t) - _id_list.append(insert_result.inserted_id) - return _id_list def fetch_task(self, query={}, status=STATUS_WAITING) -> dict: @@ -388,7 +393,7 @@ class TaskManager: def run_task( task_func: Callable, task_pool: str, - tasks: List[dict] = None, + query: dict = {}, force_release: bool = False, before_status: str = TaskManager.STATUS_WAITING, after_status: str = TaskManager.STATUS_DONE, @@ -414,8 +419,8 @@ def run_task( the function to run the task task_pool : str the name of the task pool (Collection in MongoDB) - tasks: List[dict] - will only train these tasks config, None for train all tasks. + query: dict + will use this dict to query task_pool when fetching task force_release : bool will the program force to release the resource before_status : str: @@ -428,9 +433,6 @@ def run_task( tm = TaskManager(task_pool) ever_run = False - query = {} - if tasks is not None: - query = {"filter": {"$in": tasks}} while True: with tm.safe_fetch_task(status=before_status, query=query) as task: