mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
dataset performance optm
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from ...utils import init_instance_by_config
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from inspect import getfullargspec
|
||||
@@ -194,10 +194,33 @@ class TSDataSampler:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data, start, end, step_len):
|
||||
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
|
||||
"""
|
||||
Build a dataset which looks like torch.utils.data.Dataset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : pd.DataFrame
|
||||
The raw tabular data
|
||||
start :
|
||||
The indexable start time
|
||||
end :
|
||||
The indexable end time
|
||||
step_len : int
|
||||
The length of the time-series step
|
||||
fillna_type : str
|
||||
How qlib will handle the sample if there is no sample on a specific date.
|
||||
none:
|
||||
fill with np.nan
|
||||
ffill:
|
||||
ffill with previous sample
|
||||
ffill+bfill:
|
||||
ffill with previous samples first and fill with later samples second
|
||||
"""
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.step_len = step_len
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
@@ -205,6 +228,11 @@ class TSDataSampler:
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
|
||||
def config(self, **kwargs):
    """
    Update attributes of this object in place.

    Parameters
    ----------
    **kwargs :
        Each keyword name is set as an attribute on the instance, bound to
        the given value (e.g. ``config(fillna_type="ffill+bfill")``).
    """
    for attr_name in kwargs:
        setattr(self, attr_name, kwargs[attr_name])
|
||||
|
||||
@staticmethod
|
||||
def build_index(data: pd.DataFrame) -> dict:
|
||||
"""
|
||||
@@ -253,10 +281,12 @@ class TSDataSampler:
|
||||
idx : Union[int, Tuple[object, str]]
|
||||
"""
|
||||
# Find the right row number `i` and col number `j` in idx_df
|
||||
if isinstance(idx, int):
|
||||
if isinstance(idx, (int, np.integer)):
|
||||
real_idx = self.start_idx + idx
|
||||
if self.start_idx <= real_idx < self.end_idx:
|
||||
i, j = self.idx_map[real_idx]
|
||||
else:
|
||||
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
|
||||
elif isinstance(idx, tuple):
|
||||
# <TSDataSampler object>["datetime", "instruments"]
|
||||
date, inst = idx
|
||||
@@ -265,19 +295,33 @@ class TSDataSampler:
|
||||
# NOTE: This relies on the idx_df columns sorted in `__init__`
|
||||
j = bisect.bisect_left(self.idx_df.columns, inst)
|
||||
else:
|
||||
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
data_l = []
|
||||
indices = self.idx_df.iloc[max(i - self.step_len + 1, 0) : i + 1, j].values
|
||||
indices = self.idx_df.values[max(i - self.step_len + 1, 0) : i + 1, j]
|
||||
indices = indices.reshape(-1)
|
||||
|
||||
if len(indices) < self.step_len:
|
||||
indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])
|
||||
for idx in indices:
|
||||
if np.isnan(idx):
|
||||
data_l.append(np.full((self.data.shape[1],), np.nan))
|
||||
else:
|
||||
data_l.append(self.data.iloc[idx])
|
||||
return np.array(data_l)
|
||||
|
||||
if self.fillna_type == "ffill":
|
||||
indices = np_ffill(indices)
|
||||
elif self.fillna_type == "ffill+bfill":
|
||||
indices = np_ffill(np_ffill(indices)[::-1])[::-1]
|
||||
else:
|
||||
assert self.fillna_type == "none"
|
||||
|
||||
if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float
|
||||
# All the index exists
|
||||
return self.data.values[indices.astype(np.int)]
|
||||
else:
|
||||
# Only part index exists. These days will be filled with nan
|
||||
for idx in indices:
|
||||
if np.isnan(idx):
|
||||
data_l.append(np.full((self.data.shape[1],), np.nan))
|
||||
else:
|
||||
data_l.append(self.data.values[idx])
|
||||
return np.array(data_l)
|
||||
|
||||
def __len__(self):
    """Return the number of indexable samples, i.e. ``end_idx - start_idx``."""
    first, stop = self.start_idx, self.end_idx
    return stop - first
