diff --git a/README.md b/README.md index 05571a3c4..c79e2f2fa 100644 --- a/README.md +++ b/README.md @@ -222,17 +222,17 @@ The automatic workflow may not suite the research workflow of all Quant research # [Quant Model Zoo](examples/benchmarks) Here is a list of models built on `Qlib`. -- [GBDT based on LightGBM (Guolin Ke, et al.)](qlib/contrib/model/gbdt.py) -- [GBDT based on Catboost (Liudmila Prokhorenkova, et al.)](qlib/contrib/model/catboost_model.py) -- [GBDT based on XGBoost (Tianqi Chen, et al.)](qlib/contrib/model/xgboost.py) +- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py) +- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py) +- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py) - [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py) -- [GRU based on pytorch (Kyunghyun Cho, et al.)](qlib/contrib/model/pytorch_gru.py) -- [LSTM based on pytorch (Sepp Hochreiter, et al.)](qlib/contrib/model/pytorch_lstm.py) -- [ALSTM based on pytorch (Yao Qin, et al.)](qlib/contrib/model/pytorch_alstm.py) -- [GATs based on pytorch (Petar Velickovic, et al.)](qlib/contrib/model/pytorch_gats.py) -- [SFM based on pytorch (Liheng Zhang, et al.)](qlib/contrib/model/pytorch_sfm.py) -- [TFT based on tensorflow (Bryan Lim, et al.)](examples/benchmarks/TFT/tft.py) -- [TabNet based on pytorch (Sercan O. Arik, et al.)](qlib/contrib/model/pytorch_tabnet.py) +- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py) +- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py) +- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py) +- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py) +- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py) +- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py) +- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py) Your PR of new Quant models is highly welcomed. diff --git a/examples/highfreq/README.md b/examples/highfreq/README.md new file mode 100644 index 000000000..30c2e19db --- /dev/null +++ b/examples/highfreq/README.md @@ -0,0 +1,28 @@ +# High-Frequency Dataset + +This dataset is an example for RL high frequency trading. + +## Get High-Frequency Data + +Get high-frequency data by running the following command: +```bash + python workflow.py get_data +``` + +## Dump & Reload & Reinitialize the Dataset + + +The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in the `workflow.py`. `DatatsetH` is the subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped in or loaded from disk in `pickle` format. + +### About Reinitialization + +After reloading `Dataset` from disk, `Qlib` also support reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states. + +The example is given in `workflow.py`, users can run the code as follows. + +### Run the Code + +Run the example by running the following command: +```bash + python workflow.py dump_and_load_dataset +``` \ No newline at end of file diff --git a/examples/highfreq/highfreq_handler.py b/examples/highfreq/highfreq_handler.py new file mode 100644 index 000000000..d35650514 --- /dev/null +++ b/examples/highfreq/highfreq_handler.py @@ -0,0 +1,174 @@ +from qlib.data.dataset.handler import DataHandler, DataHandlerLP +from qlib.data.dataset.processor import Processor +from qlib.utils import get_cls_kwargs +from qlib.log import TimeInspector + + +class HighFreqHandler(DataHandlerLP): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + drop_raw=True, + ): + def check_transform_proc(proc_l): + new_l = [] + for p in proc_l: + p["kwargs"].update( + { + "fit_start_time": fit_start_time, + "fit_end_time": fit_end_time, + } + ) + new_l.append(p) + return new_l + + infer_processors = check_transform_proc(infer_processors) + learn_processors = check_transform_proc(learn_processors) + + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + "freq": "1min", + }, + } + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + drop_raw=drop_raw, + ) + + def get_feature_config(self): + fields = [] + names = [] + + template_if = "If(IsNull({1}), {0}, {1})" + template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})" + template_fillnan = "BFillNan(FFillNan({0}))" + # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap + simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" + + def get_normalized_price_feature(price_field, shift=0): + """Get normalized price feature ops""" + if shift == 0: + template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)" + else: + template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)" + + feature_ops = template_norm.format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format(price_field), + ), + template_fillnan.format(template_paused.format("$close")), + ) + return feature_ops + + fields += [get_normalized_price_feature("$open", 0)] + fields += [get_normalized_price_feature("$high", 0)] + fields += [get_normalized_price_feature("$low", 0)] + fields += [get_normalized_price_feature("$close", 0)] + fields += [get_normalized_price_feature(simpson_vwap, 0)] + names += ["$open", "$high", "$low", "$close", "$vwap"] + + fields += [get_normalized_price_feature("$open", 240)] + fields += [get_normalized_price_feature("$high", 240)] + fields += [get_normalized_price_feature("$low", 240)] + fields += [get_normalized_price_feature("$close", 240)] + fields += [get_normalized_price_feature(simpson_vwap, 240)] + names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"] + + fields += [ + "Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format( + "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( + template_paused.format("$volume"), + template_paused.format(simpson_vwap), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ) + ] + names += ["$volume"] + fields += [ + "Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format( + "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( + template_paused.format("$volume"), + template_paused.format(simpson_vwap), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ) + ] + names += ["$volume_1"] + + fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))] + names += ["date"] + return fields, names + + +class HighFreqBacktestHandler(DataHandler): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + ): + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + "freq": "1min", + }, + } + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, + ) + + def get_feature_config(self): + fields = [] + names = [] + + template_if = "If(IsNull({1}), {0}, {1})" + template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})" + template_fillnan = "BFillNan(FFillNan({0}))" + # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap + simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" + fields += [ + "Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))), + ] + names += ["$close0"] + fields += [ + "Cut({0}, 240, None)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format(simpson_vwap), + ) + ) + ] + names += ["$vwap0"] + fields += [ + "Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format( + template_paused.format("$volume"), + template_paused.format(simpson_vwap), + template_paused.format("$low"), + template_paused.format("$high"), + ) + ] + names += ["$volume0"] + + return fields, names diff --git a/examples/highfreq/highfreq_ops.py b/examples/highfreq/highfreq_ops.py new file mode 100644 index 000000000..66a084f9f --- /dev/null +++ b/examples/highfreq/highfreq_ops.py @@ -0,0 +1,190 @@ +import numpy as np +import pandas as pd +import importlib +from qlib.data.ops import ElemOperator, PairOperator +from qlib.config import C +from qlib.data.cache import H +from qlib.data.data import Cal + + +def get_calendar_day(freq="day", future=False): + """Load High-Freq Calendar Date Using Memcache. + + Parameters + ---------- + freq : str + frequency of read calendar file. + future : bool + whether including future trading day. + + Returns + ------- + _calendar: + array of date. + """ + flag = f"{freq}_future_{future}_day" + if flag in H["c"]: + _calendar = H["c"][flag] + else: + _calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future)))) + H["c"][flag] = _calendar + return _calendar + + +class DayLast(ElemOperator): + """DayLast Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a series of that each value equals the last value of its day + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = get_calendar_day(freq=freq) + series = self.feature.load(instrument, start_index, end_index, freq) + return series.groupby(_calendar[series.index]).transform("last") + + +class FFillNan(ElemOperator): + """FFillNan Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a forward fill nan feature + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="ffill") + + +class BFillNan(ElemOperator): + """BFillNan Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a backfoward fill nan feature + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.fillna(method="bfill") + + +class Date(ElemOperator): + """Date Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a series of that each value is the date corresponding to feature.index + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + _calendar = get_calendar_day(freq=freq) + series = self.feature.load(instrument, start_index, end_index, freq) + return pd.Series(_calendar[series.index], index=series.index) + + +class Select(PairOperator): + """Select Operator + + Parameters + ---------- + feature_left : Expression + feature instance, select condition + feature_right : Expression + feature instance, select value + + Returns + ---------- + feature: + value(feature_right) that meets the condition(feature_left) + + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series_condition = self.feature_left.load(instrument, start_index, end_index, freq) + series_feature = self.feature_right.load(instrument, start_index, end_index, freq) + return series_feature.loc[series_condition] + + +class IsNull(ElemOperator): + """IsNull Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + A series indicating whether the feature is nan + """ + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.isnull() + + +class Cut(ElemOperator): + """Cut Operator + + Parameters + ---------- + feature : Expression + feature instance + l : int + l > 0, delete the first l elements of feature (default is None, which means 0) + r : int + r < 0, delete the last -r elements of feature (default is None, which means 0) + Returns + ---------- + feature: + A series with the first l and last -r elements deleted from the feature. + Note: It is deleted from the raw data, not the sliced data + """ + + def __init__(self, feature, l=None, r=None): + self.l = l + self.r = r + if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0): + raise ValueError("Cut operator l shoud > 0 and r should < 0") + + super(Cut, self).__init__(feature) + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.iloc[self.l : self.r] + + def get_extended_window_size(self): + ll = 0 if self.l is None else self.l + rr = 0 if self.r is None else abs(self.r) + lft_etd, rght_etd = self.feature.get_extended_window_size() + lft_etd = lft_etd + ll + rght_etd = rght_etd + rr + return lft_etd, rght_etd diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py new file mode 100644 index 000000000..f0ab0dec2 --- /dev/null +++ b/examples/highfreq/highfreq_processor.py @@ -0,0 +1,72 @@ +import numpy as np +import pandas as pd +from qlib.data.dataset.processor import Processor +from qlib.data.dataset.utils import fetch_df_by_index + + +class HighFreqNorm(Processor): + def __init__(self, fit_start_time, fit_end_time): + self.fit_start_time = fit_start_time + self.fit_end_time = fit_end_time + + def fit(self, df_features): + fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime") + del df_features + df_values = fetch_df.values + names = { + "price": slice(0, 10), + "volume": slice(10, 12), + } + self.feature_med = {} + self.feature_std = {} + self.feature_vmax = {} + self.feature_vmin = {} + for name, name_val in names.items(): + part_values = df_values[:, name_val].astype(np.float32) + if name == "volume": + part_values = np.log1p(part_values) + self.feature_med[name] = np.nanmedian(part_values) + part_values = part_values - self.feature_med[name] + self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12 + part_values = part_values / self.feature_std[name] + self.feature_vmax[name] = np.nanmax(part_values) + self.feature_vmin[name] = np.nanmin(part_values) + + def __call__(self, df_features): + df_features.set_index("date", append=True, drop=True, inplace=True) + df_values = df_features.values + names = { + "price": slice(0, 10), + "volume": slice(10, 12), + } + + for name, name_val in names.items(): + if name == "volume": + df_values[:, name_val] = np.log1p(df_values[:, name_val]) + df_values[:, name_val] -= self.feature_med[name] + df_values[:, name_val] /= self.feature_std[name] + slice0 = df_values[:, name_val] > 3.0 + slice1 = df_values[:, name_val] > 3.5 + slice2 = df_values[:, name_val] < -3.0 + slice3 = df_values[:, name_val] < -3.5 + + df_values[:, name_val][slice0] = ( + 3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5 + ) + df_values[:, name_val][slice1] = 3.5 + df_values[:, name_val][slice2] = ( + -3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5 + ) + df_values[:, name_val][slice3] = -3.5 + idx = df_features.index.droplevel("datetime").drop_duplicates() + idx.set_names(["instrument", "datetime"], inplace=True) + + # Reshape is specifically for adapting to RL high-freq executor + feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240) + feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240) + df_new_features = pd.DataFrame( + data=np.concatenate((feat, feat_1), axis=1), + index=idx, + columns=["FEATURE_%d" % i for i in range(12 * 240)], + ).sort_index() + return df_new_features diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py new file mode 100644 index 000000000..f85c6d558 --- /dev/null +++ b/examples/highfreq/workflow.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import fire +from pathlib import Path + +import qlib +import pickle +import numpy as np +import pandas as pd +from qlib.config import REG_CN, HIGH_FREQ_CONFIG +from qlib.contrib.model.gbdt import LGBModel +from qlib.contrib.data.handler import Alpha158 +from qlib.contrib.strategy.strategy import TopkDropoutStrategy +from qlib.contrib.evaluate import ( + backtest as normal_backtest, + risk_analysis, +) + +from qlib.utils import init_instance_by_config, exists_qlib_data +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.ops import Operators +from qlib.data.data import Cal +from qlib.tests.data import GetData + +from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut + + +class HighfreqWorkflow(object): + + SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} + + MARKET = "all" + BENCHMARK = "SH000300" + + start_time = "2020-09-15 00:00:00" + end_time = "2021-01-18 16:00:00" + train_end_time = "2020-11-30 16:00:00" + test_start_time = "2020-12-01 00:00:00" + + DATA_HANDLER_CONFIG0 = { + "start_time": start_time, + "end_time": end_time, + "fit_start_time": start_time, + "fit_end_time": train_end_time, + "instruments": MARKET, + "infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}], + } + DATA_HANDLER_CONFIG1 = { + "start_time": start_time, + "end_time": end_time, + "instruments": MARKET, + } + + task = { + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "HighFreqHandler", + "module_path": "highfreq_handler", + "kwargs": DATA_HANDLER_CONFIG0, + }, + "segments": { + "train": (start_time, train_end_time), + "test": ( + test_start_time, + end_time, + ), + }, + }, + }, + "dataset_backtest": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "HighFreqBacktestHandler", + "module_path": "highfreq_handler", + "kwargs": DATA_HANDLER_CONFIG1, + }, + "segments": { + "train": (start_time, train_end_time), + "test": ( + test_start_time, + end_time, + ), + }, + }, + }, + } + + def _init_qlib(self): + """initialize qlib""" + # use yahoo_cn_1min data + QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} + provider_uri = QLIB_INIT_CONFIG.get("provider_uri") + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) + qlib.init(**QLIB_INIT_CONFIG) + + def _prepare_calender_cache(self): + """preload the calendar for cache""" + + # This code used the copy-on-write feature of Linux to avoid calculating the calendar multiple times in the subprocess + # This code may accelerate, but may be not useful on Windows and Mac Os + Cal.calendar(freq="1min") + get_calendar_day(freq="1min") + + def get_data(self): + """use dataset to get highreq data""" + self._init_qlib() + self._prepare_calender_cache() + + dataset = init_instance_by_config(self.task["dataset"]) + xtrain, xtest = dataset.prepare(["train", "test"]) + print(xtrain, xtest) + + dataset_backtest = init_instance_by_config(self.task["dataset_backtest"]) + backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) + print(backtest_train, backtest_test) + + return + + def dump_and_load_dataset(self): + """dump and load dataset state on disk""" + self._init_qlib() + self._prepare_calender_cache() + dataset = init_instance_by_config(self.task["dataset"]) + dataset_backtest = init_instance_by_config(self.task["dataset_backtest"]) + + ##=============dump dataset============= + dataset.to_pickle(path="dataset.pkl") + dataset_backtest.to_pickle(path="dataset_backtest.pkl") + + del dataset, dataset_backtest + ##=============reload dataset============= + with open("dataset.pkl", "rb") as file_dataset: + dataset = pickle.load(file_dataset) + + with open("dataset_backtest.pkl", "rb") as file_dataset_backtest: + dataset_backtest = pickle.load(file_dataset_backtest) + + self._prepare_calender_cache() + ##=============reinit dataset============= + dataset.init( + handler_kwargs={ + "init_type": DataHandlerLP.IT_LS, + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segment_kwargs={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, + ) + dataset_backtest.init( + handler_kwargs={ + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segment_kwargs={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, + ) + + ##=============get data============= + xtest = dataset.prepare(["test"]) + backtest_test = dataset_backtest.prepare(["test"]) + + print(xtest, backtest_test) + return + + + def get_high_freq_data(self, data_path): + self._init_qlib() + self._prepare_calender_cache() + + import os + dataset = init_instance_by_config(self.task["dataset"]) + xtrain, xtest = dataset.prepare(["train", "test"]) + normed_feature = pd.concat([xtrain, xtest]).sort_index() + dic = dict(tuple(normed_feature.groupby("instrument"))) + feature_path = os.path.join(data_path, "normed_feature/") + if not os.path.exists(feature_path): + os.makedirs(feature_path) + for k, v in dic.items(): + v.to_pickle(feature_path + f"{k}.pkl") + + + dataset_backtest = init_instance_by_config(self.task["dataset_backtest"]) + backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) + backtest = pd.concat([backtest_train, backtest_test]).sort_index() + backtest['date'] = backtest.index.map(lambda x: x[1].date()) + backtest.set_index('date', append=True, drop=True, inplace=True) + dic = dict(tuple(backtest.groupby("instrument"))) + backtest_path = os.path.join(data_path, "backtest/") + if not os.path.exists(backtest_path): + os.makedirs(backtest_path) + for k, v in dic.items(): + v.to_pickle(backtest_path + f"{k}.pkl.backtest") + + +if __name__ == "__main__": + #fire.Fire(HighfreqWorkflow) + data_path = '../data/' + workflow = HighfreqWorkflow() + workflow.get_high_freq_data(data_path) + diff --git a/examples/trade/README.md b/examples/trade/README.md index 5a4cb53bf..5b621c37a 100644 --- a/examples/trade/README.md +++ b/examples/trade/README.md @@ -4,6 +4,91 @@ This is the experiment code for our AAAI 2021 paper "[Universal Trading for Orde ## Abstract As a fundamental problem in algorithmic trading, order execution aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument. Towards effective execution strategy, recent years have witnessed the shift from the analytical view with model-based market assumptions to model-free perspective, i.e., reinforcement learning, due to its nature of sequential decision optimization. However, the noisy and yet imperfect market information that can be leveraged by the policy has made it quite challenging to build up sample efficient reinforcement learning methods to achieve effective order execution. In this paper, we propose a novel universal trading policy optimization framework to bridge the gap between the noisy yet imperfect market states and the optimal action sequences for order execution. Particularly, this framework leverages a policy distillation method that can better guide the learning of the common policy towards practically optimal execution by an oracle teacher with perfect information to approximate the optimal trading strategy. The extensive experiments have shown significant improvements of our method over various strong baselines, with reasonable trading actions. +## Environment Dependencies + +### Dependencies + +``` +gym==0.17.3 +torch==1.6.0 +numba==0.51.2 +numpy==1.19.1 +pandas==1.1.3 +tqdm==4.50.2 +tianshou==0.3.0.post1 +env==0.1.0 +PyYAML==5.4.1 +redis==3.5.3 +``` + +### Environment Variable + +`EXP_PATH` Absolute path to your config folder, we give folder `exp` as an example. + +`OUTPUT_DIR` Absolute path to your log folder. + +## Data Processing + +For Feature processing, we take Yahoo dataset as an example, which can be precessed in `qlib/examples/highfreq/workflow.py` file. If you have a need to change your data storage path, you can change the `data_path` in `workflow.py`, and then do the following. + +``` +python workflow.py +``` + +For order generation, if you have changed change the the `data_path` in `workflow.py`, change `data_path` in `order_gen.py` again, then do the following. + +``` +python order_gen.py +``` + +## Training and backtest + +### Config file + +Config file is need to start our project, we take `PPO`, `OPDS` and `OPD` as an example in folder `exp/example`. If you want to use our given config, make sure the `data_path` you set before matches the config file. + +### Baseline method + +To run a method, you can do the following. + +``` +python main.py --config={config_path} +``` + +Where `{config_path}` means the relative path from your config.yml to `EXP_PATH`. + +If you need to run our given method such as PPO method, you can do the following. + +``` +python main.py --config=example/PPO/config.yml +``` + +### OPD method + +OPD method is a multi step method, at first you should run OPDT as the teacher in OPD method. + +``` +python main.py --config=example/OPDT/config.yml +``` + +After training, find the `policy_best` file in your OPDT log file and copy it to `trade` file for backtest. Also you can change `policy_path` in the `example/OPDT_b/config.yml` to your `policy_best` file. Then run the backtest method. + +``` +python main.py --config=example/OPDT_b/config.yml +``` + +then processed feature from teacher. Remember to change `log_path` if you have changed `log_dir` in `OPDT_b/config.yml`. + +``` +python teacher_feature.py +``` + +and finally start our OPD method. + +``` +python main.py --config=example/OPD/config.yml +``` + ### Citation You are more than welcome to cite our paper: ``` @@ -13,4 +98,4 @@ You are more than welcome to cite our paper: booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, year={2021} } -``` \ No newline at end of file +``` diff --git a/examples/trade/exp/example/OPD/config.yml b/examples/trade/exp/example/OPD/config.yml new file mode 100644 index 000000000..3bca5141e --- /dev/null +++ b/examples/trade/exp/example/OPD/config.yml @@ -0,0 +1,76 @@ +seed: 42 +task: train +log_dir: example/OPD +buffer_size: 80000 +io_conf: + test_sampler: TestSampler + train_sampler: Sampler + test_logger: DFLogger +resources: + num_cpus: 24 + num_gpus: 1 + device: cuda +train_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/train/ +valid_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/valid/ +test_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/test/ +env_conf: + name: StockEnv_Acc + max_step_num: 237 + limit: 10 + time_interval: 30 + interval_num: 8 + features: + - name: raw + type: range + loc: ../data/normed_feature/ + size: 180 + - name: teacher_action + type: interval + size: 1 + loc: ../data/feature/teacher/ + obs: + name: RuleTeacher + config: {} + action: + name: Static_Action + config: + action_num: 5 + action_map: [0, 0.25, 0.5, 0.75, 1] + reward: + VP_Penalty_small_vec: + penalty: 100 + coefficient: 1 +policy_conf: + name: PPO_sup + config: + discount_factor: 1. + max_grad_norm: 100. + reward_normalization: False + eps_clip: 0.3 + value_clip: True + vf_coef: 1. + gae_lambda: 1. + vf_clip_para: 0.3 + sup_coef: 0.01 +network_conf: + name: OPD + config: + hidden_size: 64 + out_shape: 5 + fc_size: 32 + cnn_shape: [30, 6] +optim: + lr: 1e-4 + batch_size: 1024 + max_epoch: 30 + step_per_epoch: 20 + collect_per_step: 10000 + repeat_per_collect: 5 + early_stopping: 5 + weight_decay: 0. \ No newline at end of file diff --git a/examples/trade/exp/example/OPDS/config.yml b/examples/trade/exp/example/OPDS/config.yml new file mode 100644 index 000000000..ca583ace8 --- /dev/null +++ b/examples/trade/exp/example/OPDS/config.yml @@ -0,0 +1,71 @@ +seed: 42 +task: train +log_dir: example/OPDS +buffer_size: 80000 +io_conf: + test_sampler: TestSampler + train_sampler: Sampler + test_logger: DFLogger +resources: + num_cpus: 24 + num_gpus: 1 + device: cuda +train_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/train/ +valid_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/valid/ +test_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/test/ +env_conf: + name: StockEnv_Acc + max_step_num: 237 + limit: 10 + time_interval: 30 + interval_num: 8 + features: + - name: raw + type: range + loc: ../data/normed_feature/ + size: 180 + obs: + name: TeacherObs + config: {} + action: + name: Static_Action + config: + action_num: 5 + action_map: [0, 0.25, 0.5, 0.75, 1] + reward: + VP_Penalty_small_vec: + penalty: 100 + coefficient: 1 +policy_conf: + name: PPO + config: + discount_factor: 1. + max_grad_norm: 100. + reward_normalization: False + eps_clip: 0.3 + value_clip: True + vf_coef: 1. + gae_lambda: 1. + vf_clip_para: 0.3 +network_conf: + name: PPO + config: + hidden_size: 64 + out_shape: 5 + fc_size: 32 + cnn_shape: [30, 6] +optim: + lr: 1e-4 + batch_size: 1024 + max_epoch: 30 + step_per_epoch: 20 + collect_per_step: 10000 + repeat_per_collect: 5 + early_stopping: 5 + weight_decay: 0. diff --git a/examples/trade/exp/example/OPDT/config.yml b/examples/trade/exp/example/OPDT/config.yml new file mode 100644 index 000000000..fefc76c12 --- /dev/null +++ b/examples/trade/exp/example/OPDT/config.yml @@ -0,0 +1,71 @@ +seed: 42 +task: train +log_dir: example/OPDT +buffer_size: 80000 +io_conf: + test_sampler: TestSampler + train_sampler: Sampler + test_logger: DFLogger +resources: + num_cpus: 24 + num_gpus: 1 + device: cuda +train_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/train/ +valid_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/valid/ +test_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/test/ +env_conf: + name: StockEnv_Acc + max_step_num: 237 + limit: 10 + time_interval: 30 + interval_num: 8 + features: + - name: raw + type: range + loc: ../data/normed_feature/ + size: 180 + obs: + name: TeacherObs + config: {} + action: + name: Static_Action + config: + action_num: 5 + action_map: [0, 0.25, 0.5, 0.75, 1] + reward: + VP_Penalty_small_vec: + penalty: 100 + coefficient: 1 +policy_conf: + name: PPO + config: + discount_factor: 1. + max_grad_norm: 100. + reward_normalization: False + eps_clip: 0.3 + value_clip: True + vf_coef: 1. + gae_lambda: 1. + vf_clip_para: 0.3 +network_conf: + name: Teacher + config: + hidden_size: 64 + out_shape: 5 + fc_size: 32 + cnn_shape: [30, 6] +optim: + lr: 1e-4 + batch_size: 1024 + max_epoch: 30 + step_per_epoch: 20 + collect_per_step: 10000 + repeat_per_collect: 5 + early_stopping: 5 + weight_decay: 0. \ No newline at end of file diff --git a/examples/trade/exp/example/OPDT_b/config.yml b/examples/trade/exp/example/OPDT_b/config.yml new file mode 100644 index 000000000..697f866aa --- /dev/null +++ b/examples/trade/exp/example/OPDT_b/config.yml @@ -0,0 +1,76 @@ +seed: 42 +task: eval +log_dir: example/OPDT_b +buffer_size: 80000 +io_conf: + test_sampler: TestSampler + train_sampler: Sampler + test_logger: DFLogger +resources: + num_cpus: 24 + num_gpus: 1 + device: cuda +train_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/train/ +valid_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/valid/ +test_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/all/ +env_conf: + name: StockEnv_Acc + max_step_num: 237 + limit: 10 + time_interval: 30 + interval_num: 8 + features: + - name: raw + type: range + loc: ../data/normed_feature/ + size: 180 + obs: + name: TeacherObs + config: {} + action: + name: Static_Action + config: + action_num: 5 + action_map: [0, 0.25, 0.5, 0.75, 1] + reward: + VP_Penalty_small_vec: + penalty: 100 + coefficient: 1 +policy_path: policy_best +policy_conf: + name: PPO + config: + discount_factor: 1. + max_grad_norm: 100. + reward_normalization: False + eps_clip: 0.3 + value_clip: True + vf_coef: 1. + gae_lambda: 1. + vf_clip_para: 0.3 +network_conf: + name: Teacher + config: + hidden_size: 64 + out_shape: 5 + fc_size: 32 + cnn_shape: [30, 6] +optim: + lr: 1e-4 + batch_size: 1024 + max_epoch: 30 + step_per_epoch: 20 + collect_per_step: 10000 + repeat_per_collect: 5 + early_stopping: 5 + weight_decay: 0. +search: + optim.weight_decay: + type: choice + value: [0.] \ No newline at end of file diff --git a/examples/trade/exp/example/PPO/config.yml b/examples/trade/exp/example/PPO/config.yml new file mode 100644 index 000000000..b3c759fb0 --- /dev/null +++ b/examples/trade/exp/example/PPO/config.yml @@ -0,0 +1,70 @@ +seed: 42 +task: train +log_dir: example/PPO +buffer_size: 80000 +io_conf: + test_sampler: TestSampler + train_sampler: Sampler + test_logger: DFLogger +resources: + num_cpus: 24 + num_gpus: 1 + device: cuda +train_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/train/ +valid_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/valid/ +test_paths: + raw_dir: ../data/backtest/ + order_dir: ../data/order/test/ +env_conf: + name: StockEnv_Acc + max_step_num: 237 + limit: 10 + time_interval: 30 + interval_num: 8 + features: + - name: raw + type: range + loc: ../data/normed_feature/ + size: 180 + obs: + name: TeacherObs + config: {} + action: + name: Static_Action + config: + action_num: 5 + action_map: [0, 0.25, 0.5, 0.75, 1] + reward: + PPO_Reward: + coefficient: 1 +policy_conf: + name: PPO + config: + discount_factor: 1. + max_grad_norm: 100. + reward_normalization: False + eps_clip: 0.3 + value_clip: True + vf_coef: 1. + gae_lambda: 1. + vf_clip_para: 0.3 +network_conf: + name: PPO + config: + hidden_size: 64 + out_shape: 5 + fc_size: 32 + cnn_shape: [30, 6] +optim: + lr: 1e-4 + batch_size: 1024 + max_epoch: 30 + step_per_epoch: 20 + collect_per_step: 10000 + repeat_per_collect: 5 + early_stopping: 5 + weight_decay: 0. \ No newline at end of file diff --git a/examples/trade/logger/single_logger.py b/examples/trade/logger/single_logger.py index 002801ab2..c24bc18fe 100644 --- a/examples/trade/logger/single_logger.py +++ b/examples/trade/logger/single_logger.py @@ -87,7 +87,7 @@ class DFLogger(object): df_cache[ins] = ( [], [], - len(pd.read_pickle(order_dir + ins + ".pkl.target")), + (pd.read_pickle(order_dir + ins + ".pkl.target")['amount'] != 0).sum(), ) df_cache[ins][0].append(df) df_cache[ins][1].append(res) diff --git a/examples/trade/order_gen.py b/examples/trade/order_gen.py new file mode 100644 index 000000000..898aaab7c --- /dev/null +++ b/examples/trade/order_gen.py @@ -0,0 +1,59 @@ +import numpy as np +import pandas as pd +import os +import time +import datetime +from joblib import Parallel, delayed + +data_path = '../data/' +in_dir = os.path.join(data_path, 'backtest/') + +### create order folders #### + +def generate_order(df, start, end): +# df['date'] = df.index.map(lambda x: x[1].date()) +# df.set_index('date', append=True, inplace=True) + df = df.groupby('date').take(range(start, end)).droplevel(level=0) + div = df['$volume0'].rolling((end - start)*60).mean().shift(1).groupby(level='date').transform('first') + order = df.groupby(level=(2, 0)).mean().dropna() + order = pd.DataFrame(order) + order['amount'] = np.random.lognormal(-3.28, 1.14) * order['$volume0'] + order['order_type'] = 0 + order = order.drop(columns=["$volume0", "$vwap0"]) + return order + +def w_order(f, start, end): + df = pd.read_pickle(in_dir + f) + #df['date'] = df.index.get_level_values(1).map(lambda x: x.date()) + #df = df.set_index('date', append=True, drop=True) +# old_order = pd.read_pickle('../v-zeh/full-07-20/order/ratio_test/' + f) + order = generate_order(df, start, end) +# order = order[order.index.isin(old_order.index)] + order_train = order[order.index.get_level_values(0) < '2020-12-01'] + order_test = order[order.index.get_level_values(0) >= '2020-12-01'] + order_valid = order_test[order_test.index.get_level_values(0) < '2021-01-01'] + order_test = order_test[order_test.index.get_level_values(0) >= '2021-01-01'] + if len(order_train) > 0: + train_path = os.path.join(data_path, "order/train/") + if not os.path.exists(train_path): + os.makedirs(train_path) + order_train.to_pickle(train_path + f[:-9] + '.target') + if len(order_valid) > 0: + valid_path = os.path.join(data_path, "order/valid/") + if not os.path.exists(valid_path): + os.makedirs(valid_path) + order_valid.to_pickle(valid_path + f[:-9] + '.target') + if len(order_test) > 0: + test_path = os.path.join(data_path, "order/test/") + if not os.path.exists(test_path): + os.makedirs(test_path) + order_test.to_pickle(test_path + f[:-9] + '.target') + if len(order) > 0: + all_path = os.path.join(data_path, "order/all/") + if not os.path.exists(all_path): + os.makedirs(all_path) + order_test.to_pickle(all_path + f[:-9] + '.target') + return 0 + +res = Parallel(n_jobs=64)(delayed(w_order)(f, 0, 239) for f in os.listdir(in_dir)) +print(sum(res)) diff --git a/examples/trade/teacher_feature.py b/examples/trade/teacher_feature.py new file mode 100644 index 000000000..d605fb74e --- /dev/null +++ b/examples/trade/teacher_feature.py @@ -0,0 +1,24 @@ +import pandas as pd +import os + +data_path = '../data/' +feature_path = os.path.join(data_path, 'feature/teacher/') +if not os.path.exists(feature_path): + os.makedirs(feature_path) + +log_file = os.path.join(os.environ.get('OUTPUT_DIR'),'example/OPDT_b/0/test/') +files = os.listdir(log_file) + +for f in files: + if f.endswith(".log"): + df = pd.read_pickle(log_file + f) + df['datetime'] = df.index.get_level_values(1).map(lambda x: x[1]) + df.set_index('datetime', append=True, drop=True, inplace=True) + action = df['action'] + action = action.reset_index(level=1, drop=True) + action.index = action.index.map(lambda x: (x[0], x[1], x[2].time())) + action = action.unstack().iloc[:, ::30] * 2 + action = action.fillna(0) + train_action = action.astype("int") + final = train_action + final.to_pickle(feature_path + f[:-4] + '.pkl') \ No newline at end of file diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index ea9c70083..6d166646c 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -17,7 +17,7 @@ from qlib.contrib.evaluate import ( from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord - +from qlib.tests.data import GetData if __name__ == "__main__": @@ -25,9 +25,6 @@ if __name__ == "__main__": provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir if not exists_qlib_data(provider_uri): print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) - from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/qlib/config.py b/qlib/config.py index a65d41041..e7120c23a 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -193,6 +193,12 @@ MODE_CONF = { }, } +HIGH_FREQ_CONFIG = { + "provider_uri": "~/.qlib/qlib_data/yahoo_cn_1min", + "dataset_cache": None, + "expression_cache": "DiskExpressionCache", + "region": REG_CN, +} _default_region_config = { REG_CN: { @@ -291,12 +297,12 @@ class QlibConfig(Config): def register(self): from .utils import init_instance_by_config - from .data.ops import register_custom_ops + from .data.ops import register_all_ops from .data.data import register_all_wrappers from .workflow import R, QlibRecorder from .workflow.utils import experiment_exit_handler - register_custom_ops(self) + register_all_ops(self) register_all_wrappers(self) # set up QlibRecorder exp_manager = init_instance_by_config(self["exp_manager"]) diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 88a1f0680..23e37a5e4 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -49,6 +49,7 @@ class Alpha360(DataHandlerLP): instruments="csi500", start_time=None, end_time=None, + freq="day", infer_processors=_DEFAULT_INFER_PROCESSORS, learn_processors=_DEFAULT_LEARN_PROCESSORS, fit_start_time=None, @@ -69,9 +70,10 @@ class Alpha360(DataHandlerLP): } super().__init__( - instruments, - start_time, - end_time, + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq="day", data_loader=data_loader, learn_processors=learn_processors, infer_processors=infer_processors, @@ -130,6 +132,7 @@ class Alpha158(DataHandlerLP): instruments="csi500", start_time=None, end_time=None, + freq="day", infer_processors=[], learn_processors=_DEFAULT_LEARN_PROCESSORS, fit_start_time=None, @@ -147,9 +150,10 @@ class Alpha158(DataHandlerLP): }, } super().__init__( - instruments, - start_time, - end_time, + instruments=instruments, + start_time=start_time, + end_time=end_time, + freq=freq, data_loader=data_loader, infer_processors=infer_processors, learn_processors=learn_processors, diff --git a/qlib/data/base.py b/qlib/data/base.py index 92fc57ffe..e318843c4 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -157,7 +157,7 @@ class Expression(abc.ABC): @abc.abstractmethod def _load_internal(self, instrument, start_index, end_index, freq): - pass + raise NotImplementedError("This function must be implemented in your newly defined feature") @abc.abstractmethod def get_longest_back_rolling(self): diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 243736ddc..0174dc63f 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -825,8 +825,8 @@ class DiskDatasetCache(DatasetCache): .. note:: The start is closed. The end is open!!!!! - - Each line contains two element - - It indicates the `end_index` of the data for `timestamp` + - Each line contains two element with a timestamp as its index. + - It indicates the `start_index`(included) and `end_index`(excluded) of the data for `timestamp` - meta data: cache/d41366901e25de3ec47297f12e2ba11d.meta diff --git a/qlib/data/data.py b/qlib/data/data.py index ece3c3641..2a0e569ab 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -117,7 +117,7 @@ class CalendarProvider(abc.ABC): if flag in H["c"]: _calendar, _calendar_index = H["c"][flag] else: - _calendar = np.array(self._load_calendar(freq, future)) + _calendar = np.array(self.load_calendar(freq, future)) _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search H["c"][flag] = _calendar, _calendar_index return _calendar, _calendar_index @@ -504,7 +504,7 @@ class LocalCalendarProvider(CalendarProvider): """Calendar file uri.""" return os.path.join(C.get_data_path(), "calendars", "{}.txt") - def _load_calendar(self, freq, future): + def load_calendar(self, freq, future): """Load original calendar timestamp from file. Parameters @@ -672,6 +672,8 @@ class LocalExpressionProvider(ExpressionProvider): series = series.astype(np.float32) except ValueError: pass + except TypeError: + pass if not series.empty: series = series.loc[start_index:end_index] return series diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 96e4a6e41..117da764f 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -87,6 +87,10 @@ class DatasetH(Dataset): """ super().__init__(handler, segments) + def init(self, **kwargs): + """Initialize the DatasetH, Only parameters belonging to handler.init will be passed in""" + self.handler.init(**kwargs) + def setup_data(self, handler: Union[dict, DataHandler], segments: list): """ Setup the underlying data. @@ -116,8 +120,8 @@ class DatasetH(Dataset): 'outsample': ("2017-01-01", "2020-08-01",), } """ - self._handler = init_instance_by_config(handler, accept_types=DataHandler) - self._segments = segments.copy() + self.handler = init_instance_by_config(handler, accept_types=DataHandler) + self.segments = segments.copy() def _prepare_seg(self, slc: slice, **kwargs): """ @@ -127,7 +131,7 @@ class DatasetH(Dataset): ---------- slc : slice """ - return self._handler.fetch(slc, **kwargs) + return self.handler.fetch(slc, **kwargs) def prepare( self, @@ -150,7 +154,7 @@ class DatasetH(Dataset): - ['train', 'valid'] col_set : str - The col_set will be passed to self._handler when fetching data. + The col_set will be passed to self.handler when fetching data. data_key : str The data to fetch: DK_* Default is DK_I, which indicate fetching data for **inference**. @@ -166,16 +170,16 @@ class DatasetH(Dataset): logger = get_module_logger("DatasetH") fetch_kwargs = {"col_set": col_set} fetch_kwargs.update(kwargs) - if "data_key" in getfullargspec(self._handler.fetch).args: + if "data_key" in getfullargspec(self.handler.fetch).args: fetch_kwargs["data_key"] = data_key else: logger.info(f"data_key[{data_key}] is ignored.") # Handle all kinds of segments format if isinstance(segments, (list, tuple)): - return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments] + return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments] elif isinstance(segments, str): - return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs) + return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs) elif isinstance(segments, slice): return self._prepare_seg(segments, **fetch_kwargs) else: @@ -409,7 +413,7 @@ class TSDatasetH(DatasetH): def setup_data(self, *args, **kwargs): super().setup_data(*args, **kwargs) - cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique() + cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() cal = sorted(cal) # Get the datatime index for building timestamp self.cal = cal diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 18f838300..abcd5a60c 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -57,6 +57,7 @@ class DataHandler(Serializable): instruments=None, start_time=None, end_time=None, + freq="day", data_loader: Tuple[dict, str, DataLoader] = None, init_data=True, fetch_orig=True, @@ -70,6 +71,8 @@ class DataHandler(Serializable): start_time of the original data. end_time : end_time of the original data. + freq : + frequency of data data_loader : Tuple[dict, str, DataLoader] data loader to load the data. init_data : @@ -92,6 +95,7 @@ class DataHandler(Serializable): self.instruments = instruments self.start_time = start_time self.end_time = end_time + self.freq = freq self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): @@ -119,7 +123,7 @@ class DataHandler(Serializable): # Setup data. # _data may be with multiple column index level. The outer level indicates the feature set name with TimeInspector.logt("Loading data"): - self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) + self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq) # TODO: cache CS_ALL = "__all" # return all columns with single-level index column @@ -258,10 +262,12 @@ class DataHandlerLP(DataHandler): instruments=None, start_time=None, end_time=None, + freq="day", data_loader: Tuple[dict, str, DataLoader] = None, infer_processors=[], learn_processors=[], process_type=PTYPE_A, + drop_raw=False, **kwargs, ): """ @@ -303,6 +309,8 @@ class DataHandlerLP(DataHandler): - self._learn will be processed by infer_processors + learn_processors - (e.g. self._infer processed by learn_processors ) + drop_raw: bool + Whether to drop the raw data """ # Setup preprocessor @@ -319,7 +327,8 @@ class DataHandlerLP(DataHandler): ) self.process_type = process_type - super().__init__(instruments, start_time, end_time, data_loader, **kwargs) + self.drop_raw = drop_raw + super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs) def get_all_processors(self): return self.infer_processors + self.learn_processors @@ -348,7 +357,7 @@ class DataHandlerLP(DataHandler): """ # data for inference _infer_df = self._data - if len(self.infer_processors) > 0: # avoid modifying the original data + if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data _infer_df = _infer_df.copy() for proc in self.infer_processors: @@ -378,6 +387,9 @@ class DataHandlerLP(DataHandler): _learn_df = proc(_learn_df) self._learn = _learn_df + if self.drop_raw: + del self._data + # init type IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df @@ -416,6 +428,10 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. Save the memory for data processing def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame: + if data_key == self.DK_R and self.drop_raw: + raise AttributeError( + "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" + ) df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) return df diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index a51ea119a..3b33ff749 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -19,7 +19,7 @@ class DataLoader(abc.ABC): """ @abc.abstractmethod - def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame: """ load the data as pd.DataFrame. @@ -94,7 +94,9 @@ class DLWParser(DataLoader): return exprs, names @abc.abstractmethod - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + def load_group_df( + self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day" + ) -> pd.DataFrame: """ load the dataframe for specific group @@ -114,25 +116,25 @@ class DLWParser(DataLoader): """ pass - def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: if self.is_group: df = pd.concat( { - grp: self.load_group_df(instruments, exprs, names, start_time, end_time) + grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq) for grp, (exprs, names) in self.fields.items() }, axis=1, ) else: exprs, names = self.fields - df = self.load_group_df(instruments, exprs, names, start_time, end_time) + df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq) return df class QlibDataLoader(DLWParser): """Same as QlibDataLoader. The fields can be define by config""" - def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None): + def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True): """ Parameters ---------- @@ -140,11 +142,16 @@ class QlibDataLoader(DLWParser): Please refer to the doc of DLWParser filter_pipe : Filter pipe for the instruments + swap_level : + Whether to swap level of MultiIndex """ self.filter_pipe = filter_pipe + self.swap_level = swap_level super().__init__(config) - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + def load_group_df( + self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day" + ) -> pd.DataFrame: if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") instruments = "all" @@ -153,9 +160,10 @@ class QlibDataLoader(DLWParser): elif self.filter_pipe is not None: warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") - df = D.features(instruments, exprs, start_time, end_time) + df = D.features(instruments, exprs, start_time, end_time, freq) df.columns = names - df = df.swaplevel().sort_index() # NOTE: always return + if self.swap_level: + df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df @@ -177,7 +185,7 @@ class StaticDataLoader(DataLoader): self.join = join self._data = None - def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: self._maybe_load_raw_data() if instruments is None: df = self._data diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 91f7349d2..940c24002 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function import sys +import abc import numpy as np import pandas as pd @@ -17,7 +18,7 @@ from ..log import get_module_logger try: from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi -except ImportError as err: +except ImportError: print( "#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####" ) @@ -32,12 +33,39 @@ np.seterr(invalid="ignore") class ElemOperator(ExpressionOps): """Element-wise Operator + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + Expression + feature operation output + """ + + def __init__(self, feature): + self.feature = feature + + def __str__(self): + return "{}({})".format(type(self).__name__, self.feature) + + def get_longest_back_rolling(self): + return self.feature.get_longest_back_rolling() + + def get_extended_window_size(self): + return self.feature.get_extended_window_size() + + +class NpElemOperator(ElemOperator): + """Numpy Element-wise Operator + Parameters ---------- feature : Expression feature instance func : str - feature operation method + numpy feature operation method Returns ---------- @@ -48,22 +76,14 @@ class ElemOperator(ExpressionOps): def __init__(self, feature, func): self.feature = feature self.func = func - - def __str__(self): - return "{}({})".format(type(self).__name__, self.feature) + super(NpElemOperator, self).__init__(feature) def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return getattr(np, self.func)(series) - def get_longest_back_rolling(self): - return self.feature.get_longest_back_rolling() - def get_extended_window_size(self): - return self.feature.get_extended_window_size() - - -class Abs(ElemOperator): +class Abs(NpElemOperator): """Feature Absolute Value Parameters @@ -81,7 +101,7 @@ class Abs(ElemOperator): super(Abs, self).__init__(feature, "abs") -class Sign(ElemOperator): +class Sign(NpElemOperator): """Feature Sign Parameters @@ -108,7 +128,7 @@ class Sign(ElemOperator): return getattr(np, self.func)(series) -class Log(ElemOperator): +class Log(NpElemOperator): """Feature Log Parameters @@ -126,7 +146,7 @@ class Log(ElemOperator): super(Log, self).__init__(feature, "log") -class Power(ElemOperator): +class Power(NpElemOperator): """Feature Power Parameters @@ -152,7 +172,7 @@ class Power(ElemOperator): return getattr(np, self.func)(series, self.exponent) -class Mask(ElemOperator): +class Mask(NpElemOperator): """Feature Mask Parameters @@ -179,7 +199,7 @@ class Mask(ElemOperator): return self.feature.load(self.instrument, start_index, end_index, freq) -class Not(ElemOperator): +class Not(NpElemOperator): """Not Operator Parameters @@ -218,28 +238,13 @@ class PairOperator(ExpressionOps): two features' operation output """ - def __init__(self, feature_left, feature_right, func): + def __init__(self, feature_left, feature_right): self.feature_left = feature_left self.feature_right = feature_right - self.func = func def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right) - def _load_internal(self, instrument, start_index, end_index, freq): - assert any( - [isinstance(self.feature_left, Expression), self.feature_right, Expression] - ), "at least one of two inputs is Expression instance" - if isinstance(self.feature_left, Expression): - series_left = self.feature_left.load(instrument, start_index, end_index, freq) - else: - series_left = self.feature_left # numeric value - if isinstance(self.feature_right, Expression): - series_right = self.feature_right.load(instrument, start_index, end_index, freq) - else: - series_right = self.feature_right - return getattr(np, self.func)(series_left, series_right) - def get_longest_back_rolling(self): if isinstance(self.feature_left, Expression): left_br = self.feature_left.get_longest_back_rolling() @@ -265,7 +270,46 @@ class PairOperator(ExpressionOps): return max(ll, rl), max(lr, rr) -class Add(PairOperator): +class NpPairOperator(PairOperator): + """Numpy Pair-wise operator + + Parameters + ---------- + feature_left : Expression + feature instance or numeric value + feature_right : Expression + feature instance or numeric value + func : str + operator function + + Returns + ---------- + Feature: + two features' operation output + """ + + def __init__(self, feature_left, feature_right, func): + self.feature_left = feature_left + self.feature_right = feature_right + self.func = func + super(NpPairOperator, self).__init__(feature_left, feature_right) + + def _load_internal(self, instrument, start_index, end_index, freq): + assert any( + [isinstance(self.feature_left, Expression), self.feature_right, Expression] + ), "at least one of two inputs is Expression instance" + if isinstance(self.feature_left, Expression): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + else: + series_left = self.feature_left # numeric value + if isinstance(self.feature_right, Expression): + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + else: + series_right = self.feature_right + return getattr(np, self.func)(series_left, series_right) + + +class Add(NpPairOperator): """Add Operator Parameters @@ -285,7 +329,7 @@ class Add(PairOperator): super(Add, self).__init__(feature_left, feature_right, "add") -class Sub(PairOperator): +class Sub(NpPairOperator): """Subtract Operator Parameters @@ -305,7 +349,7 @@ class Sub(PairOperator): super(Sub, self).__init__(feature_left, feature_right, "subtract") -class Mul(PairOperator): +class Mul(NpPairOperator): """Multiply Operator Parameters @@ -325,7 +369,7 @@ class Mul(PairOperator): super(Mul, self).__init__(feature_left, feature_right, "multiply") -class Div(PairOperator): +class Div(NpPairOperator): """Division Operator Parameters @@ -345,7 +389,7 @@ class Div(PairOperator): super(Div, self).__init__(feature_left, feature_right, "divide") -class Greater(PairOperator): +class Greater(NpPairOperator): """Greater Operator Parameters @@ -365,7 +409,7 @@ class Greater(PairOperator): super(Greater, self).__init__(feature_left, feature_right, "maximum") -class Less(PairOperator): +class Less(NpPairOperator): """Less Operator Parameters @@ -385,7 +429,7 @@ class Less(PairOperator): super(Less, self).__init__(feature_left, feature_right, "minimum") -class Gt(PairOperator): +class Gt(NpPairOperator): """Greater Than Operator Parameters @@ -405,7 +449,7 @@ class Gt(PairOperator): super(Gt, self).__init__(feature_left, feature_right, "greater") -class Ge(PairOperator): +class Ge(NpPairOperator): """Greater Equal Than Operator Parameters @@ -425,7 +469,7 @@ class Ge(PairOperator): super(Ge, self).__init__(feature_left, feature_right, "greater_equal") -class Lt(PairOperator): +class Lt(NpPairOperator): """Less Than Operator Parameters @@ -445,7 +489,7 @@ class Lt(PairOperator): super(Lt, self).__init__(feature_left, feature_right, "less") -class Le(PairOperator): +class Le(NpPairOperator): """Less Equal Than Operator Parameters @@ -465,7 +509,7 @@ class Le(PairOperator): super(Le, self).__init__(feature_left, feature_right, "less_equal") -class Eq(PairOperator): +class Eq(NpPairOperator): """Equal Operator Parameters @@ -485,7 +529,7 @@ class Eq(PairOperator): super(Eq, self).__init__(feature_left, feature_right, "equal") -class Ne(PairOperator): +class Ne(NpPairOperator): """Not Equal Operator Parameters @@ -505,7 +549,7 @@ class Ne(PairOperator): super(Ne, self).__init__(feature_left, feature_right, "not_equal") -class And(PairOperator): +class And(NpPairOperator): """And Operator Parameters @@ -525,7 +569,7 @@ class And(PairOperator): super(And, self).__init__(feature_left, feature_right, "bitwise_and") -class Or(PairOperator): +class Or(NpPairOperator): """Or Operator Parameters @@ -1451,6 +1495,9 @@ class OpsWrapper(object): def __init__(self): self._ops = {} + def reset(self): + self._ops = {} + def register(self, ops_list): for operator in ops_list: if not issubclass(operator, ExpressionOps): @@ -1469,12 +1516,15 @@ class OpsWrapper(object): Operators = OpsWrapper() -Operators.register(OpsList) -def register_custom_ops(C): - """register custom operator""" +def register_all_ops(C): + """register all operator""" logger = get_module_logger("ops") + + Operators.reset() + Operators.register(OpsList) + if getattr(C, "custom_ops", None) is not None: Operators.register(C.custom_ops) logger.debug("register custom operator {}".format(C.custom_ops)) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 10cd588e6..ed2f14d2f 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -66,7 +66,7 @@ class TestDataset(TestAutoData): # Check the data # Get data from DataFrame Directly data_from_df = ( - tsdh._handler.fetch(data_key=DataHandlerLP.DK_L) + tsdh.handler.fetch(data_key=DataHandlerLP.DK_L) .loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"] .iloc[-30:] .values diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py index cb172b2bb..7d3322ddc 100644 --- a/tests/test_register_ops.py +++ b/tests/test_register_ops.py @@ -26,9 +26,6 @@ class Diff(ElemOperator): a feature instance with first difference """ - def __init__(self, feature): - super(Diff, self).__init__(feature, "diff") - def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return series.diff() @@ -50,9 +47,6 @@ class Distance(PairOperator): a feature instance with distance """ - def __init__(self, feature_left, feature_right): - super(Distance, self).__init__(feature_left, feature_right, "distance") - def _load_internal(self, instrument, start_index, end_index, freq): series_left = self.feature_left.load(instrument, start_index, end_index, freq) series_right = self.feature_right.load(instrument, start_index, end_index, freq)