From 029b63c9ddc75aeb22243d4e79092c564677ea25 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 28 May 2021 22:29:21 +0800 Subject: [PATCH 1/4] fix bugs & add highfreq backtest example --- examples/multi_level_trading/README.md | 17 +++- examples/multi_level_trading/workflow.py | 109 +++++++++++++++++++++-- qlib/backtest/executor.py | 4 +- qlib/backtest/utils.py | 4 +- qlib/contrib/strategy/model_strategy.py | 3 + qlib/contrib/strategy/rule_strategy.py | 56 +++++++----- qlib/strategy/base.py | 15 ++-- qlib/utils/resam.py | 2 + 8 files changed, 168 insertions(+), 42 deletions(-) diff --git a/examples/multi_level_trading/README.md b/examples/multi_level_trading/README.md index 6761b84ff..2910de58f 100644 --- a/examples/multi_level_trading/README.md +++ b/examples/multi_level_trading/README.md @@ -8,9 +8,12 @@ Qlib supports backtesting of various strategies, including portfolio management And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies. -This example uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation. And, at the daily frequency level, this example uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to split orders. -## Usage +## Weekly Portfolio Generation and Daily Order Execution + +This workflow provides an example that uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation and uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to execute orders in daily frequency. + +### Usage Start backtesting by running the following command: ```bash @@ -22,3 +25,13 @@ Start collecting data by running the following command: python workflow.py collect_data ``` +## Daily Portfolio Generation and Minutely Order Execution + +This workflow also provides a high-frequency example that uses a DropoutTopkStrategy for portfolio generation in daily frequency and uses SBBStrategyEMA to execute orders in minutely frequency. + +### Usage + +Start backtesting by running the following command: +```bash + python workflow.py backtest_highfreq +``` \ No newline at end of file diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 8096fc76f..08c91936a 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -4,8 +4,9 @@ import qlib import fire -from qlib.config import REG_CN - +from qlib import backtest +from qlib.config import REG_CN, HIGH_FREQ_CONFIG +from qlib.data import D from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord @@ -20,7 +21,7 @@ class MultiLevelTradingWorkflow: data_handler_config = { "start_time": "2008-01-01", - "end_time": "2020-08-01", + "end_time": "2021-01-20", "fit_start_time": "2008-01-01", "fit_end_time": "2014-12-31", "instruments": market, @@ -54,15 +55,12 @@ class MultiLevelTradingWorkflow: "segments": { "train": ("2008-01-01", "2014-12-31"), "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), + "test": ("2017-01-01", "2021-01-20"), }, }, }, } - trade_start_time = "2017-01-01" - trade_end_time = "2020-08-01" - port_analysis_config = { "executor": { "class": "NestedExecutor", @@ -86,12 +84,13 @@ class MultiLevelTradingWorkflow: "instruments": market, }, }, + "generate_report": True, "track_data": True, }, }, "backtest": { - "start_time": trade_start_time, - "end_time": trade_end_time, + "start_time": "2017-01-01", + "end_time": "2020-08-01", "account": 100000000, "benchmark": benchmark, "exchange_kwargs": { @@ -167,6 +166,98 @@ class MultiLevelTradingWorkflow: for trade_decision in data_generator: print(trade_decision) + def _init_qlib_with_backend(self): + provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") + if not exists_qlib_data(provider_uri_1min): + print(f"Qlib data is not found in {provider_uri_1min}") + GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN) + + # TODO: update new data + # provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir + # if not exists_qlib_data(provider_uri_day): + # print(f"Qlib data is not found in {provider_uri_day}") + # GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN) + provider_uri_day = "/data/csdesign/qlib" + provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} + client_config = { + "calendar_provider": { + "class": "LocalCalendarProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileCalendarStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + "feature_provider": { + "class": "LocalFeatureProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileFeatureStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + } + qlib.init(provider_uri=provider_uri_day, **client_config) + + def _get_highfreq_config(self, model, dataset): + + executor_config = self.port_analysis_config["executor"] + # update executor with hierarchical decison freq ["day", "1min"] + executor_config["kwargs"]["time_per_step"] = "day" + executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "1min" + backtest_config = self.port_analysis_config["backtest"] + + # yahoo highfreq data time + backtest_config["start_time"] = "2020-09-20" + backtest_config["end_time"] = "2021-01-20" + + # update benchmark, yahoo data don't have SH000300 + instruments = D.instruments(market="csi300") + instrument_list = D.list_instruments(instruments=instruments, as_list=True) + backtest_config["benchmark"] = instrument_list + + # update exchange config + backtest_config["exchange_kwargs"]["freq"] = "1min" + + # set strategy + strategy_config = { + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.model_strategy", + "kwargs": { + "model": model, + "dataset": dataset, + "topk": 50, + "n_drop": 5, + }, + } + + return executor_config, strategy_config, backtest_config + + def backtest_highfreq(self): + self._init_qlib_with_backend() + model = init_instance_by_config(self.task["model"]) + dataset = init_instance_by_config(self.task["dataset"]) + self._train_model(model, dataset) + executor_config, strategy_config, backtest_config = self._get_highfreq_config(model, dataset) + + highfreq_port_analysis_config = { + "executor": executor_config, + "strategy": strategy_config, + "backtest": backtest_config, + } + + with R.start(experiment_name="backtest_highfreq"): + + recorder = R.get_recorder() + par = PortAnaRecord(recorder, highfreq_port_analysis_config, "day") + par.generate() + if __name__ == "__main__": fire.Fire(MultiLevelTradingWorkflow) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 88a219f41..c51fc4d9d 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -304,7 +304,7 @@ class SimulatorExecutor(BaseExecutor): if self.verbose: if order.direction == Order.SELL: # sell print( - "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( + "[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( trade_start_time, order.stock_id, trade_price, @@ -316,7 +316,7 @@ class SimulatorExecutor(BaseExecutor): ) else: print( - "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( + "[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( trade_start_time, order.stock_id, trade_price, diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index fe51c99f3..f66fa091d 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -41,7 +41,7 @@ class TradeCalendarManager: - self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1] """ _calendar, freq, freq_sam = get_resam_calendar(freq=freq) - self.trade_calendar = _calendar + self._calendar = _calendar _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam) self.start_index = _start_index self.end_index = _end_index @@ -91,7 +91,7 @@ class TradeCalendarManager: """ trade_step = trade_step - shift calendar_index = self.start_index + trade_step - return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1) + return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1) def get_all_time(self): """Get the start_time and end_time for trading""" diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index d563bccea..3a1087be4 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -93,6 +93,9 @@ class TopkDropoutStrategy(ModelStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + print( + trade_step, pred_start_time, pred_end_time, trade_start_time, trade_end_time, pred_score, self.pred_scores + ) if pred_score is None: return [] if self.only_tradable: diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 24873caae..a85b81636 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -53,7 +53,7 @@ class TWAPStrategy(BaseStrategy): outer_trade_decision : object, optional """ - super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs) + super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: self.trade_amount = {} for order in outer_trade_decision: @@ -73,21 +73,24 @@ class TWAPStrategy(BaseStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) order_list = [] for order in self.outer_trade_decision: + # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): continue _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) _order_amount = None - # consider trade unit + # considering trade unit if _amount_trade_unit is None: - # divide the order equally + # divide the order into equal parts, and trade one part _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1) # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: - # divide the order equally - # floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1)) + # divide the order into equal parts, and trade one part + # calculate the total count of trade units to trade trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + # calculate the amount of one part, ceil the amount + # floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1)) _order_amount = ( (trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit ) @@ -144,6 +147,14 @@ class SBBStrategyBase(BaseStrategy): self.trade_exchange = trade_exchange def reset_common_infra(self, common_infra): + """ + Parameters + ---------- + common_infra : dict, optional + common infrastructure for backtesting, by default None + - It should include `trade_account`, used to get position + - It should include `trade_exchange`, used to provide market info + """ super(SBBStrategyBase, self).reset_common_infra(common_infra) if common_infra is not None: if "trade_exchange" in common_infra: @@ -154,10 +165,6 @@ class SBBStrategyBase(BaseStrategy): Parameters ---------- outer_trade_decision : object, optional - common_infra : None, optional - common infrastructure for backtesting, by default None - - It should include `trade_account`, used to get position - - It should include `trade_exchange`, used to provide market info """ super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: @@ -186,10 +193,12 @@ class SBBStrategyBase(BaseStrategy): order_list = [] # for each order in in self.outer_trade_decision for order in self.outer_trade_decision: - # predict the price trend + # get the price trend if trade_step % 2 == 0: + # in the first of two adjacent bars, predict the price trend _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) else: + # in the second of two adjacent bars, use the trend predicted in the first one _pred_trend = self.trade_trend[(order.stock_id, order.direction)] # if not tradable, continue if not self.trade_exchange.is_stock_tradable( @@ -204,13 +213,14 @@ class SBBStrategyBase(BaseStrategy): _order_amount = None # considering trade unit if _amount_trade_unit is None: - # divide the order equally + # divide the order into equal parts, and trade one part _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step) # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: - # cal how many trade unit + # divide the order into equal parts, and trade one part + # calculate the total count of trade units to trade trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) - # divide the order equally + # calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step)) _order_amount = ( (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit @@ -262,9 +272,9 @@ class SBBStrategyBase(BaseStrategy): if _order_amount: _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) if trade_step % 2 == 0: - # in the first of two adjacent bar + # in the first one of two adjacent bars # if look short on the price, sell the stock more - # if look long on the price, sell the stock more + # if look long on the price, buy the stock more if ( _pred_trend == self.TREND_SHORT and order.direction == order.SELL @@ -281,7 +291,7 @@ class SBBStrategyBase(BaseStrategy): ) order_list.append(_order) else: - # in the second of two adjacent bar + # in the second one of two adjacent bars # if look short on the price, buy the stock more # if look long on the price, sell the stock more if ( @@ -301,6 +311,7 @@ class SBBStrategyBase(BaseStrategy): order_list.append(_order) if trade_step % 2 == 0: + # in the first one of two adjacent bars, store the trend for the second one to use self.trade_trend[(order.stock_id, order.direction)] = _pred_trend return order_list @@ -328,7 +339,7 @@ class SBBStrategyEMA(SBBStrategyBase): instruments of EMA signal, by default "csi300" freq : str, optional freq of EMA signal, by default "day" - Note: `freq` may be different from `steb_bar` + Note: `freq` may be different from `time_per_step` """ if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") @@ -349,8 +360,10 @@ class SBBStrategyEMA(SBBStrategyBase): signal_df = convert_index_format(signal_df) signal_df.columns = ["signal"] self.signal = {} - for stock_id, stock_val in signal_df.groupby(level="instrument"): - self.signal[stock_id] = stock_val + + if not signal_df.empty: + for stock_id, stock_val in signal_df.groupby(level="instrument"): + self.signal[stock_id] = stock_val def reset_level_infra(self, level_infra): """ @@ -367,16 +380,19 @@ class SBBStrategyEMA(SBBStrategyBase): self._reset_signal() def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): - + # if no signal, return mid trend if stock_id not in self.signal: return self.TREND_MID else: _sample_signal = resam_ts_data( self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last" ) + # if EMA signal == 0 or None, return mid trend if _sample_signal is None or _sample_signal.iloc[0] == 0: return self.TREND_MID + # if EMA signal > 0, return long trend elif _sample_signal.iloc[0] > 0: return self.TREND_LONG + # if EMA signal > 0, return short trend else: return self.TREND_SHORT diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 7828db609..f04bcb097 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from typing import Union from ..model.base import BaseModel from ..data.dataset import DatasetH @@ -141,8 +142,8 @@ class RLIntStrategy(RLStrategy): def __init__( self, policy, - state_interpreter: StateInterpreter, - action_interpreter: ActionInterpreter, + state_interpreter: Union[dict, StateInterpreter], + action_interpreter: Union[dict, ActionInterpreter], outer_trade_decision: object = None, level_infra: dict = {}, common_infra: dict = {}, @@ -151,9 +152,9 @@ class RLIntStrategy(RLStrategy): """ Parameters ---------- - state_interpreter : StateInterpreter - interpretor that interprets the qlib execute result into rl env state. - action_interpreter : ActionInterpreter + state_interpreter : Union[dict, StateInterpreter] + interpretor that interprets the qlib execute result into rl env state + action_interpreter : Union[dict, ActionInterpreter] interpretor that interprets the rl agent action into qlib order list start_time : Union[str, pd.Timestamp], optional start time of trading, by default None @@ -163,8 +164,8 @@ class RLIntStrategy(RLStrategy): super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs) self.policy = policy - self.state_interpreter = init_instance_by_config(state_interpreter) - self.action_interpreter = init_instance_by_config(action_interpreter) + self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter) + self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter) def generate_trade_decision(self, execute_result=None): _interpret_state = self.state_interpretor.interpret(execute_result=execute_result) diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index cdac48533..026870077 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -288,11 +288,13 @@ def resam_ts_data( from ..data.dataset.utils import get_level_index feature = lazy_sort_index(ts_feature) + datetime_level = get_level_index(feature, level="datetime") == 0 if datetime_level: feature = feature.loc[selector_datetime] else: feature = feature.loc[(slice(None), selector_datetime)] + if feature.empty: return None if isinstance(feature.index, pd.MultiIndex): From 96e393b599c718931af3d8b5289c17c6aafbcf13 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 28 May 2021 22:32:33 +0800 Subject: [PATCH 2/4] del DEBUG log --- qlib/contrib/strategy/model_strategy.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 3a1087be4..d563bccea 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -93,9 +93,6 @@ class TopkDropoutStrategy(ModelStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") - print( - trade_step, pred_start_time, pred_end_time, trade_start_time, trade_end_time, pred_score, self.pred_scores - ) if pred_score is None: return [] if self.only_tradable: From bf3b757294772f635798b42e27064e75afe01558 Mon Sep 17 00:00:00 2001 From: bxdd Date: Sat, 29 May 2021 00:31:40 +0800 Subject: [PATCH 3/4] fix bugs --- examples/multi_level_trading/workflow.py | 12 ++++++------ qlib/backtest/report.py | 4 ++-- qlib/contrib/strategy/model_strategy.py | 2 -- qlib/utils/resam.py | 8 ++++---- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 08c91936a..531b88f64 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -173,11 +173,11 @@ class MultiLevelTradingWorkflow: GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN) # TODO: update new data - # provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir - # if not exists_qlib_data(provider_uri_day): - # print(f"Qlib data is not found in {provider_uri_day}") - # GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN) - provider_uri_day = "/data/csdesign/qlib" + provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri_day): + print(f"Qlib data is not found in {provider_uri_day}") + GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN) + provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} client_config = { "calendar_provider": { @@ -210,7 +210,7 @@ class MultiLevelTradingWorkflow: executor_config = self.port_analysis_config["executor"] # update executor with hierarchical decison freq ["day", "1min"] executor_config["kwargs"]["time_per_step"] = "day" - executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "1min" + executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "15min" backtest_config = self.port_analysis_config["backtest"] # yahoo highfreq data time diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index c26c46f9d..4b9b0ce26 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -80,12 +80,12 @@ class Report: fields = ["$close/Ref($close,1)-1"] try: _temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1) - except ValueError: + except (ValueError, KeyError): _, norm_freq = parse_freq(freq) if norm_freq in ["month", "week", "day"]: try: _temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1) - except ValueError: + except (ValueError, KeyError): _temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1) elif norm_freq == "minute": _temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1) diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index d563bccea..9125329d4 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -177,8 +177,6 @@ class TopkDropoutStrategy(ModelStrategy): # Get the stock list we really want to buy buy = today[: len(sell) + self.topk - len(last)] - # print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last)) - # print("flag", len(sell), len(buy), self.topk, len(last)) for code in current_stock_list: if not self.trade_exchange.is_stock_tradable( stock_id=code, start_time=trade_start_time, end_time=trade_end_time diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 026870077..71e0aa654 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -182,7 +182,7 @@ def get_resam_calendar( try: _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, future=future) freq, freq_sam = freq, None - except ValueError: + except (ValueError, KeyError): freq_sam = freq if norm_freq in ["month", "week", "day"]: try: @@ -190,16 +190,16 @@ def get_resam_calendar( start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future ) freq = "day" - except ValueError: + except (ValueError, KeyError): _calendar = Cal.calendar( start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future ) - freq = "min" + freq = "1min" elif norm_freq == "minute": _calendar = Cal.calendar( start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future ) - freq = "min" + freq = "1min" else: raise ValueError(f"freq {freq} is not supported") return _calendar, freq, freq_sam From 60e082e44662da769d76a07e7d811b8818ca97bb Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 31 May 2021 20:40:11 +0800 Subject: [PATCH 4/4] add infra interface & fix no KeyboardInterpret bug --- .../README.md | 11 +--- .../workflow.py | 7 ++- qlib/backtest/__init__.py | 6 +-- qlib/backtest/executor.py | 20 +++---- qlib/backtest/utils.py | 44 +++++++++++++++ qlib/contrib/strategy/cost_control.py | 4 +- qlib/contrib/strategy/model_strategy.py | 12 ++--- qlib/contrib/strategy/rule_strategy.py | 54 ++++++++++--------- qlib/strategy/base.py | 33 +++++++----- qlib/workflow/utils.py | 1 + 10 files changed, 120 insertions(+), 72 deletions(-) rename examples/{multi_level_trading => nested_decision_execution}/README.md (67%) rename examples/{multi_level_trading => nested_decision_execution}/workflow.py (98%) diff --git a/examples/multi_level_trading/README.md b/examples/nested_decision_execution/README.md similarity index 67% rename from examples/multi_level_trading/README.md rename to examples/nested_decision_execution/README.md index 2910de58f..312f94d31 100644 --- a/examples/multi_level_trading/README.md +++ b/examples/nested_decision_execution/README.md @@ -1,13 +1,6 @@ -# Multi-level Trading - -This worflow is an example for multi-level trading. - -## Introduction - -Qlib supports backtesting of various strategies, including portfolio management strategies, order split strategies, model-based strategies (such as deep learning models), rule-based strategies, and RL-based strategies. - -And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies. +# Nested Decision Execution +This worflow is an example for nested decision execution in backtesting. Qlib supports nested decision execution in backtesting. It means that users can use different strategies to make trade decision in different frequencies. ## Weekly Portfolio Generation and Daily Order Execution diff --git a/examples/multi_level_trading/workflow.py b/examples/nested_decision_execution/workflow.py similarity index 98% rename from examples/multi_level_trading/workflow.py rename to examples/nested_decision_execution/workflow.py index 531b88f64..b8e9e5fb5 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -4,7 +4,6 @@ import qlib import fire -from qlib import backtest from qlib.config import REG_CN, HIGH_FREQ_CONFIG from qlib.data import D from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict @@ -14,7 +13,7 @@ from qlib.tests.data import GetData from qlib.backtest import collect_data -class MultiLevelTradingWorkflow: +class NestedDecisonExecutionWorkflow: market = "csi300" benchmark = "SH000300" @@ -172,7 +171,7 @@ class MultiLevelTradingWorkflow: print(f"Qlib data is not found in {provider_uri_1min}") GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN) - # TODO: update new data + # TODO: update latest data provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir if not exists_qlib_data(provider_uri_day): print(f"Qlib data is not found in {provider_uri_day}") @@ -260,4 +259,4 @@ class MultiLevelTradingWorkflow: if __name__ == "__main__": - fire.Fire(MultiLevelTradingWorkflow) + fire.Fire(NestedDecisonExecutionWorkflow) diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 12db0a314..33c2cb2d8 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -7,6 +7,7 @@ from .executor import BaseExecutor from .backtest import backtest as backtest_func from .backtest import collect_data as data_generator +from .utils import CommonInfrastructure from ..strategy.base import BaseStrategy from ..utils import init_instance_by_config from ..log import get_module_logger @@ -101,10 +102,7 @@ def get_strategy_executor( ) trade_exchange = get_exchange(**exchange_kwargs) - common_infra = { - "trade_account": trade_account, - "trade_exchange": trade_exchange, - } + common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange) trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra) trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index c51fc4d9d..1cc198bf6 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -9,7 +9,7 @@ from ..utils.resam import parse_freq from .order import Order from .exchange import Exchange -from .utils import TradeCalendarManager +from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure class BaseExecutor: @@ -23,7 +23,7 @@ class BaseExecutor: generate_report: bool = False, verbose: bool = False, track_data: bool = False, - common_infra: dict = {}, + common_infra: CommonInfrastructure = None, **kwargs, ): """ @@ -39,7 +39,7 @@ class BaseExecutor: whether to generate trade_decision, will be used when making data for multi-level training - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data` - Else, `trade_decision` will not be generated - common_infra : dict, optional: + common_infra : CommonInfrastructure, optional: common infrastructure for backtesting, may including: - trade_account : Account, optional trade account for trading @@ -63,11 +63,11 @@ class BaseExecutor: else: self.common_infra.update(common_infra) - if "trade_account" in common_infra: + if common_infra.has("trade_account"): self.trade_account = copy.copy(common_infra.get("trade_account")) self.trade_account.reset(freq=self.time_per_step, init_report=True) - def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs): + def reset(self, track_data: bool = None, common_infra: CommonInfrastructure = None, **kwargs): """ - reset `start_time` and `end_time`, used in trade calendar - reset `track_data`, used when making data for multi-level training @@ -88,7 +88,7 @@ class BaseExecutor: self.reset_common_infra(common_infra) def get_level_infra(self): - return {"trade_calendar": self.trade_calendar} + return LevelInfrastructure(trade_calendar=self.trade_calendar) def finished(self): return self.trade_calendar.finished() @@ -138,7 +138,7 @@ class NestedExecutor(BaseExecutor): verbose: bool = False, track_data: bool = False, trade_exchange: Exchange = None, - common_infra: dict = {}, + common_infra: CommonInfrastructure = None, **kwargs, ): """ @@ -182,7 +182,7 @@ class NestedExecutor(BaseExecutor): """ super(NestedExecutor, self).reset_common_infra(common_infra) - if self.generate_report and "trade_exchange" in common_infra: + if self.generate_report and common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") self.inner_executor.reset_common_infra(common_infra) @@ -257,7 +257,7 @@ class SimulatorExecutor(BaseExecutor): verbose: bool = False, track_data: bool = False, trade_exchange: Exchange = None, - common_infra: dict = {}, + common_infra: CommonInfrastructure = None, **kwargs, ): """ @@ -286,7 +286,7 @@ class SimulatorExecutor(BaseExecutor): - reset trade_exchange """ super(SimulatorExecutor, self).reset_common_infra(common_infra) - if "trade_exchange" in common_infra: + if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") def execute(self, trade_decision): diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index f66fa091d..8582cfe28 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import pandas as pd +import warnings from typing import Union from ..utils.resam import get_resam_calendar @@ -96,3 +97,46 @@ class TradeCalendarManager: def get_all_time(self): """Get the start_time and end_time for trading""" return self.start_time, self.end_time + + +class BaseInfrastructure: + def __init__(self, **kwargs): + self.reset_infra(**kwargs) + + def get_support_infra(self): + raise NotImplementedError("`get_support_infra` is not implemented!") + + def reset_infra(self, **kwargs): + support_infra = self.get_support_infra() + for k, v in kwargs.items(): + if k in support_infra: + setattr(self, k, v) + else: + warnings.warn(f"{k} is ignored in `reset_infra`!") + + def get(self, infra_name): + if hasattr(self, infra_name): + return getattr(self, infra_name) + else: + warnings.warn(f"infra {infra_name} is not found!") + + def has(self, infra_name): + if infra_name in self.get_support_infra() and hasattr(self, infra_name): + return True + else: + return False + + def update(self, other): + support_infra = other.get_support_infra() + infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)} + self.reset_infra(**infra_dict) + + +class CommonInfrastructure(BaseInfrastructure): + def get_support_infra(self): + return ["trade_account", "trade_exchange"] + + +class LevelInfrastructure(BaseInfrastructure): + def get_support_infra(self): + return ["trade_calendar"] diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index e7f6cce04..88e35b2e4 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -18,8 +18,8 @@ class SoftTopkStrategy(WeightStrategyBase): risk_degree=0.95, buy_method="first_fill", trade_exchange=None, - level_infra={}, - common_infra={}, + level_infra=None, + common_infra=None, **kwargs, ): """Parameter diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 9125329d4..ba1e3c785 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -22,8 +22,8 @@ class TopkDropoutStrategy(ModelStrategy): hold_thresh=1, only_tradable=False, trade_exchange=None, - level_infra={}, - common_infra={}, + level_infra=None, + common_infra=None, **kwargs, ): """ @@ -76,7 +76,7 @@ class TopkDropoutStrategy(ModelStrategy): """ super(TopkDropoutStrategy, self).reset_common_infra(common_infra) - if "trade_exchange" in common_infra: + if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") def get_risk_degree(self, trade_step=None): @@ -249,8 +249,8 @@ class WeightStrategyBase(ModelStrategy): dataset, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, - level_infra={}, - common_infra={}, + level_infra=None, + common_infra=None, **kwargs, ): super(WeightStrategyBase, self).__init__( @@ -274,7 +274,7 @@ class WeightStrategyBase(ModelStrategy): """ super(WeightStrategyBase, self).reset_common_infra(common_infra) - if "trade_exchange" in common_infra: + if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") def get_risk_degree(self, trade_step=None): diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index a85b81636..b72f32c29 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -1,4 +1,5 @@ import warnings +from typing import List, Union from ...utils.resam import resam_ts_data from ...data.data import D @@ -6,6 +7,7 @@ from ...data.dataset.utils import convert_index_format from ...strategy.base import BaseStrategy from ...backtest.order import Order from ...backtest.exchange import Exchange +from ...backtest.utils import CommonInfrastructure, LevelInfrastructure class TWAPStrategy(BaseStrategy): @@ -13,17 +15,20 @@ class TWAPStrategy(BaseStrategy): def __init__( self, - outer_trade_decision: object = None, + outer_trade_decision: List[Order] = None, trade_exchange: Exchange = None, - level_infra: dict = {}, - common_infra: dict = {}, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, ): """ Parameters ---------- + outer_trade_decision : List[Order] + the trade decison of outer strategy which this startegy relies, it should be List[Order] in TWAPStrategy trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra + """ super(TWAPStrategy, self).__init__( outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra @@ -36,21 +41,21 @@ class TWAPStrategy(BaseStrategy): """ Parameters ---------- - common_infra : dict, optional + common_infra : CommonInfrastructure, optional common infrastructure for backtesting, by default None - It should include `trade_account`, used to get position - It should include `trade_exchange`, used to provide market info """ super(TWAPStrategy, self).reset_common_infra(common_infra) - if common_infra is not None: - if "trade_exchange" in common_infra: - self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision: object = None, **kwargs): + if common_infra.has("trade_exchange"): + self.trade_exchange = common_infra.get("trade_exchange") + + def reset(self, outer_trade_decision: List[Order] = None, **kwargs): """ Parameters ---------- - outer_trade_decision : object, optional + outer_trade_decision : List[Order], optional """ super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) @@ -127,14 +132,16 @@ class SBBStrategyBase(BaseStrategy): def __init__( self, - outer_trade_decision: object = None, + outer_trade_decision: List[Order] = None, trade_exchange: Exchange = None, - level_infra: dict = {}, - common_infra: dict = {}, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, ): """ Parameters ---------- + outer_trade_decision : List[Order] + the trade decison of outer strategy which this startegy relies, it should be List[Order] in SBBStrategyBase trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra @@ -156,15 +163,14 @@ class SBBStrategyBase(BaseStrategy): - It should include `trade_exchange`, used to provide market info """ super(SBBStrategyBase, self).reset_common_infra(common_infra) - if common_infra is not None: - if "trade_exchange" in common_infra: - self.trade_exchange = common_infra.get("trade_exchange") + if common_infra.has("trade_exchange"): + self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision=None, **kwargs): + def reset(self, outer_trade_decision: List[Order] = None, **kwargs): """ Parameters ---------- - outer_trade_decision : object, optional + outer_trade_decision : List[Order], optional """ super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: @@ -324,18 +330,18 @@ class SBBStrategyEMA(SBBStrategyBase): def __init__( self, - outer_trade_decision=[], - instruments="csi300", - freq="day", + outer_trade_decision: List[Order] = None, + instruments: Union[List, str] = "csi300", + freq: str = "day", trade_exchange: Exchange = None, - level_infra={}, - common_infra={}, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, **kwargs, ): """ Parameters ---------- - instruments : str, optional + instruments : Union[List, str], optional instruments of EMA signal, by default "csi300" freq : str, optional freq of EMA signal, by default "day" @@ -375,7 +381,7 @@ class SBBStrategyEMA(SBBStrategyBase): else: self.level_infra.update(level_infra) - if "trade_calendar" in level_infra: + if level_infra.has("trade_calendar"): self.trade_calendar = level_infra.get("trade_calendar") self._reset_signal() diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index f04bcb097..9d3e0c72b 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -7,6 +7,7 @@ from ..data.dataset import DatasetH from ..data.dataset.utils import convert_index_format from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config +from ..backtest.utils import CommonInfrastructure, LevelInfrastructure class BaseStrategy: @@ -15,8 +16,8 @@ class BaseStrategy: def __init__( self, outer_trade_decision: object = None, - level_infra: dict = {}, - common_infra: dict = {}, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, ): """ Parameters @@ -25,9 +26,9 @@ class BaseStrategy: the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None - If the strategy is used to split trade decison, it will be used - If the strategy is used for portfolio management, it can be ignored - level_infra : dict, optional + level_infra : LevelInfrastructure, optional level shared infrastructure for backtesting, including trade calendar - common_infra : dict, optional + common_infra : CommonInfrastructure, optional common infrastructure for backtesting, including trade_account, trade_exchange, .etc """ @@ -39,7 +40,7 @@ class BaseStrategy: else: self.level_infra.update(level_infra) - if "trade_calendar" in level_infra: + if level_infra.has("trade_calendar"): self.trade_calendar = level_infra.get("trade_calendar") def reset_common_infra(self, common_infra): @@ -48,10 +49,16 @@ class BaseStrategy: else: self.common_infra.update(common_infra) - if "trade_account" in common_infra: + if common_infra.has("trade_account"): self.trade_position = common_infra.get("trade_account").current - def reset(self, level_infra: dict = None, common_infra: dict = None, outer_trade_decision=None, **kwargs): + def reset( + self, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, + outer_trade_decision=None, + **kwargs, + ): """ - reset `level_infra`, used to reset trade calendar, .etc - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc @@ -86,8 +93,8 @@ class ModelStrategy(BaseStrategy): model: BaseModel, dataset: DatasetH, outer_trade_decision: object = None, - level_infra: dict = {}, - common_infra: dict = {}, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, **kwargs, ): """ @@ -122,8 +129,8 @@ class RLStrategy(BaseStrategy): self, policy, outer_trade_decision: object = None, - level_infra: dict = {}, - common_infra: dict = {}, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, **kwargs, ): """ @@ -145,8 +152,8 @@ class RLIntStrategy(RLStrategy): state_interpreter: Union[dict, StateInterpreter], action_interpreter: Union[dict, ActionInterpreter], outer_trade_decision: object = None, - level_infra: dict = {}, - common_infra: dict = {}, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, **kwargs, ): """ diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index 596ff0927..cd87187e9 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -46,3 +46,4 @@ def experiment_kill_signal_handler(signum, frame): End an experiment when user kill the program through keyboard (CTRL+C, etc.). """ R.end_exp(recorder_status=Recorder.STATUS_FA) + raise KeyboardInterrupt