diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py
index c35cf62a3..5cb81e8c9 100644
--- a/qlib/data/dataset/__init__.py
+++ b/qlib/data/dataset/__init__.py
@@ -3,7 +3,7 @@ from typing import Callable, Union, List, Tuple, Dict, Text, Optional
 from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
 from ...log import get_module_logger
 from .handler import DataHandler, DataHandlerLP
-from copy import deepcopy
+from copy import copy, deepcopy
 from inspect import getfullargspec
 import pandas as pd
 import numpy as np
@@ -83,7 +83,13 @@ class DatasetH(Dataset):
     - The processing is related to data split.
     """
 
-    def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
+    def __init__(
+        self,
+        handler: Union[Dict, DataHandler],
+        segments: Dict[Text, Tuple],
+        fetch_kwargs: Optional[Dict] = None,
+        **kwargs,
+    ):
         """
         Setup the underlying data.
 
@@ -114,7 +120,7 @@ class DatasetH(Dataset):
         """
         self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
         self.segments = segments.copy()
-        self.fetch_kwargs = {}
+        self.fetch_kwargs = {} if fetch_kwargs is None else copy(fetch_kwargs)
         super().__init__(**kwargs)
 
     def config(self, handler_kwargs: dict = None, **kwargs):
@@ -164,13 +170,13 @@ class DatasetH(Dataset):
             name=self.__class__.__name__, handler=self.handler, segments=self.segments
         )
 
-    def _prepare_seg(self, slc: slice, **kwargs):
+    def _prepare_seg(self, slc, **kwargs):
         """
-        Give a slice, retrieve the according data
+        Give a query, retrieve the according data
 
         Parameters
         ----------
-        slc : slice
+        slc : please refer to the docs of `prepare`
         """
         if hasattr(self, "fetch_kwargs"):
             return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
@@ -179,7 +185,7 @@ class DatasetH(Dataset):
 
     def prepare(
         self,
-        segments: Union[List[Text], Tuple[Text], Text, slice],
+        segments: Union[List[Text], Tuple[Text], Text, slice, pd.Index],
         col_set=DataHandler.CS_ALL,
         data_key=DataHandlerLP.DK_I,
         **kwargs,
@@ -218,22 +224,27 @@ class DatasetH(Dataset):
         NotImplementedError:
         """
         logger = get_module_logger("DatasetH")
-        fetch_kwargs = {"col_set": col_set}
-        fetch_kwargs.update(kwargs)
+        seg_kwargs = {"col_set": col_set}
+        seg_kwargs.update(kwargs)
         if "data_key" in getfullargspec(self.handler.fetch).args:
-            fetch_kwargs["data_key"] = data_key
+            seg_kwargs["data_key"] = data_key
         else:
             logger.info(f"data_key[{data_key}] is ignored.")
 
-        # Handle all kinds of segments format
-        if isinstance(segments, (list, tuple)):
-            return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments]
-        elif isinstance(segments, str):
-            return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs)
-        elif isinstance(segments, slice):
-            return self._prepare_seg(segments, **fetch_kwargs)
-        else:
-            raise NotImplementedError(f"This type of input is not supported")
+        # Conflicts may happen here
+        # - The fetched data and the segment key may both be string
+        # To resolve the conflict
+        # - The segment name will have higher priorities
+
+        # 1) Use it as segment name first
+        if isinstance(segments, str) and segments in self.segments:
+            return self._prepare_seg(self.segments[segments], **seg_kwargs)
+
+        if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
+            return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
+
+        # 2) Otherwise pass it directly to prepare a single seg
+        return self._prepare_seg(segments, **seg_kwargs)
 
     # helper functions
     @staticmethod
@@ -582,8 +593,11 @@ class TSDatasetH(DatasetH):
     def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
         """
         split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
+        NOTE: TSDatasetH only supports slicing segments on datetime !!!
         """
         dtype = kwargs.pop("dtype", None)
+        if not isinstance(slc, slice):
+            slc = slice(*slc)
         start, end = slc.start, slc.stop
         flt_col = kwargs.pop("flt_col", None)
         # TSDatasetH will retrieve more data for complete time-series
diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py
index f2be3a3c6..0547ef41a 100644
--- a/qlib/data/dataset/handler.py
+++ b/qlib/data/dataset/handler.py
@@ -154,7 +154,7 @@ class DataHandler(Serializable):
 
     def fetch(
         self,
-        selector: Union[pd.Timestamp, slice, str] = slice(None, None),
+        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
         level: Union[str, int] = "datetime",
         col_set: Union[str, List[str]] = CS_ALL,
         squeeze: bool = False,
@@ -167,13 +167,24 @@ class DataHandler(Serializable):
         ----------
         selector : Union[pd.Timestamp, slice, str]
             describe how to select data by index
+            It can be categorized as follows
+            - fetch single index
+            - fetch a range of index
+                - a slice range
+                - pd.Index for specific indexes
+
+            The following conflicts may occur
+            - Does ["20200101", "20210101"] mean selecting this slice or these two days?
+                - slice has higher priority
+
         level : Union[str, int]
             which index level to select the data
+
         col_set : Union[str, List[str]]
 
             - if isinstance(col_set, str):
 
-                select a set of meaningful columns.(e.g. features, columns)
+                select a set of meaningful, pd.Index columns.(e.g. features, columns)
 
                 if col_set == CS_RAW:
                     the raw dataset will be returned.
@@ -181,6 +192,7 @@ class DataHandler(Serializable):
             - if isinstance(col_set, List[str]):
 
                 select several sets of meaningful columns, the returned data has multiple levels
+
         proc_func: Callable
             - Give a hook for processing data before fetching
             - An example to explain the necessity of the hook:
@@ -197,9 +209,39 @@ class DataHandler(Serializable):
         -------
         pd.DataFrame.
         """
+        return self._fetch_data(
+            data_storage=self._data,
+            selector=selector,
+            level=level,
+            col_set=col_set,
+            squeeze=squeeze,
+            proc_func=proc_func,
+        )
+
+    def _fetch_data(
+        self,
+        data_storage,
+        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
+        level: Union[str, int] = "datetime",
+        col_set: Union[str, List[str]] = CS_ALL,
+        squeeze: bool = False,
+        proc_func: Callable = None,
+    ):
+        # This method is extracted from `fetch` so that subclasses can share it
         from .storage import BaseHandlerStorage
 
-        data_storage = self._data
+        # The following conflicts may occur
+        # - Does ["20200101", "20210101"] mean selecting this slice or these two days?
+        # To solve this issue
+        # - slice has higher priority (except when level is None)
+        if isinstance(selector, (tuple, list)) and level is not None:
+            # when level is None, the argument will be passed in directly
+            # we don't have to convert it into slice
+            try:
+                selector = slice(*selector)
+            except (TypeError, ValueError):  # slice() raises TypeError for 0 or >3 items
+                get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. It will used directly")
+
         if isinstance(data_storage, pd.DataFrame):
             data_df = data_storage
             if proc_func is not None:
@@ -551,6 +593,7 @@ class DataHandlerLP(DataHandler):
         level: Union[str, int] = "datetime",
         col_set=DataHandler.CS_ALL,
         data_key: str = DK_I,
+        squeeze: bool = False,
         proc_func: Callable = None,
     ) -> pd.DataFrame:
         """
@@ -575,34 +618,14 @@ class DataHandlerLP(DataHandler):
         """
         from .storage import BaseHandlerStorage
 
-        data_storage = self._get_df_by_key(data_key)
-        if isinstance(data_storage, pd.DataFrame):
-            data_df = data_storage
-            if proc_func is not None:
-                # FIXME: fetch by time first will be more friendly to proc_func
-                # Copy incase of `proc_func` changing the data inplace....
-                data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
-                data_df = fetch_df_by_col(data_df, col_set)
-            else:
-                # Fetch column first will be more friendly to SepDataFrame
-                data_df = fetch_df_by_col(data_df, col_set)
-                data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
-
-        elif isinstance(data_storage, BaseHandlerStorage):
-            if not data_storage.is_proc_func_supported():
-                if proc_func is not None:
-                    raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
-                data_df = data_storage.fetch(
-                    selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
-                )
-            else:
-                data_df = data_storage.fetch(
-                    selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
-                )
-        else:
-            raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
-
-        return data_df
+        return self._fetch_data(
+            data_storage=self._get_df_by_key(data_key),
+            selector=selector,
+            level=level,
+            col_set=col_set,
+            squeeze=squeeze,
+            proc_func=proc_func,
+        )
 
     def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
         """
diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py
index 15946f3dc..4b8fedb0b 100644
--- a/qlib/data/dataset/utils.py
+++ b/qlib/data/dataset/utils.py
@@ -41,13 +41,16 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
 
 def fetch_df_by_index(
     df: pd.DataFrame,
-    selector: Union[pd.Timestamp, slice, str, list],
+    selector: Union[pd.Timestamp, slice, str, list, pd.Index],
     level: Union[str, int],
     fetch_orig=True,
 ) -> pd.DataFrame:
     """
     fetch data from `data` with `selector` and `level`
 
+    `selector` is assumed to be well processed.
+    `fetch_df_by_index` is only responsible for getting the right level
+
     Parameters
     ----------
     selector : Union[pd.Timestamp, slice, str, list]