From fb4a2e65ccdc62963d8aadf9e0ab5a57777d29fb Mon Sep 17 00:00:00 2001 From: Young Date: Mon, 7 Dec 2020 10:31:14 +0000 Subject: [PATCH] support multi indexing of TSDatasetSample --- qlib/data/dataset/__init__.py | 122 +++++++++++++++++++++++----------- tests/test_dataset.py | 11 +++ 2 files changed, 94 insertions(+), 39 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 384a7ea47..a07f7ab8f 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -229,11 +229,16 @@ class TSDataSampler: assert get_level_index(data, "datetime") == 0 self.data = lazy_sort_index(data) self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But + # NOTE: append last line with full NaN for better performance in `__getitem__` + self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0) + self.nan_idx = -1 # The last line is all NaN + # the data type will be changed # The index of usable data is between start_idx and end_idx self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) # self.index_link = self.build_link(self.data) self.idx_df, self.idx_map = self.build_index(self.data) + self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance def get_index(self): """ @@ -276,7 +281,68 @@ class TSDataSampler: idx_map[real_idx] = (i, j) return idx_df, idx_map - def __getitem__(self, idx: Union[int, Tuple[object, str]]): + def _get_indices(self, row: int, col: int) -> np.array: + """ + get series indices of self.data_arr from the row, col indices of self.idx_df + + Parameters + ---------- + row : int + the row in self.idx_df + col : int + the col in self.idx_df + + Returns + ------- + np.array: + The indices of data of the data + """ + indices = self.idx_arr[max(row - self.step_len + 1, 0) : row + 1, col] + + if len(indices) < self.step_len: + indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices]) + + if self.fillna_type == "ffill": + indices = np_ffill(indices) + elif self.fillna_type == "ffill+bfill": + indices = np_ffill(np_ffill(indices)[::-1])[::-1] + else: + assert self.fillna_type == "none" + return indices + + def _get_row_col(self, idx) -> Tuple[int]: + """ + get the col index and row index of a given sample index in self.idx_df + + Parameters + ---------- + idx : + the input of `__getitem__` + + Returns + ------- + Tuple[int]: + the row and col index + """ + # The the right row number `i` and col number `j` in idx_df + if isinstance(idx, (int, np.integer)): + real_idx = self.start_idx + idx + if self.start_idx <= real_idx < self.end_idx: + i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good + else: + raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})") + elif isinstance(idx, tuple): + # ["datetime", "instruments"] + date, inst = idx + date = pd.Timestamp(date) + i = bisect.bisect_right(self.idx_df.index, date) - 1 + # NOTE: This relies on the idx_df columns sorted in `__init__` + j = bisect.bisect_left(self.idx_df.columns, inst) + else: + raise NotImplementedError(f"This type of input is not supported") + return i, j + + def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]): """ # We have two method to get the time-series of a sample tsds is a instance of TSDataSampler @@ -294,48 +360,26 @@ class TSDataSampler: ---------- idx : Union[int, Tuple[object, str]] """ - # The the right row number `i` and col number `j` in idx_df - if isinstance(idx, (int, np.integer)): - real_idx = self.start_idx + idx - if self.start_idx <= real_idx < self.end_idx: - i, j = self.idx_map[real_idx] - else: - raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})") - elif isinstance(idx, tuple): - # ["datetime", "instruments"] - date, inst = idx - date = pd.Timestamp(date) - i = bisect.bisect_right(self.idx_df.index, date) - 1 - # NOTE: This relies on the idx_df columns sorted in `__init__` - j = bisect.bisect_left(self.idx_df.columns, inst) + # Multi-index type + mtit = (list, np.ndarray) + if isinstance(idx, mtit): + indices = [self._get_indices(*self._get_row_col(i)) for i in idx] + indices = np.concatenate(indices) else: - raise NotImplementedError(f"This type of input is not supported") + indices = self._get_indices(*self._get_row_col(idx)) - data_l = [] - indices = self.idx_df.values[max(i - self.step_len + 1, 0) : i + 1, j] - indices = indices.reshape(-1) - if len(indices) < self.step_len: - indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices]) + # 1) for better performance, use the last nan line for padding the lost date + # 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in + # precision problems. It will not cause any problems in my tests at least + indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(np.int) - if self.fillna_type == "ffill": - indices = np_ffill(indices) - elif self.fillna_type == "ffill+bfill": - indices = np_ffill(np_ffill(indices)[::-1])[::-1] - else: - assert self.fillna_type == "none" - - if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float - # All the index exists - return self.data_arr[indices.astype(np.int)] - else: - # Only part index exists. These days will be filled with nan - for idx in indices: - if np.isnan(idx): - data_l.append(np.full((self.data_arr.shape[1],), np.nan)) - else: - data_l.append(self.data_arr[idx]) - return np.array(data_l) + data = self.data_arr[indices] + if isinstance(idx, mtit): + # if we get multiple indexes, addition dimension should be added. + # + data = data.reshape(-1, self.step_len, *data.shape[1:]) + return data def __len__(self): return self.end_idx - self.start_idx diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 056653ffa..cc781f024 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -48,6 +48,12 @@ class TestDataset(TestAutoData): _ = tsds_train[idx] print(f"2000 sample takes {time.time() - t}s") + t = time.time() + for _ in range(20): + data = tsds_train[np.random.randint(0, len(tsds_train), size=2000)] + print(data.shape) + print(f"2000 sample(batch index) * 20 times takes {time.time() - t}s") + # FIXME: Please remove pytorch related function. Otherwise the CI tests will fail train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True, num_workers=10) t = time.time() @@ -88,3 +94,8 @@ class TestDataset(TestAutoData): if __name__ == "__main__": unittest.main(verbosity=10) + + # User could use following code to run test when using line_profiler + # td = TestDataset() + # td.setUpClass() + # td.testTSDataset()