diff --git a/examples/high_freq/__init__.py b/examples/high_freq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/high_freq/highfreq_handler.py b/examples/high_freq/highfreq_handler.py new file mode 100644 index 000000000..32557f768 --- /dev/null +++ b/examples/high_freq/highfreq_handler.py @@ -0,0 +1,220 @@ +from qlib.data.dataset.handler import DataHandler, DataHandlerLP +from qlib.data.dataset.processor import Processor +from qlib.utils import get_cls_kwargs +from qlib.log import TimeInspector + + +class HighFreqHandler(DataHandlerLP): + def __init__( + self, + instruments="csi500", + start_time=None, + end_time=None, + freq="1min", + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + drop_raw=True, + ): + def check_transform_proc(proc_l): + new_l = [] + for p in proc_l: + p["kwargs"].update( + { + "fit_start_time": fit_start_time, + "fit_end_time": fit_end_time, + } + ) + new_l.append(p) + return new_l + + infer_processors = [] + learn_processors = [] + + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + }, + } + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq=freq, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + drop_raw=drop_raw, + ) + + + def get_feature_config(self): + fields = [] + names = [] + + template_if = "If(IsNull({1}), {0}, {1})" + template_paused = "Select(Eq($paused, 0.0), {0})" + template_fillnan = "FFillNan({0})" + fields += [ + "{0}/Ref(DayLast({1}), 240)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$open"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + fields += [ + "{0}/Ref(DayLast({1}), 240)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$high"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + fields += [ + "{0}/Ref(DayLast({1}), 240)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$low"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + fields += ["{0}/Ref(DayLast({0}), 240)".format(template_fillnan.format(template_paused.format("$close")))] + fields += [ + "{0}/Ref(DayLast({1}), 240)".format( + "If(IsNull({1}), {0}, If(Or(Or(Or(Eq({1}, np.inf), Eq({1}, -np.inf)), Eq({1}, 0)), Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2})))), {0}, {1}))".format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$vwap"), + template_paused.format("$low"), + template_paused.format("$high"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + names += ["$open", "$high", "$low", "$close", "$vwap"] + + fields += [ + "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$open"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + fields += [ + "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$high"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + fields += [ + "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$low"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + fields += [ + "Ref({0}, 240)/Ref(DayLast({0}), 240)".format(template_fillnan.format(template_paused.format("$close"))) + ] + fields += [ + "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( + "If(IsNull({1}), {0}, If(Or(Or(Or(Eq({1}, np.inf), Eq({1}, -np.inf)), Eq({1}, 0)), Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2})))), {0}, {1}))".format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$vwap"), + template_paused.format("$low"), + template_paused.format("$high"), + ), + template_fillnan.format(template_paused.format("$close")), + ) + ] + names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"] + + fields += [ + "{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format( + "If(IsNull({1}), 0, If(Or(Gt({2}, Mul(1.001, {4})), Lt({2}, Mul(0.999, {3}))), 0, {1}))".format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$volume"), + template_paused.format("$vwap"), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ) + ] + names += ["$volume"] + fields += [ + "Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format( + "If(IsNull({1}), 0, If(Or(Gt({2}, Mul(1.001, {4})), Lt({2}, Mul(0.999, {3}))), 0, {1}))".format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$volume"), + template_paused.format("$vwap"), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ) + ] + names += ["$volume_1"] + + fields += [template_paused.format("Date($close)")] + names += ["date"] + return fields, names + + +class HighFreqBacktestHandler(DataHandler): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + freq="1min", + ): + infer_processors = check_transform_proc(infer_processors) + learn_processors = check_transform_proc(learn_processors) + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + }, + } + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq=freq, + data_loader=data_loader, + ) + + def get_feature_config(self): + fields = [] + names = [] + + template_if = "If(Eq({1}, np.nan), {0}, {1})" + template_paused = "Select(Eq($paused, 0.0), {0})" + template_fillnan = "FFillNan({0})" + + fields += [template_fillnan.format(template_paused.format("$close")),] + names += ["$close0"] + fields += [ + "If(Eq({1}, np.nan), 0, If(Or(Gt({2}, Mul(1.001, {4})), Lt({2}, Mul(0.999, {3}))), 0, {1}))".format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format("$volume"), + template_paused.format("$vwap"), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ] + names += ["$volume0"] + return fields, names diff --git a/examples/high_freq/highfreq_ops.py b/examples/high_freq/highfreq_ops.py new file mode 100644 index 000000000..f6470d68e --- /dev/null +++ b/examples/high_freq/highfreq_ops.py @@ -0,0 +1,62 @@ +import numpy as np +import pandas as pd +import importlib +from qlib.data.ops import ElemOperator, PairOperator +from qlib.config import C +from qlib.data.data import Cal + + +class DayFirst(ElemOperator): + def __init__(self, feature): + super(DayFirst, self).__init__(feature, "day_first") + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = Cal.get_calender_day(freq=freq)[0] + series = self.feature.load(instrument, start_index, end_index, freq) + return series.groupby(_calendar[series.index]).transform("first") + + +class DayLast(ElemOperator): + def __init__(self, feature): + super(DayLast, self).__init__(feature, "day_last") + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = Cal.get_calender_day(freq=freq)[0] + series = self.feature.load(instrument, start_index, end_index, freq) + return series.groupby(_calendar[series.index]).transform("last") + + +class FFillNan(ElemOperator): + def __init__(self, feature): + super(FFillNan, self).__init__(feature, "fill_nan") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="ffill") + + +class Date(ElemOperator): + def __init__(self, feature): + super(Date, self).__init__(feature, "date") + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = Cal.get_calender_day(freq=freq)[0] + series = self.feature.load(instrument, start_index, end_index, freq) + return pd.Series(_calendar[series.index], index=series.index) + +class Select(PairOperator): + def __init__(self, condition, feature): + super(Select, self).__init__(condition, feature, "select") + + def _load_internal(self, instrument, start_index, end_index, freq): + series_condition = self.feature_left.load(instrument, start_index, end_index, freq) + series_feature = self.feature_right.load(instrument, start_index, end_index, freq) + return series_feature.loc[series_condition] + +class IsNull(ElemOperator): + def __init__(self, feature): + super(IsNull, self).__init__(feature, "isnull") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.isnull() \ No newline at end of file diff --git a/examples/high_freq/highfreq_processor.py b/examples/high_freq/highfreq_processor.py new file mode 100644 index 000000000..fc86b1a70 --- /dev/null +++ b/examples/high_freq/highfreq_processor.py @@ -0,0 +1,70 @@ +import numpy as np +import pandas as pd +from qlib.data.dataset.processor import Processor +from qlib.log import TimeInspector +from qlib.data.dataset.utils import fetch_df_by_index + + +class HighFreqNorm(Processor): + def __init__(self, fit_start_time, fit_end_time): + self.fit_start_time = fit_start_time + self.fit_end_time = fit_end_time + + def fit(self, df_features): + fetch_df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime") + del df + df_values = fetch_df.values + names = { + "price": slice(0, 10), + "volume": slice(10, 12), + } + self.feature_med = {} + self.feature_std = {} + self.feature_vmax = {} + self.feature_vmin = {} + for name, name_val in names.items(): + part_values = df_values[:, name_val] + if name == "volume": + df_features.loc(axis=1)[name_val] = np.log1p(part_values) + self.feature_med[name] = np.nanmedian(part_values) + part_values = part_values - self.feature_med # mean, copy + self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12 + part_values = part_values / self.feature_std + self.feature_vmax[name] = np.nanmax(part_values) + self.feature_vmin[name] = np.nanmin(part_values) + + def __call__(self, df_features): + df_features.set_index("date", append=True, drop=True, inplace=True) + df_values = df_features.values + names = { + "price": slice(0, 10), + "volume": slice(10, 12), + } + + for name, name_val in names.items(): + part_values = df_values[:, name_val] + if name == "volume": + part_values[:] = np.log1p(part_values) + part_values -= self.feature_med[name] + part_values /= self.feature_std[name] + slice0 = part_values > 3.0 + slice1 = part_values > 3.5 + slice2 = part_values < -3.0 + slice3 = part_values < -3.5 + + part_values[slice0] = 3.0 + (part_values[slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5 + part_values[slice1] = 3.5 + part_values[slice2] = -3.0 - (part_values[slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5 + part_values[slice3] = -3.5 + # print("start_call_feature_reshape") + idx = df_features.index.droplevel("datetime").drop_duplicates() + feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240) + feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240) + + df_new_features = pd.DataFrame( + data=np.concatenate((feat, feat_1), axis=1), + index=idx, + columns=["FEATURE_%d" % i for i in range(12 * 240)], + ).sort_index() + + return df_new_features \ No newline at end of file diff --git a/examples/high_freq/workflow.py b/examples/high_freq/workflow.py new file mode 100644 index 000000000..83a344b0f --- /dev/null +++ b/examples/high_freq/workflow.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from pathlib import Path + +import qlib +import pickle +import numpy as np +import pandas as pd +from qlib.config import REG_CN +from qlib.contrib.model.gbdt import LGBModel +from qlib.contrib.data.handler import Alpha158 +from qlib.contrib.strategy.strategy import TopkDropoutStrategy +from qlib.contrib.evaluate import ( + backtest as normal_backtest, + risk_analysis, +) + +from qlib.utils import init_instance_by_config +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.ops import Operators +from qlib.data.data import Cal + +from highfreq_ops import DayFirst, DayLast, FFillNan, Date, Select, IsNull + +def save_dataset(dataset, path: [Path, str]): + """ + save dataset to path + + Parameters + ---------- + path : [Path, str] + path to save + """ + dataset.to_pickle(path=path) + +def load_dataset(path: [Path, str], init_type=DataHandlerLP.IT_LS): + """ + load dataset from path + + Parameters + ---------- + path : [Path, str] + path to load + + init_type : str + - if `init_type` == DataHandlerLP.IT_FIT_SEQ: + + the input of `DataHandlerLP.fit` will be the output of the previous processor + + - if `init_type` == DataHandlerLP.IT_FIT_IND: + + the input of `DataHandlerLP.fit` will be the original df + + - if `init_type` == DataHandlerLP.IT_LS: + + The state of the object has been load by pickle + """ + fd = open(path, 'rb') + dataset = pickle.load(fd) + dataset.init(init_type=init_type) + fd.close() + return dataset + +if __name__ == "__main__": + + # use default data + provider_uri = "/mnt/v-xiabi/data/qlib/high_freq" # target_dir + qlib.init(provider_uri=provider_uri, custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull], redis_port=233, region=REG_CN, auto_mount=False) + + MARKET = "csi300" + BENCHMARK = "SH000300" + + ################################### + # train model + ################################### + DATA_HANDLER_CONFIG0 = { + "start_time": "2017-01-01 00:00:00", + "end_time": "2020-11-30 15:00:00", + "freq": "1min", + "fit_start_time": "2017-01-01 00:00:00", + "fit_end_time": "2020-08-31 15:00:00", + "instruments": "all", + "infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}], + } + DATA_HANDLER_CONFIG1 = { + "start_time": "2017-01-01 00:00:00", + "end_time": "2020-11-30 15:00:00", + "freq": "1min", + "instruments": "all", + } + + task = { + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "HighFreqHandler", + "module_path": "highfreq_handler", + "kwargs": DATA_HANDLER_CONFIG0, + }, + "segments": { + "train": ("2017-01-01 00:00:00", "2020-08-31 15:00:00"), + "test": ( + "2020-09-01 00:00:00", + "2020-11-30 15:00:00", + ), + }, + }, + }, + # You shoud record the data in specific sequence + # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'], + "dataset_backtest": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "HighFreqBacktestHandler", + "module_path": "highfreq_hander", + "kwargs": DATA_HANDLER_CONFIG1, + }, + "segments": { + "train": ("2017-01-01 00:00:00", "2020-08-31 15:00:00"), + "test": ( + "2020-09-01 00:00:00", + "2020-11-30 15:00:00", + ), + }, + }, + }, + } + Cal.get_calender_day(freq="1min") # TO FIX: load the calendar day for cache + dataset = init_instance_by_config(task["dataset"]) + dataset_backtest = init_instance_by_config(task["dataset_backtest"]) + diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 88a1f0680..23e37a5e4 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -49,6 +49,7 @@ class Alpha360(DataHandlerLP): instruments="csi500", start_time=None, end_time=None, + freq="day", infer_processors=_DEFAULT_INFER_PROCESSORS, learn_processors=_DEFAULT_LEARN_PROCESSORS, fit_start_time=None, @@ -69,9 +70,10 @@ class Alpha360(DataHandlerLP): } super().__init__( - instruments, - start_time, - end_time, + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq="day", data_loader=data_loader, learn_processors=learn_processors, infer_processors=infer_processors, @@ -130,6 +132,7 @@ class Alpha158(DataHandlerLP): instruments="csi500", start_time=None, end_time=None, + freq="day", infer_processors=[], learn_processors=_DEFAULT_LEARN_PROCESSORS, fit_start_time=None, @@ -147,9 +150,10 @@ class Alpha158(DataHandlerLP): }, } super().__init__( - instruments, - start_time, - end_time, + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq=freq, data_loader=data_loader, infer_processors=infer_processors, learn_processors=learn_processors, diff --git a/qlib/data/data.py b/qlib/data/data.py index d95728199..3021ebe82 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -123,6 +123,16 @@ class CalendarProvider(abc.ABC): H["c"][flag] = _calendar, _calendar_index return _calendar, _calendar_index + def get_calender_day(self, freq="day", future=False): + flag = f"{freq}_future_{future}_day" + if flag in H["c"]: + _calendar, _calendar_index = H["c"][flag] + else: + _calendar = np.array(list(map(lambda x: x.date(), self._load_calendar(freq, future)))) + _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search + H["c"][flag] = _calendar, _calendar_index + return _calendar, _calendar_index + def _uri(self, start_time, end_time, freq, future=False): """Get the uri of calendar generation task.""" return hash_args(start_time, end_time, freq, future) @@ -686,7 +696,10 @@ class LocalExpressionProvider(ExpressionProvider): # 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented. # 2) The the precision should be configurable try: - series = series.astype(np.float32) + if series.dtype == np.float64: + series = series.astype(np.float32) + elif series.dtype == np.bool: + series = series.astype(np.int8) except ValueError: pass if not series.empty: diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 96e4a6e41..df7af3f5e 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -87,6 +87,36 @@ class DatasetH(Dataset): """ super().__init__(handler, segments) + + def init(self, init_type: str = DataHandlerLP.IT_FIT_SEQ, enable_cache: bool = False): + """ + Initialize the data of Qlib + + Parameters + ---------- + init_type : str + - if `init_type` == DataHandlerLP.IT_FIT_SEQ: + + the input of `DataHandlerLP.fit` will be the output of the previous processor + + - if `init_type` == DataHandlerLP.IT_FIT_IND: + + the input of `DataHandlerLP.fit` will be the original df + + - if `init_type` == DataHandlerLP.IT_LS: + + The state of the object has been load by pickle + + enable_cache : bool + default value is false: + + - if `enable_cache` == True: + + the processed data will be saved on disk, and handler will load the cached data from the disk directly + when we call `init` next time + """ + self.handler.init(init_type=init_type, enable_cache=enable_cache) + def setup_data(self, handler: Union[dict, DataHandler], segments: list): """ Setup the underlying data. @@ -116,8 +146,8 @@ class DatasetH(Dataset): 'outsample': ("2017-01-01", "2020-08-01",), } """ - self._handler = init_instance_by_config(handler, accept_types=DataHandler) - self._segments = segments.copy() + self.handler = init_instance_by_config(handler, accept_types=DataHandler) + self.segments = segments.copy() def _prepare_seg(self, slc: slice, **kwargs): """ @@ -127,7 +157,7 @@ class DatasetH(Dataset): ---------- slc : slice """ - return self._handler.fetch(slc, **kwargs) + return self.handler.fetch(slc, **kwargs) def prepare( self, @@ -150,7 +180,7 @@ class DatasetH(Dataset): - ['train', 'valid'] col_set : str - The col_set will be passed to self._handler when fetching data. + The col_set will be passed to self.handler when fetching data. data_key : str The data to fetch: DK_* Default is DK_I, which indicate fetching data for **inference**. @@ -166,16 +196,16 @@ class DatasetH(Dataset): logger = get_module_logger("DatasetH") fetch_kwargs = {"col_set": col_set} fetch_kwargs.update(kwargs) - if "data_key" in getfullargspec(self._handler.fetch).args: + if "data_key" in getfullargspec(self.handler.fetch).args: fetch_kwargs["data_key"] = data_key else: logger.info(f"data_key[{data_key}] is ignored.") # Handle all kinds of segments format if isinstance(segments, (list, tuple)): - return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments] + return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments] elif isinstance(segments, str): - return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs) + return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs) elif isinstance(segments, slice): return self._prepare_seg(segments, **fetch_kwargs) else: @@ -409,7 +439,7 @@ class TSDatasetH(DatasetH): def setup_data(self, *args, **kwargs): super().setup_data(*args, **kwargs) - cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique() + cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() cal = sorted(cal) # Get the datatime index for building timestamp self.cal = cal diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 18f838300..9dfc4746a 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -57,6 +57,7 @@ class DataHandler(Serializable): instruments=None, start_time=None, end_time=None, + freq="day", data_loader: Tuple[dict, str, DataLoader] = None, init_data=True, fetch_orig=True, @@ -70,6 +71,8 @@ class DataHandler(Serializable): start_time of the original data. end_time : end_time of the original data. + freq : + frequency of data data_loader : Tuple[dict, str, DataLoader] data loader to load the data. init_data : @@ -92,6 +95,7 @@ class DataHandler(Serializable): self.instruments = instruments self.start_time = start_time self.end_time = end_time + self.freq = freq self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): @@ -119,7 +123,7 @@ class DataHandler(Serializable): # Setup data. # _data may be with multiple column index level. The outer level indicates the feature set name with TimeInspector.logt("Loading data"): - self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) + self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq) # TODO: cache CS_ALL = "__all" # return all columns with single-level index column @@ -258,10 +262,12 @@ class DataHandlerLP(DataHandler): instruments=None, start_time=None, end_time=None, + freq="day", data_loader: Tuple[dict, str, DataLoader] = None, infer_processors=[], learn_processors=[], process_type=PTYPE_A, + drop_raw=False, **kwargs, ): """ @@ -303,6 +309,8 @@ class DataHandlerLP(DataHandler): - self._learn will be processed by infer_processors + learn_processors - (e.g. self._infer processed by learn_processors ) + drop_raw: bool + Whether to drop the raw data """ # Setup preprocessor @@ -319,7 +327,8 @@ class DataHandlerLP(DataHandler): ) self.process_type = process_type - super().__init__(instruments, start_time, end_time, data_loader, **kwargs) + self.drop_raw = drop_raw + super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs) def get_all_processors(self): return self.infer_processors + self.learn_processors @@ -348,7 +357,7 @@ class DataHandlerLP(DataHandler): """ # data for inference _infer_df = self._data - if len(self.infer_processors) > 0: # avoid modifying the original data + if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data _infer_df = _infer_df.copy() for proc in self.infer_processors: @@ -378,6 +387,8 @@ class DataHandlerLP(DataHandler): _learn_df = proc(_learn_df) self._learn = _learn_df + if self.drop_raw: + del self._data # init type IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df @@ -416,7 +427,11 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. Save the memory for data processing def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame: - df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) + try: + df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) + except AttributeError: + print("please set drop_raw = False if you want to use raw data") + raise return df def fetch( diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index a51ea119a..c6d06b57f 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -19,7 +19,7 @@ class DataLoader(abc.ABC): """ @abc.abstractmethod - def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame: """ load the data as pd.DataFrame. @@ -94,7 +94,7 @@ class DLWParser(DataLoader): return exprs, names @abc.abstractmethod - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day") -> pd.DataFrame: """ load the dataframe for specific group @@ -114,25 +114,25 @@ class DLWParser(DataLoader): """ pass - def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: if self.is_group: df = pd.concat( { - grp: self.load_group_df(instruments, exprs, names, start_time, end_time) + grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq) for grp, (exprs, names) in self.fields.items() }, axis=1, ) else: exprs, names = self.fields - df = self.load_group_df(instruments, exprs, names, start_time, end_time) + df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq) return df class QlibDataLoader(DLWParser): """Same as QlibDataLoader. The fields can be define by config""" - def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None): + def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True): """ Parameters ---------- @@ -140,11 +140,15 @@ class QlibDataLoader(DLWParser): Please refer to the doc of DLWParser filter_pipe : Filter pipe for the instruments + swap_level : + Whether to swap level of MultiIndex """ self.filter_pipe = filter_pipe + self.swap_level = swap_level + print("swap level", swap_level) super().__init__(config) - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day") -> pd.DataFrame: if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") instruments = "all" @@ -153,9 +157,10 @@ class QlibDataLoader(DLWParser): elif self.filter_pipe is not None: warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") - df = D.features(instruments, exprs, start_time, end_time) + df = D.features(instruments, exprs, start_time, end_time, freq) df.columns = names - df = df.swaplevel().sort_index() # NOTE: always return + if self.swap_level: + df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df @@ -177,7 +182,7 @@ class StaticDataLoader(DataLoader): self.join = join self._data = None - def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: self._maybe_load_raw_data() if instruments is None: df = self._data