|
||||
|
||||
@@ -101,6 +101,7 @@ class DropCol(Processor):
|
||||
mask = df.columns.isin(self.col_list)
|
||||
return df.loc[:, ~mask]
|
||||
|
||||
|
||||
class FilterCol(Processor):
|
||||
def __init__(self, fields_group="feature", col_list=[]):
|
||||
self.fields_group = fields_group
|
||||
@@ -119,6 +120,7 @@ class FilterCol(Processor):
|
||||
mask = df.columns.isin(self.col_list)
|
||||
return df.loc[:, mask]
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
""" Use tanh to process noise data"""
|
||||
|
||||
|
||||
@@ -55,6 +55,22 @@ def read_bin(file_path, start_index, end_index):
|
||||
return series
|
||||
|
||||
|
||||
def np_ffill(arr: np.ndarray):
    """
    forward fill a 1D numpy array

    Every NaN entry is replaced by the closest non-NaN entry before it.
    NaN entries at the head of the array have nothing before them and are
    therefore left unchanged.

    Parameters
    ----------
    arr : np.ndarray
        Input numpy 1D array

    Returns
    -------
    np.ndarray
        A new array with the same length and dtype as `arr`; the input is not modified.
    """
    # np.isnan only works on float arrays, so cast first (the input may be an object array).
    # NOTE: use the builtin `float`; the `np.float` alias was removed in NumPy 1.24.
    mask = np.isnan(arr.astype(float))
    # get fill index: valid positions keep their own index, NaN positions start at 0 ...
    idx = np.where(~mask, np.arange(mask.shape[0]), 0)
    # ... and the running maximum turns them into "index of the last valid position so far"
    np.maximum.accumulate(idx, out=idx)
    return arr[idx]
|
||||
|
||||
|
||||
#################### Search ####################
|
||||
def lower_bound(data, val, level=0):
|
||||
"""multi fields list lower bound.
|
||||
|
||||
@@ -23,35 +23,17 @@ class TestDataset(TestAutoData):
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi300",
|
||||
"infer_processors": [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "FilterCol",
|
||||
"kwargs":{"col_list": ["RESI5", "WVMA5", "RSQR5"]}
|
||||
},
|
||||
{
|
||||
"class" : "CSZFillna",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
}
|
||||
],
|
||||
"learn_processors": [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "DropnaProcessor",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
},
|
||||
"DropnaLabel",
|
||||
{
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {"fields_group": "label"}
|
||||
}
|
||||
{"class": "DropCol", "kwargs": {"col_list": ["VWAP0"]}},
|
||||
{"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}},
|
||||
{"class": "CSZFillna", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"process_type": "independent"
|
||||
"learn_processors": [
|
||||
{"class": "DropCol", "kwargs": {"col_list": ["VWAP0"]}},
|
||||
{"class": "DropnaProcessor", "kwargs": {"fields_group": "feature"}},
|
||||
"DropnaLabel",
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
"process_type": "independent",
|
||||
},
|
||||
},
|
||||
segments={
|
||||
@@ -62,10 +44,26 @@ class TestDataset(TestAutoData):
|
||||
)
|
||||
tsds_train = tsdh.prepare("train") # Test the correctness
|
||||
tsds = tsdh.prepare("valid") # prepare a dataset which is friendly to converting tabular data to time-series
|
||||
|
||||
t = time.time()
|
||||
for idx in np.random.randint(0, len(tsds_train), size=2000):
|
||||
data = tsds_train[idx]
|
||||
print(f"2000 sample takes {time.time() - t}s")
|
||||
|
||||
# FIXME: Please remove pytorch related function. Otherwise the CI tests will fail
|
||||
train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True, num_workers=10)
|
||||
t = time.time()
|
||||
for data in train_loader:
|
||||
now = time.localtime()
|
||||
print(time.strftime("%Y-%m-%d-%H_%M_%S", now))
|
||||
pass
|
||||
print(f"Passing all training batches takes {time.time() - t}s")
|
||||
|
||||
# Here is an example of ffill+bfill for index
|
||||
tsds_train.config(fillna_type="ffill+bfill")
|
||||
train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True, num_workers=10)
|
||||
t = time.time()
|
||||
for data in train_loader:
|
||||
pass
|
||||
print(f"Passing all training batches with fill takes {time.time() - t}s")
|
||||
|
||||
# The dimension of sample is same as tabular data, but it will return timeseries data of the sample
|
||||
|
||||
|
||||
Reference in New Issue
Block a user