1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

update TSDataSampler refineing the memory layout of data array to speed up NN training (#1342)

* update TSDataSampler

* reformat code with black

* use pre-commit to reformat the code

* Add documents

* More docstring

* More Safety

Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
Xu Yang
2022-11-11 19:35:10 +08:00
committed by GitHub
parent 3b471a0fe3
commit a82cc0b129

View File

@@ -82,7 +82,11 @@ class DatasetH(Dataset):
"""
def __init__(
self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], fetch_kwargs: Dict = {}, **kwargs
self,
handler: Union[Dict, DataHandler],
segments: Dict[Text, Tuple],
fetch_kwargs: Dict = {},
**kwargs,
):
"""
Setup the underlying data.
@@ -284,10 +288,69 @@ class TSDataSampler:
- For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result
in a different data type
Indices design:
TSDataSampler has an index mechanism to help users query time-series data efficiently.
The definition of related variables:
data_arr: np.ndarray
The original data. It contains all the original data.
Queries are often for the time-series of a specific stock.
To leverage this data characteristic and speed up querying, the multi-index of data_arr is rearranged in (instrument, datetime) order
data_index: pd.MultiIndex with index order <instrument, datetime>
It has the same length as `idx_map`. Their elements are expected to be aligned.
idx_map: np.ndarray
It is the indexable data. It originates from data_arr, and is then filtered by 1) `start` and `end` 2) `flt_data`
The extra data in data_arr is useful in following cases
1) creating meaningful time series data before `start` instead of padding them with zeros
2) some data are excluded by `flt_data` (e.g. no <X, y> sample pair for that index). but they are still used in time-series in X
Finally, it will look like:
array([[ 0, 0],
[ 1, 0],
[ 2, 0],
...,
[241, 348],
[242, 348],
[243, 348]], dtype=int32)
It lists all indexable data (some data only used in historical time series may not be indexable); the values are the corresponding row and col in idx_df
idx_df: pd.DataFrame
It aims to map the <datetime, instrument> key to the original position in data_arr
For example, it may look like (NOTE: the index for an instrument time-series is continuous in memory)
instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ...
datetime
2017-01-03 0 242 473 717 NaN 974 ...
2017-01-04 1 243 474 718 NaN 975 ...
2017-01-05 2 244 475 719 NaN 976 ...
2017-01-06 3 245 476 720 NaN 977 ...
With these two indices(idx_map, idx_df) and original data(data_arr), we can make the following queries fast (implemented in __getitem__)
(1) Get the i-th indexable sample(time-series): (indexable sample index) -> [idx_map] -> (row col) -> [idx_df] -> (index in data_arr)
(2) Get the specific sample by <datetime, instrument>: (<datetime, instrument>, i.e. <row, col>) -> [idx_df] -> (index in data_arr)
(3) Get the index of a time-series data: (get the <row, col>, refer to (1), (2)) -> [idx_df] -> (all indices in data_arr for time-series)
"""
# Please refer to the docstring of TSDataSampler for the definition of following attributes
data_arr: np.ndarray
data_index: pd.MultiIndex
idx_map: np.ndarray
idx_df: pd.DataFrame
def __init__(
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
self,
data: pd.DataFrame,
start,
end,
step_len: int,
fillna_type: str = "none",
dtype=None,
flt_data=None,
):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -295,7 +358,7 @@ class TSDataSampler:
Parameters
----------
data : pd.DataFrame
The raw tabular data
The raw tabular data whose index order is <"datetime", "instrument">
start :
The indexable start time
end :
@@ -311,7 +374,7 @@ class TSDataSampler:
ffill+bfill:
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data.
a column of data(True or False) to filter data. Its index order is <"datetime", "instrument">
None:
keep all data
@@ -321,7 +384,10 @@ class TSDataSampler:
self.step_len = step_len
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
self.data = data.swaplevel().sort_index().copy()
data.drop(
data.columns, axis=1, inplace=True
) # data is useless since it's passed to a transposed one, hard code to free the memory of this dataframe to avoid three big dataframe in the memory(including: data, self.data, self.data_arr)
kwargs = {"object": self.data}
if dtype is not None:
@@ -332,7 +398,9 @@ class TSDataSampler:
# - append last line with full NaN for better performance in `__getitem__`
# - Keep the same dtype will result in a better performance
self.data_arr = np.append(
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
self.data_arr,
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = -1 # The last line is all NaN
@@ -347,19 +415,36 @@ class TSDataSampler:
flt_data = flt_data.iloc[:, 0]
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.swaplevel()
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data)[0]]
self.idx_map = self.idx_map2arr(self.idx_map)
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
self.idx_map, self.data_index = self.slice_idx_map_and_data_index(
self.idx_map, self.idx_df, self.data_index, start, end
)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
@staticmethod
def slice_idx_map_and_data_index(idx_map, idx_df, data_index, start, end):
    """
    Keep only the entries of `idx_map` / `data_index` whose row falls in [start, end].

    Parameters
    ----------
    idx_map : np.ndarray
        2D array; its first column is compared against row positions of `idx_df.index`.
    idx_df : pd.DataFrame
        Its index is used to translate `start` and `end` into positional bounds via `slice_locs`.
    data_index :
        Index object aligned element-by-element with `idx_map`.
    start :
        The start of the range; converted by `time_to_slc_point` before lookup.
    end :
        The end of the range; converted by `time_to_slc_point` before lookup.

    Returns
    -------
    tuple
        (filtered idx_map, filtered data_index), both selected with the same boolean mask.
    """
    # Both containers must have the same length, otherwise a mask computed from
    # idx_map could not be applied to data_index.
    assert len(idx_map) == data_index.shape[0]

    first_row, stop_row = idx_df.index.slice_locs(
        start=time_to_slc_point(start),
        end=time_to_slc_point(end),
    )

    # Half-open positional range [first_row, stop_row) on the first column of idx_map.
    row_positions = idx_map[:, 0]
    in_range = (row_positions >= first_row) & (row_positions < stop_row)
    return idx_map[in_range], data_index[in_range]
@staticmethod
def idx_map2arr(idx_map):
# pytorch data sampler will have better memory control without large dict or list
@@ -394,7 +479,7 @@ class TSDataSampler:
Get the pandas index of the data, it will be useful in following scenarios
- Special sampler will be used (e.g. user want to sample day by day)
"""
return self.data_index[self.start_idx : self.end_idx]
return self.data_index.swaplevel() # to align the order of multiple index of original data received by __init__
def config(self, **kwargs):
# Config the attributes
@@ -409,25 +494,33 @@ class TSDataSampler:
Parameters
----------
data : pd.DataFrame
The dataframe with <datetime, DataFrame>
A DataFrame with index in order <instrument, datetime>
RSQR5 RESI5 WVMA5 LABEL0
instrument datetime
SH600000 2017-01-03 0.016389 0.461632 -1.154788 -0.048056
2017-01-04 0.884545 -0.110597 -1.059332 -0.030139
2017-01-05 0.507540 -0.535493 -1.099665 -0.644983
2017-01-06 -1.267771 -0.669685 -1.636733 0.295366
2017-01-09 0.339346 0.074317 -0.984989 0.765540
Returns
-------
Tuple[pd.DataFrame, dict]:
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ...
datetime
2021-01-11 0 1 2 3 4 5 ...
2021-01-12 4146 4147 4148 4149 4150 4151 ...
2021-01-13 8293 8294 8295 8296 8297 8298 ...
2021-01-14 12441 12442 12443 12444 12445 12446 ...
2017-01-03 0 242 473 717 NaN 974 ...
2017-01-04 1 243 474 718 NaN 975 ...
2017-01-05 2 244 475 719 NaN 976 ...
2017-01-06 3 245 476 720 NaN 977 ...
2) the second element: {<original index>: <row, col>}
"""
# use object dtype in case pandas converts int to float
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)
idx_df = lazy_sort_index(idx_df, axis=1).T
idx_map = {}
for i, (_, row) in enumerate(idx_df.iterrows()):
@@ -485,11 +578,11 @@ class TSDataSampler:
"""
# Get the right row number `i` and col number `j` in idx_df
if isinstance(idx, (int, np.integer)):
real_idx = self.start_idx + idx
if self.start_idx <= real_idx < self.end_idx:
real_idx = idx
if 0 <= real_idx < len(self.idx_map):
i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good
else:
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
raise KeyError(f"{real_idx} is out of [0, {len(self.idx_map)})")
elif isinstance(idx, tuple):
# <TSDataSampler object>["datetime", "instruments"]
date, inst = idx
@@ -532,7 +625,10 @@ class TSDataSampler:
# precision problems. It will not cause any problems in my tests at least
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)
data = self.data_arr[indices]
if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up.
data = self.data_arr[indices[0] : indices[-1] + 1]
else:
data = self.data_arr[indices]
if isinstance(idx, mtit):
# if we get multiple indexes, addition dimension should be added.
# <sample_idx, step_idx, feature_idx>
@@ -540,7 +636,7 @@ class TSDataSampler:
return data
def __len__(self):
return self.end_idx - self.start_idx
return len(self.idx_map)
class TSDatasetH(DatasetH):
@@ -611,7 +707,14 @@ class TSDatasetH(DatasetH):
else:
flt_data = None
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
tsds = TSDataSampler(
data=data,
start=start,
end=end,
step_len=self.step_len,
dtype=dtype,
flt_data=flt_data,
)
return tsds