1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

fix sampler performance bug

This commit is contained in:
Young
2020-12-06 12:44:09 +00:00
committed by you-n-g
parent a7c6aea386
commit abb90ca2f6
2 changed files with 28 additions and 8 deletions

View File

@@ -190,7 +190,12 @@ class TSDataSampler:
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
dataset based on tabular data.
If user have further requirements for processing data, user could process
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
more powerful subclasses.
Known Issues:
- For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result
in a different data type
"""
@@ -223,11 +228,20 @@ class TSDataSampler:
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But
# the data type will be changed
# The index of usable data is between start_idx and end_idx
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
# self.index_link = self.build_link(self.data)
self.idx_df, self.idx_map = self.build_index(self.data)
def get_index(self):
"""
Get the pandas index of the data, it will be useful in following scenarios
- Special sampler will be used (e.g. user want to sample day by day)
"""
return self.data.index[self.start_idx : self.end_idx]
def config(self, **kwargs):
# Config the attributes
for k, v in kwargs.items():
@@ -313,14 +327,14 @@ class TSDataSampler:
if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float
# All the index exists
return self.data.values[indices.astype(np.int)]
return self.data_arr[indices.astype(np.int)]
else:
# Only part index exists. These days will be filled with nan
for idx in indices:
if np.isnan(idx):
data_l.append(np.full((self.data.shape[1],), np.nan))
data_l.append(np.full((self.data_arr.shape[1],), np.nan))
else:
data_l.append(self.data.values[idx])
data_l.append(self.data_arr[idx])
return np.array(data_l)
def __len__(self):

View File

@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
import time
from qlib.data.dataset.handler import DataHandlerLP
class TestDataset(TestAutoData):
def testTSDataset(self):
tsdh = TSDatasetH(
@@ -24,12 +25,12 @@ class TestDataset(TestAutoData):
"instruments": "csi300",
"infer_processors": [
{"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}},
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier":"true"}},
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": "true"}},
{"class": "Fillna", "kwargs": {"fields_group": "feature"}},
],
"learn_processors": [
"DropnaLabel",
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm
{"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm
],
},
},
@@ -44,7 +45,7 @@ class TestDataset(TestAutoData):
t = time.time()
for idx in np.random.randint(0, len(tsds_train), size=2000):
data = tsds_train[idx]
_ = tsds_train[idx]
print(f"2000 sample takes {time.time() - t}s")
# FIXME: Please remove pytorch related function. Otherwise the CI tests will fail
@@ -74,7 +75,12 @@ class TestDataset(TestAutoData):
# Check the data
# Get data from DataFrame Directly
data_from_df = tsdh._handler.fetch().loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"].iloc[-30:].values
data_from_df = (
tsdh._handler.fetch(data_key=DataHandlerLP.DK_L)
.loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"]
.iloc[-30:]
.values
)
equal = np.isclose(data_from_df, data_from_ds)
self.assertTrue(equal[~np.isnan(data_from_df)].all())