mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
filter
This commit is contained in:
@@ -114,6 +114,7 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
self.fetch_kwargs = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
@@ -171,7 +172,7 @@ class DatasetH(Dataset):
|
||||
----------
|
||||
slc : slice
|
||||
"""
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@@ -288,13 +289,29 @@ class TSDataSampler:
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
self.data_idx = deepcopy(self.data.index)
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data).reshape(-1)
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
|
||||
def flt_idx_map(flt_data, idx_map):
|
||||
idx = 0
|
||||
new_idx_map = {}
|
||||
for i, exist in enumerate(flt_data):
|
||||
if exist:
|
||||
new_idx_map[idx] = idx_map[i]
|
||||
idx += 1
|
||||
return new_idx_map
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
Get the pandas index of the data, it will be useful in following scenarios
|
||||
@@ -488,8 +505,19 @@ class TSDatasetH(DatasetH):
|
||||
"""
|
||||
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
dtype = kwargs.pop("dtype")
|
||||
start, end = slc.start, slc.stop
|
||||
data = self._prepare_raw_seg(slc=slc, **kwargs)
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype)
|
||||
return tsds
|
||||
flt_col = kwargs.pop('flt_col', None)
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = self._prepare_raw_seg(slc, **kwargs)
|
||||
|
||||
flt_kwargs = deepcopy(kwargs)
|
||||
if flt_col is not None:
|
||||
flt_kwargs['col_set'] = flt_col
|
||||
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
|
||||
assert len(flt_data.columns) == 1
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
|
||||
return tsds
|
||||
Reference in New Issue
Block a user