1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

dataset performance optm

This commit is contained in:
Young
2020-12-05 17:00:23 +00:00
committed by you-n-g
parent 65902e424c
commit d2107c9957
4 changed files with 101 additions and 41 deletions

View File

@@ -1,6 +1,6 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple
from ...utils import init_instance_by_config
from ...utils import init_instance_by_config, np_ffill
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from inspect import getfullargspec
@@ -194,10 +194,33 @@ class TSDataSampler:
"""
def __init__(self, data, start, end, step_len):
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
"""
Build a dataset which looks like torch.utils.data.Dataset.
Parameters
----------
data : pd.DataFrame
The raw tabular data
start :
The indexable start time
end :
The indexable end time
step_len : int
The length of the time-series step
fillna_type : str
How qlib handles the sample if there is no sample on a specific date.
none:
fill with np.nan
ffill:
ffill with previous sample
ffill+bfill:
ffill with previous samples first and fill with later samples second
"""
self.start = start
self.end = end
self.step_len = step_len
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
# The index of usable data is between start_idx and end_idx
@@ -205,6 +228,11 @@ class TSDataSampler:
# self.index_link = self.build_link(self.data)
self.idx_df, self.idx_map = self.build_index(self.data)
def config(self, **kwargs):
    """Set every keyword argument on this object as an attribute of the same name."""
    # Plain setattr: an existing attribute is overwritten, a new name is created.
    for attr_name in kwargs:
        setattr(self, attr_name, kwargs[attr_name])
@staticmethod
def build_index(data: pd.DataFrame) -> dict:
"""
@@ -253,10 +281,12 @@ class TSDataSampler:
idx : Union[int, Tuple[object, str]]
"""
# Find the right row number `i` and column number `j` in idx_df
if isinstance(idx, int):
if isinstance(idx, (int, np.integer)):
real_idx = self.start_idx + idx
if self.start_idx <= real_idx < self.end_idx:
i, j = self.idx_map[real_idx]
else:
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
elif isinstance(idx, tuple):
# <TSDataSampler object>["datetime", "instruments"]
date, inst = idx
@@ -265,19 +295,33 @@ class TSDataSampler:
# NOTE: This relies on the idx_df columns sorted in `__init__`
j = bisect.bisect_left(self.idx_df.columns, inst)
else:
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
raise NotImplementedError(f"This type of input is not supported")
data_l = []
indices = self.idx_df.iloc[max(i - self.step_len + 1, 0) : i + 1, j].values
indices = self.idx_df.values[max(i - self.step_len + 1, 0) : i + 1, j]
indices = indices.reshape(-1)
if len(indices) < self.step_len:
indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])
for idx in indices:
if np.isnan(idx):
data_l.append(np.full((self.data.shape[1],), np.nan))
else:
data_l.append(self.data.iloc[idx])
return np.array(data_l)
if self.fillna_type == "ffill":
indices = np_ffill(indices)
elif self.fillna_type == "ffill+bfill":
indices = np_ffill(np_ffill(indices)[::-1])[::-1]
else:
assert self.fillna_type == "none"
if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float
# All the indices exist
return self.data.values[indices.astype(np.int)]
else:
# Only some of the indices exist. The missing days will be filled with nan
for idx in indices:
if np.isnan(idx):
data_l.append(np.full((self.data.shape[1],), np.nan))
else:
data_l.append(self.data.values[idx])
return np.array(data_l)
def __len__(self):
    """Return how many positions lie in the half-open range [start_idx, end_idx)."""
    stop, first = self.end_idx, self.start_idx
    return stop - first