mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
Make the logic of handler Clear (#877)
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Callable, Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from copy import copy, deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -83,7 +83,9 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
|
||||
def __init__(
|
||||
self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], fetch_kwargs: Dict = {}, **kwargs
|
||||
):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -114,7 +116,7 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
self.fetch_kwargs = {}
|
||||
self.fetch_kwargs = copy(fetch_kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
@@ -164,13 +166,13 @@ class DatasetH(Dataset):
|
||||
name=self.__class__.__name__, handler=self.handler, segments=self.segments
|
||||
)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
"""
|
||||
Give a slice, retrieve the according data
|
||||
Given a query, retrieve the corresponding data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
slc : slice
|
||||
slc : please refer to the docs of `prepare`
|
||||
"""
|
||||
if hasattr(self, "fetch_kwargs"):
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
@@ -179,7 +181,7 @@ class DatasetH(Dataset):
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice],
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice, pd.Index],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
@@ -218,22 +220,27 @@ class DatasetH(Dataset):
|
||||
NotImplementedError:
|
||||
"""
|
||||
logger = get_module_logger("DatasetH")
|
||||
fetch_kwargs = {"col_set": col_set}
|
||||
fetch_kwargs.update(kwargs)
|
||||
seg_kwargs = {"col_set": col_set}
|
||||
seg_kwargs.update(kwargs)
|
||||
if "data_key" in getfullargspec(self.handler.fetch).args:
|
||||
fetch_kwargs["data_key"] = data_key
|
||||
seg_kwargs["data_key"] = data_key
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
# Handle all kinds of segments format
|
||||
if isinstance(segments, (list, tuple)):
|
||||
return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments]
|
||||
elif isinstance(segments, str):
|
||||
return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs)
|
||||
elif isinstance(segments, slice):
|
||||
return self._prepare_seg(segments, **fetch_kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
# Conflicts may happen here
|
||||
# - The fetched data and the segment key may both be string
|
||||
# To resolve the conflict
|
||||
# - The segment name takes priority
|
||||
|
||||
# 1) Use it as segment name first
|
||||
if isinstance(segments, str) and segments in self.segments:
|
||||
return self._prepare_seg(self.segments[segments], **seg_kwargs)
|
||||
|
||||
if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
|
||||
return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
|
||||
|
||||
# 2) Otherwise, pass it directly to prepare a single segment
|
||||
return self._prepare_seg(segments, **seg_kwargs)
|
||||
|
||||
# helper functions
|
||||
@staticmethod
|
||||
@@ -582,8 +589,11 @@ class TSDatasetH(DatasetH):
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
|
||||
NOTE: TSDatasetH only supports slice segments on datetime!
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
if not isinstance(slc, slice):
|
||||
slc = slice(*slc)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
|
||||
@@ -154,7 +154,7 @@ class DataHandler(Serializable):
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
@@ -167,13 +167,24 @@ class DataHandler(Serializable):
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
describe how to select data by index
|
||||
It can be categorized as follows
|
||||
- fetch single index
|
||||
- fetch a range of index
|
||||
- a slice range
|
||||
- pd.Index for specific indexes
|
||||
|
||||
The following conflict may occur
|
||||
- Does ["20200101", "20210101"] mean selecting this slice or these two days?
|
||||
- slices have higher priority
|
||||
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
|
||||
col_set : Union[str, List[str]]
|
||||
|
||||
- if isinstance(col_set, str):
|
||||
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
select a set of meaningful, pd.Index columns.(e.g. features, columns)
|
||||
|
||||
if col_set == CS_RAW:
|
||||
the raw dataset will be returned.
|
||||
@@ -181,6 +192,7 @@ class DataHandler(Serializable):
|
||||
- if isinstance(col_set, List[str]):
|
||||
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
|
||||
proc_func: Callable
|
||||
- Give a hook for processing data before fetching
|
||||
- An example to explain the necessity of the hook:
|
||||
@@ -197,9 +209,39 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
return self._fetch_data(
|
||||
data_storage=self._data,
|
||||
selector=selector,
|
||||
level=level,
|
||||
col_set=col_set,
|
||||
squeeze=squeeze,
|
||||
proc_func=proc_func,
|
||||
)
|
||||
|
||||
def _fetch_data(
|
||||
self,
|
||||
data_storage,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
):
|
||||
# This method is extracted for sharing in subclasses
|
||||
from .storage import BaseHandlerStorage
|
||||
|
||||
data_storage = self._data
|
||||
# The following conflict may occur
|
||||
# - Does ["20200101", "20210101"] mean selecting this slice or these two days?
|
||||
# To solve this issue
|
||||
# - slices have higher priority (except when level is None)
|
||||
if isinstance(selector, (tuple, list)) and level is not None:
|
||||
# when level is None, the argument will be passed in directly
|
||||
# we don't have to convert it into slice
|
||||
try:
|
||||
selector = slice(*selector)
|
||||
except ValueError:
|
||||
get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. It will used directly")
|
||||
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
@@ -551,6 +593,7 @@ class DataHandlerLP(DataHandler):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@@ -575,34 +618,14 @@ class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
from .storage import BaseHandlerStorage
|
||||
|
||||
data_storage = self._get_df_by_key(data_key)
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy in case `proc_func` changes the data in place
|
||||
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
else:
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
elif isinstance(data_storage, BaseHandlerStorage):
|
||||
if not data_storage.is_proc_func_supported():
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
|
||||
)
|
||||
else:
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
return data_df
|
||||
return self._fetch_data(
|
||||
data_storage=self._get_df_by_key(data_key),
|
||||
selector=selector,
|
||||
level=level,
|
||||
col_set=col_set,
|
||||
squeeze=squeeze,
|
||||
proc_func=proc_func,
|
||||
)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
"""
|
||||
|
||||
@@ -41,13 +41,16 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
|
||||
def fetch_df_by_index(
|
||||
df: pd.DataFrame,
|
||||
selector: Union[pd.Timestamp, slice, str, list],
|
||||
selector: Union[pd.Timestamp, slice, str, list, pd.Index],
|
||||
level: Union[str, int],
|
||||
fetch_orig=True,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from `data` with `selector` and `level`
|
||||
|
||||
The selector is assumed to be well processed already.
|
||||
`fetch_df_by_index` is only responsible for getting the right level
|
||||
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str, list]
|
||||
|
||||
Reference in New Issue
Block a user