mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
support multi indexing of TSDatasetSample
This commit is contained in:
@@ -229,11 +229,16 @@ class TSDataSampler:
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But
|
||||
# NOTE: append last line with full NaN for better performance in `__getitem__`
|
||||
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
@@ -276,7 +281,68 @@ class TSDataSampler:
|
||||
idx_map[real_idx] = (i, j)
|
||||
return idx_df, idx_map
|
||||
|
||||
def __getitem__(self, idx: Union[int, Tuple[object, str]]):
|
||||
def _get_indices(self, row: int, col: int) -> np.array:
|
||||
"""
|
||||
get series indices of self.data_arr from the row, col indices of self.idx_df
|
||||
|
||||
Parameters
|
||||
----------
|
||||
row : int
|
||||
the row in self.idx_df
|
||||
col : int
|
||||
the col in self.idx_df
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.array:
|
||||
The indices of data of the data
|
||||
"""
|
||||
indices = self.idx_arr[max(row - self.step_len + 1, 0) : row + 1, col]
|
||||
|
||||
if len(indices) < self.step_len:
|
||||
indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])
|
||||
|
||||
if self.fillna_type == "ffill":
|
||||
indices = np_ffill(indices)
|
||||
elif self.fillna_type == "ffill+bfill":
|
||||
indices = np_ffill(np_ffill(indices)[::-1])[::-1]
|
||||
else:
|
||||
assert self.fillna_type == "none"
|
||||
return indices
|
||||
|
||||
def _get_row_col(self, idx) -> Tuple[int]:
|
||||
"""
|
||||
get the col index and row index of a given sample index in self.idx_df
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idx :
|
||||
the input of `__getitem__`
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int]:
|
||||
the row and col index
|
||||
"""
|
||||
# The the right row number `i` and col number `j` in idx_df
|
||||
if isinstance(idx, (int, np.integer)):
|
||||
real_idx = self.start_idx + idx
|
||||
if self.start_idx <= real_idx < self.end_idx:
|
||||
i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good
|
||||
else:
|
||||
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
|
||||
elif isinstance(idx, tuple):
|
||||
# <TSDataSampler object>["datetime", "instruments"]
|
||||
date, inst = idx
|
||||
date = pd.Timestamp(date)
|
||||
i = bisect.bisect_right(self.idx_df.index, date) - 1
|
||||
# NOTE: This relies on the idx_df columns sorted in `__init__`
|
||||
j = bisect.bisect_left(self.idx_df.columns, inst)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return i, j
|
||||
|
||||
def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]):
|
||||
"""
|
||||
# We have two method to get the time-series of a sample
|
||||
tsds is a instance of TSDataSampler
|
||||
@@ -294,48 +360,26 @@ class TSDataSampler:
|
||||
----------
|
||||
idx : Union[int, Tuple[object, str]]
|
||||
"""
|
||||
# The the right row number `i` and col number `j` in idx_df
|
||||
if isinstance(idx, (int, np.integer)):
|
||||
real_idx = self.start_idx + idx
|
||||
if self.start_idx <= real_idx < self.end_idx:
|
||||
i, j = self.idx_map[real_idx]
|
||||
else:
|
||||
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
|
||||
elif isinstance(idx, tuple):
|
||||
# <TSDataSampler object>["datetime", "instruments"]
|
||||
date, inst = idx
|
||||
date = pd.Timestamp(date)
|
||||
i = bisect.bisect_right(self.idx_df.index, date) - 1
|
||||
# NOTE: This relies on the idx_df columns sorted in `__init__`
|
||||
j = bisect.bisect_left(self.idx_df.columns, inst)
|
||||
# Multi-index type
|
||||
mtit = (list, np.ndarray)
|
||||
if isinstance(idx, mtit):
|
||||
indices = [self._get_indices(*self._get_row_col(i)) for i in idx]
|
||||
indices = np.concatenate(indices)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
indices = self._get_indices(*self._get_row_col(idx))
|
||||
|
||||
data_l = []
|
||||
indices = self.idx_df.values[max(i - self.step_len + 1, 0) : i + 1, j]
|
||||
indices = indices.reshape(-1)
|
||||
|
||||
if len(indices) < self.step_len:
|
||||
indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])
|
||||
# 1) for better performance, use the last nan line for padding the lost date
|
||||
# 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
|
||||
# precision problems. It will not cause any problems in my tests at least
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(np.int)
|
||||
|
||||
if self.fillna_type == "ffill":
|
||||
indices = np_ffill(indices)
|
||||
elif self.fillna_type == "ffill+bfill":
|
||||
indices = np_ffill(np_ffill(indices)[::-1])[::-1]
|
||||
else:
|
||||
assert self.fillna_type == "none"
|
||||
|
||||
if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float
|
||||
# All the index exists
|
||||
return self.data_arr[indices.astype(np.int)]
|
||||
else:
|
||||
# Only part index exists. These days will be filled with nan
|
||||
for idx in indices:
|
||||
if np.isnan(idx):
|
||||
data_l.append(np.full((self.data_arr.shape[1],), np.nan))
|
||||
else:
|
||||
data_l.append(self.data_arr[idx])
|
||||
return np.array(data_l)
|
||||
data = self.data_arr[indices]
|
||||
if isinstance(idx, mtit):
|
||||
# if we get multiple indexes, addition dimension should be added.
|
||||
# <sample_idx, step_idx, feature_idx>
|
||||
data = data.reshape(-1, self.step_len, *data.shape[1:])
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return self.end_idx - self.start_idx
|
||||
|
||||
@@ -48,6 +48,12 @@ class TestDataset(TestAutoData):
|
||||
_ = tsds_train[idx]
|
||||
print(f"2000 sample takes {time.time() - t}s")
|
||||
|
||||
t = time.time()
|
||||
for _ in range(20):
|
||||
data = tsds_train[np.random.randint(0, len(tsds_train), size=2000)]
|
||||
print(data.shape)
|
||||
print(f"2000 sample(batch index) * 20 times takes {time.time() - t}s")
|
||||
|
||||
# FIXME: Please remove pytorch related function. Otherwise the CI tests will fail
|
||||
train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True, num_workers=10)
|
||||
t = time.time()
|
||||
@@ -88,3 +94,8 @@ class TestDataset(TestAutoData):
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
# User could use following code to run test when using line_profiler
|
||||
# td = TestDataset()
|
||||
# td.setUpClass()
|
||||
# td.testTSDataset()
|
||||
|
||||
Reference in New Issue
Block a user