diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index dab91bdcd..415d1084b 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -1,6 +1,6 @@ from ...utils.serial import Serializable from typing import Union, List, Tuple -from ...utils import init_instance_by_config +from ...utils import init_instance_by_config, np_ffill from ...log import get_module_logger from .handler import DataHandler, DataHandlerLP from inspect import getfullargspec @@ -194,10 +194,33 @@ class TSDataSampler: """ - def __init__(self, data, start, end, step_len): + def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"): + """ + Build a dataset which looks like torch.data.utils.Dataset. + + Parameters + ---------- + data : pd.DataFrame + The raw tabular data + start : + The indexable start time + end : + The indexable end time + step_len : int + The length of the time-series step + fillna_type : int + How will qlib handle the sample if there is on sample in a specific date. + none: + fill with np.nan + ffill: + ffill with previous sample + ffill+bfill: + ffill with previous samples first and fill with later samples second + """ self.start = start self.end = end self.step_len = step_len + self.fillna_type = fillna_type assert get_level_index(data, "datetime") == 0 self.data = lazy_sort_index(data) # The index of usable data is between start_idx and end_idx @@ -205,6 +228,11 @@ class TSDataSampler: # self.index_link = self.build_link(self.data) self.idx_df, self.idx_map = self.build_index(self.data) + def config(self, **kwargs): + # Config the attributes + for k, v in kwargs.items(): + setattr(self, k, v) + @staticmethod def build_index(data: pd.DataFrame) -> dict: """ @@ -253,10 +281,12 @@ class TSDataSampler: idx : Union[int, Tuple[object, str]] """ # The the right row number `i` and col number `j` in idx_df - if isinstance(idx, int): + if isinstance(idx, (int, np.integer)): real_idx = self.start_idx + idx if self.start_idx <= real_idx < self.end_idx: i, j = self.idx_map[real_idx] + else: + raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})") elif isinstance(idx, tuple): # ["datetime", "instruments"] date, inst = idx @@ -265,19 +295,33 @@ class TSDataSampler: # NOTE: This relies on the idx_df columns sorted in `__init__` j = bisect.bisect_left(self.idx_df.columns, inst) else: - raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})") + raise NotImplementedError(f"This type of input is not supported") data_l = [] - indices = self.idx_df.iloc[max(i - self.step_len + 1, 0) : i + 1, j].values + indices = self.idx_df.values[max(i - self.step_len + 1, 0) : i + 1, j] indices = indices.reshape(-1) + if len(indices) < self.step_len: indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices]) - for idx in indices: - if np.isnan(idx): - data_l.append(np.full((self.data.shape[1],), np.nan)) - else: - data_l.append(self.data.iloc[idx]) - return np.array(data_l) + + if self.fillna_type == "ffill": + indices = np_ffill(indices) + elif self.fillna_type == "ffill+bfill": + indices = np_ffill(np_ffill(indices)[::-1])[::-1] + else: + assert self.fillna_type == "none" + + if np.isnan(indices.astype(np.float)).sum() == 0: # np.isnan only works on np.float + # All the index exists + return self.data.values[indices.astype(np.int)] + else: + # Only part index exists. These days will be filled with nan + for idx in indices: + if np.isnan(idx): + data_l.append(np.full((self.data.shape[1],), np.nan)) + else: + data_l.append(self.data.values[idx]) + return np.array(data_l) def __len__(self): return self.end_idx - self.start_idx diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 350f75382..0bdb192a5 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -101,6 +101,7 @@ class DropCol(Processor): mask = df.columns.isin(self.col_list) return df.loc[:, ~mask] + class FilterCol(Processor): def __init__(self, fields_group="feature", col_list=[]): self.fields_group = fields_group @@ -119,6 +120,7 @@ class FilterCol(Processor): mask = df.columns.isin(self.col_list) return df.loc[:, mask] + class TanhProcess(Processor): """ Use tanh to process noise data""" diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 138ce35f8..ab67b67e3 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -55,6 +55,22 @@ def read_bin(file_path, start_index, end_index): return series +def np_ffill(arr: np.array): + """ + forward fill a 1D numpy array + + Parameters + ---------- + arr : np.array + Input numpy 1D array + """ + mask = np.isnan(arr.astype(np.float)) # np.isnan only works on np.float + # get fill index + idx = np.where(~mask, np.arange(mask.shape[0]), 0) + np.maximum.accumulate(idx, out=idx) + return arr[idx] + + #################### Search #################### def lower_bound(data, val, level=0): """multi fields list lower bound. diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 5c0c9d843..36234c879 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -23,35 +23,17 @@ class TestDataset(TestAutoData): "fit_end_time": "2014-12-31", "instruments": "csi300", "infer_processors": [ - { - "class" : "DropCol", - "kwargs":{"col_list": ["VWAP0"]} - }, - { - "class" : "FilterCol", - "kwargs":{"col_list": ["RESI5", "WVMA5", "RSQR5"]} - }, - { - "class" : "CSZFillna", - "kwargs":{"fields_group": "feature"} - } - ], - "learn_processors": [ - { - "class" : "DropCol", - "kwargs":{"col_list": ["VWAP0"]} - }, - { - "class" : "DropnaProcessor", - "kwargs":{"fields_group": "feature"} - }, - "DropnaLabel", - { - "class": "CSZScoreNorm", - "kwargs": {"fields_group": "label"} - } + {"class": "DropCol", "kwargs": {"col_list": ["VWAP0"]}}, + {"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}}, + {"class": "CSZFillna", "kwargs": {"fields_group": "feature"}}, ], - "process_type": "independent" + "learn_processors": [ + {"class": "DropCol", "kwargs": {"col_list": ["VWAP0"]}}, + {"class": "DropnaProcessor", "kwargs": {"fields_group": "feature"}}, + "DropnaLabel", + {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, + ], + "process_type": "independent", }, }, segments={ @@ -62,10 +44,26 @@ class TestDataset(TestAutoData): ) tsds_train = tsdh.prepare("train") # Test the correctness tsds = tsdh.prepare("valid") # prepare a dataset with is friendly to converting tabular data to time-series + + t = time.time() + for idx in np.random.randint(0, len(tsds_train), size=2000): + data = tsds_train[idx] + print(f"2000 sample takes {time.time() - t}s") + + # FIXME: Please remove pytorch related function. Otherwise the CI tests will fail train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True, num_workers=10) + t = time.time() for data in train_loader: - now = time.localtime() - print(time.strftime("%Y-%m-%d-%H_%M_%S", now)) + pass + print(f"Passing all training batches takes {time.time() - t}s") + + # Here is an example of ffill+bfill for index + tsds_train.config(fillna_type="ffill+bfill") + train_loader = DataLoader(tsds_train, batch_size=800, shuffle=True, num_workers=10) + t = time.time() + for data in train_loader: + pass + print(f"Passing all training batches with fill takes {time.time() - t}s") # The dimension of sample is same as tabular data, but it will return timeseries data of the sample