1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

refactor: introduce BaseDataHandler and unify fetch interface (#1958)

* refactor: introduce BaseDataHandler and unify fetch interface

* refactor: include data_key in seg_kwargs and simplify segments loop

* refactor: default data_key to BaseDataHandler.DK_I in _get_df_by_key

* style: fix indentation and remove extra blank lines in data handlers

* refactor: use BaseDataHandler.DK_I as default data_key

* docs: fix BaseDataHandler docstring grammar and formatting

* refactor: remove unused **kwargs from storage fetch methods

* docs: refine BaseDataHandler and DataHandler docstrings

* refactor: rename BaseDataHandler to DataHandlerABC, update type hints

* feat: add flt_col to TSDatasetH and list-to-slice conversion in storage

* lint

* comment
This commit is contained in:
you-n-g
2025-06-29 15:50:59 +08:00
committed by GitHub
parent ba8b6cc30f
commit de86e46ed0
4 changed files with 126 additions and 61 deletions

View File

@@ -226,13 +226,8 @@ class DatasetH(Dataset):
------
NotImplementedError:
"""
logger = get_module_logger("DatasetH")
seg_kwargs = {"col_set": col_set}
seg_kwargs = {"col_set": col_set, "data_key": data_key}
seg_kwargs.update(kwargs)
if "data_key" in getfullargspec(self.handler.fetch).args:
seg_kwargs["data_key"] = data_key
else:
logger.info(f"data_key[{data_key}] is ignored.")
# Conflicts may happen here
# - The fetched data and the segment key may both be string
@@ -240,9 +235,11 @@ class DatasetH(Dataset):
# - The segment name will have higher priorities
# 1) Use it as segment name first
# 1.1) directly fetch split like "train" "valid" "test"
if isinstance(segments, str) and segments in self.segments:
return self._prepare_seg(self.segments[segments], **seg_kwargs)
# 1.2) fetch multiple splits like ["train", "valid"] ["train", "valid", "test"]
if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
@@ -262,7 +259,7 @@ class DatasetH(Dataset):
def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp):
"""it will act like sort and return the max value or None"""
candidate = None
for k, seg in segments.items():
for _, seg in segments.items():
point = seg[idx]
if point is None:
# None indicates unbounded, return directly
@@ -376,6 +373,8 @@ class TSDataSampler:
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data. Its index order is <"datetime", "instrument">
This feature is essential because:
- We want some samples to be excluded by label-based filtering, but we can't drop them at the beginning because their features are still needed later.
None:
keep all data
@@ -661,8 +660,9 @@ class TSDatasetH(DatasetH):
DEFAULT_STEP_LEN = 30
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs):
self.step_len = step_len
self.flt_col = flt_col
super().__init__(**kwargs)
def config(self, **kwargs):
@@ -693,10 +693,10 @@ class TSDatasetH(DatasetH):
dtype = kwargs.pop("dtype", None)
if not isinstance(slc, slice):
slc = slice(*slc)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete time-series
if (flt_col := kwargs.pop("flt_col", None)) is None:
flt_col = self.flt_col
# TSDatasetH will retrieve more data for complete time-series
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
data = super()._prepare_seg(ext_slice, **kwargs)
@@ -710,8 +710,8 @@ class TSDatasetH(DatasetH):
tsds = TSDataSampler(
data=data,
start=start,
end=end,
start=slc.start,
end=slc.stop,
step_len=self.step_len,
dtype=dtype,
flt_data=flt_data,

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
# coding=utf-8
from abc import abstractmethod
import warnings
from typing import Callable, Union, Tuple, List, Iterator, Optional
@@ -19,9 +20,59 @@ from . import processor as processor_module
from . import loader as data_loader_module
# TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed.
class DataHandler(Serializable):
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
class DataHandlerABC(Serializable):
    """
    Interface for data handlers.

    This class makes no assumption about how a handler stores its data internally.
    It only fixes the interface exposed to external users, which hands data back as a ``pd.DataFrame``.

    In the future, the more detailed implementation of data handlers should be refactored.
    Some guidelines:

    A handler covers several components:

    - [data loader] -> internal representation of the data -> data preprocessing -> interface adaptor for the fetch interface
    - The workflow that combines them all:
        The workflow may be very complicated. DataHandlerLP is one of the practices, but it can't satisfy all the requirements.
        So leaving the flexibility to the user to implement the workflow is a more reasonable choice.
    """

    # column-set selectors
    CS_ALL = "__all"  # return all columns with single-level index column
    CS_RAW = "__raw"  # return raw data with multi-level index column

    # data keys
    DK_R: DATA_KEY_TYPE = "raw"
    DK_I: DATA_KEY_TYPE = "infer"
    DK_L: DATA_KEY_TYPE = "learn"

    def __init__(self, *args, **kwargs):
        """
        Subclasses should define here how the handler gets ready for fetching.
        """
        super().__init__(*args, **kwargs)

    # NOTE(review): `abstractmethod` is only enforced at instantiation time when the
    # metaclass derives from ABCMeta -- confirm against Serializable.
    @abstractmethod
    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = CS_ALL,
        data_key: DATA_KEY_TYPE = DK_I,
    ) -> pd.DataFrame:
        """
        Fetch data from the handler; concrete handlers must override this method.

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str, pd.Index]
            describes which index entries to select; defaults to everything.
        level : Union[str, int]
            the index level the selector applies to; defaults to "datetime".
        col_set : Union[str, List[str]]
            the column set(s) to return; defaults to `CS_ALL`.
        data_key : DATA_KEY_TYPE
            which piece of data to fetch ("raw", "infer" or "learn"); defaults to `DK_I`.

        Returns
        -------
        pd.DataFrame
        """
class DataHandler(DataHandlerABC):
"""
The motivation of DataHandler:
- It provides an implementation of DataHandlerABC that we implement with:
- Handling responses with an internal loaded DataFrame
- The DataFrame is loaded by a data loader.
The steps to using a handler
1. initialize the data handler (called by `init`).
2. use the data.
@@ -144,16 +195,14 @@ class DataHandler(Serializable):
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
# TODO: cache
CS_ALL = "__all" # return all columns with single-level index column
CS_RAW = "__raw" # return raw data with multi-level index column
def fetch(
self,
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,
squeeze: bool = False,
proc_func: Callable = None,
proc_func: Optional[Callable] = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -216,6 +265,8 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
# DataHandler is an example with only one dataframe, so data_key is not used.
_ = data_key # avoid linting errors (e.g., unused-argument)
return self._fetch_data(
data_storage=self._data,
selector=selector,
@@ -230,7 +281,7 @@ class DataHandler(Serializable):
data_storage,
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
squeeze: bool = False,
proc_func: Callable = None,
):
@@ -261,16 +312,9 @@ class DataHandler(Serializable):
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, BaseHandlerStorage):
if not data_storage.is_proc_func_supported():
if proc_func is not None:
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
)
else:
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
)
if proc_func is not None:
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}")
@@ -282,7 +326,7 @@ class DataHandler(Serializable):
data_df = data_df.reset_index(level=level, drop=True)
return data_df
def get_cols(self, col_set=CS_ALL) -> list:
def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list:
"""
get the column names
@@ -336,11 +380,12 @@ class DataHandler(Serializable):
yield cur_date, self.fetch(selector, **kwargs)
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
class DataHandlerLP(DataHandler):
"""
Motivation:
- For cases where we want to use different processor workflows for learning and inference;
DataHandler with **(L)earnable (P)rocessor**
This handler will produce three pieces of data in pd.DataFrame format.
@@ -374,12 +419,8 @@ class DataHandlerLP(DataHandler):
_infer: pd.DataFrame # data for inference
_learn: pd.DataFrame # data for learning models
# data key
DK_R: DATA_KEY_TYPE = "raw"
DK_I: DATA_KEY_TYPE = "infer"
DK_L: DATA_KEY_TYPE = "learn"
# map data_key to attribute name
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
ATTR_MAP = {DataHandler.DK_R: "_data", DataHandler.DK_I: "_infer", DataHandler.DK_L: "_learn"}
# process type
PTYPE_I = "independent"
@@ -622,7 +663,7 @@ class DataHandlerLP(DataHandler):
# TODO: Be able to cache handler data. Save the memory for data processing
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame:
if data_key == self.DK_R and self.drop_raw:
raise AttributeError(
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
@@ -635,7 +676,7 @@ class DataHandlerLP(DataHandler):
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: DATA_KEY_TYPE = DK_I,
data_key: DATA_KEY_TYPE = DataHandler.DK_I,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
@@ -669,7 +710,7 @@ class DataHandlerLP(DataHandler):
proc_func=proc_func,
)
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list:
"""
get the column names

View File

@@ -1,8 +1,10 @@
from abc import abstractmethod
import pandas as pd
import numpy as np
from .handler import DataHandler
from typing import Union, List, Callable
from typing import Union, List
from qlib.log import get_module_logger
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
@@ -14,14 +16,13 @@ class BaseHandlerStorage:
- If users want to use custom data storage, they should define subclass inherited BaseHandlerStorage, and implement the following method
"""
@abstractmethod
def fetch(
self,
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
proc_func: Callable = None,
**kwargs,
) -> pd.DataFrame:
"""fetch data from the data storage
@@ -41,8 +42,6 @@ class BaseHandlerStorage:
select several sets of meaningful columns, the returned data has multiple level
fetch_orig : bool
Return the original data instead of copy if possible.
proc_func: Callable
please refer to the doc of DataHandler.fetch
Returns
-------
@@ -51,13 +50,40 @@ class BaseHandlerStorage:
"""
raise NotImplementedError("fetch is method not implemented!")
@staticmethod
def from_df(df: pd.DataFrame):
raise NotImplementedError("from_df method is not implemented!")
def is_proc_func_supported(self):
"""whether the arg `proc_func` in `fetch` method is supported."""
raise NotImplementedError("is_proc_func_supported method is not implemented!")
class NaiveDFStorage(BaseHandlerStorage):
    """Naive data storage for a data handler.

    - NaiveDFStorage wraps a single ``pandas.DataFrame`` and serves fetch requests directly from it.
    """

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
        fetch_orig: bool = True,
    ) -> pd.DataFrame:
        """Fetch data from the wrapped DataFrame.

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str, pd.Index]
            describes how to select data by index.
            A list/tuple is first tried as the arguments of ``slice`` (unless `level` is None).
        level : Union[str, int]
            which index level to select the data on; forwarded to `fetch_df_by_index`.
        col_set : Union[str, List[str]]
            the column set(s) to select; forwarded to `fetch_df_by_col`.
        fetch_orig : bool
            Return the original data instead of copy if possible.

        Returns
        -------
        pd.DataFrame
            the fetched data
        """
        # The following conflict may occur:
        # - Does ["20200101", "20210101"] mean selecting this slice or these two days?
        # To resolve it:
        # - slice has the higher priority (except when level is None)
        if isinstance(selector, (tuple, list)) and level is not None:
            # when level is None, the argument is passed through directly,
            # so there is no need to convert it into a slice
            try:
                selector = slice(*selector)
            except (TypeError, ValueError):
                # `slice()` raises TypeError (not ValueError) when it receives zero or more
                # than three arguments; catching it keeps the documented fallback of using
                # the selector as-is instead of crashing.
                get_module_logger("DataHandlerLP").info(
                    "Fail to converting to query to slice. It will used directly"
                )

        data_df = fetch_df_by_col(self.df, col_set)
        return fetch_df_by_index(data_df, selector, level, fetch_orig=fetch_orig)
class HashingStockStorage(BaseHandlerStorage):
@@ -142,7 +168,7 @@ class HashingStockStorage(BaseHandlerStorage):
def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
@@ -164,7 +190,3 @@ class HashingStockStorage(BaseHandlerStorage):
return fetch_stock_df_list[0]
else:
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
def is_proc_func_supported(self):
"""the arg `proc_func` in `fetch` method is not supported in HashingStockStorage"""
return False

View File

@@ -240,7 +240,9 @@ class TrainerR(Trainer):
self.train_func = train_func
self._call_in_subproc = call_in_subproc
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
def train(
self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs
) -> List[Recorder]:
"""
Given a list of `tasks` and return a list of trained Recorder. The order can be guaranteed.