1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

add highfreq example

This commit is contained in:
bxdd
2021-01-25 17:58:45 +00:00
parent 3f9f295a87
commit ffedb6382f
10 changed files with 585 additions and 29 deletions

View File

@@ -123,6 +123,16 @@ class CalendarProvider(abc.ABC):
H["c"][flag] = _calendar, _calendar_index
return _calendar, _calendar_index
def get_calender_day(self, freq="day", future=False):
flag = f"{freq}_future_{future}_day"
if flag in H["c"]:
_calendar, _calendar_index = H["c"][flag]
else:
_calendar = np.array(list(map(lambda x: x.date(), self._load_calendar(freq, future))))
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
H["c"][flag] = _calendar, _calendar_index
return _calendar, _calendar_index
def _uri(self, start_time, end_time, freq, future=False):
"""Get the uri of calendar generation task."""
return hash_args(start_time, end_time, freq, future)
@@ -686,7 +696,10 @@ class LocalExpressionProvider(ExpressionProvider):
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
# 2) The the precision should be configurable
try:
series = series.astype(np.float32)
if series.dtype == np.float64:
series = series.astype(np.float32)
elif series.dtype == np.bool:
series = series.astype(np.int8)
except ValueError:
pass
if not series.empty:

View File

@@ -87,6 +87,36 @@ class DatasetH(Dataset):
"""
super().__init__(handler, segments)
def init(self, init_type: str = DataHandlerLP.IT_FIT_SEQ, enable_cache: bool = False):
"""
Initialize the data of Qlib
Parameters
----------
init_type : str
- if `init_type` == DataHandlerLP.IT_FIT_SEQ:
the input of `DataHandlerLP.fit` will be the output of the previous processor
- if `init_type` == DataHandlerLP.IT_FIT_IND:
the input of `DataHandlerLP.fit` will be the original df
- if `init_type` == DataHandlerLP.IT_LS:
The state of the object has been load by pickle
enable_cache : bool
default value is false:
- if `enable_cache` == True:
the processed data will be saved on disk, and handler will load the cached data from the disk directly
when we call `init` next time
"""
self.handler.init(init_type=init_type, enable_cache=enable_cache)
def setup_data(self, handler: Union[dict, DataHandler], segments: list):
"""
Setup the underlying data.
@@ -116,8 +146,8 @@ class DatasetH(Dataset):
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
self._segments = segments.copy()
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
def _prepare_seg(self, slc: slice, **kwargs):
"""
@@ -127,7 +157,7 @@ class DatasetH(Dataset):
----------
slc : slice
"""
return self._handler.fetch(slc, **kwargs)
return self.handler.fetch(slc, **kwargs)
def prepare(
self,
@@ -150,7 +180,7 @@ class DatasetH(Dataset):
- ['train', 'valid']
col_set : str
The col_set will be passed to self._handler when fetching data.
The col_set will be passed to self.handler when fetching data.
data_key : str
The data to fetch: DK_*
Default is DK_I, which indicate fetching data for **inference**.
@@ -166,16 +196,16 @@ class DatasetH(Dataset):
logger = get_module_logger("DatasetH")
fetch_kwargs = {"col_set": col_set}
fetch_kwargs.update(kwargs)
if "data_key" in getfullargspec(self._handler.fetch).args:
if "data_key" in getfullargspec(self.handler.fetch).args:
fetch_kwargs["data_key"] = data_key
else:
logger.info(f"data_key[{data_key}] is ignored.")
# Handle all kinds of segments format
if isinstance(segments, (list, tuple)):
return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments]
elif isinstance(segments, str):
return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs)
return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs)
elif isinstance(segments, slice):
return self._prepare_seg(segments, **fetch_kwargs)
else:
@@ -409,7 +439,7 @@ class TSDatasetH(DatasetH):
def setup_data(self, *args, **kwargs):
super().setup_data(*args, **kwargs)
cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique()
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
# Get the datatime index for building timestamp
self.cal = cal

View File

