From f6dd006c35139c6528c5507e2f60d7c3c7eaab72 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 28 Jan 2021 11:31:15 +0000 Subject: [PATCH] update --- examples/{high_freq => highfreq}/__init__.py | 0 .../highfreq_handler.py | 107 ++++---------- .../{high_freq => highfreq}/highfreq_ops.py | 42 ++++-- .../highfreq_processor.py | 32 ++-- examples/{high_freq => highfreq}/workflow.py | 120 +++++++-------- examples/workflow_by_code.py | 2 +- qlib/config.py | 6 + qlib/data/base.py | 2 +- qlib/data/data.py | 21 +-- qlib/data/dataset/__init__.py | 11 +- qlib/data/dataset/handler.py | 12 +- qlib/data/ops.py | 138 ++++++++++++------ 12 files changed, 242 insertions(+), 251 deletions(-) rename examples/{high_freq => highfreq}/__init__.py (100%) rename examples/{high_freq => highfreq}/highfreq_handler.py (62%) rename examples/{high_freq => highfreq}/highfreq_ops.py (69%) rename examples/{high_freq => highfreq}/highfreq_processor.py (69%) rename examples/{high_freq => highfreq}/workflow.py (54%) diff --git a/examples/high_freq/__init__.py b/examples/highfreq/__init__.py similarity index 100% rename from examples/high_freq/__init__.py rename to examples/highfreq/__init__.py diff --git a/examples/high_freq/highfreq_handler.py b/examples/highfreq/highfreq_handler.py similarity index 62% rename from examples/high_freq/highfreq_handler.py rename to examples/highfreq/highfreq_handler.py index 298ffb5c0..cb23f48bb 100644 --- a/examples/high_freq/highfreq_handler.py +++ b/examples/highfreq/highfreq_handler.py @@ -56,88 +56,44 @@ class HighFreqHandler(DataHandlerLP): template_if = "If(IsNull({1}), {0}, {1})" template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})" - # template_paused="{0}" - template_fillnan = "FFillNan({0})" + template_fillnan = "BFillNan(FFillNan({0}))" + # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" - fields += [ - "{0}/Ref(DayLast({1}), 240)".format( + + def get_04_price_feature(price_field): + """Get 0~4 column price feature ops""" + feature_ops = "{0}/Ref(DayLast({1}), 240)".format( template_if.format( template_fillnan.format(template_paused.format("$close")), - template_paused.format("$open"), + template_paused.format(price_field), ), template_fillnan.format(template_paused.format("$close")), ) - ] - fields += [ - "{0}/Ref(DayLast({1}), 240)".format( - template_if.format( - template_fillnan.format(template_paused.format("$close")), - template_paused.format("$high"), - ), - template_fillnan.format(template_paused.format("$close")), - ) - ] - fields += [ - "{0}/Ref(DayLast({1}), 240)".format( - template_if.format( - template_fillnan.format(template_paused.format("$close")), - template_paused.format("$low"), - ), - template_fillnan.format(template_paused.format("$close")), - ) - ] - fields += ["{0}/Ref(DayLast({0}), 240)".format(template_fillnan.format(template_paused.format("$close")))] - fields += [ - "{0}/Ref(DayLast({1}), 240)".format( - template_if.format( - template_fillnan.format(template_paused.format("$close")), - template_paused.format(simpson_vwap), - ), - template_fillnan.format(template_paused.format("$close")), - ) - ] + return feature_ops + + fields += [get_04_price_feature("$open")] + fields += [get_04_price_feature("$high")] + fields += [get_04_price_feature("$low")] + fields += [get_04_price_feature("$close")] + fields += [get_04_price_feature(simpson_vwap)] names += ["$open", "$high", "$low", "$close", "$vwap"] - fields += [ - "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( + def get_59_price_feature(price_field): + """Get 5~9 column price feature ops""" + feature_ops = "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( template_if.format( template_fillnan.format(template_paused.format("$close")), - template_paused.format("$open"), + template_paused.format(price_field), ), template_fillnan.format(template_paused.format("$close")), ) - ] - fields += [ - "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( - template_if.format( - template_fillnan.format(template_paused.format("$close")), - template_paused.format("$high"), - ), - template_fillnan.format(template_paused.format("$close")), - ) - ] - fields += [ - "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( - template_if.format( - template_fillnan.format(template_paused.format("$close")), - template_paused.format("$low"), - ), - template_fillnan.format(template_paused.format("$close")), - ) - ] - fields += [ - "Ref({0}, 240)/Ref(DayLast({0}), 240)".format(template_fillnan.format(template_paused.format("$close"))) - ] + return feature_ops - fields += [ - "Ref({0}, 240)/Ref(DayLast({1}), 240)".format( - template_if.format( - template_fillnan.format(template_paused.format("$close")), - template_paused.format(simpson_vwap), - ), - template_fillnan.format(template_paused.format("$close")), - ) - ] + fields += [get_59_price_feature("$open")] + fields += [get_59_price_feature("$high")] + fields += [get_59_price_feature("$low")] + fields += [get_59_price_feature("$close")] + fields += [get_59_price_feature(simpson_vwap)] names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"] fields += [ @@ -197,19 +153,20 @@ class HighFreqBacktestHandler(DataHandler): template_if = "If(IsNull({1}), {0}, {1})" template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})" - # template_paused="{0}" - template_fillnan = "FFillNan({0})" + template_fillnan = "BFillNan(FFillNan({0}))" + # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" - # fields += [ - # template_fillnan.format(template_paused.format("$close")), - # ] + fields += [ + template_fillnan.format(template_paused.format("$close")), + ] + names += ["$close0"] fields += [ template_if.format( template_fillnan.format(template_paused.format("$close")), template_paused.format(simpson_vwap), ) ] - names += ["$vwap_0"] + names += ["$vwap0"] fields += [ "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( template_paused.format("$volume"), @@ -218,6 +175,6 @@ class HighFreqBacktestHandler(DataHandler): template_paused.format("$high"), ) ] - names += ["$volume_0"] + names += ["$volume0"] return fields, names diff --git a/examples/high_freq/highfreq_ops.py b/examples/highfreq/highfreq_ops.py similarity index 69% rename from examples/high_freq/highfreq_ops.py rename to examples/highfreq/highfreq_ops.py index a3fa7ac4a..cee6914a2 100644 --- a/examples/high_freq/highfreq_ops.py +++ b/examples/highfreq/highfreq_ops.py @@ -3,51 +3,61 @@ import pandas as pd import importlib from qlib.data.ops import ElemOperator, PairOperator from qlib.config import C +from qlib.data.cache import H from qlib.data.data import Cal -class DayFirst(ElemOperator): - def __init__(self, feature): - super(DayFirst, self).__init__(feature, "day_first") - - def _load_internal(self, instrument, start_index, end_index, freq): - _calendar = Cal.get_calendar_day(freq=freq)[0] - series = self.feature.load(instrument, start_index, end_index, freq) - return series.groupby(_calendar[series.index]).transform("first") +def get_calendar_day(freq="day", future=False): + flag = f"{freq}_future_{future}_day" + if flag in H["c"]: + _calendar = H["c"][flag] + else: + _calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future)))) + H["c"][flag] = _calendar + return _calendar class DayLast(ElemOperator): def __init__(self, feature): - super(DayLast, self).__init__(feature, "day_last") + super(DayLast, self).__init__(feature) def _load_internal(self, instrument, start_index, end_index, freq): - _calendar = Cal.get_calendar_day(freq=freq)[0] + _calendar = get_calendar_day(freq=freq) series = self.feature.load(instrument, start_index, end_index, freq) return series.groupby(_calendar[series.index]).transform("last") class FFillNan(ElemOperator): def __init__(self, feature): - super(FFillNan, self).__init__(feature, "fill_nan") + super(FFillNan, self).__init__(feature) def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return series.fillna(method="ffill") -class Date(ElemOperator): +class BFillNan(ElemOperator): def __init__(self, feature): - super(Date, self).__init__(feature, "date") + super(BFillNan, self).__init__(feature) def _load_internal(self, instrument, start_index, end_index, freq): - _calendar = Cal.get_calendar_day(freq=freq)[0] + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="bfill") + + +class Date(ElemOperator): + def __init__(self, feature): + super(Date, self).__init__(feature) + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = get_calendar_day(freq=freq) series = self.feature.load(instrument, start_index, end_index, freq) return pd.Series(_calendar[series.index], index=series.index) class Select(PairOperator): def __init__(self, condition, feature): - super(Select, self).__init__(condition, feature, "select") + super(Select, self).__init__(condition, feature) def _load_internal(self, instrument, start_index, end_index, freq): series_condition = self.feature_left.load(instrument, start_index, end_index, freq) @@ -57,7 +67,7 @@ class Select(PairOperator): class IsNull(ElemOperator): def __init__(self, feature): - super(IsNull, self).__init__(feature, "isnull") + super(IsNull, self).__init__(feature) def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) diff --git a/examples/high_freq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py similarity index 69% rename from examples/high_freq/highfreq_processor.py rename to examples/highfreq/highfreq_processor.py index d71cd2e85..f0ab0dec2 100644 --- a/examples/high_freq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -26,7 +26,7 @@ class HighFreqNorm(Processor): if name == "volume": part_values = np.log1p(part_values) self.feature_med[name] = np.nanmedian(part_values) - part_values = part_values - self.feature_med[name] # mean, copy + part_values = part_values - self.feature_med[name] self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12 part_values = part_values / self.feature_std[name] self.feature_vmax[name] = np.nanmax(part_values) @@ -41,23 +41,27 @@ class HighFreqNorm(Processor): } for name, name_val in names.items(): - part_values = df_values[:, name_val] if name == "volume": - part_values[:] = np.log1p(part_values) - part_values -= self.feature_med[name] - part_values /= self.feature_std[name] - slice0 = part_values > 3.0 - slice1 = part_values > 3.5 - slice2 = part_values < -3.0 - slice3 = part_values < -3.5 + df_values[:, name_val] = np.log1p(df_values[:, name_val]) + df_values[:, name_val] -= self.feature_med[name] + df_values[:, name_val] /= self.feature_std[name] + slice0 = df_values[:, name_val] > 3.0 + slice1 = df_values[:, name_val] > 3.5 + slice2 = df_values[:, name_val] < -3.0 + slice3 = df_values[:, name_val] < -3.5 - part_values[slice0] = 3.0 + (part_values[slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5 - part_values[slice1] = 3.5 - part_values[slice2] = -3.0 - (part_values[slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5 - part_values[slice3] = -3.5 - # print("start_call_feature_reshape") + df_values[:, name_val][slice0] = ( + 3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5 + ) + df_values[:, name_val][slice1] = 3.5 + df_values[:, name_val][slice2] = ( + -3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5 + ) + df_values[:, name_val][slice3] = -3.5 idx = df_features.index.droplevel("datetime").drop_duplicates() idx.set_names(["instrument", "datetime"], inplace=True) + + # Reshape is specifically for adapting to RL high-freq executor feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240) feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240) df_new_features = pd.DataFrame( diff --git a/examples/high_freq/workflow.py b/examples/highfreq/workflow.py similarity index 54% rename from examples/high_freq/workflow.py rename to examples/highfreq/workflow.py index a2ec67365..7bbb03df4 100644 --- a/examples/high_freq/workflow.py +++ b/examples/highfreq/workflow.py @@ -2,13 +2,14 @@ # Licensed under the MIT License. import sys +import fire from pathlib import Path import qlib import pickle import numpy as np import pandas as pd -from qlib.config import REG_CN +from qlib.config import HIGH_FREQ_CONFIG from qlib.contrib.model.gbdt import LGBModel from qlib.contrib.data.handler import Alpha158 from qlib.contrib.strategy.strategy import TopkDropoutStrategy @@ -23,42 +24,22 @@ from qlib.data.ops import Operators from qlib.data.data import Cal from qlib.utils import exists_qlib_data -from highfreq_ops import DayFirst, DayLast, FFillNan, Date, Select, IsNull +from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull -if __name__ == "__main__": - # use yahoo_cn_1min data - provider_uri = "~/.qlib/qlib_data/yahoo_cn_1min" - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))) - from get_data import GetData +class HighfreqWorkflow(object): - GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) - - qlib.init( - provider_uri=provider_uri, - custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull], - redis_port=-1, - region=REG_CN, - auto_mount=False, - ) + SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None} MARKET = "all" BENCHMARK = "SH000300" DROP_LOAD_DATASET = False # flag wether to test [drop and load dataset] - # start_time = "2019-01-01 00:00:00" - # end_time = "2019-12-31 15:00:00" - # train_end_time = "2019-05-31 15:00:00" - # test_start_time = "2019-06-01 00:00:00" start_time = "2020-09-14 00:00:00" end_time = "2021-01-18 16:00:00" train_end_time = "2020-11-30 16:00:00" test_start_time = "2020-12-01 00:00:00" - ################################### - # train model - ################################### + DATA_HANDLER_CONFIG0 = { "start_time": start_time, "end_time": end_time, @@ -94,8 +75,6 @@ if __name__ == "__main__": }, }, }, - # You shoud record the data in specific sequence - # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'], "dataset_backtest": { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -115,26 +94,50 @@ if __name__ == "__main__": }, }, } - ##=============load the calendar for cache============= - # unnecessary, but may accelerate - Cal.calendar(freq="1min") # load the calendar for cache - Cal.get_calendar_day(freq="1min") # load the calendar for cache - ##=============get data============= + def _init_qlib(self): + """initialize qlib""" + # use yahoo_cn_1min data + QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} + provider_uri = QLIB_INIT_CONFIG.get("provider_uri") + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))) + from get_data import GetData - dataset = init_instance_by_config(task["dataset"]) - xtrain, xtest = dataset.prepare(["train", "test"]) - print(xtrain, xtest) + GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) + qlib.init(**QLIB_INIT_CONFIG) - dataset_backtest = init_instance_by_config(task["dataset_backtest"]) - backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) - print(backtest_train, backtest_test) + def _prepare_calender_cache(self): + """preload the calendar for cache""" - del xtrain, xtest - del backtest_train, backtest_test + # This code used the copy-on-write feature of Linux to avoid calculating the calendar multiple times in the subprocess + # This code may accelerate, but may be not useful on Windows and Mac Os + Cal.calendar(freq="1min") + get_calendar_day(freq="1min") - ## example to show how to save the dataset and reload it, and how to use different data - if DROP_LOAD_DATASET: + def get_data(self): + """use dataset to get highreq data""" + self._init_qlib() + self._prepare_calender_cache() + + dataset = init_instance_by_config(self.task["dataset"]) + xtrain, xtest = dataset.prepare(["train", "test"]) + print(xtrain, xtest) + + dataset_backtest = init_instance_by_config(self.task["dataset_backtest"]) + backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) + print(backtest_train, backtest_test) + + del xtrain, xtest + del backtest_train, backtest_test + + def dump_and_load_dataset(self): + """dump and load dataset state on disk""" + self._init_qlib() + self._prepare_calender_cache() + dataset = init_instance_by_config(self.task["dataset"]) + dataset_backtest = init_instance_by_config(self.task["dataset_backtest"]) ##=============dump dataset============= dataset.to_pickle(path="dataset.pkl") @@ -142,33 +145,18 @@ if __name__ == "__main__": del dataset, dataset_backtest ##=============reload dataset============= - file_dataset = open("dataset.pkl", "rb") - dataset = pickle.load(file_dataset) - file_dataset.close() + with open("dataset.pkl", "rb") as file_dataset: + dataset = pickle.load(file_dataset) - file_dataset_backtest = open("dataset_backtest.pkl", "rb") - dataset_backtest = pickle.load(file_dataset_backtest) - - file_dataset_backtest.close() + with open("dataset_backtest.pkl", "rb") as file_dataset_backtest: + dataset_backtest = pickle.load(file_dataset_backtest) + self._prepare_calender_cache() ##=============reload_dataset============= dataset.init(init_type=DataHandlerLP.IT_LS) - dataset_backtest.init(init_type=DataHandlerLP.IT_LS) + dataset_backtest.init() - ##=============reinit qlib============= - ## Unless you want to modify the provider_uri and other configurations, reinit is unnecessary - qlib.init( - provider_uri=provider_uri, - custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull], - redis_port=-1, - region=REG_CN, - auto_mount=False, - ) - - Cal.calendar(freq="1min") # load the calendar for cache - Cal.get_calendar_day(freq="1min") # load the calendar for cache - - ##=============test dataset============= + ##=============get data============= xtrain, xtest = dataset.prepare(["train", "test"]) backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) @@ -176,3 +164,7 @@ if __name__ == "__main__": print(backtest_train, backtest_test) del xtrain, xtest del backtest_train, backtest_test + + +if __name__ == "__main__": + fire.Fire(HighfreqWorkflow) diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 6253f3ee4..ea9c70083 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -30,7 +30,7 @@ if __name__ == "__main__": GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - qlib.init(provider_uri=provider_uri, region=REG_CN, redis_port=233) + qlib.init(provider_uri=provider_uri, region=REG_CN) market = "csi300" benchmark = "SH000300" diff --git a/qlib/config.py b/qlib/config.py index e94752953..e7120c23a 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -193,6 +193,12 @@ MODE_CONF = { }, } +HIGH_FREQ_CONFIG = { + "provider_uri": "~/.qlib/qlib_data/yahoo_cn_1min", + "dataset_cache": None, + "expression_cache": "DiskExpressionCache", + "region": REG_CN, +} _default_region_config = { REG_CN: { diff --git a/qlib/data/base.py b/qlib/data/base.py index 92fc57ffe..e318843c4 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -157,7 +157,7 @@ class Expression(abc.ABC): @abc.abstractmethod def _load_internal(self, instrument, start_index, end_index, freq): - pass + raise NotImplementedError("This function must be implemented in your newly defined feature") @abc.abstractmethod def get_longest_back_rolling(self): diff --git a/qlib/data/data.py b/qlib/data/data.py index d7f50e0b0..2a0e569ab 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -117,17 +117,7 @@ class CalendarProvider(abc.ABC): if flag in H["c"]: _calendar, _calendar_index = H["c"][flag] else: - _calendar = np.array(self._load_calendar(freq, future)) - _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search - H["c"][flag] = _calendar, _calendar_index - return _calendar, _calendar_index - - def get_calendar_day(self, freq="day", future=False): - flag = f"{freq}_future_{future}_day" - if flag in H["c"]: - _calendar, _calendar_index = H["c"][flag] - else: - _calendar = np.array(list(map(lambda x: x.date(), self._load_calendar(freq, future)))) + _calendar = np.array(self.load_calendar(freq, future)) _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search H["c"][flag] = _calendar, _calendar_index return _calendar, _calendar_index @@ -514,7 +504,7 @@ class LocalCalendarProvider(CalendarProvider): """Calendar file uri.""" return os.path.join(C.get_data_path(), "calendars", "{}.txt") - def _load_calendar(self, freq, future): + def load_calendar(self, freq, future): """Load original calendar timestamp from file. Parameters @@ -679,12 +669,11 @@ class LocalExpressionProvider(ExpressionProvider): # 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented. # 2) The the precision should be configurable try: - if series.dtype == np.float64: - series = series.astype(np.float32) - elif series.dtype == np.bool: - series = series.astype(np.int8) + series = series.astype(np.float32) except ValueError: pass + except TypeError: + pass if not series.empty: series = series.loc[start_index:end_index] return series diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 65dcf7ccb..117da764f 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -88,15 +88,8 @@ class DatasetH(Dataset): super().__init__(handler, segments) def init(self, **kwargs): - - logger = get_module_logger("DatasetH") - handler_init_kwargs = {} - for arg_key, arg_value in kwargs.items(): - if arg_key in getfullargspec(self.handler.init).args: - handler_init_kwargs[arg_key] = arg_value - else: - logger.info(f"init arguments[{arg_key}] is ignored.") - self.handler.init(**handler_init_kwargs) + """Initialize the DatasetH, Only parameters belonging to handler.init will be passed in""" + self.handler.init(**kwargs) def setup_data(self, handler: Union[dict, DataHandler], segments: list): """ diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 627624022..abcd5a60c 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -428,13 +428,11 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. Save the memory for data processing def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame: - try: - df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) - except AttributeError: - print("please set drop_raw = False if you want to use raw data") - raise - except: - raise + if data_key == self.DK_R and self.drop_raw: + raise AttributeError( + "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" + ) + df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) return df def fetch( diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 66e588be1..940c24002 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function import sys +import abc import numpy as np import pandas as pd @@ -22,8 +23,6 @@ except ImportError: "#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####" ) raise -except: - raise np.seterr(invalid="ignore") @@ -34,12 +33,39 @@ np.seterr(invalid="ignore") class ElemOperator(ExpressionOps): """Element-wise Operator + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + Expression + feature operation output + """ + + def __init__(self, feature): + self.feature = feature + + def __str__(self): + return "{}({})".format(type(self).__name__, self.feature) + + def get_longest_back_rolling(self): + return self.feature.get_longest_back_rolling() + + def get_extended_window_size(self): + return self.feature.get_extended_window_size() + + +class NpElemOperator(ElemOperator): + """Numpy Element-wise Operator + Parameters ---------- feature : Expression feature instance func : str - feature operation method + numpy feature operation method Returns ---------- @@ -50,22 +76,14 @@ class ElemOperator(ExpressionOps): def __init__(self, feature, func): self.feature = feature self.func = func - - def __str__(self): - return "{}({})".format(type(self).__name__, self.feature) + super(NpElemOperator, self).__init__(feature) def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return getattr(np, self.func)(series) - def get_longest_back_rolling(self): - return self.feature.get_longest_back_rolling() - def get_extended_window_size(self): - return self.feature.get_extended_window_size() - - -class Abs(ElemOperator): +class Abs(NpElemOperator): """Feature Absolute Value Parameters @@ -83,7 +101,7 @@ class Abs(ElemOperator): super(Abs, self).__init__(feature, "abs") -class Sign(ElemOperator): +class Sign(NpElemOperator): """Feature Sign Parameters @@ -110,7 +128,7 @@ class Sign(ElemOperator): return getattr(np, self.func)(series) -class Log(ElemOperator): +class Log(NpElemOperator): """Feature Log Parameters @@ -128,7 +146,7 @@ class Log(ElemOperator): super(Log, self).__init__(feature, "log") -class Power(ElemOperator): +class Power(NpElemOperator): """Feature Power Parameters @@ -154,7 +172,7 @@ class Power(ElemOperator): return getattr(np, self.func)(series, self.exponent) -class Mask(ElemOperator): +class Mask(NpElemOperator): """Feature Mask Parameters @@ -181,7 +199,7 @@ class Mask(ElemOperator): return self.feature.load(self.instrument, start_index, end_index, freq) -class Not(ElemOperator): +class Not(NpElemOperator): """Not Operator Parameters @@ -220,28 +238,13 @@ class PairOperator(ExpressionOps): two features' operation output """ - def __init__(self, feature_left, feature_right, func): + def __init__(self, feature_left, feature_right): self.feature_left = feature_left self.feature_right = feature_right - self.func = func def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right) - def _load_internal(self, instrument, start_index, end_index, freq): - assert any( - [isinstance(self.feature_left, Expression), self.feature_right, Expression] - ), "at least one of two inputs is Expression instance" - if isinstance(self.feature_left, Expression): - series_left = self.feature_left.load(instrument, start_index, end_index, freq) - else: - series_left = self.feature_left # numeric value - if isinstance(self.feature_right, Expression): - series_right = self.feature_right.load(instrument, start_index, end_index, freq) - else: - series_right = self.feature_right - return getattr(np, self.func)(series_left, series_right) - def get_longest_back_rolling(self): if isinstance(self.feature_left, Expression): left_br = self.feature_left.get_longest_back_rolling() @@ -267,7 +270,46 @@ class PairOperator(ExpressionOps): return max(ll, rl), max(lr, rr) -class Add(PairOperator): +class NpPairOperator(PairOperator): + """Numpy Pair-wise operator + + Parameters + ---------- + feature_left : Expression + feature instance or numeric value + feature_right : Expression + feature instance or numeric value + func : str + operator function + + Returns + ---------- + Feature: + two features' operation output + """ + + def __init__(self, feature_left, feature_right, func): + self.feature_left = feature_left + self.feature_right = feature_right + self.func = func + super(NpPairOperator, self).__init__(feature_left, feature_right) + + def _load_internal(self, instrument, start_index, end_index, freq): + assert any( + [isinstance(self.feature_left, Expression), self.feature_right, Expression] + ), "at least one of two inputs is Expression instance" + if isinstance(self.feature_left, Expression): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + else: + series_left = self.feature_left # numeric value + if isinstance(self.feature_right, Expression): + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + else: + series_right = self.feature_right + return getattr(np, self.func)(series_left, series_right) + + +class Add(NpPairOperator): """Add Operator Parameters @@ -287,7 +329,7 @@ class Add(PairOperator): super(Add, self).__init__(feature_left, feature_right, "add") -class Sub(PairOperator): +class Sub(NpPairOperator): """Subtract Operator Parameters @@ -307,7 +349,7 @@ class Sub(PairOperator): super(Sub, self).__init__(feature_left, feature_right, "subtract") -class Mul(PairOperator): +class Mul(NpPairOperator): """Multiply Operator Parameters @@ -327,7 +369,7 @@ class Mul(PairOperator): super(Mul, self).__init__(feature_left, feature_right, "multiply") -class Div(PairOperator): +class Div(NpPairOperator): """Division Operator Parameters @@ -347,7 +389,7 @@ class Div(PairOperator): super(Div, self).__init__(feature_left, feature_right, "divide") -class Greater(PairOperator): +class Greater(NpPairOperator): """Greater Operator Parameters @@ -367,7 +409,7 @@ class Greater(PairOperator): super(Greater, self).__init__(feature_left, feature_right, "maximum") -class Less(PairOperator): +class Less(NpPairOperator): """Less Operator Parameters @@ -387,7 +429,7 @@ class Less(PairOperator): super(Less, self).__init__(feature_left, feature_right, "minimum") -class Gt(PairOperator): +class Gt(NpPairOperator): """Greater Than Operator Parameters @@ -407,7 +449,7 @@ class Gt(PairOperator): super(Gt, self).__init__(feature_left, feature_right, "greater") -class Ge(PairOperator): +class Ge(NpPairOperator): """Greater Equal Than Operator Parameters @@ -427,7 +469,7 @@ class Ge(PairOperator): super(Ge, self).__init__(feature_left, feature_right, "greater_equal") -class Lt(PairOperator): +class Lt(NpPairOperator): """Less Than Operator Parameters @@ -447,7 +489,7 @@ class Lt(PairOperator): super(Lt, self).__init__(feature_left, feature_right, "less") -class Le(PairOperator): +class Le(NpPairOperator): """Less Equal Than Operator Parameters @@ -467,7 +509,7 @@ class Le(PairOperator): super(Le, self).__init__(feature_left, feature_right, "less_equal") -class Eq(PairOperator): +class Eq(NpPairOperator): """Equal Operator Parameters @@ -487,7 +529,7 @@ class Eq(PairOperator): super(Eq, self).__init__(feature_left, feature_right, "equal") -class Ne(PairOperator): +class Ne(NpPairOperator): """Not Equal Operator Parameters @@ -507,7 +549,7 @@ class Ne(PairOperator): super(Ne, self).__init__(feature_left, feature_right, "not_equal") -class And(PairOperator): +class And(NpPairOperator): """And Operator Parameters @@ -527,7 +569,7 @@ class And(PairOperator): super(And, self).__init__(feature_left, feature_right, "bitwise_and") -class Or(PairOperator): +class Or(NpPairOperator): """Or Operator Parameters