mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
refactor: introduce BaseDataHandler and unify fetch interface (#1958)
* refactor: introduce BaseDataHandler and unify fetch interface * refactor: include data_key in seg_kwargs and simplify segments loop * refactor: default data_key to BaseDataHandler.DK_I in _get_df_by_key * style: fix indentation and remove extra blank lines in data handlers * refactor: use BaseDataHandler.DK_I as default data_key * docs: fix BaseDataHandler docstring grammar and formatting * refactor: remove unused **kwargs from storage fetch methods * docs: refine BaseDataHandler and DataHandler docstrings * refactor: rename BaseDataHandler to DataHandlerABC, update type hints * feat: add flt_col to TSDatasetH and list-to-slice conversion in storage * lint * comment
This commit is contained in:
@@ -226,13 +226,8 @@ class DatasetH(Dataset):
|
||||
------
|
||||
NotImplementedError:
|
||||
"""
|
||||
logger = get_module_logger("DatasetH")
|
||||
seg_kwargs = {"col_set": col_set}
|
||||
seg_kwargs = {"col_set": col_set, "data_key": data_key}
|
||||
seg_kwargs.update(kwargs)
|
||||
if "data_key" in getfullargspec(self.handler.fetch).args:
|
||||
seg_kwargs["data_key"] = data_key
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
# Conflicts may happen here
|
||||
# - The fetched data and the segment key may both be string
|
||||
@@ -240,9 +235,11 @@ class DatasetH(Dataset):
|
||||
# - The segment name will have higher priorities
|
||||
|
||||
# 1) Use it as segment name first
|
||||
# 1.1) directly fetch split like "train" "valid" "test"
|
||||
if isinstance(segments, str) and segments in self.segments:
|
||||
return self._prepare_seg(self.segments[segments], **seg_kwargs)
|
||||
|
||||
# 1.2) fetch multiple splits like ["train", "valid"] ["train", "valid", "test"]
|
||||
if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
|
||||
return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
|
||||
|
||||
@@ -262,7 +259,7 @@ class DatasetH(Dataset):
|
||||
def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp):
|
||||
"""it will act like sort and return the max value or None"""
|
||||
candidate = None
|
||||
for k, seg in segments.items():
|
||||
for _, seg in segments.items():
|
||||
point = seg[idx]
|
||||
if point is None:
|
||||
# None indicates unbounded, return directly
|
||||
@@ -376,6 +373,8 @@ class TSDataSampler:
|
||||
ffill with previous samples first and fill with later samples second
|
||||
flt_data : pd.Series
|
||||
a column of data(True or False) to filter data. Its index order is <"datetime", "instrument">
|
||||
This feature is essential because:
|
||||
- We want some samples to be excluded by label-based filtering, but we can't filter them out at the beginning because their features are still important in the feature.
|
||||
None:
|
||||
keep all data
|
||||
|
||||
@@ -661,8 +660,9 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
DEFAULT_STEP_LEN = 30
|
||||
|
||||
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
|
||||
def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs):
|
||||
self.step_len = step_len
|
||||
self.flt_col = flt_col
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, **kwargs):
|
||||
@@ -693,10 +693,10 @@ class TSDatasetH(DatasetH):
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
if not isinstance(slc, slice):
|
||||
slc = slice(*slc)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
if (flt_col := kwargs.pop("flt_col", None)) is None:
|
||||
flt_col = self.flt_col
|
||||
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
|
||||
data = super()._prepare_seg(ext_slice, **kwargs)
|
||||
|
||||
@@ -710,8 +710,8 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
tsds = TSDataSampler(
|
||||
data=data,
|
||||
start=start,
|
||||
end=end,
|
||||
start=slc.start,
|
||||
end=slc.stop,
|
||||
step_len=self.step_len,
|
||||
dtype=dtype,
|
||||
flt_data=flt_data,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
from abc import abstractmethod
|
||||
import warnings
|
||||
from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
@@ -19,9 +20,59 @@ from . import processor as processor_module
|
||||
from . import loader as data_loader_module
|
||||
|
||||
|
||||
# TODO: A more general handler interface which does not rely on an internal pd.DataFrame is needed.
|
||||
class DataHandler(Serializable):
|
||||
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
|
||||
|
||||
|
||||
class DataHandlerABC(Serializable):
    """
    Interface for data handler.

    This class makes no assumption about how a data handler stores its data internally.
    It only defines the interface exposed to external users (who receive the data as a DataFrame).

    The more detailed implementation of data handlers should be refactored in the future.
    Some guidelines:

    A data handler covers several components:

    - [data loader] -> internal representation of the data -> data preprocessing -> interface adaptor for the fetch interface
    - The workflow that combines them all:
      This workflow may become very complicated. DataHandlerLP is one practice, but it cannot satisfy every requirement,
      so leaving the user free to implement the workflow is the more reasonable choice.
    """

    # column-set selectors
    CS_ALL = "__all"  # return all columns with single-level index column
    CS_RAW = "__raw"  # return raw data with multi-level index column

    # data key
    DK_R: DATA_KEY_TYPE = "raw"
    DK_I: DATA_KEY_TYPE = "infer"
    DK_L: DATA_KEY_TYPE = "learn"

    def __init__(self, *args, **kwargs):
        """
        Subclasses should define here how the handler gets ready for fetching.

        All arguments are forwarded unchanged to the parent initializer.
        """
        super().__init__(*args, **kwargs)

    @abstractmethod
    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = CS_ALL,
        data_key: DATA_KEY_TYPE = DK_I,
    ) -> pd.DataFrame:
        """
        Fetch data from the handler; concrete handlers must override this method.

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str, pd.Index]
            describes which rows to select; defaults to the unbounded slice.
        level : Union[str, int]
            the index level the selector applies to; defaults to "datetime".
        col_set : Union[str, List[str]]
            the column set(s) to return; defaults to `CS_ALL`.
        data_key : DATA_KEY_TYPE
            which data to fetch: `DK_R`, `DK_I` or `DK_L`; defaults to `DK_I`.

        Returns
        -------
        pd.DataFrame
        """
|
||||
|
||||
|
||||
class DataHandler(DataHandlerABC):
|
||||
"""
|
||||
The motivation of DataHandler:
|
||||
|
||||
- It provides an implementation of DataHandlerABC that we implement with:
|
||||
- Handling responses with an internal loaded DataFrame
|
||||
- The DataFrame is loaded by a data loader.
|
||||
|
||||
The steps to using a handler
|
||||
1. initialize the data handler (called by `init`).
|
||||
2. use the data.
|
||||
@@ -144,16 +195,14 @@ class DataHandler(Serializable):
|
||||
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
|
||||
# TODO: cache
|
||||
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
CS_RAW = "__raw" # return raw data with multi-level index column
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
|
||||
data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
proc_func: Optional[Callable] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -216,6 +265,8 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
# DataHandler is an example with only one dataframe, so data_key is not used.
|
||||
_ = data_key # avoid linting errors (e.g., unused-argument)
|
||||
return self._fetch_data(
|
||||
data_storage=self._data,
|
||||
selector=selector,
|
||||
@@ -230,7 +281,7 @@ class DataHandler(Serializable):
|
||||
data_storage,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
):
|
||||
@@ -261,16 +312,9 @@ class DataHandler(Serializable):
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
elif isinstance(data_storage, BaseHandlerStorage):
|
||||
if not data_storage.is_proc_func_supported():
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
|
||||
)
|
||||
else:
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
|
||||
)
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}")
|
||||
|
||||
@@ -282,7 +326,7 @@ class DataHandler(Serializable):
|
||||
data_df = data_df.reset_index(level=level, drop=True)
|
||||
return data_df
|
||||
|
||||
def get_cols(self, col_set=CS_ALL) -> list:
|
||||
def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list:
|
||||
"""
|
||||
get the column names
|
||||
|
||||
@@ -336,11 +380,12 @@ class DataHandler(Serializable):
|
||||
yield cur_date, self.fetch(selector, **kwargs)
|
||||
|
||||
|
||||
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
|
||||
|
||||
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
Motivation:
|
||||
- For the case where we hope to use different processor workflows for learning and inference;
|
||||
|
||||
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
|
||||
This handler will produce three pieces of data in pd.DataFrame format.
|
||||
@@ -374,12 +419,8 @@ class DataHandlerLP(DataHandler):
|
||||
_infer: pd.DataFrame # data for inference
|
||||
_learn: pd.DataFrame # data for learning models
|
||||
|
||||
# data key
|
||||
DK_R: DATA_KEY_TYPE = "raw"
|
||||
DK_I: DATA_KEY_TYPE = "infer"
|
||||
DK_L: DATA_KEY_TYPE = "learn"
|
||||
# map data_key to attribute name
|
||||
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
|
||||
ATTR_MAP = {DataHandler.DK_R: "_data", DataHandler.DK_I: "_infer", DataHandler.DK_L: "_learn"}
|
||||
|
||||
# process type
|
||||
PTYPE_I = "independent"
|
||||
@@ -622,7 +663,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
|
||||
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame:
|
||||
if data_key == self.DK_R and self.drop_raw:
|
||||
raise AttributeError(
|
||||
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
|
||||
@@ -635,7 +676,7 @@ class DataHandlerLP(DataHandler):
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: DATA_KEY_TYPE = DK_I,
|
||||
data_key: DATA_KEY_TYPE = DataHandler.DK_I,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
@@ -669,7 +710,7 @@ class DataHandlerLP(DataHandler):
|
||||
proc_func=proc_func,
|
||||
)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list:
|
||||
"""
|
||||
get the column names
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from abc import abstractmethod
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from .handler import DataHandler
|
||||
from typing import Union, List, Callable
|
||||
from typing import Union, List
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
|
||||
|
||||
@@ -14,14 +16,13 @@ class BaseHandlerStorage:
|
||||
- If users want to use custom data storage, they should define subclass inherited BaseHandlerStorage, and implement the following method
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
proc_func: Callable = None,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""fetch data from the data storage
|
||||
|
||||
@@ -41,8 +42,6 @@ class BaseHandlerStorage:
|
||||
select several sets of meaningful columns, the returned data has multiple level
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
proc_func: Callable
|
||||
please refer to the doc of DataHandler.fetch
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -51,13 +50,40 @@ class BaseHandlerStorage:
|
||||
"""
|
||||
raise NotImplementedError("fetch is method not implemented!")
|
||||
|
||||
@staticmethod
|
||||
def from_df(df: pd.DataFrame):
|
||||
raise NotImplementedError("from_df method is not implemented!")
|
||||
|
||||
def is_proc_func_supported(self):
|
||||
"""whether the arg `proc_func` in `fetch` method is supported."""
|
||||
raise NotImplementedError("is_proc_func_supported method is not implemented!")
|
||||
class NaiveDFStorage(BaseHandlerStorage):
    """Naive data storage for datahandler

    - NaiveDFStorage is a naive data storage for datahandler
    - NaiveDFStorage takes a pandas.DataFrame as input and provides interface support for fetching data
    """

    def __init__(self, df: pd.DataFrame):
        # The wrapped DataFrame; `fetch` reads from it on every call.
        self.df = df

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
        fetch_orig: bool = True,
    ) -> pd.DataFrame:
        """fetch data from the wrapped DataFrame

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str, pd.Index]
            describes how to select data by index.
            A tuple/list is first converted with `slice(*selector)` when `level` is not None.
        level : Union[str, int]
            the index level passed to `fetch_df_by_index`.
            When it is None, a tuple/list selector is passed through untouched.
        col_set : Union[str, List[str]]
            the column set(s) passed to `fetch_df_by_col`.
        fetch_orig : bool
            forwarded to `fetch_df_by_index`.

        Returns
        -------
        pd.DataFrame
            the result of column selection followed by index selection.
        """
        # Following conflicts may occur
        # - Does ["20200101", "20210101"] mean selecting this slice or these two days?
        # To solve this issue
        #   - slice has higher priority (except when level is None)
        if isinstance(selector, (tuple, list)) and level is not None:
            # when level is None, the argument will be passed in directly
            # we don't have to convert it into slice
            try:
                selector = slice(*selector)
            except (TypeError, ValueError):
                # `slice()` raises TypeError (not ValueError) when given 0 or more than 3
                # arguments; catch it so the selector really is used directly as the
                # message below promises.
                get_module_logger("DataHandlerLP").info("Fail to converting to query to slice. It will used directly")

        data_df = fetch_df_by_col(self.df, col_set)
        return fetch_df_by_index(data_df, selector, level, fetch_orig=fetch_orig)
|
||||
|
||||
|
||||
class HashingStockStorage(BaseHandlerStorage):
|
||||
@@ -142,7 +168,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
@@ -164,7 +190,3 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
return fetch_stock_df_list[0]
|
||||
else:
|
||||
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
|
||||
|
||||
def is_proc_func_supported(self):
|
||||
"""the arg `proc_func` in `fetch` method is not supported in HashingStockStorage"""
|
||||
return False
|
||||
|
||||
@@ -240,7 +240,9 @@ class TrainerR(Trainer):
|
||||
self.train_func = train_func
|
||||
self._call_in_subproc = call_in_subproc
|
||||
|
||||
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
def train(
|
||||
self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs
|
||||
) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of `tasks` and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user