From cca43cf102c6c18958d1363e22cc6855aaaeb473 Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 11 Apr 2021 14:39:19 +0000 Subject: [PATCH] Refactor update & modification when running NN --- qlib/model/ens/ensemble.py | 2 +- qlib/model/ens/group.py | 51 ++++++++----- qlib/model/trainer.py | 8 +- qlib/utils/__init__.py | 4 +- qlib/workflow/online/update.py | 129 +++++++++++++++++++++++++++++++++ qlib/workflow/task/collect.py | 15 ++-- qlib/workflow/task/gen.py | 6 ++ qlib/workflow/task/manage.py | 29 +++++++- 8 files changed, 211 insertions(+), 33 deletions(-) diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index a2333cfeb..942303c18 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -58,7 +58,7 @@ class RollingEnsemble(Ensemble): """Merge the rolling objects in an Ensemble""" - def __call__(self, ensemble_dict: dict, *args, **kwargs): + def __call__(self, ensemble_dict: dict): """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. NOTE: The values of dict must be pd.Dataframe, and have the index "datetime" diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index d138b917c..f5ab5d8a7 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -1,6 +1,7 @@ from qlib.model.ens.ensemble import Ensemble, RollingEnsemble from typing import Callable, Union from qlib.utils.serial import Serializable +from joblib import Parallel, delayed class Group(Serializable): @@ -18,10 +19,23 @@ class Group(Serializable): ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping. """ - self.group = group_func - self.ens = ens + self._group_func = group_func + self._ens_func = ens - def __call__(self, ungrouped_dict: dict, *args, **kwargs): + def group(self, *args, **kwargs): + # TODO: such design is weird when `_group_func` is the only configurable part in the class + if isinstance(getattr(self, "_group_func", None), Callable): + return self._group_func(*args, **kwargs) + else: + raise NotImplementedError(f"Please specify valid `group_func`.") + + def reduce(self, *args, **kwargs): + if isinstance(getattr(self, "_ens_func", None), Callable): + return self._ens_func(*args, **kwargs) + else: + raise NotImplementedError(f"Please specify valid `_ens_func`.") + + def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs): """Group the ungrouped_dict into different groups. Args: @@ -30,23 +44,24 @@ class Group(Serializable): Returns: dict: grouped_dict like {G1: object, G2: object} """ - if isinstance(getattr(self, "group", None), Callable): - grouped_dict = self.group(ungrouped_dict, *args, **kwargs) - if self.ens is not None: - ens_dict = {} - for key, value in grouped_dict.items(): - ens_dict[key] = self.ens(value) - grouped_dict = ens_dict - return grouped_dict - else: - raise NotImplementedError(f"Please specify valid group_func.") + + # FIXME: The multiprocessing will raise the following error + # NotImplementedError: Please specify valid `_ens_func`. + # The problem maybe the state of the function is lost + grouped_dict = self.group(ungrouped_dict, *args, **kwargs) + + key_l = [] + job_l = [] + for key, value in grouped_dict.items(): + key_l.append(key) + job_l.append(delayed(Group.reduce)(self, value)) + return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l))) class RollingGroup(Group): """group the rolling dict""" - @staticmethod - def rolling_group(rolling_dict: dict): + def group(self, rolling_dict: dict): """Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}} NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly. @@ -63,7 +78,5 @@ class RollingGroup(Group): grouped_dict.setdefault(key[:-1], {})[key[-1]] = values return grouped_dict - def __init__(self, group_func=None): - super().__init__(group_func=group_func, ens=RollingEnsemble()) - if group_func is None: - self.group = RollingGroup.rolling_group + def __init__(self): + super().__init__(ens=RollingEnsemble()) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 7ffca20ee..516554155 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -8,6 +8,7 @@ from qlib.workflow.record_temp import SignalRecord from qlib.workflow.task.manage import TaskManager, run_task from qlib.data.dataset import Dataset from qlib.model.base import Model +import socket def task_train(task_config: dict, experiment_name: str) -> Recorder: @@ -35,16 +36,17 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder: # train model R.log_params(**flatten_dict(task_config)) - model.fit(dataset) - recorder = R.get_recorder() - R.save_objects(**{"params.pkl": model}) R.save_objects(**{"task": task_config}) # keep the original format and datatype + R.set_tags(hostname=socket.gethostname()) + model.fit(dataset) + R.save_objects(**{"params.pkl": model}) # This dataset is saved for online inference. So the concrete data should not be dumped dataset.config(dump_all=False, recursive=True) R.save_objects(**{"dataset": dataset}) # generate records: prediction, backtest, and analysis records = task_config.get("record", []) + recorder = R.get_recorder() if isinstance(records, dict): # prevent only one dict records = [records] for record in records: diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 7e71ba76c..3ebc6fc1c 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -522,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): return calendar -def get_date_by_shift(trading_date, shift, future=False, clip_shift=True): +def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"): """get trading date with shift bias wil cur_date e.g. : shift == 1, return next trading date shift == -1, return previous trading date @@ -535,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True): """ from qlib.data import D - cal = D.calendar(future=future) + cal = D.calendar(future=future, freq=freq) if pd.to_datetime(trading_date) not in list(cal): raise ValueError("{} is not trading day!".format(str(trading_date))) _index = bisect.bisect_left(cal, trading_date) diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 1a6897d02..8835fdae2 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -1,13 +1,142 @@ from typing import Union, List +from qlib.data.dataset import DatasetH from qlib.workflow import R from qlib.data import D import pandas as pd from qlib import get_module_logger from qlib.workflow import R +from qlib.model import Model from qlib.model.trainer import task_train from qlib.workflow.recorder import Recorder from qlib.workflow.task.utils import list_recorders from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.dataset import DatasetH +from abc import ABCMeta, abstractmethod +from qlib.utils import get_date_by_shift + + +class RMDLoader: + """ + Recorder Model Dataset Loader + """ + + def __init__(self, rec: Recorder): + self.rec = rec + + def get_dataset(self, start_time, end_time, segments=None) -> DatasetH: + """ + load, config and setup dataset. + + This dataset is for inferene + + Parameters + ---------- + start_time : + the start_time of underlying data + end_time : + the end_time of underlying data + segments : dict + the segments config for dataset + Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time + """ + if segments is None: + segments = {"test": (start_time, end_time)} + dataset: DatasetH = self.rec.load_object("dataset") + dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments) + dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}) + return dataset + + def get_model(self) -> Model: + return self.rec.load_object("params.pkl") + + +class RecordUpdater(metaclass=ABCMeta): + """ + Updata a specific recorders + """ + + def __init__(self, record: Recorder, *args, **kwargs): + self.record = record + + @abstractmethod + def update(self, *args, **kwargs): + """ + Update info for specific recorder + """ + ... + + +class PredUpdater(RecordUpdater): + """ + Update the prediction in the Recorder + """ + + LATEST = "__latest" + + def __init__(self, record: Recorder, to_date=LATEST, hist_ref: int = 0, freq="day"): + """ + Parameters + ---------- + record : Recorder + to_date : + update to prediction to the `to_date` + hist_ref : int + Sometimes, the dataset will have historical depends. + Leave the problem to user to set the length of historical dependancy + NOTE: the start_time is not included in the hist_ref + # TODO: automate this step in the future. + """ + super().__init__(record=record) + + self.to_date = to_date + self.hist_ref = hist_ref + self.freq = freq + self.rmdl = RMDLoader(rec=record) + + if to_date == self.LATEST: + to_date = D.calendar(freq=freq)[-1] + self.to_date = pd.Timestamp(to_date) + self.old_pred = record.load_object("pred.pkl") + self.last_end = self.old_pred.index.get_level_values("datetime").max() + + def prepare_data(self) -> DatasetH: + """ + # Load dataset + + Seperating this function will make it easier to reuse the dataset + """ + start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq) + start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) + seg = {"test": (start_time, self.to_date)} + dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg) + return dataset + + def update(self, dataset: DatasetH = None): + """ + update the precition in a recorder + """ + # FIXME: the problme below is not solved + # The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised + # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU. + + # load dataset + if dataset is None: + # For reusing the dataset + dataset = self.prepare_data() + + # Load model + model = self.rmdl.get_model() + + new_pred = model.predict(dataset) + + cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0) + cb_pred = cb_pred.sort_index() + + self.record.save_objects(**{"pred.pkl": cb_pred}) + + get_module_logger(self.__class__.__name__).info( + f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}." + ) class ModelUpdater: diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 6b9418daf..9bd609670 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -25,6 +25,8 @@ class Collector(Serializable): class RecorderCollector(Collector): + ART_KEY_RAW = "__raw" + def __init__( self, exp_name, @@ -48,9 +50,9 @@ class RecorderCollector(Collector): rec_key_func = lambda rec: rec.info["id"] if artifacts_key is None: artifacts_key = self.artifacts_path.keys() - self.rec_key = rec_key_func + self._rec_key_func = rec_key_func self.artifacts_key = artifacts_key - self.rec_filter = rec_filter_func + self._rec_filter_func = rec_filter_func def collect(self, artifacts_key=None, rec_filter_func=None): """Collect different artifacts based on recorder after filtering. @@ -65,7 +67,7 @@ class RecorderCollector(Collector): if artifacts_key is None: artifacts_key = self.artifacts_key if rec_filter_func is None: - rec_filter_func = self.rec_filter + rec_filter_func = self._rec_filter_func if isinstance(artifacts_key, str): artifacts_key = [artifacts_key] @@ -74,9 +76,12 @@ class RecorderCollector(Collector): # filter records recs_flt = list_recorders(self.exp_name, rec_filter_func) for _, rec in recs_flt.items(): - rec_key = self.rec_key(rec) + rec_key = self._rec_key_func(rec) for key in artifacts_key: - artifact = rec.load_object(self.artifacts_path[key]) + if self.ART_KEY_RAW == key: + artifact = rec + else: + artifact = rec.load_object(self.artifacts_path[key]) collect_dict.setdefault(key, {})[rec_key] = artifact return collect_dict diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index a8426d920..325089126 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -80,6 +80,12 @@ class TaskGen(metaclass=abc.ABCMeta): """ pass + def __call__(self, *args, **kwargs): + """ + This is just a syntactic sugar for generate + """ + return self.generate(*args, **kwargs) + class RollingGen(TaskGen): ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 720eeb12f..815529b66 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -18,7 +18,8 @@ import concurrent import pymongo from qlib.config import C from .utils import get_mongodb -from qlib import get_module_logger +from qlib import get_module_logger, auto_init +import fire class TaskManager: @@ -49,7 +50,7 @@ class TaskManager: ENCODE_FIELDS_PREFIX = ["def", "res"] - def __init__(self, task_pool: str): + def __init__(self, task_pool: str = None): """ init Task Manager, remember to make the statement of MongoDB url and database name firstly. @@ -59,7 +60,8 @@ class TaskManager: the name of Collection in MongoDB """ self.mdb = get_mongodb() - self.task_pool = getattr(self.mdb, task_pool) + if task_pool is not None: + self.task_pool = getattr(self.mdb, task_pool) self.logger = get_module_logger(self.__class__.__name__) def list(self): @@ -287,6 +289,20 @@ class TaskManager: query["_id"] = ObjectId(query["_id"]) print(self.task_pool.update_many(query, {"$set": {"status": status}})) + def prioritize(self, task, priority: int): + """ + set priority for task + + Parameters + ---------- + task : dict + The task query from the database + priority : int + the target priority + """ + update_dict = {"$set": {"priority": priority}} + self.task_pool.update_one({"_id": task["_id"]}, update_dict) + def _get_undone_n(self, task_stat): return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0) @@ -345,3 +361,10 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs): ever_run = True return ever_run + + +if __name__ == "__main__": + # This is for using it in cmd + # E.g. : `python -m qlib.workflow.task.manage list` + auto_init() + fire.Fire(TaskManager)