From abb90ca2f6703f83450278a4ab2556cc988610a9 Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 6 Dec 2020 12:44:09 +0000 Subject: [PATCH] fix sampler performance bug --- qlib/data/dataset/__init__.py | 22 ++++++++++++++++++---- tests/test_dataset.py | 14 ++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 415d1084b..384a7ea47 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -190,7 +190,12 @@ class TSDataSampler: It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series dataset based on tabular data. - If user have further requirements for processing data, user could process + If user have further requirements for processing data, user could process them based on `TSDataSampler` or create + more powerful subclasses. + + Known Issues: + - For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result + in a different data type """ @@ -223,11 +228,20 @@ class TSDataSampler: self.fillna_type = fillna_type assert get_level_index(data, "datetime") == 0 self.data = lazy_sort_index(data) + self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But + # the data type will be changed # The index of usable data is between start_idx and end_idx self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) # self.index_link = self.build_link(self.data) self.idx_df, self.idx_map = self.build_index(self.data) + def get_index(self): + """ + Get the pandas index of the data, it will be useful in following scenarios + - Special sampler will be used (e.g. user want to sample day by day) + """ + return self.data.index[self.start_idx : self.end_idx] + def config(self, **kwargs): # Config the attributes for k, v in kwargs.items(): @@ -313,14 +327,14 @@ class TSDataSampler: if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float # All the index exists - return self.data.values[indices.astype(np.int)] + return self.data_arr[indices.astype(np.int)] else: # Only part index exists. These days will be filled with nan for idx in indices: if np.isnan(idx): - data_l.append(np.full((self.data.shape[1],), np.nan)) + data_l.append(np.full((self.data_arr.shape[1],), np.nan)) else: - data_l.append(self.data.values[idx]) + data_l.append(self.data_arr[idx]) return np.array(data_l) def __len__(self): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 01454fff8..056653ffa 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -10,6 +10,7 @@ from torch.utils.data import DataLoader import time from qlib.data.dataset.handler import DataHandlerLP + class TestDataset(TestAutoData): def testTSDataset(self): tsdh = TSDatasetH( @@ -24,12 +25,12 @@ class TestDataset(TestAutoData): "instruments": "csi300", "infer_processors": [ {"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}}, - {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier":"true"}}, + {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": "true"}}, {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, ], "learn_processors": [ "DropnaLabel", - {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm + {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm ], }, }, @@ -44,7 +45,7 @@ class TestDataset(TestAutoData): t = time.time() for idx in np.random.randint(0, len(tsds_train), size=2000): - data = tsds_train[idx] + _ = tsds_train[idx] print(f"2000 sample takes {time.time() - t}s") # FIXME: Please remove pytorch related function. Otherwise the CI tests will fail @@ -74,7 +75,12 @@ class TestDataset(TestAutoData): # Check the data # Get data from DataFrame Directly - data_from_df = tsdh._handler.fetch().loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"].iloc[-30:].values + data_from_df = ( + tsdh._handler.fetch(data_key=DataHandlerLP.DK_L) + .loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"] + .iloc[-30:] + .values + ) equal = np.isclose(data_from_df, data_from_ds) self.assertTrue(equal[~np.isnan(data_from_df)].all())