From a82cc0b12963e989fd33527bce666aabfa21ce75 Mon Sep 17 00:00:00 2001 From: Xu Yang Date: Fri, 11 Nov 2022 19:35:10 +0800 Subject: [PATCH] update TSDataSampler refineing the memory layout of data array to speed up NN training (#1342) * update TSDataSampler * reformat code with black * use pre-commit to reformat the code * Add documents * More docstring * More Safety Co-authored-by: Young --- qlib/data/dataset/__init__.py | 151 ++++++++++++++++++++++++++++------ 1 file changed, 127 insertions(+), 24 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 5e98bfc97..dcc9957ed 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -82,7 +82,11 @@ class DatasetH(Dataset): """ def __init__( - self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], fetch_kwargs: Dict = {}, **kwargs + self, + handler: Union[Dict, DataHandler], + segments: Dict[Text, Tuple], + fetch_kwargs: Dict = {}, + **kwargs, ): """ Setup the underlying data. @@ -284,10 +288,69 @@ class TSDataSampler: - For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result in a different data type + + Indices design: + TSDataSampler has an index mechanism to help users query time-series data efficiently. + + The definition of related variables: + data_arr: np.ndarray + The original data. It contains all the original data. + Queries are often for the time-series of a specific stock. + To leverage this data characteristic and speed up querying, the multi-index of data_arr is rearranged in (instrument, datetime) order + + data_index: pd.MultiIndex with index order <instrument, datetime> + It has the same length as `idx_map`; their elements are expected to be aligned. + + idx_map: np.ndarray + It is the indexable data. 
It originates from data_arr, and is then filtered by 1) `start` and `end` 2) `flt_data` + The extra data in data_arr is useful in the following cases + 1) creating meaningful time series data before `start` instead of padding them with zeros + 2) some data are excluded by `flt_data` (e.g. no sample pair for that index), but they are still used in the time-series in X + + Finally, it will look like: + + array([[ 0, 0], + [ 1, 0], + [ 2, 0], + ..., + [241, 348], + [242, 348], + [243, 348]], dtype=int32) + + It lists all indexable data (some data used only in historical time series may not be indexable); the values are the corresponding row and col in idx_df + idx_df: pd.DataFrame + It aims to map the key to the original position in data_arr + + For example, it may look like (NOTE: the index for an instrument's time-series is continuous in memory) + + instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ... + datetime + 2017-01-03 0 242 473 717 NaN 974 ... + 2017-01-04 1 243 474 718 NaN 975 ... + 2017-01-05 2 244 475 719 NaN 976 ... + 2017-01-06 3 245 476 720 NaN 977 ... + + With these two indices(idx_map, idx_df) and original data(data_arr), we can make the following queries fast (implemented in __getitem__) + (1) Get the i-th indexable sample(time-series): (indexable sample index) -> [idx_map] -> (row col) -> [idx_df] -> (index in data_arr) + (2) Get the specific sample by <datetime, instrument>: (<datetime, instrument>, i.e. <row, col>
) -> [idx_df] -> (index in data_arr) + (3) Get the index of a time-series data: (get the <row, col>, refer to (1), (2)) -> [idx_df] -> (all indices in data_arr for time-series) """ + # Please refer to the docstring of TSDataSampler for the definition of the following attributes + data_arr: np.ndarray + data_index: pd.MultiIndex + idx_map: np.ndarray + idx_df: pd.DataFrame + def __init__( - self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None + self, + data: pd.DataFrame, + start, + end, + step_len: int, + fillna_type: str = "none", + dtype=None, + flt_data=None, ): """ Build a dataset which looks like torch.data.utils.Dataset. @@ -295,7 +358,7 @@ class TSDataSampler: Parameters ---------- data : pd.DataFrame - The raw tabular data + The raw tabular data whose index order is <"datetime", "instrument"> start : The indexable start time end : @@ -311,7 +374,7 @@ class TSDataSampler: ffill+bfill: ffill with previous samples first and fill with later samples second flt_data : pd.Series - a column of data(True or False) to filter data. + a column of data(True or False) to filter data. 
Its index order is <"datetime", "instrument"> None: kepp all data @@ -321,7 +384,10 @@ class TSDataSampler: self.step_len = step_len self.fillna_type = fillna_type assert get_level_index(data, "datetime") == 0 - self.data = lazy_sort_index(data) + self.data = data.swaplevel().sort_index().copy() + data.drop( + data.columns, axis=1, inplace=True + ) # data is useless since it's passed to a transposed one, hard code to free the memory of this dataframe to avoid three big dataframe in the memory(including: data, self.data, self.data_arr) kwargs = {"object": self.data} if dtype is not None: @@ -332,7 +398,9 @@ class TSDataSampler: # - append last line with full NaN for better performance in `__getitem__` # - Keep the same dtype will result in a better performance self.data_arr = np.append( - self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0 + self.data_arr, + np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), + axis=0, ) self.nan_idx = -1 # The last line is all NaN @@ -347,19 +415,36 @@ class TSDataSampler: flt_data = flt_data.iloc[:, 0] # NOTE: bool(np.nan) is True !!!!!!!! # make sure reindex comes first. Otherwise extra NaN may appear. 
+ flt_data = flt_data.swaplevel() flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool) self.flt_data = flt_data.values self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) self.data_index = self.data_index[np.where(self.flt_data)[0]] self.idx_map = self.idx_map2arr(self.idx_map) - - self.start_idx, self.end_idx = self.data_index.slice_locs( - start=time_to_slc_point(start), end=time_to_slc_point(end) + self.idx_map, self.data_index = self.slice_idx_map_and_data_index( + self.idx_map, self.idx_df, self.data_index, start, end ) - self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance + self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance del self.data # save memory + @staticmethod + def slice_idx_map_and_data_index( + idx_map, + idx_df, + data_index, + start, + end, + ): + assert ( + len(idx_map) == data_index.shape[0] + ) # make sure idx_map and data_index is same so index of idx_map can be used on data_index + + start_row_idx, end_row_idx = idx_df.index.slice_locs(start=time_to_slc_point(start), end=time_to_slc_point(end)) + + time_flter_idx = (idx_map[:, 0] < end_row_idx) & (idx_map[:, 0] >= start_row_idx) + return idx_map[time_flter_idx], data_index[time_flter_idx] + @staticmethod def idx_map2arr(idx_map): # pytorch data sampler will have better memory control without large dict or list @@ -394,7 +479,7 @@ class TSDataSampler: Get the pandas index of the data, it will be useful in following scenarios - Special sampler will be used (e.g. 
user want to sample day by day) """ - return self.data_index[self.start_idx : self.end_idx] + return self.data_index.swaplevel() # to align the order of multiple index of original data received by __init__ def config(self, **kwargs): # Config the attributes @@ -409,25 +494,33 @@ class TSDataSampler: Parameters ---------- data : pd.DataFrame - The dataframe with + A DataFrame with index in order + + RSQR5 RESI5 WVMA5 LABEL0 + instrument datetime + SH600000 2017-01-03 0.016389 0.461632 -1.154788 -0.048056 + 2017-01-04 0.884545 -0.110597 -1.059332 -0.030139 + 2017-01-05 0.507540 -0.535493 -1.099665 -0.644983 + 2017-01-06 -1.267771 -0.669685 -1.636733 0.295366 + 2017-01-09 0.339346 0.074317 -0.984989 0.765540 Returns ------- Tuple[pd.DataFrame, dict]: 1) the first element: reshape the original index into a 2D dataframe - instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ... + instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ... datetime - 2021-01-11 0 1 2 3 4 5 ... - 2021-01-12 4146 4147 4148 4149 4150 4151 ... - 2021-01-13 8293 8294 8295 8296 8297 8298 ... - 2021-01-14 12441 12442 12443 12444 12445 12446 ... + 2017-01-03 0 242 473 717 NaN 974 ... + 2017-01-04 1 243 474 718 NaN 975 ... + 2017-01-05 2 244 475 719 NaN 976 ... + 2017-01-06 3 245 476 720 NaN 977 ... 
2) the second element: {: } """ # object incase of pandas converting int to float idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object) idx_df = lazy_sort_index(idx_df.unstack()) # NOTE: the correctness of `__getitem__` depends on columns sorted here - idx_df = lazy_sort_index(idx_df, axis=1) + idx_df = lazy_sort_index(idx_df, axis=1).T idx_map = {} for i, (_, row) in enumerate(idx_df.iterrows()): @@ -485,11 +578,11 @@ class TSDataSampler: """ # The the right row number `i` and col number `j` in idx_df if isinstance(idx, (int, np.integer)): - real_idx = self.start_idx + idx - if self.start_idx <= real_idx < self.end_idx: + real_idx = idx + if 0 <= real_idx < len(self.idx_map): i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good else: - raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})") + raise KeyError(f"{real_idx} is out of [0, {len(self.idx_map)})") elif isinstance(idx, tuple): # ["datetime", "instruments"] date, inst = idx @@ -532,7 +625,10 @@ class TSDataSampler: # precision problems. It will not cause any problems in my tests at least indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int) - data = self.data_arr[indices] + if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up. + data = self.data_arr[indices[0] : indices[-1] + 1] + else: + data = self.data_arr[indices] if isinstance(idx, mtit): # if we get multiple indexes, addition dimension should be added. # @@ -540,7 +636,7 @@ class TSDataSampler: return data def __len__(self): - return self.end_idx - self.start_idx + return len(self.idx_map) class TSDatasetH(DatasetH): @@ -611,7 +707,14 @@ class TSDatasetH(DatasetH): else: flt_data = None - tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data) + tsds = TSDataSampler( + data=data, + start=start, + end=end, + step_len=self.step_len, + dtype=dtype, + flt_data=flt_data, + ) return tsds