diff --git a/examples/highfreq/__init__.py b/examples/highfreq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/highfreq/highfreq_handler.py b/examples/highfreq/highfreq_handler.py new file mode 100644 index 000000000..be2084626 --- /dev/null +++ b/examples/highfreq/highfreq_handler.py @@ -0,0 +1,174 @@ +from qlib.data.dataset.handler import DataHandler, DataHandlerLP +from qlib.data.dataset.processor import Processor +from qlib.utils import get_cls_kwargs +from qlib.log import TimeInspector + + +class HighFreqHandler(DataHandlerLP): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + freq="1min", + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + drop_raw=True, + ): + def check_transform_proc(proc_l): + new_l = [] + for p in proc_l: + p["kwargs"].update( + { + "fit_start_time": fit_start_time, + "fit_end_time": fit_end_time, + } + ) + new_l.append(p) + return new_l + + infer_processors = check_transform_proc(infer_processors) + learn_processors = check_transform_proc(learn_processors) + + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + }, + } + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq=freq, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + drop_raw=drop_raw, + ) + + def get_feature_config(self): + fields = [] + names = [] + + template_if = "If(IsNull({1}), {0}, {1})" + template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})" + template_fillnan = "BFillNan(FFillNan({0}))" + # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap + simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" + + def get_normalized_price_feature(price_field, shift=0): + """Get normalized price feature ops""" + if shift == 0: + template_norm = "{0}/Ref(DayLast({1}), 240)" + else: + template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)" + + feature_ops = template_norm.format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format(price_field), + ), + template_fillnan.format(template_paused.format("$close")), + ) + return feature_ops + + fields += [get_normalized_price_feature("$open", 0)] + fields += [get_normalized_price_feature("$high", 0)] + fields += [get_normalized_price_feature("$low", 0)] + fields += [get_normalized_price_feature("$close", 0)] + fields += [get_normalized_price_feature(simpson_vwap, 0)] + names += ["$open", "$high", "$low", "$close", "$vwap"] + + fields += [get_normalized_price_feature("$open", 240)] + fields += [get_normalized_price_feature("$high", 240)] + fields += [get_normalized_price_feature("$low", 240)] + fields += [get_normalized_price_feature("$close", 240)] + fields += [get_normalized_price_feature(simpson_vwap, 240)] + names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"] + + fields += [ + "{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format( + "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( + template_paused.format("$volume"), + template_paused.format(simpson_vwap), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ) + ] + names += ["$volume"] + fields += [ + "Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format( + "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( + template_paused.format("$volume"), + template_paused.format(simpson_vwap), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ) + ] + names += ["$volume_1"] + + fields += [template_paused.format("Date($close)")] + names += ["date"] + return fields, names + + +class HighFreqBacktestHandler(DataHandler): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + freq="1min", + ): + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + }, + } + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq=freq, + data_loader=data_loader, + ) + + def get_feature_config(self): + fields = [] + names = [] + + template_if = "If(IsNull({1}), {0}, {1})" + template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})" + template_fillnan = "BFillNan(FFillNan({0}))" + # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap + simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" + fields += [ + template_fillnan.format(template_paused.format("$close")), + ] + names += ["$close0"] + fields += [ + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format(simpson_vwap), + ) + ] + names += ["$vwap0"] + fields += [ + "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( + template_paused.format("$volume"), + template_paused.format(simpson_vwap), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ] + names += ["$volume0"] + + return fields, names diff --git a/examples/highfreq/highfreq_ops.py b/examples/highfreq/highfreq_ops.py new file mode 100644 index 000000000..85ed63285 --- /dev/null +++ b/examples/highfreq/highfreq_ops.py @@ -0,0 +1,56 @@ +import numpy as np +import pandas as pd +import importlib +from qlib.data.ops import ElemOperator, PairOperator +from qlib.config import C +from qlib.data.cache import H +from qlib.data.data import Cal + + +def get_calendar_day(freq="day", future=False): + flag = f"{freq}_future_{future}_day" + if flag in H["c"]: + _calendar = H["c"][flag] + else: + _calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future)))) + H["c"][flag] = _calendar + return _calendar + + +class DayLast(ElemOperator): + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = get_calendar_day(freq=freq) + series = self.feature.load(instrument, start_index, end_index, freq) + return series.groupby(_calendar[series.index]).transform("last") + + +class FFillNan(ElemOperator): + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="ffill") + + +class BFillNan(ElemOperator): + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="bfill") + + +class Date(ElemOperator): + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = get_calendar_day(freq=freq) + series = self.feature.load(instrument, start_index, end_index, freq) + return pd.Series(_calendar[series.index], index=series.index) + + +class Select(PairOperator): + def _load_internal(self, instrument, start_index, end_index, freq): + series_condition = self.feature_left.load(instrument, start_index, end_index, freq) + series_feature = self.feature_right.load(instrument, start_index, end_index, freq) + return series_feature.loc[series_condition] + + +class IsNull(ElemOperator): + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.isnull() diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py new file mode 100644 index 000000000..f0ab0dec2 --- /dev/null +++ b/examples/highfreq/highfreq_processor.py @@ -0,0 +1,72 @@ +import numpy as np +import pandas as pd +from qlib.data.dataset.processor import Processor +from qlib.data.dataset.utils import fetch_df_by_index + + +class HighFreqNorm(Processor): + def __init__(self, fit_start_time, fit_end_time): + self.fit_start_time = fit_start_time + self.fit_end_time = fit_end_time + + def fit(self, df_features): + fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime") + del df_features + df_values = fetch_df.values + names = { + "price": slice(0, 10), + "volume": slice(10, 12), + } + self.feature_med = {} + self.feature_std = {} + self.feature_vmax = {} + self.feature_vmin = {} + for name, name_val in names.items(): + part_values = df_values[:, name_val].astype(np.float32) + if name == "volume": + part_values = np.log1p(part_values) + self.feature_med[name] = np.nanmedian(part_values) + part_values = part_values - self.feature_med[name] + self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12 + part_values = part_values / self.feature_std[name] + self.feature_vmax[name] = np.nanmax(part_values) + self.feature_vmin[name] = np.nanmin(part_values) + + def __call__(self, df_features): + df_features.set_index("date", append=True, drop=True, inplace=True) + df_values = df_features.values + names = { + "price": slice(0, 10), + "volume": slice(10, 12), + } + + for name, name_val in names.items(): + if name == "volume": + df_values[:, name_val] = np.log1p(df_values[:, name_val]) + df_values[:, name_val] -= self.feature_med[name] + df_values[:, name_val] /= self.feature_std[name] + slice0 = df_values[:, name_val] > 3.0 + slice1 = df_values[:, name_val] > 3.5 + slice2 = df_values[:, name_val] < -3.0 + slice3 = df_values[:, name_val] < -3.5 + + df_values[:, name_val][slice0] = ( + 3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5 + ) + df_values[:, name_val][slice1] = 3.5 + df_values[:, name_val][slice2] = ( + -3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5 + ) + df_values[:, name_val][slice3] = -3.5 + idx = df_features.index.droplevel("datetime").drop_duplicates() + idx.set_names(["instrument", "datetime"], inplace=True) + + # Reshape is specifically for adapting to RL high-freq executor + feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240) + feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240) + df_new_features = pd.DataFrame( + data=np.concatenate((feat, feat_1), axis=1), + index=idx, + columns=["FEATURE_%d" % i for i in range(12 * 240)], + ).sort_index() + return df_new_features diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py new file mode 100644 index 000000000..e5fdcdb59 --- /dev/null +++ b/examples/highfreq/workflow.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import fire +from pathlib import Path + +import qlib +import pickle +import numpy as np +import pandas as pd +from qlib.config import HIGH_FREQ_CONFIG +from qlib.contrib.model.gbdt import LGBModel +from qlib.contrib.data.handler import Alpha158 +from qlib.contrib.strategy.strategy import TopkDropoutStrategy +from qlib.contrib.evaluate import ( + backtest as normal_backtest, + risk_analysis, +) + +from qlib.utils import init_instance_by_config, exists_qlib_data +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.ops import Operators +from qlib.data.data import Cal +from qlib.tests.data import GetData + +from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull + + +class HighfreqWorkflow(object): + + SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None} + + MARKET = "all" + BENCHMARK = "SH000300" + + start_time = "2020-09-14 00:00:00" + end_time = "2021-01-18 16:00:00" + train_end_time = "2020-11-30 16:00:00" + test_start_time = "2020-12-01 00:00:00" + + DATA_HANDLER_CONFIG0 = { + "start_time": start_time, + "end_time": end_time, + "freq": "1min", + "fit_start_time": start_time, + "fit_end_time": train_end_time, + "instruments": MARKET, + "infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}], + } + DATA_HANDLER_CONFIG1 = { + "start_time": start_time, + "end_time": end_time, + "freq": "1min", + "instruments": MARKET, + } + + task = { + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "HighFreqHandler", + "module_path": "highfreq_handler", + "kwargs": DATA_HANDLER_CONFIG0, + }, + "segments": { + "train": (start_time, train_end_time), + "test": ( + test_start_time, + end_time, + ), + }, + }, + }, + "dataset_backtest": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "HighFreqBacktestHandler", + "module_path": "highfreq_handler", + "kwargs": DATA_HANDLER_CONFIG1, + }, + "segments": { + "train": (start_time, train_end_time), + "test": ( + test_start_time, + end_time, + ), + }, + }, + }, + } + + def _init_qlib(self): + """initialize qlib""" + # use yahoo_cn_1min data + QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} + provider_uri = QLIB_INIT_CONFIG.get("provider_uri") + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) + qlib.init(**QLIB_INIT_CONFIG) + + def _prepare_calender_cache(self): + """preload the calendar for cache""" + + # This code used the copy-on-write feature of Linux to avoid calculating the calendar multiple times in the subprocess + # This code may accelerate, but may be not useful on Windows and Mac Os + Cal.calendar(freq="1min") + get_calendar_day(freq="1min") + + def get_data(self): + """use dataset to get highreq data""" + self._init_qlib() + self._prepare_calender_cache() + + dataset = init_instance_by_config(self.task["dataset"]) + xtrain, xtest = dataset.prepare(["train", "test"]) + print(xtrain, xtest) + + dataset_backtest = init_instance_by_config(self.task["dataset_backtest"]) + backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) + print(backtest_train, backtest_test) + + del xtrain, xtest + del backtest_train, backtest_test + + def dump_and_load_dataset(self): + """dump and load dataset state on disk""" + self._init_qlib() + self._prepare_calender_cache() + dataset = init_instance_by_config(self.task["dataset"]) + dataset_backtest = init_instance_by_config(self.task["dataset_backtest"]) + + ##=============dump dataset============= + dataset.to_pickle(path="dataset.pkl") + dataset_backtest.to_pickle(path="dataset_backtest.pkl") + + del dataset, dataset_backtest + ##=============reload dataset============= + with open("dataset.pkl", "rb") as file_dataset: + dataset = pickle.load(file_dataset) + + with open("dataset_backtest.pkl", "rb") as file_dataset_backtest: + dataset_backtest = pickle.load(file_dataset_backtest) + + self._prepare_calender_cache() + ##=============reload_dataset============= + dataset.init(init_type=DataHandlerLP.IT_LS) + dataset_backtest.init() + + ##=============get data============= + xtrain, xtest = dataset.prepare(["train", "test"]) + backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) + + print(xtrain, xtest) + print(backtest_train, backtest_test) + del xtrain, xtest + del backtest_train, backtest_test + + +if __name__ == "__main__": + fire.Fire(HighfreqWorkflow) diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index ea9c70083..6d166646c 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -17,7 +17,7 @@ from qlib.contrib.evaluate import ( from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord - +from qlib.tests.data import GetData if __name__ == "__main__": @@ -25,9 +25,6 @@ if __name__ == "__main__": provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir if not exists_qlib_data(provider_uri): print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) - from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/qlib/config.py b/qlib/config.py index a65d41041..e7120c23a 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -193,6 +193,12 @@ MODE_CONF = { }, } +HIGH_FREQ_CONFIG = { + "provider_uri": "~/.qlib/qlib_data/yahoo_cn_1min", + "dataset_cache": None, + "expression_cache": "DiskExpressionCache", + "region": REG_CN, +} _default_region_config = { REG_CN: { @@ -291,12 +297,12 @@ class QlibConfig(Config): def register(self): from .utils import init_instance_by_config - from .data.ops import register_custom_ops + from .data.ops import register_all_ops from .data.data import register_all_wrappers from .workflow import R, QlibRecorder from .workflow.utils import experiment_exit_handler - register_custom_ops(self) + register_all_ops(self) register_all_wrappers(self) # set up QlibRecorder exp_manager = init_instance_by_config(self["exp_manager"]) diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 88a1f0680..23e37a5e4 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -49,6 +49,7 @@ class Alpha360(DataHandlerLP): instruments="csi500", start_time=None, end_time=None, + freq="day", infer_processors=_DEFAULT_INFER_PROCESSORS, learn_processors=_DEFAULT_LEARN_PROCESSORS, fit_start_time=None, @@ -69,9 +70,10 @@ class Alpha360(DataHandlerLP): } super().__init__( - instruments, - start_time, - end_time, + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq="day", data_loader=data_loader, learn_processors=learn_processors, infer_processors=infer_processors, @@ -130,6 +132,7 @@ class Alpha158(DataHandlerLP): instruments="csi500", start_time=None, end_time=None, + freq="day", infer_processors=[], learn_processors=_DEFAULT_LEARN_PROCESSORS, fit_start_time=None, @@ -147,9 +150,10 @@ class Alpha158(DataHandlerLP): }, } super().__init__( - instruments, - start_time, - end_time, + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq=freq, data_loader=data_loader, infer_processors=infer_processors, learn_processors=learn_processors, diff --git a/qlib/data/base.py b/qlib/data/base.py index 92fc57ffe..e318843c4 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -157,7 +157,7 @@ class Expression(abc.ABC): @abc.abstractmethod def _load_internal(self, instrument, start_index, end_index, freq): - pass + raise NotImplementedError("This function must be implemented in your newly defined feature") @abc.abstractmethod def get_longest_back_rolling(self): diff --git a/qlib/data/data.py b/qlib/data/data.py index ece3c3641..2a0e569ab 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -117,7 +117,7 @@ class CalendarProvider(abc.ABC): if flag in H["c"]: _calendar, _calendar_index = H["c"][flag] else: - _calendar = np.array(self._load_calendar(freq, future)) + _calendar = np.array(self.load_calendar(freq, future)) _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search H["c"][flag] = _calendar, _calendar_index return _calendar, _calendar_index @@ -504,7 +504,7 @@ class LocalCalendarProvider(CalendarProvider): """Calendar file uri.""" return os.path.join(C.get_data_path(), "calendars", "{}.txt") - def _load_calendar(self, freq, future): + def load_calendar(self, freq, future): """Load original calendar timestamp from file. Parameters @@ -672,6 +672,8 @@ class LocalExpressionProvider(ExpressionProvider): series = series.astype(np.float32) except ValueError: pass + except TypeError: + pass if not series.empty: series = series.loc[start_index:end_index] return series diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 96e4a6e41..117da764f 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -87,6 +87,10 @@ class DatasetH(Dataset): """ super().__init__(handler, segments) + def init(self, **kwargs): + """Initialize the DatasetH, Only parameters belonging to handler.init will be passed in""" + self.handler.init(**kwargs) + def setup_data(self, handler: Union[dict, DataHandler], segments: list): """ Setup the underlying data. @@ -116,8 +120,8 @@ class DatasetH(Dataset): 'outsample': ("2017-01-01", "2020-08-01",), } """ - self._handler = init_instance_by_config(handler, accept_types=DataHandler) - self._segments = segments.copy() + self.handler = init_instance_by_config(handler, accept_types=DataHandler) + self.segments = segments.copy() def _prepare_seg(self, slc: slice, **kwargs): """ @@ -127,7 +131,7 @@ class DatasetH(Dataset): ---------- slc : slice """ - return self._handler.fetch(slc, **kwargs) + return self.handler.fetch(slc, **kwargs) def prepare( self, @@ -150,7 +154,7 @@ class DatasetH(Dataset): - ['train', 'valid'] col_set : str - The col_set will be passed to self._handler when fetching data. + The col_set will be passed to self.handler when fetching data. data_key : str The data to fetch: DK_* Default is DK_I, which indicate fetching data for **inference**. @@ -166,16 +170,16 @@ class DatasetH(Dataset): logger = get_module_logger("DatasetH") fetch_kwargs = {"col_set": col_set} fetch_kwargs.update(kwargs) - if "data_key" in getfullargspec(self._handler.fetch).args: + if "data_key" in getfullargspec(self.handler.fetch).args: fetch_kwargs["data_key"] = data_key else: logger.info(f"data_key[{data_key}] is ignored.") # Handle all kinds of segments format if isinstance(segments, (list, tuple)): - return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments] + return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments] elif isinstance(segments, str): - return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs) + return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs) elif isinstance(segments, slice): return self._prepare_seg(segments, **fetch_kwargs) else: @@ -409,7 +413,7 @@ class TSDatasetH(DatasetH): def setup_data(self, *args, **kwargs): super().setup_data(*args, **kwargs) - cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique() + cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() cal = sorted(cal) # Get the datatime index for building timestamp self.cal = cal diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 18f838300..abcd5a60c 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -57,6 +57,7 @@ class DataHandler(Serializable): instruments=None, start_time=None, end_time=None, + freq="day", data_loader: Tuple[dict, str, DataLoader] = None, init_data=True, fetch_orig=True, @@ -70,6 +71,8 @@ class DataHandler(Serializable): start_time of the original data. end_time : end_time of the original data. + freq : + frequency of data data_loader : Tuple[dict, str, DataLoader] data loader to load the data. init_data : @@ -92,6 +95,7 @@ class DataHandler(Serializable): self.instruments = instruments self.start_time = start_time self.end_time = end_time + self.freq = freq self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): @@ -119,7 +123,7 @@ class DataHandler(Serializable): # Setup data. # _data may be with multiple column index level. The outer level indicates the feature set name with TimeInspector.logt("Loading data"): - self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) + self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq) # TODO: cache CS_ALL = "__all" # return all columns with single-level index column @@ -258,10 +262,12 @@ class DataHandlerLP(DataHandler): instruments=None, start_time=None, end_time=None, + freq="day", data_loader: Tuple[dict, str, DataLoader] = None, infer_processors=[], learn_processors=[], process_type=PTYPE_A, + drop_raw=False, **kwargs, ): """ @@ -303,6 +309,8 @@ class DataHandlerLP(DataHandler): - self._learn will be processed by infer_processors + learn_processors - (e.g. self._infer processed by learn_processors ) + drop_raw: bool + Whether to drop the raw data """ # Setup preprocessor @@ -319,7 +327,8 @@ class DataHandlerLP(DataHandler): ) self.process_type = process_type - super().__init__(instruments, start_time, end_time, data_loader, **kwargs) + self.drop_raw = drop_raw + super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs) def get_all_processors(self): return self.infer_processors + self.learn_processors @@ -348,7 +357,7 @@ class DataHandlerLP(DataHandler): """ # data for inference _infer_df = self._data - if len(self.infer_processors) > 0: # avoid modifying the original data + if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data _infer_df = _infer_df.copy() for proc in self.infer_processors: @@ -378,6 +387,9 @@ class DataHandlerLP(DataHandler): _learn_df = proc(_learn_df) self._learn = _learn_df + if self.drop_raw: + del self._data + # init type IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df @@ -416,6 +428,10 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. Save the memory for data processing def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame: + if data_key == self.DK_R and self.drop_raw: + raise AttributeError( + "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" + ) df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) return df diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index a51ea119a..3b33ff749 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -19,7 +19,7 @@ class DataLoader(abc.ABC): """ @abc.abstractmethod - def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame: """ load the data as pd.DataFrame. @@ -94,7 +94,9 @@ class DLWParser(DataLoader): return exprs, names @abc.abstractmethod - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + def load_group_df( + self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day" + ) -> pd.DataFrame: """ load the dataframe for specific group @@ -114,25 +116,25 @@ class DLWParser(DataLoader): """ pass - def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: if self.is_group: df = pd.concat( { - grp: self.load_group_df(instruments, exprs, names, start_time, end_time) + grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq) for grp, (exprs, names) in self.fields.items() }, axis=1, ) else: exprs, names = self.fields - df = self.load_group_df(instruments, exprs, names, start_time, end_time) + df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq) return df class QlibDataLoader(DLWParser): """Same as QlibDataLoader. The fields can be define by config""" - def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None): + def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True): """ Parameters ---------- @@ -140,11 +142,16 @@ class QlibDataLoader(DLWParser): Please refer to the doc of DLWParser filter_pipe : Filter pipe for the instruments + swap_level : + Whether to swap level of MultiIndex """ self.filter_pipe = filter_pipe + self.swap_level = swap_level super().__init__(config) - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + def load_group_df( + self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day" + ) -> pd.DataFrame: if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") instruments = "all" @@ -153,9 +160,10 @@ class QlibDataLoader(DLWParser): elif self.filter_pipe is not None: warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") - df = D.features(instruments, exprs, start_time, end_time) + df = D.features(instruments, exprs, start_time, end_time, freq) df.columns = names - df = df.swaplevel().sort_index() # NOTE: always return + if self.swap_level: + df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df @@ -177,7 +185,7 @@ class StaticDataLoader(DataLoader): self.join = join self._data = None - def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: self._maybe_load_raw_data() if instruments is None: df = self._data diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 91f7349d2..940c24002 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function import sys +import abc import numpy as np import pandas as pd @@ -17,7 +18,7 @@ from ..log import get_module_logger try: from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi -except ImportError as err: +except ImportError: print( "#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####" ) @@ -32,12 +33,39 @@ np.seterr(invalid="ignore") class ElemOperator(ExpressionOps): """Element-wise Operator + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + Expression + feature operation output + """ + + def __init__(self, feature): + self.feature = feature + + def __str__(self): + return "{}({})".format(type(self).__name__, self.feature) + + def get_longest_back_rolling(self): + return self.feature.get_longest_back_rolling() + + def get_extended_window_size(self): + return self.feature.get_extended_window_size() + + +class NpElemOperator(ElemOperator): + """Numpy Element-wise Operator + Parameters ---------- feature : Expression feature instance func : str - feature operation method + numpy feature operation method Returns ---------- @@ -48,22 +76,14 @@ class ElemOperator(ExpressionOps): def __init__(self, feature, func): self.feature = feature self.func = func - - def __str__(self): - return "{}({})".format(type(self).__name__, self.feature) + super(NpElemOperator, self).__init__(feature) def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return getattr(np, self.func)(series) - def get_longest_back_rolling(self): - return self.feature.get_longest_back_rolling() - def get_extended_window_size(self): - return self.feature.get_extended_window_size() - - -class Abs(ElemOperator): +class Abs(NpElemOperator): """Feature Absolute Value Parameters @@ -81,7 +101,7 @@ class Abs(ElemOperator): super(Abs, self).__init__(feature, "abs") -class Sign(ElemOperator): +class Sign(NpElemOperator): """Feature Sign Parameters @@ -108,7 +128,7 @@ class Sign(ElemOperator): return getattr(np, self.func)(series) -class Log(ElemOperator): +class Log(NpElemOperator): """Feature Log Parameters @@ -126,7 +146,7 @@ class Log(ElemOperator): super(Log, self).__init__(feature, "log") -class Power(ElemOperator): +class Power(NpElemOperator): """Feature Power Parameters @@ -152,7 +172,7 @@ class Power(ElemOperator): return getattr(np, self.func)(series, self.exponent) -class Mask(ElemOperator): +class Mask(NpElemOperator): """Feature Mask Parameters @@ -179,7 +199,7 @@ class Mask(ElemOperator): return self.feature.load(self.instrument, start_index, end_index, freq) -class Not(ElemOperator): +class Not(NpElemOperator): """Not Operator Parameters @@ -218,28 +238,13 @@ class PairOperator(ExpressionOps): two features' operation output """ - def __init__(self, feature_left, feature_right, func): + def __init__(self, feature_left, feature_right): self.feature_left = feature_left self.feature_right = feature_right - self.func = func def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right) - def _load_internal(self, instrument, start_index, end_index, freq): - assert any( - [isinstance(self.feature_left, Expression), self.feature_right, Expression] - ), "at least one of two inputs is Expression instance" - if isinstance(self.feature_left, Expression): - series_left = self.feature_left.load(instrument, start_index, end_index, freq) - else: - series_left = self.feature_left # numeric value - if isinstance(self.feature_right, Expression): - series_right = self.feature_right.load(instrument, start_index, end_index, freq) - else: - series_right = self.feature_right - return getattr(np, self.func)(series_left, series_right) - def get_longest_back_rolling(self): if isinstance(self.feature_left, Expression): left_br = self.feature_left.get_longest_back_rolling() @@ -265,7 +270,46 @@ class PairOperator(ExpressionOps): return max(ll, rl), max(lr, rr) -class Add(PairOperator): +class NpPairOperator(PairOperator): + """Numpy Pair-wise operator + + Parameters + ---------- + feature_left : Expression + feature instance or numeric value + feature_right : Expression + feature instance or numeric value + func : str + operator function + + Returns + ---------- + Feature: + two features' operation output + """ + + def __init__(self, feature_left, feature_right, func): + self.feature_left = feature_left + self.feature_right = feature_right + self.func = func + super(NpPairOperator, self).__init__(feature_left, feature_right) + + def _load_internal(self, instrument, start_index, end_index, freq): + assert any( + [isinstance(self.feature_left, Expression), self.feature_right, Expression] + ), "at least one of two inputs is Expression instance" + if isinstance(self.feature_left, Expression): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + else: + series_left = self.feature_left # numeric value + if isinstance(self.feature_right, Expression): + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + else: + series_right = self.feature_right + return getattr(np, self.func)(series_left, series_right) + + +class Add(NpPairOperator): """Add Operator Parameters @@ -285,7 +329,7 @@ class Add(PairOperator): super(Add, self).__init__(feature_left, feature_right, "add") -class Sub(PairOperator): +class Sub(NpPairOperator): """Subtract Operator Parameters @@ -305,7 +349,7 @@ class Sub(PairOperator): super(Sub, self).__init__(feature_left, feature_right, "subtract") -class Mul(PairOperator): +class Mul(NpPairOperator): """Multiply Operator Parameters @@ -325,7 +369,7 @@ class Mul(PairOperator): super(Mul, self).__init__(feature_left, feature_right, "multiply") -class Div(PairOperator): +class Div(NpPairOperator): """Division Operator Parameters @@ -345,7 +389,7 @@ class Div(PairOperator): super(Div, self).__init__(feature_left, feature_right, "divide") -class Greater(PairOperator): +class Greater(NpPairOperator): """Greater Operator Parameters @@ -365,7 +409,7 @@ class Greater(PairOperator): super(Greater, self).__init__(feature_left, feature_right, "maximum") -class Less(PairOperator): +class Less(NpPairOperator): """Less Operator Parameters @@ -385,7 +429,7 @@ class Less(PairOperator): super(Less, self).__init__(feature_left, feature_right, "minimum") -class Gt(PairOperator): +class Gt(NpPairOperator): """Greater Than Operator Parameters @@ -405,7 +449,7 @@ class Gt(PairOperator): super(Gt, self).__init__(feature_left, feature_right, "greater") -class Ge(PairOperator): +class Ge(NpPairOperator): """Greater Equal Than Operator Parameters @@ -425,7 +469,7 @@ class Ge(PairOperator): super(Ge, self).__init__(feature_left, feature_right, "greater_equal") -class Lt(PairOperator): +class Lt(NpPairOperator): """Less Than Operator Parameters @@ -445,7 +489,7 @@ class Lt(PairOperator): super(Lt, self).__init__(feature_left, feature_right, "less") -class Le(PairOperator): +class Le(NpPairOperator): """Less Equal Than Operator Parameters @@ -465,7 +509,7 @@ class Le(PairOperator): super(Le, self).__init__(feature_left, feature_right, "less_equal") -class Eq(PairOperator): +class Eq(NpPairOperator): """Equal Operator Parameters @@ -485,7 +529,7 @@ class Eq(PairOperator): super(Eq, self).__init__(feature_left, feature_right, "equal") -class Ne(PairOperator): +class Ne(NpPairOperator): """Not Equal Operator Parameters @@ -505,7 +549,7 @@ class Ne(PairOperator): super(Ne, self).__init__(feature_left, feature_right, "not_equal") -class And(PairOperator): +class And(NpPairOperator): """And Operator Parameters @@ -525,7 +569,7 @@ class And(PairOperator): super(And, self).__init__(feature_left, feature_right, "bitwise_and") -class Or(PairOperator): +class Or(NpPairOperator): """Or Operator Parameters @@ -1451,6 +1495,9 @@ class OpsWrapper(object): def __init__(self): self._ops = {} + def reset(self): + self._ops = {} + def register(self, ops_list): for operator in ops_list: if not issubclass(operator, ExpressionOps): @@ -1469,12 +1516,15 @@ class OpsWrapper(object): Operators = OpsWrapper() -Operators.register(OpsList) -def register_custom_ops(C): - """register custom operator""" +def register_all_ops(C): + """register all operator""" logger = get_module_logger("ops") + + Operators.reset() + Operators.register(OpsList) + if getattr(C, "custom_ops", None) is not None: Operators.register(C.custom_ops) logger.debug("register custom operator {}".format(C.custom_ops)) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 10cd588e6..ed2f14d2f 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -66,7 +66,7 @@ class TestDataset(TestAutoData): # Check the data # Get data from DataFrame Directly data_from_df = ( - tsdh._handler.fetch(data_key=DataHandlerLP.DK_L) + tsdh.handler.fetch(data_key=DataHandlerLP.DK_L) .loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"] .iloc[-30:] .values diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py index cb172b2bb..7d3322ddc 100644 --- a/tests/test_register_ops.py +++ b/tests/test_register_ops.py @@ -26,9 +26,6 @@ class Diff(ElemOperator): a feature instance with first difference """ - def __init__(self, feature): - super(Diff, self).__init__(feature, "diff") - def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return series.diff() @@ -50,9 +47,6 @@ class Distance(PairOperator): a feature instance with distance """ - def __init__(self, feature_left, feature_right): - super(Distance, self).__init__(feature_left, feature_right, "distance") - def _load_internal(self, instrument, start_index, end_index, freq): series_left = self.feature_left.load(instrument, start_index, end_index, freq) series_right = self.feature_right.load(instrument, start_index, end_index, freq)