@@ -57,6 +57,7 @@ class DataHandler(Serializable):
instruments=None,
start_time=None,
end_time=None,
freq="day",
data_loader: Tuple[dict, str, DataLoader] = None,
init_data=True,
fetch_orig=True,
@@ -70,6 +71,8 @@ class DataHandler(Serializable):
start_time of the original data.
end_time :
end_time of the original data.
freq :
frequency of data
data_loader : Tuple[dict, str, DataLoader]
data loader to load the data.
init_data :
@@ -92,6 +95,7 @@ class DataHandler(Serializable):
self.instruments = instruments
self.start_time = start_time
self.end_time = end_time
self.freq = freq
self.fetch_orig = fetch_orig
if init_data:
with TimeInspector.logt("Init data"):
@@ -119,7 +123,7 @@ class DataHandler(Serializable):
# Setup data.
# _data may be with multiple column index level. The outer level indicates the feature set name
with TimeInspector.logt("Loading data"):
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq)
# TODO: cache
CS_ALL = "__all" # return all columns with single-level index column
@@ -258,10 +262,12 @@ class DataHandlerLP(DataHandler):
instruments=None,
start_time=None,
end_time=None,
freq="day",
data_loader: Tuple[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
process_type=PTYPE_A,
drop_raw=False,
**kwargs,
):
"""
@@ -303,6 +309,8 @@ class DataHandlerLP(DataHandler):
- self._learn will be processed by infer_processors + learn_processors
- (e.g. self._infer processed by learn_processors )
drop_raw: bool
Whether to drop the raw data
"""
# Setup preprocessor
@@ -319,7 +327,8 @@ class DataHandlerLP(DataHandler):
)
self.process_type = process_type
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
self.drop_raw = drop_raw
super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs)
def get_all_processors(self):
return self.infer_processors + self.learn_processors
@@ -348,7 +357,7 @@ class DataHandlerLP(DataHandler):
"""
# data for inference
_infer_df = self._data
if len(self.infer_processors) > 0: # avoid modifying the original data
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
_infer_df = _infer_df.copy()
for proc in self.infer_processors:
@@ -378,6 +387,8 @@ class DataHandlerLP(DataHandler):
_learn_df = proc(_learn_df)
self._learn = _learn_df
if self.drop_raw:
del self._data
# init type
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
@@ -416,7 +427,11 @@ class DataHandlerLP(DataHandler):
# TODO: Be able to cache handler data. Save the memory for data processing
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
try:
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
except AttributeError:
print("please set drop_raw = False if you want to use raw data")
raise
return df
def fetch(

View File

@@ -19,7 +19,7 @@ class DataLoader(abc.ABC):
"""
@abc.abstractmethod
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
"""
load the data as pd.DataFrame.
@@ -94,7 +94,7 @@ class DLWParser(DataLoader):
return exprs, names
@abc.abstractmethod
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
"""
load the dataframe for specific group
@@ -114,25 +114,25 @@ class DLWParser(DataLoader):
"""
pass
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
if self.is_group:
df = pd.concat(
{
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
for grp, (exprs, names) in self.fields.items()
},
axis=1,
)
else:
exprs, names = self.fields
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
return df
class QlibDataLoader(DLWParser):
"""Same as QlibDataLoader. The fields can be define by config"""
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True):
"""
Parameters
----------
@@ -140,11 +140,15 @@ class QlibDataLoader(DLWParser):
Please refer to the doc of DLWParser
filter_pipe :
Filter pipe for the instruments
swap_level :
Whether to swap level of MultiIndex
"""
self.filter_pipe = filter_pipe
self.swap_level = swap_level
print("swap level", swap_level)
super().__init__(config)
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
instruments = "all"
@@ -153,9 +157,10 @@ class QlibDataLoader(DLWParser):
elif self.filter_pipe is not None:
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
df = D.features(instruments, exprs, start_time, end_time)
df = D.features(instruments, exprs, start_time, end_time, freq)
df.columns = names
df = df.swaplevel().sort_index() # NOTE: always return <datetime, instrument>
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
return df
@@ -177,7 +182,7 @@ class StaticDataLoader(DataLoader):
self.join = join
self._data = None
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
self._maybe_load_raw_data()
if instruments is None:
df = self._data