From de86e46ed02de7492f5e82f7f9cc7ac1b5ff235d Mon Sep 17 00:00:00 2001 From: you-n-g Date: Sun, 29 Jun 2025 15:50:59 +0800 Subject: [PATCH] refactor: introduce BaseDataHandler and unify fetch interface (#1958) * refactor: introduce BaseDataHandler and unify fetch interface * refactor: include data_key in seg_kwargs and simplify segments loop * refactor: default data_key to BaseDataHandler.DK_I in _get_df_by_key * style: fix indentation and remove extra blank lines in data handlers * refactor: use BaseDataHandler.DK_I as default data_key * docs: fix BaseDataHandler docstring grammar and formatting * refactor: remove unused **kwargs from storage fetch methods * docs: refine BaseDataHandler and DataHandler docstrings * refactor: rename BaseDataHandler to DataHandlerABC, update type hints * feat: add flt_col to TSDatasetH and list-to-slice conversion in storage * lint * comment --- qlib/data/dataset/__init__.py | 26 ++++----- qlib/data/dataset/handler.py | 101 ++++++++++++++++++++++++---------- qlib/data/dataset/storage.py | 56 +++++++++++++------ qlib/model/trainer.py | 4 +- 4 files changed, 126 insertions(+), 61 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 0b6c552a3..a6cace373 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -226,13 +226,8 @@ class DatasetH(Dataset): ------ NotImplementedError: """ - logger = get_module_logger("DatasetH") - seg_kwargs = {"col_set": col_set} + seg_kwargs = {"col_set": col_set, "data_key": data_key} seg_kwargs.update(kwargs) - if "data_key" in getfullargspec(self.handler.fetch).args: - seg_kwargs["data_key"] = data_key - else: - logger.info(f"data_key[{data_key}] is ignored.") # Conflictions may happen here # - The fetched data and the segment key may both be string @@ -240,9 +235,11 @@ class DatasetH(Dataset): # - The segment name will have higher priorities # 1) Use it as segment name first + # 1.1) directly fetch split like "train" "valid" "test" if 
isinstance(segments, str) and segments in self.segments: return self._prepare_seg(self.segments[segments], **seg_kwargs) + # 1.2) fetch multiple splits like ["train", "valid"] ["train", "valid", "test"] if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments): return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments] @@ -262,7 +259,7 @@ class DatasetH(Dataset): def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp): """it will act like sort and return the max value or None""" candidate = None - for k, seg in segments.items(): + for _, seg in segments.items(): point = seg[idx] if point is None: # None indicates unbounded, return directly @@ -376,6 +373,8 @@ class TSDataSampler: ffill with previous samples first and fill with later samples second flt_data : pd.Series a column of data(True or False) to filter data. Its index order is <"datetime", "instrument"> + This feature is essential because: + - We want some samples to be excluded by label-based filtering, but we can't filter them out at the beginning because their features are still important when building the features. 
None: kepp all data @@ -661,8 +660,9 @@ class TSDatasetH(DatasetH): DEFAULT_STEP_LEN = 30 - def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs): + def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs): self.step_len = step_len + self.flt_col = flt_col super().__init__(**kwargs) def config(self, **kwargs): @@ -693,10 +693,10 @@ class TSDatasetH(DatasetH): dtype = kwargs.pop("dtype", None) if not isinstance(slc, slice): slc = slice(*slc) - start, end = slc.start, slc.stop - flt_col = kwargs.pop("flt_col", None) - # TSDatasetH will retrieve more data for complete time-series + if (flt_col := kwargs.pop("flt_col", None)) is None: + flt_col = self.flt_col + # TSDatasetH will retrieve more data for complete time-series ext_slice = self._extend_slice(slc, self.cal, self.step_len) data = super()._prepare_seg(ext_slice, **kwargs) @@ -710,8 +710,8 @@ class TSDatasetH(DatasetH): tsds = TSDataSampler( data=data, - start=start, - end=end, + start=slc.start, + end=slc.stop, step_len=self.step_len, dtype=dtype, flt_data=flt_data, diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index ad4178d34..551b43a98 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. # coding=utf-8 +from abc import abstractmethod import warnings from typing import Callable, Union, Tuple, List, Iterator, Optional @@ -19,9 +20,59 @@ from . import processor as processor_module from . import loader as data_loader_module -# TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed. -class DataHandler(Serializable): +DATA_KEY_TYPE = Literal["raw", "infer", "learn"] + + +class DataHandlerABC(Serializable): """ + Interface for data handler. + + This class does not assume the internal data structure of the data handler. + It only defines the interface for external users (uses DataFrame as the internal data structure). 
+ + In the future, the data handler's more detailed implementation should be refactored. Here are some guidelines: + + It covers several components: + + - [data loader] -> internal representation of the data -> data preprocessing -> interface adaptor for the fetch interface + - The workflow to combine them all: + The workflow may be very complicated. DataHandlerLP is one of the practices, but it can't satisfy all the requirements. + So leaving the flexibility to the user to implement the workflow is a more reasonable choice. + """ + + def __init__(self, *args, **kwargs): + """ + We should define how to get ready for the fetching. + """ + super().__init__(*args, **kwargs) + + CS_ALL = "__all" # return all columns with single-level index column + CS_RAW = "__raw" # return raw data with multi-level index column + + # data key + DK_R: DATA_KEY_TYPE = "raw" + DK_I: DATA_KEY_TYPE = "infer" + DK_L: DATA_KEY_TYPE = "learn" + + @abstractmethod + def fetch( + self, + selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None), + level: Union[str, int] = "datetime", + col_set: Union[str, List[str]] = CS_ALL, + data_key: DATA_KEY_TYPE = DK_I, + ) -> pd.DataFrame: + pass + + +class DataHandler(DataHandlerABC): + """ + The motivation of DataHandler: + + - It provides an implementation of DataHandlerABC that we implement with: + - Handling responses with an internal loaded DataFrame + - The DataFrame is loaded by a data loader. + The steps to using a handler 1. initialized data handler (call by `init`). 2. use the data. 
@@ -144,16 +195,14 @@ class DataHandler(Serializable): self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time)) # TODO: cache - CS_ALL = "__all" # return all columns with single-level index column - CS_RAW = "__raw" # return raw data with multi-level index column - def fetch( self, selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None), level: Union[str, int] = "datetime", - col_set: Union[str, List[str]] = CS_ALL, + col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL, + data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I, squeeze: bool = False, - proc_func: Callable = None, + proc_func: Optional[Callable] = None, ) -> pd.DataFrame: """ fetch data from underlying data source @@ -216,6 +265,8 @@ class DataHandler(Serializable): ------- pd.DataFrame. """ + # DataHandler is an example with only one dataframe, so data_key is not used. + _ = data_key # avoid linting errors (e.g., unused-argument) return self._fetch_data( data_storage=self._data, selector=selector, @@ -230,7 +281,7 @@ class DataHandler(Serializable): data_storage, selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None), level: Union[str, int] = "datetime", - col_set: Union[str, List[str]] = CS_ALL, + col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL, squeeze: bool = False, proc_func: Callable = None, ): @@ -261,16 +312,9 @@ class DataHandler(Serializable): data_df = fetch_df_by_col(data_df, col_set) data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) elif isinstance(data_storage, BaseHandlerStorage): - if not data_storage.is_proc_func_supported(): - if proc_func is not None: - raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") - data_df = data_storage.fetch( - selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig - ) - else: - data_df = data_storage.fetch( - selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, 
proc_func=proc_func - ) + if proc_func is not None: + raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") + data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) else: raise TypeError(f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}") @@ -282,7 +326,7 @@ class DataHandler(Serializable): data_df = data_df.reset_index(level=level, drop=True) return data_df - def get_cols(self, col_set=CS_ALL) -> list: + def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list: """ get the column names @@ -336,11 +380,12 @@ class DataHandler(Serializable): yield cur_date, self.fetch(selector, **kwargs) -DATA_KEY_TYPE = Literal["raw", "infer", "learn"] - - class DataHandlerLP(DataHandler): """ + Motivation: + - For the case that we hope using different processor workflows for learning and inference; + + DataHandler with **(L)earnable (P)rocessor** This handler will produce three pieces of data in pd.DataFrame format. @@ -374,12 +419,8 @@ class DataHandlerLP(DataHandler): _infer: pd.DataFrame # data for inference _learn: pd.DataFrame # data for learning models - # data key - DK_R: DATA_KEY_TYPE = "raw" - DK_I: DATA_KEY_TYPE = "infer" - DK_L: DATA_KEY_TYPE = "learn" # map data_key to attribute name - ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"} + ATTR_MAP = {DataHandler.DK_R: "_data", DataHandler.DK_I: "_infer", DataHandler.DK_L: "_learn"} # process type PTYPE_I = "independent" @@ -622,7 +663,7 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. 
Save the memory for data processing - def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame: + def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame: if data_key == self.DK_R and self.drop_raw: raise AttributeError( "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" @@ -635,7 +676,7 @@ class DataHandlerLP(DataHandler): selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set=DataHandler.CS_ALL, - data_key: DATA_KEY_TYPE = DK_I, + data_key: DATA_KEY_TYPE = DataHandler.DK_I, squeeze: bool = False, proc_func: Callable = None, ) -> pd.DataFrame: @@ -669,7 +710,7 @@ class DataHandlerLP(DataHandler): proc_func=proc_func, ) - def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list: + def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list: """ get the column names diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 62e7ba7e4..dd51f1d5f 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -1,8 +1,10 @@ +from abc import abstractmethod import pandas as pd import numpy as np from .handler import DataHandler -from typing import Union, List, Callable +from typing import Union, List +from qlib.log import get_module_logger from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col @@ -14,14 +16,13 @@ class BaseHandlerStorage: - If users want to use custom data storage, they should define subclass inherited BaseHandlerStorage, and implement the following method """ + @abstractmethod def fetch( self, - selector: Union[pd.Timestamp, slice, str, list] = slice(None, None), + selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None), level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, fetch_orig: bool = True, - proc_func: Callable = None, - 
**kwargs, ) -> pd.DataFrame: """fetch data from the data storage @@ -41,8 +42,6 @@ class BaseHandlerStorage: select several sets of meaningful columns, the returned data has multiple level fetch_orig : bool Return the original data instead of copy if possible. - proc_func: Callable - please refer to the doc of DataHandler.fetch Returns ------- @@ -51,13 +50,40 @@ class BaseHandlerStorage: """ raise NotImplementedError("fetch is method not implemented!") - @staticmethod - def from_df(df: pd.DataFrame): - raise NotImplementedError("from_df method is not implemented!") - def is_proc_func_supported(self): - """whether the arg `proc_func` in `fetch` method is supported.""" - raise NotImplementedError("is_proc_func_supported method is not implemented!") +class NaiveDFStorage(BaseHandlerStorage): + """Naive data storage for datahandler + - NaiveDFStorage is a naive data storage for datahandler + - NaiveDFStorage takes a pandas.DataFrame as input and provides an interface for fetching data + """ + + def __init__(self, df: pd.DataFrame): + self.df = df + + def fetch( + self, + selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None), + level: Union[str, int] = "datetime", + col_set: Union[str, List[str]] = DataHandler.CS_ALL, + fetch_orig: bool = True, + ) -> pd.DataFrame: + + # Following conflicts may occur + # - Does ["20200101", "20210101"] mean selecting this slice or these two days? + # To solve this issue + # - slices have higher priority (except when level is None) + if isinstance(selector, (tuple, list)) and level is not None: + # when level is None, the argument will be passed in directly + # we don't have to convert it into slice + try: + selector = slice(*selector) + except ValueError: + get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. 
It will used directly") + + data_df = self.df + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=fetch_orig) + return data_df class HashingStockStorage(BaseHandlerStorage): @@ -142,7 +168,7 @@ class HashingStockStorage(BaseHandlerStorage): def fetch( self, - selector: Union[pd.Timestamp, slice, str] = slice(None, None), + selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None), level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, fetch_orig: bool = True, @@ -164,7 +190,3 @@ class HashingStockStorage(BaseHandlerStorage): return fetch_stock_df_list[0] else: return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) - - def is_proc_func_supported(self): - """the arg `proc_func` in `fetch` method is not supported in HashingStockStorage""" - return False diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 606c2154e..ce204420f 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -240,7 +240,9 @@ class TrainerR(Trainer): self.train_func = train_func self._call_in_subproc = call_in_subproc - def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]: + def train( + self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs + ) -> List[Recorder]: """ Given a list of `tasks` and return a list of trained Recorder. The order can be guaranteed.