diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2fe8f8a63..9b2a6fa32 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -7,6 +7,7 @@ from typing import Callable, Union, Tuple, List, Iterator, Optional import pandas as pd +from qlib.typehint import Literal from ...log import get_module_logger, TimeInspector from ...utils import init_instance_by_config from ...utils.serial import Serializable @@ -49,6 +50,8 @@ class DataHandler(Serializable): - Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc` """ + _data: pd.DataFrame # underlying data. + def __init__( self, instruments=None, @@ -155,6 +158,11 @@ class DataHandler(Serializable): """ fetch data from underlying data source + Design motivation: + - providing a unified interface for underlying data. + - Potential to make the interface more friendly. + - User can improve performance when fetching data in this extra layer + Parameters ---------- selector : Union[pd.Timestamp, slice, str] @@ -328,6 +336,9 @@ class DataHandler(Serializable): yield cur_date, self.fetch(selector, **kwargs) +DATA_KEY_TYPE = Literal["raw", "infer", "learn"] + + class DataHandlerLP(DataHandler): """ DataHandler with **(L)earnable (P)rocessor** @@ -353,10 +364,15 @@ class DataHandlerLP(DataHandler): - `drop_raw=True`: this will modify the data inplace on raw data; """ + # based on `self._data`, _infer and _learn are generated after processors + _infer: pd.DataFrame # data for inference + _learn: pd.DataFrame # data for learning models + # data key - DK_R = "raw" - DK_I = "infer" - DK_L = "learn" + DK_R: DATA_KEY_TYPE = "raw" + DK_I: DATA_KEY_TYPE = "infer" + DK_L: DATA_KEY_TYPE = "learn" + # map data_key to attribute name ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"} # process type @@ -600,7 +616,7 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. 
Save the memory for data processing - def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame: + def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame: if data_key == self.DK_R and self.drop_raw: raise AttributeError( "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" @@ -613,7 +629,7 @@ class DataHandlerLP(DataHandler): selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set=DataHandler.CS_ALL, - data_key: str = DK_I, + data_key: DATA_KEY_TYPE = DK_I, squeeze: bool = False, proc_func: Callable = None, ) -> pd.DataFrame: @@ -647,7 +663,7 @@ class DataHandlerLP(DataHandler): proc_func=proc_func, ) - def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: + def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list: """ get the column names @@ -655,7 +671,7 @@ class DataHandlerLP(DataHandler): ---------- col_set : str select a set of meaningful columns.(e.g. features, columns). - data_key : str + data_key : DATA_KEY_TYPE the data to fetch: DK_*. Returns