diff --git a/examples/multi_level_trading/README.md b/examples/multi_level_trading/README.md index 6761b84ff..2910de58f 100644 --- a/examples/multi_level_trading/README.md +++ b/examples/multi_level_trading/README.md @@ -8,9 +8,12 @@ Qlib supports backtesting of various strategies, including portfolio management And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies. -This example uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation. And, at the daily frequency level, this example uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to split orders. -## Usage +## Weekly Portfolio Generation and Daily Order Execution + +This workflow provides an example that uses a TopkDropoutStrategy (a strategy based on the daily-frequency LightGBM model) at a weekly frequency for portfolio generation and uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to execute orders at a daily frequency. + +### Usage Start backtesting by running the following command: ```bash @@ -22,3 +25,13 @@ Start collecting data by running the following command: python workflow.py collect_data ``` +## Daily Portfolio Generation and Minutely Order Execution + +This workflow also provides a high-frequency example that uses a TopkDropoutStrategy for portfolio generation at a daily frequency and uses SBBStrategyEMA to execute orders at a 1-minute frequency. 
+ +### Usage + +Start backtesting by running the following command: +```bash + python workflow.py backtest_highfreq +``` \ No newline at end of file diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 8096fc76f..08c91936a 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -4,8 +4,9 @@ import qlib import fire -from qlib.config import REG_CN - +from qlib import backtest +from qlib.config import REG_CN, HIGH_FREQ_CONFIG +from qlib.data import D from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord @@ -20,7 +21,7 @@ class MultiLevelTradingWorkflow: data_handler_config = { "start_time": "2008-01-01", - "end_time": "2020-08-01", + "end_time": "2021-01-20", "fit_start_time": "2008-01-01", "fit_end_time": "2014-12-31", "instruments": market, @@ -54,15 +55,12 @@ class MultiLevelTradingWorkflow: "segments": { "train": ("2008-01-01", "2014-12-31"), "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), + "test": ("2017-01-01", "2021-01-20"), }, }, }, } - trade_start_time = "2017-01-01" - trade_end_time = "2020-08-01" - port_analysis_config = { "executor": { "class": "NestedExecutor", @@ -86,12 +84,13 @@ class MultiLevelTradingWorkflow: "instruments": market, }, }, + "generate_report": True, "track_data": True, }, }, "backtest": { - "start_time": trade_start_time, - "end_time": trade_end_time, + "start_time": "2017-01-01", + "end_time": "2020-08-01", "account": 100000000, "benchmark": benchmark, "exchange_kwargs": { @@ -167,6 +166,98 @@ class MultiLevelTradingWorkflow: for trade_decision in data_generator: print(trade_decision) + def _init_qlib_with_backend(self): + provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") + if not exists_qlib_data(provider_uri_1min): + print(f"Qlib data is not found in {provider_uri_1min}") + 
GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN) + + # TODO: update new data + # provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir + # if not exists_qlib_data(provider_uri_day): + # print(f"Qlib data is not found in {provider_uri_day}") + # GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN) + provider_uri_day = "/data/csdesign/qlib" + provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} + client_config = { + "calendar_provider": { + "class": "LocalCalendarProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileCalendarStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + "feature_provider": { + "class": "LocalFeatureProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileFeatureStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + } + }, + }, + } + qlib.init(provider_uri=provider_uri_day, **client_config) + + def _get_highfreq_config(self, model, dataset): + + executor_config = self.port_analysis_config["executor"] + # update executor with hierarchical decision freq ["day", "1min"] + executor_config["kwargs"]["time_per_step"] = "day" + executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "1min" + backtest_config = self.port_analysis_config["backtest"] + + # yahoo highfreq data time + backtest_config["start_time"] = "2020-09-20" + backtest_config["end_time"] = "2021-01-20" + + # update benchmark, yahoo data doesn't have SH000300 + instruments = D.instruments(market="csi300") + instrument_list = D.list_instruments(instruments=instruments, as_list=True) + backtest_config["benchmark"] = instrument_list + + # update exchange config + backtest_config["exchange_kwargs"]["freq"] = "1min" + + # set strategy + strategy_config = { + "class": "TopkDropoutStrategy", + "module_path": 
"qlib.contrib.strategy.model_strategy", + "kwargs": { + "model": model, + "dataset": dataset, + "topk": 50, + "n_drop": 5, + }, + } + + return executor_config, strategy_config, backtest_config + + def backtest_highfreq(self): + self._init_qlib_with_backend() + model = init_instance_by_config(self.task["model"]) + dataset = init_instance_by_config(self.task["dataset"]) + self._train_model(model, dataset) + executor_config, strategy_config, backtest_config = self._get_highfreq_config(model, dataset) + + highfreq_port_analysis_config = { + "executor": executor_config, + "strategy": strategy_config, + "backtest": backtest_config, + } + + with R.start(experiment_name="backtest_highfreq"): + + recorder = R.get_recorder() + par = PortAnaRecord(recorder, highfreq_port_analysis_config, "day") + par.generate() + if __name__ == "__main__": fire.Fire(MultiLevelTradingWorkflow) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 88a219f41..c51fc4d9d 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -304,7 +304,7 @@ class SimulatorExecutor(BaseExecutor): if self.verbose: if order.direction == Order.SELL: # sell print( - "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( + "[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( trade_start_time, order.stock_id, trade_price, @@ -316,7 +316,7 @@ class SimulatorExecutor(BaseExecutor): ) else: print( - "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( + "[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( trade_start_time, order.stock_id, trade_price, diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index fe51c99f3..f66fa091d 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -41,7 +41,7 @@ class TradeCalendarManager: - self.trade_step : The number of 
trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1] """ _calendar, freq, freq_sam = get_resam_calendar(freq=freq) - self.trade_calendar = _calendar + self._calendar = _calendar _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam) self.start_index = _start_index self.end_index = _end_index @@ -91,7 +91,7 @@ class TradeCalendarManager: """ trade_step = trade_step - shift calendar_index = self.start_index + trade_step - return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1) + return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1) def get_all_time(self): """Get the start_time and end_time for trading""" diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index d563bccea..3a1087be4 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -93,6 +93,9 @@ class TopkDropoutStrategy(ModelStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + print( + trade_step, pred_start_time, pred_end_time, trade_start_time, trade_end_time, pred_score, self.pred_scores + ) if pred_score is None: return [] if self.only_tradable: diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 24873caae..a85b81636 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -53,7 +53,7 @@ class TWAPStrategy(BaseStrategy): outer_trade_decision : object, optional """ - super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs) + super(TWAPStrategy, 
self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: self.trade_amount = {} for order in outer_trade_decision: @@ -73,21 +73,24 @@ class TWAPStrategy(BaseStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) order_list = [] for order in self.outer_trade_decision: + # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): continue _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) _order_amount = None - # consider trade unit + # considering trade unit if _amount_trade_unit is None: - # divide the order equally + # divide the order into equal parts, and trade one part _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1) # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: - # divide the order equally - # floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1)) + # divide the order into equal parts, and trade one part + # calculate the total count of trade units to trade trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + # calculate the amount of one part, ceil the amount + # floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1)) _order_amount = ( (trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit ) @@ -144,6 +147,14 @@ class SBBStrategyBase(BaseStrategy): self.trade_exchange = trade_exchange def reset_common_infra(self, common_infra): + """ + Parameters + ---------- + common_infra : dict, optional + common infrastructure for backtesting, by default None + - It should include `trade_account`, used to get 
position + - It should include `trade_exchange`, used to provide market info + """ super(SBBStrategyBase, self).reset_common_infra(common_infra) if common_infra is not None: if "trade_exchange" in common_infra: @@ -154,10 +165,6 @@ class SBBStrategyBase(BaseStrategy): Parameters ---------- outer_trade_decision : object, optional - common_infra : None, optional - common infrastructure for backtesting, by default None - - It should include `trade_account`, used to get position - - It should include `trade_exchange`, used to provide market info """ super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: @@ -186,10 +193,12 @@ class SBBStrategyBase(BaseStrategy): order_list = [] # for each order in in self.outer_trade_decision for order in self.outer_trade_decision: - # predict the price trend + # get the price trend if trade_step % 2 == 0: + # in the first of two adjacent bars, predict the price trend _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) else: + # in the second of two adjacent bars, use the trend predicted in the first one _pred_trend = self.trade_trend[(order.stock_id, order.direction)] # if not tradable, continue if not self.trade_exchange.is_stock_tradable( @@ -204,13 +213,14 @@ class SBBStrategyBase(BaseStrategy): _order_amount = None # considering trade unit if _amount_trade_unit is None: - # divide the order equally + # divide the order into equal parts, and trade one part _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step) # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: - # cal how many trade unit + # divide the order into equal parts, and trade one part + # calculate the total count of trade units to trade trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) - # divide the order equally + # 
calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step)) _order_amount = ( (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit @@ -262,9 +272,9 @@ class SBBStrategyBase(BaseStrategy): if _order_amount: _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) if trade_step % 2 == 0: - # in the first of two adjacent bar + # in the first one of two adjacent bars # if look short on the price, sell the stock more - # if look long on the price, sell the stock more + # if look long on the price, buy the stock more if ( _pred_trend == self.TREND_SHORT and order.direction == order.SELL @@ -281,7 +291,7 @@ class SBBStrategyBase(BaseStrategy): ) order_list.append(_order) else: - # in the second of two adjacent bar + # in the second one of two adjacent bars # if look short on the price, buy the stock more # if look long on the price, sell the stock more if ( @@ -301,6 +311,7 @@ class SBBStrategyBase(BaseStrategy): order_list.append(_order) if trade_step % 2 == 0: + # in the first one of two adjacent bars, store the trend for the second one to use self.trade_trend[(order.stock_id, order.direction)] = _pred_trend return order_list @@ -328,7 +339,7 @@ class SBBStrategyEMA(SBBStrategyBase): instruments of EMA signal, by default "csi300" freq : str, optional freq of EMA signal, by default "day" - Note: `freq` may be different from `steb_bar` + Note: `freq` may be different from `time_per_step` """ if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") @@ -349,8 +360,10 @@ class SBBStrategyEMA(SBBStrategyBase): signal_df = convert_index_format(signal_df) signal_df.columns = ["signal"] self.signal = {} - for stock_id, stock_val in signal_df.groupby(level="instrument"): - self.signal[stock_id] = stock_val + + if not signal_df.empty: + for stock_id, 
stock_val in signal_df.groupby(level="instrument"): + self.signal[stock_id] = stock_val def reset_level_infra(self, level_infra): """ @@ -367,16 +380,19 @@ class SBBStrategyEMA(SBBStrategyBase): self._reset_signal() def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): - + # if no signal, return mid trend if stock_id not in self.signal: return self.TREND_MID else: _sample_signal = resam_ts_data( self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last" ) + # if EMA signal == 0 or None, return mid trend if _sample_signal is None or _sample_signal.iloc[0] == 0: return self.TREND_MID + # if EMA signal > 0, return long trend elif _sample_signal.iloc[0] > 0: return self.TREND_LONG + # if EMA signal < 0, return short trend else: return self.TREND_SHORT diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 7828db609..f04bcb097 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from typing import Union from ..model.base import BaseModel from ..data.dataset import DatasetH @@ -141,8 +142,8 @@ class RLIntStrategy(RLStrategy): def __init__( self, policy, - state_interpreter: StateInterpreter, - action_interpreter: ActionInterpreter, + state_interpreter: Union[dict, StateInterpreter], + action_interpreter: Union[dict, ActionInterpreter], outer_trade_decision: object = None, level_infra: dict = {}, common_infra: dict = {}, @@ -151,9 +152,9 @@ class RLIntStrategy(RLStrategy): """ Parameters ---------- - state_interpreter : StateInterpreter - interpretor that interprets the qlib execute result into rl env state. 
- action_interpreter : ActionInterpreter + state_interpreter : Union[dict, StateInterpreter] + interpretor that interprets the qlib execute result into rl env state + action_interpreter : Union[dict, ActionInterpreter] interpretor that interprets the rl agent action into qlib order list start_time : Union[str, pd.Timestamp], optional start time of trading, by default None @@ -163,8 +164,8 @@ class RLIntStrategy(RLStrategy): super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs) self.policy = policy - self.state_interpreter = init_instance_by_config(state_interpreter) - self.action_interpreter = init_instance_by_config(action_interpreter) + self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter) + self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter) def generate_trade_decision(self, execute_result=None): _interpret_state = self.state_interpretor.interpret(execute_result=execute_result) diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index cdac48533..026870077 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -288,11 +288,13 @@ def resam_ts_data( from ..data.dataset.utils import get_level_index feature = lazy_sort_index(ts_feature) + datetime_level = get_level_index(feature, level="datetime") == 0 if datetime_level: feature = feature.loc[selector_datetime] else: feature = feature.loc[(slice(None), selector_datetime)] + if feature.empty: return None if isinstance(feature.index, pd.MultiIndex):