mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
update TSDataSampler refineing the memory layout of data array to speed up NN training (#1342)
* update TSDataSampler * reformat code with black * use pre-commit to reformat the code * Add documents * More docstring * More Safety Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
@@ -82,7 +82,11 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], fetch_kwargs: Dict = {}, **kwargs
|
||||
self,
|
||||
handler: Union[Dict, DataHandler],
|
||||
segments: Dict[Text, Tuple],
|
||||
fetch_kwargs: Dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
@@ -284,10 +288,69 @@ class TSDataSampler:
|
||||
- For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result
|
||||
in a different data type
|
||||
|
||||
|
||||
Indices design:
|
||||
TSDataSampler has an index mechanism to help users query time-series data efficiently.
|
||||
|
||||
The definition of related variables:
|
||||
data_arr: np.ndarray
|
||||
The original data. It contains all the original data.
|
||||
Queries are often for the time series of a specific stock.
|
||||
To leverage this data characteristic and speed up querying, the multi-index of data_arr is rearranged in (instrument, datetime) order
|
||||
|
||||
data_index: pd.MultiIndex with index order <instrument, datetime>
|
||||
It has the same length as `idx_map`. Their elements are expected to be aligned one-to-one.
|
||||
|
||||
idx_map: np.ndarray
|
||||
It is the indexable data. It originates from data_arr and is then filtered by 1) `start` and `end` 2) `flt_data`
|
||||
The extra data in data_arr is useful in following cases
|
||||
1) creating meaningful time series data before `start` instead of padding them with zeros
|
||||
2) some data are excluded by `flt_data` (e.g. no <X, y> sample pair for that index), but they are still used in the time series in X
|
||||
|
||||
Finally, it will look like:
|
||||
|
||||
array([[ 0, 0],
|
||||
[ 1, 0],
|
||||
[ 2, 0],
|
||||
...,
|
||||
[241, 348],
|
||||
[242, 348],
|
||||
[243, 348]], dtype=int32)
|
||||
|
||||
It lists all indexable data (some data only used in historical time series may not be indexable); the values are the corresponding row and col in idx_df
|
||||
idx_df: pd.DataFrame
|
||||
It aims to map the <datetime, instrument> key to the original position in data_arr
|
||||
|
||||
For example, it may look like (NOTE: the indices of an instrument's time series are contiguous in memory)
|
||||
|
||||
instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ...
|
||||
datetime
|
||||
2017-01-03 0 242 473 717 NaN 974 ...
|
||||
2017-01-04 1 243 474 718 NaN 975 ...
|
||||
2017-01-05 2 244 475 719 NaN 976 ...
|
||||
2017-01-06 3 245 476 720 NaN 977 ...
|
||||
|
||||
With these two indices(idx_map, idx_df) and original data(data_arr), we can make the following queries fast (implemented in __getitem__)
|
||||
(1) Get the i-th indexable sample(time-series): (indexable sample index) -> [idx_map] -> (row col) -> [idx_df] -> (index in data_arr)
|
||||
(2) Get the specific sample by <datetime, instrument>: (<datetime, instrument>, i.e. <row, col>) -> [idx_df] -> (index in data_arr)
|
||||
(3) Get the index of a time-series data: (get the <row, col>, refer to (1), (2)) -> [idx_df] -> (all indices in data_arr for time-series)
|
||||
"""
|
||||
|
||||
# Please refer to the docstring of TSDataSampler for the definition of following attributes
|
||||
data_arr: np.ndarray
|
||||
data_index: pd.MultiIndex
|
||||
idx_map: np.ndarray
|
||||
idx_df: pd.DataFrame
|
||||
|
||||
def __init__(
|
||||
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
|
||||
self,
|
||||
data: pd.DataFrame,
|
||||
start,
|
||||
end,
|
||||
step_len: int,
|
||||
fillna_type: str = "none",
|
||||
dtype=None,
|
||||
flt_data=None,
|
||||
):
|
||||
"""
|
||||
Build a dataset which looks like torch.data.utils.Dataset.
|
||||
@@ -295,7 +358,7 @@ class TSDataSampler:
|
||||
Parameters
|
||||
----------
|
||||
data : pd.DataFrame
|
||||
The raw tabular data
|
||||
The raw tabular data whose index order is <"datetime", "instrument">
|
||||
start :
|
||||
The indexable start time
|
||||
end :
|
||||
@@ -311,7 +374,7 @@ class TSDataSampler:
|
||||
ffill+bfill:
|
||||
ffill with previous samples first and fill with later samples second
|
||||
flt_data : pd.Series
|
||||
a column of data(True or False) to filter data.
|
||||
a column of data(True or False) to filter data. Its index order is <"datetime", "instrument">
|
||||
None:
|
||||
keep all data
|
||||
|
||||
@@ -321,7 +384,10 @@ class TSDataSampler:
|
||||
self.step_len = step_len
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data = data.swaplevel().sort_index().copy()
|
||||
data.drop(
|
||||
data.columns, axis=1, inplace=True
|
||||
) # data is useless since it's passed to a transposed one, hard code to free the memory of this dataframe to avoid three big dataframe in the memory(including: data, self.data, self.data_arr)
|
||||
|
||||
kwargs = {"object": self.data}
|
||||
if dtype is not None:
|
||||
@@ -332,7 +398,9 @@ class TSDataSampler:
|
||||
# - append last line with full NaN for better performance in `__getitem__`
|
||||
# - Keep the same dtype will result in a better performance
|
||||
self.data_arr = np.append(
|
||||
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
|
||||
self.data_arr,
|
||||
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
|
||||
axis=0,
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
|
||||
@@ -347,19 +415,36 @@ class TSDataSampler:
|
||||
flt_data = flt_data.iloc[:, 0]
|
||||
# NOTE: bool(np.nan) is True !!!!!!!!
|
||||
# make sure reindex comes first. Otherwise extra NaN may appear.
|
||||
flt_data = flt_data.swaplevel()
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data)[0]]
|
||||
self.idx_map = self.idx_map2arr(self.idx_map)
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
start=time_to_slc_point(start), end=time_to_slc_point(end)
|
||||
self.idx_map, self.data_index = self.slice_idx_map_and_data_index(
|
||||
self.idx_map, self.idx_df, self.data_index, start, end
|
||||
)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
|
||||
def slice_idx_map_and_data_index(
|
||||
idx_map,
|
||||
idx_df,
|
||||
data_index,
|
||||
start,
|
||||
end,
|
||||
):
|
||||
assert (
|
||||
len(idx_map) == data_index.shape[0]
|
||||
) # make sure idx_map and data_index is same so index of idx_map can be used on data_index
|
||||
|
||||
start_row_idx, end_row_idx = idx_df.index.slice_locs(start=time_to_slc_point(start), end=time_to_slc_point(end))
|
||||
|
||||
time_flter_idx = (idx_map[:, 0] < end_row_idx) & (idx_map[:, 0] >= start_row_idx)
|
||||
return idx_map[time_flter_idx], data_index[time_flter_idx]
|
||||
|
||||
@staticmethod
|
||||
def idx_map2arr(idx_map):
|
||||
# pytorch data sampler will have better memory control without large dict or list
|
||||
@@ -394,7 +479,7 @@ class TSDataSampler:
|
||||
Get the pandas index of the data, it will be useful in following scenarios
|
||||
- Special sampler will be used (e.g. user want to sample day by day)
|
||||
"""
|
||||
return self.data_index[self.start_idx : self.end_idx]
|
||||
return self.data_index.swaplevel() # to align the order of multiple index of original data received by __init__
|
||||
|
||||
def config(self, **kwargs):
|
||||
# Config the attributes
|
||||
@@ -409,25 +494,33 @@ class TSDataSampler:
|
||||
Parameters
|
||||
----------
|
||||
data : pd.DataFrame
|
||||
The dataframe with <datetime, DataFrame>
|
||||
A DataFrame with index in order <instrument, datetime>
|
||||
|
||||
RSQR5 RESI5 WVMA5 LABEL0
|
||||
instrument datetime
|
||||
SH600000 2017-01-03 0.016389 0.461632 -1.154788 -0.048056
|
||||
2017-01-04 0.884545 -0.110597 -1.059332 -0.030139
|
||||
2017-01-05 0.507540 -0.535493 -1.099665 -0.644983
|
||||
2017-01-06 -1.267771 -0.669685 -1.636733 0.295366
|
||||
2017-01-09 0.339346 0.074317 -0.984989 0.765540
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[pd.DataFrame, dict]:
|
||||
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
|
||||
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
|
||||
instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ...
|
||||
datetime
|
||||
2021-01-11 0 1 2 3 4 5 ...
|
||||
2021-01-12 4146 4147 4148 4149 4150 4151 ...
|
||||
2021-01-13 8293 8294 8295 8296 8297 8298 ...
|
||||
2021-01-14 12441 12442 12443 12444 12445 12446 ...
|
||||
2017-01-03 0 242 473 717 NaN 974 ...
|
||||
2017-01-04 1 243 474 718 NaN 975 ...
|
||||
2017-01-05 2 244 475 719 NaN 976 ...
|
||||
2017-01-06 3 245 476 720 NaN 977 ...
|
||||
2) the second element: {<original index>: <row, col>}
|
||||
"""
|
||||
# use object dtype in case pandas converts int to float
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
idx_df = lazy_sort_index(idx_df.unstack())
|
||||
# NOTE: the correctness of `__getitem__` depends on columns sorted here
|
||||
idx_df = lazy_sort_index(idx_df, axis=1)
|
||||
idx_df = lazy_sort_index(idx_df, axis=1).T
|
||||
|
||||
idx_map = {}
|
||||
for i, (_, row) in enumerate(idx_df.iterrows()):
|
||||
@@ -485,11 +578,11 @@ class TSDataSampler:
|
||||
"""
|
||||
# Get the right row number `i` and col number `j` in idx_df
|
||||
if isinstance(idx, (int, np.integer)):
|
||||
real_idx = self.start_idx + idx
|
||||
if self.start_idx <= real_idx < self.end_idx:
|
||||
real_idx = idx
|
||||
if 0 <= real_idx < len(self.idx_map):
|
||||
i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good
|
||||
else:
|
||||
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
|
||||
raise KeyError(f"{real_idx} is out of [0, {len(self.idx_map)})")
|
||||
elif isinstance(idx, tuple):
|
||||
# <TSDataSampler object>["datetime", "instruments"]
|
||||
date, inst = idx
|
||||
@@ -532,7 +625,10 @@ class TSDataSampler:
|
||||
# precision problems. It will not cause any problems in my tests at least
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)
|
||||
|
||||
data = self.data_arr[indices]
|
||||
if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up.
|
||||
data = self.data_arr[indices[0] : indices[-1] + 1]
|
||||
else:
|
||||
data = self.data_arr[indices]
|
||||
if isinstance(idx, mtit):
|
||||
# if we get multiple indexes, addition dimension should be added.
|
||||
# <sample_idx, step_idx, feature_idx>
|
||||
@@ -540,7 +636,7 @@ class TSDataSampler:
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return self.end_idx - self.start_idx
|
||||
return len(self.idx_map)
|
||||
|
||||
|
||||
class TSDatasetH(DatasetH):
|
||||
@@ -611,7 +707,14 @@ class TSDatasetH(DatasetH):
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
|
||||
tsds = TSDataSampler(
|
||||
data=data,
|
||||
start=start,
|
||||
end=end,
|
||||
step_len=self.step_len,
|
||||
dtype=dtype,
|
||||
flt_data=flt_data,
|
||||
)
|
||||
return tsds
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user