mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
fix bugs & add highfreq backtest example
This commit is contained in:
@@ -8,9 +8,12 @@ Qlib supports backtesting of various strategies, including portfolio management
|
||||
|
||||
And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies.
|
||||
|
||||
This example uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation. And, at the daily frequency level, this example uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to split orders.
|
||||
|
||||
## Usage
|
||||
## Weekly Portfolio Generation and Daily Order Execution
|
||||
|
||||
This workflow provides an example that uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation and uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to execute orders in daily frequency.
|
||||
|
||||
### Usage
|
||||
|
||||
Start backtesting by running the following command:
|
||||
```bash
|
||||
@@ -22,3 +25,13 @@ Start collecting data by running the following command:
|
||||
python workflow.py collect_data
|
||||
```
|
||||
|
||||
## Daily Portfolio Generation and Minutely Order Execution
|
||||
|
||||
This workflow also provides a high-frequency example that uses a DropoutTopkStrategy for portfolio generation in daily frequency and uses SBBStrategyEMA to execute orders in minutely frequency.
|
||||
|
||||
### Usage
|
||||
|
||||
Start backtesting by running the following command:
|
||||
```bash
|
||||
python workflow.py backtest_highfreq
|
||||
```
|
||||
@@ -4,8 +4,9 @@
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib import backtest
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.data import D
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
@@ -20,7 +21,7 @@ class MultiLevelTradingWorkflow:
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"end_time": "2021-01-20",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
@@ -54,15 +55,12 @@ class MultiLevelTradingWorkflow:
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
"test": ("2017-01-01", "2021-01-20"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
|
||||
port_analysis_config = {
|
||||
"executor": {
|
||||
"class": "NestedExecutor",
|
||||
@@ -86,12 +84,13 @@ class MultiLevelTradingWorkflow:
|
||||
"instruments": market,
|
||||
},
|
||||
},
|
||||
"generate_report": True,
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"start_time": trade_start_time,
|
||||
"end_time": trade_end_time,
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"exchange_kwargs": {
|
||||
@@ -167,6 +166,98 @@ class MultiLevelTradingWorkflow:
|
||||
for trade_decision in data_generator:
|
||||
print(trade_decision)
|
||||
|
||||
def _init_qlib_with_backend(self):
|
||||
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
|
||||
if not exists_qlib_data(provider_uri_1min):
|
||||
print(f"Qlib data is not found in {provider_uri_1min}")
|
||||
GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN)
|
||||
|
||||
# TODO: update new data
|
||||
# provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
# if not exists_qlib_data(provider_uri_day):
|
||||
# print(f"Qlib data is not found in {provider_uri_day}")
|
||||
# GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN)
|
||||
provider_uri_day = "/data/csdesign/qlib"
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri_day, **client_config)
|
||||
|
||||
def _get_highfreq_config(self, model, dataset):
|
||||
|
||||
executor_config = self.port_analysis_config["executor"]
|
||||
# update executor with hierarchical decison freq ["day", "1min"]
|
||||
executor_config["kwargs"]["time_per_step"] = "day"
|
||||
executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "1min"
|
||||
backtest_config = self.port_analysis_config["backtest"]
|
||||
|
||||
# yahoo highfreq data time
|
||||
backtest_config["start_time"] = "2020-09-20"
|
||||
backtest_config["end_time"] = "2021-01-20"
|
||||
|
||||
# update benchmark, yahoo data don't have SH000300
|
||||
instruments = D.instruments(market="csi300")
|
||||
instrument_list = D.list_instruments(instruments=instruments, as_list=True)
|
||||
backtest_config["benchmark"] = instrument_list
|
||||
|
||||
# update exchange config
|
||||
backtest_config["exchange_kwargs"]["freq"] = "1min"
|
||||
|
||||
# set strategy
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}
|
||||
|
||||
return executor_config, strategy_config, backtest_config
|
||||
|
||||
def backtest_highfreq(self):
|
||||
self._init_qlib_with_backend()
|
||||
model = init_instance_by_config(self.task["model"])
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
self._train_model(model, dataset)
|
||||
executor_config, strategy_config, backtest_config = self._get_highfreq_config(model, dataset)
|
||||
|
||||
highfreq_port_analysis_config = {
|
||||
"executor": executor_config,
|
||||
"strategy": strategy_config,
|
||||
"backtest": backtest_config,
|
||||
}
|
||||
|
||||
with R.start(experiment_name="backtest_highfreq"):
|
||||
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(recorder, highfreq_port_analysis_config, "day")
|
||||
par.generate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(MultiLevelTradingWorkflow)
|
||||
|
||||
@@ -304,7 +304,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
@@ -316,7 +316,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
|
||||
@@ -41,7 +41,7 @@ class TradeCalendarManager:
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
|
||||
"""
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
|
||||
self.trade_calendar = _calendar
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
@@ -91,7 +91,7 @@ class TradeCalendarManager:
|
||||
"""
|
||||
trade_step = trade_step - shift
|
||||
calendar_index = self.start_index + trade_step
|
||||
return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
|
||||
return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
|
||||
|
||||
def get_all_time(self):
|
||||
"""Get the start_time and end_time for trading"""
|
||||
|
||||
@@ -93,6 +93,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
print(
|
||||
trade_step, pred_start_time, pred_end_time, trade_start_time, trade_end_time, pred_score, self.pred_scores
|
||||
)
|
||||
if pred_score is None:
|
||||
return []
|
||||
if self.only_tradable:
|
||||
|
||||
@@ -53,7 +53,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
outer_trade_decision : object, optional
|
||||
"""
|
||||
|
||||
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs)
|
||||
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
for order in outer_trade_decision:
|
||||
@@ -73,21 +73,24 @@ class TWAPStrategy(BaseStrategy):
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
order_list = []
|
||||
for order in self.outer_trade_decision:
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
_order_amount = None
|
||||
# consider trade unit
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order equally
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# divide the order equally
|
||||
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit
|
||||
)
|
||||
@@ -144,6 +147,14 @@ class SBBStrategyBase(BaseStrategy):
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset_common_infra(common_infra)
|
||||
if common_infra is not None:
|
||||
if "trade_exchange" in common_infra:
|
||||
@@ -154,10 +165,6 @@ class SBBStrategyBase(BaseStrategy):
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : object, optional
|
||||
common_infra : None, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
@@ -186,10 +193,12 @@ class SBBStrategyBase(BaseStrategy):
|
||||
order_list = []
|
||||
# for each order in in self.outer_trade_decision
|
||||
for order in self.outer_trade_decision:
|
||||
# predict the price trend
|
||||
# get the price trend
|
||||
if trade_step % 2 == 0:
|
||||
# in the first of two adjacent bars, predict the price trend
|
||||
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
|
||||
else:
|
||||
# in the second of two adjacent bars, use the trend predicted in the first one
|
||||
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
@@ -204,13 +213,14 @@ class SBBStrategyBase(BaseStrategy):
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order equally
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# cal how many trade unit
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
# divide the order equally
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
|
||||
@@ -262,9 +272,9 @@ class SBBStrategyBase(BaseStrategy):
|
||||
if _order_amount:
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
if trade_step % 2 == 0:
|
||||
# in the first of two adjacent bar
|
||||
# in the first one of two adjacent bars
|
||||
# if look short on the price, sell the stock more
|
||||
# if look long on the price, sell the stock more
|
||||
# if look long on the price, buy the stock more
|
||||
if (
|
||||
_pred_trend == self.TREND_SHORT
|
||||
and order.direction == order.SELL
|
||||
@@ -281,7 +291,7 @@ class SBBStrategyBase(BaseStrategy):
|
||||
)
|
||||
order_list.append(_order)
|
||||
else:
|
||||
# in the second of two adjacent bar
|
||||
# in the second one of two adjacent bars
|
||||
# if look short on the price, buy the stock more
|
||||
# if look long on the price, sell the stock more
|
||||
if (
|
||||
@@ -301,6 +311,7 @@ class SBBStrategyBase(BaseStrategy):
|
||||
order_list.append(_order)
|
||||
|
||||
if trade_step % 2 == 0:
|
||||
# in the first one of two adjacent bars, store the trend for the second one to use
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
|
||||
return order_list
|
||||
@@ -328,7 +339,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
instruments of EMA signal, by default "csi300"
|
||||
freq : str, optional
|
||||
freq of EMA signal, by default "day"
|
||||
Note: `freq` may be different from `steb_bar`
|
||||
Note: `freq` may be different from `time_per_step`
|
||||
"""
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
@@ -349,8 +360,10 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
"""
|
||||
@@ -367,16 +380,19 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
self._reset_signal()
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
|
||||
# if no signal, return mid trend
|
||||
if stock_id not in self.signal:
|
||||
return self.TREND_MID
|
||||
else:
|
||||
_sample_signal = resam_ts_data(
|
||||
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
|
||||
)
|
||||
# if EMA signal == 0 or None, return mid trend
|
||||
if _sample_signal is None or _sample_signal.iloc[0] == 0:
|
||||
return self.TREND_MID
|
||||
# if EMA signal > 0, return long trend
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
return self.TREND_LONG
|
||||
# if EMA signal > 0, return short trend
|
||||
else:
|
||||
return self.TREND_SHORT
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from typing import Union
|
||||
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import DatasetH
|
||||
@@ -141,8 +142,8 @@ class RLIntStrategy(RLStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
@@ -151,9 +152,9 @@ class RLIntStrategy(RLStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
state_interpreter : StateInterpreter
|
||||
interpretor that interprets the qlib execute result into rl env state.
|
||||
action_interpreter : ActionInterpreter
|
||||
state_interpreter : Union[dict, StateInterpreter]
|
||||
interpretor that interprets the qlib execute result into rl env state
|
||||
action_interpreter : Union[dict, ActionInterpreter]
|
||||
interpretor that interprets the rl agent action into qlib order list
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start time of trading, by default None
|
||||
@@ -163,8 +164,8 @@ class RLIntStrategy(RLStrategy):
|
||||
super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
self.policy = policy
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter)
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
_interpret_state = self.state_interpretor.interpret(execute_result=execute_result)
|
||||
|
||||
@@ -288,11 +288,13 @@ def resam_ts_data(
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
feature = lazy_sort_index(ts_feature)
|
||||
|
||||
datetime_level = get_level_index(feature, level="datetime") == 0
|
||||
if datetime_level:
|
||||
feature = feature.loc[selector_datetime]
|
||||
else:
|
||||
feature = feature.loc[(slice(None), selector_datetime)]
|
||||
|
||||
if feature.empty:
|
||||
return None
|
||||
if isinstance(feature.index, pd.MultiIndex):
|
||||
|
||||
Reference in New Issue
Block a user