1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00

Merge branch 'online_srv' of https://github.com/you-n-g/qlib into online_srv

This commit is contained in:
lzh222333
2021-05-09 10:52:07 +00:00

View File

@@ -114,6 +114,7 @@ class DatasetH(Dataset):
"""
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
self.fetch_kwargs = {}
super().__init__(**kwargs)
def config(self, handler_kwargs: dict = None, **kwargs):
@@ -171,7 +172,7 @@ class DatasetH(Dataset):
----------
slc : slice
"""
return self.handler.fetch(slc, **kwargs)
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
def prepare(
self,
@@ -199,6 +200,12 @@ class DatasetH(Dataset):
The data to fetch: DK_*
Default is DK_I, which indicates fetching data for **inference**.
kwargs :
The parameters that kwargs may contain:
flt_col : str
Only supported by TSDatasetH. The name of a column set that must contain exactly one column of
boolean values (True or False); that column is used to filter the data.
Returns
-------
Union[List[pd.DataFrame], pd.DataFrame]:
@@ -243,7 +250,7 @@ class TSDataSampler:
"""
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None):
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -265,6 +272,11 @@ class TSDataSampler:
ffill with previous sample
ffill+bfill:
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data.
None:
keep all data
"""
self.start = start
self.end = end
@@ -288,18 +300,35 @@ class TSDataSampler:
# the data type will be changed
# The index of usable data is between start_idx and end_idx
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_df, self.idx_map = self.build_index(self.data)
self.data_index = deepcopy(self.data.index)
if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
self.data_idx = deepcopy(self.data.index)
del self.data # save memory
@staticmethod
def flt_idx_map(flt_data, idx_map):
idx = 0
new_idx_map = {}
for i, exist in enumerate(flt_data):
if exist:
new_idx_map[idx] = idx_map[i]
idx += 1
return new_idx_map
def get_index(self):
    """
    Get the pandas index of the data, it will be useful in following scenarios
    - Special sampler will be used (e.g. user want to sample day by day)

    Returns
    -------
    the part of `self.data_index` between `self.start_idx` (inclusive) and
    `self.end_idx` (exclusive).
    """
    # `start_idx` / `end_idx` are positions in `data_index` (the index that has the
    # filter applied when `flt_data` is given), so that is the index to slice.
    return self.data_index[self.start_idx : self.end_idx]
def config(self, **kwargs):
# Config the attributes
@@ -489,6 +518,17 @@ class TSDatasetH(DatasetH):
"""
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
data = self._prepare_raw_seg(slc=slc, **kwargs)
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype)
return tsds
flt_col = kwargs.pop('flt_col', None)
# TSDatasetH will retrieve more data for complete
data = self._prepare_raw_seg(slc, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs['col_set'] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
return tsds