mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
fix sampler performance bug
This commit is contained in:
@@ -190,7 +190,12 @@ class TSDataSampler:
|
||||
It works like `torch.utils.data.Dataset`; it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
|
||||
If user have further requirements for processing data, user could process
|
||||
If users have further requirements for processing data, they can process it based on `TSDataSampler` or create
|
||||
more powerful subclasses.
|
||||
|
||||
Known Issues:
|
||||
- For performance reasons, this sampler converts the dataframe into arrays. This could result
|
||||
in a different data type
|
||||
|
||||
"""
|
||||
|
||||
@@ -223,11 +228,20 @@ class TSDataSampler:
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
Get the pandas index of the data. It is useful in the following scenarios:
|
||||
- A special sampler will be used (e.g. the user wants to sample day by day)
|
||||
"""
|
||||
return self.data.index[self.start_idx : self.end_idx]
|
||||
|
||||
def config(self, **kwargs):
|
||||
# Config the attributes
|
||||
for k, v in kwargs.items():
|
||||
@@ -313,14 +327,14 @@ class TSDataSampler:
|
||||
|
||||
if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float
|
||||
# All the index exists
|
||||
return self.data.values[indices.astype(np.int)]
|
||||
return self.data_arr[indices.astype(np.int)]
|
||||
else:
|
||||
# Only part index exists. These days will be filled with nan
|
||||
for idx in indices:
|
||||
if np.isnan(idx):
|
||||
data_l.append(np.full((self.data.shape[1],), np.nan))
|
||||
data_l.append(np.full((self.data_arr.shape[1],), np.nan))
|
||||
else:
|
||||
data_l.append(self.data.values[idx])
|
||||
data_l.append(self.data_arr[idx])
|
||||
return np.array(data_l)
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
|
||||
import time
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TestDataset(TestAutoData):
|
||||
def testTSDataset(self):
|
||||
tsdh = TSDatasetH(
|
||||
@@ -24,12 +25,12 @@ class TestDataset(TestAutoData):
|
||||
"instruments": "csi300",
|
||||
"infer_processors": [
|
||||
{"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}},
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier":"true"}},
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": "true"}},
|
||||
{"class": "Fillna", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
"DropnaLabel",
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm
|
||||
{"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm
|
||||
],
|
||||
},
|
||||
},
|
||||
@@ -44,7 +45,7 @@ class TestDataset(TestAutoData):
|
||||
|
||||
t = time.time()
|
||||
for idx in np.random.randint(0, len(tsds_train), size=2000):
|
||||
data = tsds_train[idx]
|
||||
_ = tsds_train[idx]
|
||||
print(f"2000 sample takes {time.time() - t}s")
|
||||
|
||||
# FIXME: Please remove pytorch related function. Otherwise the CI tests will fail
|
||||
@@ -74,7 +75,12 @@ class TestDataset(TestAutoData):
|
||||
|
||||
# Check the data
|
||||
# Get data from DataFrame Directly
|
||||
data_from_df = tsdh._handler.fetch().loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"].iloc[-30:].values
|
||||
data_from_df = (
|
||||
tsdh._handler.fetch(data_key=DataHandlerLP.DK_L)
|
||||
.loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"]
|
||||
.iloc[-30:]
|
||||
.values
|
||||
)
|
||||
|
||||
equal = np.isclose(data_from_df, data_from_ds)
|
||||
self.assertTrue(equal[~np.isnan(data_from_df)].all())
|
||||
|
||||
Reference in New Issue
Block a user