1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

Make the logic of handler Clear (#877)

This commit is contained in:
you-n-g
2022-01-20 22:36:28 +08:00
committed by GitHub
parent f979dcf5e8
commit da48f42f3f
3 changed files with 87 additions and 51 deletions

View File

@@ -3,7 +3,7 @@ from typing import Callable, Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
from copy import copy, deepcopy
from inspect import getfullargspec
import pandas as pd
import numpy as np
@@ -83,7 +83,9 @@ class DatasetH(Dataset):
- The processing is related to data split.
"""
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
def __init__(
self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], fetch_kwargs: Dict = {}, **kwargs
):
"""
Setup the underlying data.
@@ -114,7 +116,7 @@ class DatasetH(Dataset):
"""
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
self.fetch_kwargs = {}
self.fetch_kwargs = copy(fetch_kwargs)
super().__init__(**kwargs)
def config(self, handler_kwargs: dict = None, **kwargs):
@@ -164,13 +166,13 @@ class DatasetH(Dataset):
name=self.__class__.__name__, handler=self.handler, segments=self.segments
)
def _prepare_seg(self, slc: slice, **kwargs):
def _prepare_seg(self, slc, **kwargs):
"""
Give a slice, retrieve the according data
Give a query, retrieve the according data
Parameters
----------
slc : slice
slc : please refer to the docs of `prepare`
"""
if hasattr(self, "fetch_kwargs"):
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
@@ -179,7 +181,7 @@ class DatasetH(Dataset):
def prepare(
self,
segments: Union[List[Text], Tuple[Text], Text, slice],
segments: Union[List[Text], Tuple[Text], Text, slice, pd.Index],
col_set=DataHandler.CS_ALL,
data_key=DataHandlerLP.DK_I,
**kwargs,
@@ -218,22 +220,27 @@ class DatasetH(Dataset):
NotImplementedError:
"""
logger = get_module_logger("DatasetH")
fetch_kwargs = {"col_set": col_set}
fetch_kwargs.update(kwargs)
seg_kwargs = {"col_set": col_set}
seg_kwargs.update(kwargs)
if "data_key" in getfullargspec(self.handler.fetch).args:
fetch_kwargs["data_key"] = data_key
seg_kwargs["data_key"] = data_key
else:
logger.info(f"data_key[{data_key}] is ignored.")
# Handle all kinds of segments format
if isinstance(segments, (list, tuple)):
return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments]
elif isinstance(segments, str):
return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs)
elif isinstance(segments, slice):
return self._prepare_seg(segments, **fetch_kwargs)
else:
raise NotImplementedError(f"This type of input is not supported")
# Conflicts may happen here
# - The fetched data and the segment key may both be strings
# To resolve the conflict
# - The segment name takes higher priority
# 1) Use it as segment name first
if isinstance(segments, str) and segments in self.segments:
return self._prepare_seg(self.segments[segments], **seg_kwargs)
if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
# 2) Otherwise, pass it directly to prepare a single seg
return self._prepare_seg(segments, **seg_kwargs)
# helper functions
@staticmethod
@@ -582,8 +589,11 @@ class TSDatasetH(DatasetH):
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
"""
Splitting out `_prepare_raw_seg` is to leave a hook for data preprocessing before creating the processed data
NOTE: TSDatasetH only supports slc segments on datetime !!!
"""
dtype = kwargs.pop("dtype", None)
if not isinstance(slc, slice):
slc = slice(*slc)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete time-series

View File

@@ -154,7 +154,7 @@ class DataHandler(Serializable):
def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
@@ -167,13 +167,24 @@ class DataHandler(Serializable):
----------
selector : Union[pd.Timestamp, slice, str, pd.Index]
describe how to select data by index
It can be categorized as follows
- fetch single index
- fetch a range of index
- a slice range
- pd.Index for specific indexes
The following conflict may occur
- Does ["20200101", "20210101"] mean selecting this slice or these two days?
- slice has higher priority
level : Union[str, int]
which index level to select the data
col_set : Union[str, List[str]]
- if isinstance(col_set, str):
select a set of meaningful columns.(e.g. features, columns)
select a set of meaningful, pd.Index columns.(e.g. features, columns)
if col_set == CS_RAW:
the raw dataset will be returned.
@@ -181,6 +192,7 @@ class DataHandler(Serializable):
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels
proc_func: Callable
- Give a hook for processing data before fetching
- An example to explain the necessity of the hook:
@@ -197,9 +209,39 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
return self._fetch_data(
data_storage=self._data,
selector=selector,
level=level,
col_set=col_set,
squeeze=squeeze,
proc_func=proc_func,
)
def _fetch_data(
self,
data_storage,
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
proc_func: Callable = None,
):
# This method is extracted for sharing in subclasses
from .storage import BaseHandlerStorage
data_storage = self._data
# The following conflict may occur
# - Does ["20200101", "20210101"] mean selecting this slice or these two days?
# To solve this issue
# - slice has higher priority (except when level is None)
if isinstance(selector, (tuple, list)) and level is not None:
# when level is None, the argument will be passed in directly
# we don't have to convert it into slice
try:
selector = slice(*selector)
except ValueError:
get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. It will used directly")
if isinstance(data_storage, pd.DataFrame):
data_df = data_storage
if proc_func is not None:
@@ -551,6 +593,7 @@ class DataHandlerLP(DataHandler):
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
@@ -575,34 +618,14 @@ class DataHandlerLP(DataHandler):
"""
from .storage import BaseHandlerStorage
data_storage = self._get_df_by_key(data_key)
if isinstance(data_storage, pd.DataFrame):
data_df = data_storage
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
data_df = fetch_df_by_col(data_df, col_set)
else:
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, BaseHandlerStorage):
if not data_storage.is_proc_func_supported():
if proc_func is not None:
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
)
else:
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
return data_df
return self._fetch_data(
data_storage=self._get_df_by_key(data_key),
selector=selector,
level=level,
col_set=col_set,
squeeze=squeeze,
proc_func=proc_func,
)
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
"""

View File

@@ -41,13 +41,16 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
def fetch_df_by_index(
df: pd.DataFrame,
selector: Union[pd.Timestamp, slice, str, list],
selector: Union[pd.Timestamp, slice, str, list, pd.Index],
level: Union[str, int],
fetch_orig=True,
) -> pd.DataFrame:
"""
fetch data from `data` with `selector` and `level`
selector is assumed to be well processed.
`fetch_df_by_index` is only responsible for getting the right level
Parameters
----------
selector : Union[pd.Timestamp, slice, str, list, pd.Index]