From 03cce8c90826e33d060c64f1feb37bf26d540649 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Mon, 3 Jan 2022 15:52:03 +0800 Subject: [PATCH] Some Optimization of online code (#784) * Some Optimization of online code * more flexible updater and load_object & fix p*_uri * make recorder more friendly * remove unused import --- qlib/data/data.py | 9 ---- qlib/data/storage/file_storage.py | 22 ++++++--- qlib/model/ens/ensemble.py | 2 + qlib/model/ens/group.py | 8 ++-- qlib/model/trainer.py | 1 + qlib/utils/__init__.py | 2 +- qlib/workflow/exp.py | 8 ++-- qlib/workflow/online/update.py | 77 +++++++++++++++++++++++++------ qlib/workflow/online/utils.py | 8 +--- qlib/workflow/recorder.py | 6 ++- qlib/workflow/task/collect.py | 19 +++++++- 11 files changed, 113 insertions(+), 49 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index cdc5d8076..186e907f1 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -55,15 +55,6 @@ class ProviderBackendMixin: def backend_obj(self, **kwargs): backend = self.backend if self.backend else self.get_default_backend() backend = copy.deepcopy(backend) - - # set default storage kwargs - # NOTE: provider_uri priority: - # 1. backend_config: backend_obj["kwargs"]["provider_uri"] - # 2. qlib.init: provider_uri - backend_kwargs = backend.setdefault("kwargs", {}) - provider_uri = backend_kwargs.get("provider_uri", None) - provider_uri = C.dpm.provider_uri if provider_uri is None else C.dpm.format_provider_uri(provider_uri) - backend_kwargs["provider_uri"] = provider_uri backend.setdefault("kwargs", {}).update(**kwargs) return init_instance_by_config(backend) diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index cd9794638..31f2712a2 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -24,9 +24,17 @@ class FileStorageMixin: """ + # NOTE: provider_uri priority: + # 1. self._provider_uri : if provider_uri is provided. + # 2. 
provider_uri in qlib.config.C + + @property + def provider_uri(self): + return C["provider_uri"] if getattr(self, "_provider_uri", None) is None else self._provider_uri + @property def dpm(self): - return C.DataPathManager(self.provider_uri, None) + return C.dpm if getattr(self, "_provider_uri", None) is None else C.DataPathManager(self._provider_uri, None) @property def support_freq(self) -> List[str]: @@ -62,10 +70,10 @@ class FileStorageMixin: class FileCalendarStorage(FileStorageMixin, CalendarStorage): - def __init__(self, freq: str, future: bool, provider_uri: dict, **kwargs): + def __init__(self, freq: str, future: bool, provider_uri: dict = None, **kwargs): super(FileCalendarStorage, self).__init__(freq, future, **kwargs) self.future = future - self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri) + self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) self.enable_read_cache = True # TODO: make it configurable @property @@ -173,9 +181,9 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage): INSTRUMENT_END_FIELD = "end_datetime" SYMBOL_FIELD_NAME = "instrument" - def __init__(self, market: str, freq: str, provider_uri: dict, **kwargs): + def __init__(self, market: str, freq: str, provider_uri: dict = None, **kwargs): super(FileInstrumentStorage, self).__init__(market, freq, **kwargs) - self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri) + self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) self.file_name = f"{market.lower()}.txt" def _read_instrument(self) -> Dict[InstKT, InstVT]: @@ -262,9 +270,9 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage): class FileFeatureStorage(FileStorageMixin, FeatureStorage): - def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict, **kwargs): + def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = None, 
**kwargs): super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs) - self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri) + self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin" def clear(self): diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 0997c9367..863282416 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -122,6 +122,8 @@ class AverageEnsemble(Ensemble): # need to flatten the nested dict ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE) values = list(ensemble_dict.values()) + # NOTE: this may change the type of the underlying data!!!! + # from pd.DataFrame to pd.Series results = pd.concat(values, axis=1) results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std()) results = results.mean(axis=1) diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index 7f45b06a5..f9a3cb81b 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -3,9 +3,9 @@ """ Group can group a set of objects based on `group_func` and change them to a dict. -After group, we provide a method to reduce them. +After group, we provide a method to reduce them. 
-For example: +For example: group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object} @@ -109,5 +109,5 @@ class RollingGroup(Group): grouped_dict.setdefault(key[:-1], {})[key[-1]] = values return grouped_dict - def __init__(self): - super().__init__(ens=RollingEnsemble()) + def __init__(self, ens=RollingEnsemble()): + super().__init__(ens=ens) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index c9050638c..642fe7b35 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -26,6 +26,7 @@ from qlib.workflow.record_temp import SignalRecord from qlib.workflow.recorder import Recorder from qlib.workflow.task.manage import TaskManager, run_task + # from qlib.data.dataset.weight import Reweighter diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 3c503fa40..07f69bbf2 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -27,7 +27,7 @@ import collections import numpy as np import pandas as pd from pathlib import Path -from typing import Dict, Union, Tuple, Any, Text, Optional +from typing import Dict, Union, Tuple, Any, Text, Optional, Callable from types import ModuleType from urllib.parse import urlparse diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 2136ece8d..506d382fd 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Union +from typing import Dict, Union import mlflow, logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException @@ -214,7 +214,7 @@ class Experiment: """ raise NotImplementedError(f"Please implement the `_get_recorder` method") - def list_recorders(self, **flt_kwargs): + def list_recorders(self, **flt_kwargs) -> Dict[str, Recorder]: """ List all the existing recorders of this experiment. Please first get the experiment instance before calling this method. 
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`. @@ -325,7 +325,9 @@ class MLflowExperiment(Experiment): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! - def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""): + def list_recorders( + self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = "" + ) -> Dict[str, Recorder]: """ Parameters ---------- diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index ae6a21427..1e8c7d750 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -5,11 +5,12 @@ Updater is a module to update artifacts such as predictions when the stock data """ from abc import ABCMeta, abstractmethod +from typing import Optional import pandas as pd from qlib import get_module_logger from qlib.data import D -from qlib.data.dataset import Dataset, DatasetH +from qlib.data.dataset import Dataset, DatasetH, TSDatasetH from qlib.data.dataset.handler import DataHandlerLP from qlib.model import Model from qlib.utils import get_date_by_shift @@ -25,7 +26,9 @@ class RMDLoader: def __init__(self, rec: Recorder): self.rec = rec - def get_dataset(self, start_time, end_time, segments=None) -> DatasetH: + def get_dataset( + self, start_time, end_time, segments=None, unprepared_dataset: Optional[DatasetH] = None + ) -> DatasetH: """ Load, config and setup dataset. 
@@ -39,6 +42,8 @@ class RMDLoader: segments : dict the segments config for dataset Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time + unprepared_dataset: Optional[DatasetH] + if the user doesn't want to load the dataset from the recorder, please specify the user's dataset Returns: DatasetH: the instance of DatasetH @@ -46,7 +51,10 @@ class RMDLoader: """ if segments is None: segments = {"test": (start_time, end_time)} - dataset: DatasetH = self.rec.load_object("dataset") + if unprepared_dataset is None: + dataset: DatasetH = self.rec.load_object("dataset") + else: + dataset = unprepared_dataset dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments) dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}) return dataset @@ -90,7 +98,16 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): SZ300676 -0.001321 """ - def __init__(self, record: Recorder, to_date=None, from_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"): + def __init__( + self, + record: Recorder, + to_date=None, + from_date=None, + hist_ref: Optional[int] = None, + freq="day", + fname="pred.pkl", + loader_cls: type = RMDLoader, + ): """ Init PredUpdater. @@ -111,11 +128,15 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): hist_ref : int Sometimes, the dataset will have historical depends. Leave the problem to users to set the length of historical dependency + If the user doesn't specify this parameter, the Updater will try to load the dataset to automatically determine the hist_ref .. note:: the start_time is not included in the hist_ref + loader_cls : type + the class to load the model and dataset + """ # TODO: automate this hist_ref in the future. 
super().__init__(record=record) @@ -124,7 +145,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): self.hist_ref = hist_ref self.freq = freq self.fname = fname - self.rmdl = RMDLoader(rec=record) + self.rmdl = loader_cls(rec=record) latest_date = D.calendar(freq=freq)[-1] if to_date == None: @@ -148,27 +169,50 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): else: self.last_end = get_date_by_shift(from_date, -1, align="right") - def prepare_data(self) -> DatasetH: + def prepare_data(self, unprepared_dataset: Optional[DatasetH] = None) -> DatasetH: """ Load dataset + - if unprepared_dataset is specified, then prepare the dataset directly + - Otherwise, the loader loads the dataset (from the recorder by default) before preparing it Separating this function will make it easier to reuse the dataset Returns: DatasetH: the instance of DatasetH """ - start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq) + # automatically getting the historical dependency if not specified + if self.hist_ref is None: + dataset: DatasetH = self.record.load_object("dataset") if unprepared_dataset is None else unprepared_dataset + # Special treatment of historical dependencies + if isinstance(dataset, TSDatasetH): + hist_ref = dataset.step_len + else: + hist_ref = 0 + else: + hist_ref = self.hist_ref + + start_time_buffer = get_date_by_shift(self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq) start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) seg = {"test": (start_time, self.to_date)} - dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg) - return dataset + return self.rmdl.get_dataset( + start_time=start_time_buffer, end_time=self.to_date, segments=seg, unprepared_dataset=unprepared_dataset + ) - def update(self, dataset: DatasetH = None): + def update(self, dataset: DatasetH = None, write: bool = True, ret_new: bool = False) -> Optional[object]: """ - Update the data in a recorder. 
+ Parameters + ---------- + dataset : DatasetH + DatasetH: the instance of DatasetH. None for preparing it again. + write : bool + whether the write action will be executed + ret_new : bool + whether the updated data will be returned - Args: - DatasetH: the instance of DatasetH. None for reprepare. + Returns + ------- + Optional[object] + the updated data if `ret_new` is True; otherwise None """ # FIXME: the problem below is not solved # The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised @@ -186,7 +230,12 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): # For reusing the dataset dataset = self.prepare_data() - self.record.save_objects(**{self.fname: self.get_update_data(dataset)}) + updated_data = self.get_update_data(dataset) + + if write: + self.record.save_objects(**{self.fname: updated_data}) + if ret_new: + return updated_data @abstractmethod def get_update_data(self, dataset: Dataset) -> pd.DataFrame: diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index b1743d932..75ff3c4fd 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -169,14 +169,8 @@ class OnlineToolR(OnlineTool): exp_name = self._get_exp_name(exp_name) online_models = self.online_models(exp_name=exp_name) for rec in online_models: - hist_ref = 0 - task = rec.load_object("task") - # Special treatment of historical dependencies - cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset") - if issubclass(cls, TSDatasetH): - hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN) try: - updater = PredUpdater(rec, to_date=to_date, from_date=from_date, hist_ref=hist_ref) + updater = PredUpdater(rec, to_date=to_date, from_date=from_date) except LoadObjectError as e: # skip the recorder without pred self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.") diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 056d75be1..bd777e40b 100644 --- 
a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -326,13 +326,15 @@ class MLflowRecorder(Recorder): self.client.log_artifact(self.id, temp_dir / name, artifact_path) shutil.rmtree(temp_dir) - def load_object(self, name): + def load_object(self, name, unpickler=pickle.Unpickler): """ Load object such as prediction file or model checkpoint in mlflow. Args: name (str): the object name + unpickler: Supporting using custom unpickler + Raises: LoadObjectError: if raise some exceptions when load the object @@ -344,7 +346,7 @@ class MLflowRecorder(Recorder): try: path = self.client.download_artifacts(self.id, name) with Path(path).open("rb") as f: - data = pickle.load(f) + data = unpickler(f).load() ar = self.client._tracking_client._get_artifact_repo(self.id) if isinstance(ar, AzureBlobArtifactRepository): # for saving disk space diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index b5b63bba6..8d7b6a71c 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -5,12 +5,14 @@ Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on. """ +from collections import defaultdict from qlib.log import TimeInspector -from typing import Callable, Dict, List +from typing import Callable, Dict, Iterable, List from qlib.log import get_module_logger from qlib.utils.serial import Serializable from qlib.workflow import R from qlib.workflow.exp import Experiment +from qlib.workflow.recorder import Recorder class Collector(Serializable): @@ -142,6 +144,7 @@ class RecorderCollector(Collector): artifacts_path={"pred": "pred.pkl"}, artifacts_key=None, list_kwargs={}, + status: Iterable = {Recorder.STATUS_FI}, ): """ Init RecorderCollector. @@ -156,6 +159,7 @@ class RecorderCollector(Collector): artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. 
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. list_kwargs (str): arguments for list_recorders function. + status (Iterable): only collect recorders with specific status. None indicating collecting all the recorders """ super().__init__(process_list=process_list) if isinstance(experiment, str): @@ -171,6 +175,7 @@ class RecorderCollector(Collector): self.artifacts_key = artifacts_key self.rec_filter_func = rec_filter_func self.list_kwargs = list_kwargs + self.status = status def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict: """ @@ -202,9 +207,19 @@ class RecorderCollector(Collector): elif isinstance(self.experiment, Callable): recs = self.experiment() - recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)] + recs = [ + rec + for rec in recs + if ( + (self.status is None or rec.status in self.status) and (rec_filter_func is None or rec_filter_func(rec)) + ) + ] logger = get_module_logger("RecorderCollector") + status_stat = defaultdict(int) + for r in recs: + status_stat[r.status] += 1 + logger.info(f"Nubmer of recorders after filter: {status_stat}") for rec in recs: rec_key = self.rec_key_func(rec) for key in artifacts_key: