From 36ab078fbdbdd69f1ac93b0be75ab29253b357d3 Mon Sep 17 00:00:00 2001 From: blin Date: Wed, 28 Apr 2021 07:15:59 +0000 Subject: [PATCH] filter --- qlib/data/dataset/__init__.py | 44 ++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index cd15a98c9..5485796ef 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -114,6 +114,7 @@ class DatasetH(Dataset): """ self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() + self.fetch_kwargs = {} super().__init__(**kwargs) def config(self, handler_kwargs: dict = None, **kwargs): @@ -171,7 +172,7 @@ class DatasetH(Dataset): ---------- slc : slice """ - return self.handler.fetch(slc, **kwargs) + return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs) def prepare( self, @@ -288,13 +289,29 @@ class TSDataSampler: # the data type will be changed # The index of usable data is between start_idx and end_idx - self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_df, self.idx_map = self.build_index(self.data) - self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance - self.data_idx = deepcopy(self.data.index) + self.data_index = deepcopy(self.data.index) + if flt_data is not None: + self.flt_data = np.array(flt_data).reshape(-1) + self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) + self.data_index = self.data_index[np.where(self.flt_data == True)[0]] + + self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) + self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance + del self.data # save memory + @staticmethod + def flt_idx_map(flt_data, idx_map): + idx = 0 + new_idx_map = {} + for i, exist in enumerate(flt_data): + if exist: + new_idx_map[idx] = idx_map[i] + idx += 1 + return new_idx_map + def get_index(self): """ Get the pandas index of the data, it will be useful in following scenarios @@ -488,8 +505,19 @@ class TSDatasetH(DatasetH): """ split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data """ - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype") start, end = slc.start, slc.stop - data = self._prepare_raw_seg(slc=slc, **kwargs) - tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype) - return tsds + flt_col = kwargs.pop('flt_col', None) + # TSDatasetH will retrieve more data for complete + data = self._prepare_raw_seg(slc, **kwargs) + + flt_kwargs = deepcopy(kwargs) + if flt_col is not None: + flt_kwargs['col_set'] = flt_col + flt_data = self._prepare_raw_seg(slc, **flt_kwargs) + assert len(flt_data.columns) == 1 + else: + flt_data = None + + tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data) + return tsds \ No newline at end of file