1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

fix bugs & add highfreq backtest example

This commit is contained in:
bxdd
2021-05-28 22:29:21 +08:00
parent 6a636546c4
commit 029b63c9dd
8 changed files with 168 additions and 42 deletions

View File

@@ -8,9 +8,12 @@ Qlib supports backtesting of various strategies, including portfolio management
And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies.
This example uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation. And, at the daily frequency level, this example uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to split orders.
## Usage
## Weekly Portfolio Generation and Daily Order Execution
This workflow provides an example that uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation and uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to execute orders in daily frequency.
### Usage
Start backtesting by running the following command:
```bash
@@ -22,3 +25,13 @@ Start collecting data by running the following command:
python workflow.py collect_data
```
## Daily Portfolio Generation and Minutely Order Execution
This workflow also provides a high-frequency example that uses a DropoutTopkStrategy for portfolio generation in daily frequency and uses SBBStrategyEMA to execute orders in minutely frequency.
### Usage
Start backtesting by running the following command:
```bash
python workflow.py backtest_highfreq
```

View File

@@ -4,8 +4,9 @@
import qlib
import fire
from qlib.config import REG_CN
from qlib import backtest
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.data import D
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
@@ -20,7 +21,7 @@ class MultiLevelTradingWorkflow:
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"end_time": "2021-01-20",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
@@ -54,15 +55,12 @@ class MultiLevelTradingWorkflow:
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
"test": ("2017-01-01", "2021-01-20"),
},
},
},
}
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
port_analysis_config = {
"executor": {
"class": "NestedExecutor",
@@ -86,12 +84,13 @@ class MultiLevelTradingWorkflow:
"instruments": market,
},
},
"generate_report": True,
"track_data": True,
},
},
"backtest": {
"start_time": trade_start_time,
"end_time": trade_end_time,
"start_time": "2017-01-01",
"end_time": "2020-08-01",
"account": 100000000,
"benchmark": benchmark,
"exchange_kwargs": {
@@ -167,6 +166,98 @@ class MultiLevelTradingWorkflow:
for trade_decision in data_generator:
print(trade_decision)
def _init_qlib_with_backend(self):
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri_1min):
print(f"Qlib data is not found in {provider_uri_1min}")
GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN)
# TODO: update new data
# provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
# if not exists_qlib_data(provider_uri_day):
# print(f"Qlib data is not found in {provider_uri_day}")
# GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN)
provider_uri_day = "/data/csdesign/qlib"
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
client_config = {
"calendar_provider": {
"class": "LocalCalendarProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileCalendarStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
"feature_provider": {
"class": "LocalFeatureProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileFeatureStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
}
qlib.init(provider_uri=provider_uri_day, **client_config)
def _get_highfreq_config(self, model, dataset):
executor_config = self.port_analysis_config["executor"]
# update executor with hierarchical decison freq ["day", "1min"]
executor_config["kwargs"]["time_per_step"] = "day"
executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "1min"
backtest_config = self.port_analysis_config["backtest"]
# yahoo highfreq data time
backtest_config["start_time"] = "2020-09-20"
backtest_config["end_time"] = "2021-01-20"
# update benchmark, yahoo data don't have SH000300
instruments = D.instruments(market="csi300")
instrument_list = D.list_instruments(instruments=instruments, as_list=True)
backtest_config["benchmark"] = instrument_list
# update exchange config
backtest_config["exchange_kwargs"]["freq"] = "1min"
# set strategy
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"topk": 50,
"n_drop": 5,
},
}
return executor_config, strategy_config, backtest_config
def backtest_highfreq(self):
self._init_qlib_with_backend()
model = init_instance_by_config(self.task["model"])
dataset = init_instance_by_config(self.task["dataset"])
self._train_model(model, dataset)
executor_config, strategy_config, backtest_config = self._get_highfreq_config(model, dataset)
highfreq_port_analysis_config = {
"executor": executor_config,
"strategy": strategy_config,
"backtest": backtest_config,
}
with R.start(experiment_name="backtest_highfreq"):
recorder = R.get_recorder()
par = PortAnaRecord(recorder, highfreq_port_analysis_config, "day")
par.generate()
if __name__ == "__main__":
fire.Fire(MultiLevelTradingWorkflow)

View File

@@ -304,7 +304,7 @@ class SimulatorExecutor(BaseExecutor):
if self.verbose:
if order.direction == Order.SELL: # sell
print(
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
"[I {:%Y-%m-%d %H:%M:%S}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,
@@ -316,7 +316,7 @@ class SimulatorExecutor(BaseExecutor):
)
else:
print(
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
"[I {:%Y-%m-%d %H:%M:%S}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,

View File

@@ -41,7 +41,7 @@ class TradeCalendarManager:
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
"""
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
self.trade_calendar = _calendar
self._calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
self.start_index = _start_index
self.end_index = _end_index
@@ -91,7 +91,7 @@ class TradeCalendarManager:
"""
trade_step = trade_step - shift
calendar_index = self.start_index + trade_step
return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
return self._calendar[calendar_index], self._calendar[calendar_index + 1] - pd.Timedelta(seconds=1)
def get_all_time(self):
"""Get the start_time and end_time for trading"""

View File

@@ -93,6 +93,9 @@ class TopkDropoutStrategy(ModelStrategy):
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
print(
trade_step, pred_start_time, pred_end_time, trade_start_time, trade_end_time, pred_score, self.pred_scores
)
if pred_score is None:
return []
if self.only_tradable:

View File

@@ -53,7 +53,7 @@ class TWAPStrategy(BaseStrategy):
outer_trade_decision : object, optional
"""
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs)
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.trade_amount = {}
for order in outer_trade_decision:
@@ -73,21 +73,24 @@ class TWAPStrategy(BaseStrategy):
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
order_list = []
for order in self.outer_trade_decision:
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
continue
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
_order_amount = None
# consider trade unit
# considering trade unit
if _amount_trade_unit is None:
# divide the order equally
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# divide the order equally
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
_order_amount = (
(trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit
)
@@ -144,6 +147,14 @@ class SBBStrategyBase(BaseStrategy):
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(SBBStrategyBase, self).reset_common_infra(common_infra)
if common_infra is not None:
if "trade_exchange" in common_infra:
@@ -154,10 +165,6 @@ class SBBStrategyBase(BaseStrategy):
Parameters
----------
outer_trade_decision : object, optional
common_infra : None, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
@@ -186,10 +193,12 @@ class SBBStrategyBase(BaseStrategy):
order_list = []
# for each order in in self.outer_trade_decision
for order in self.outer_trade_decision:
# predict the price trend
# get the price trend
if trade_step % 2 == 0:
# in the first of two adjacent bars, predict the price trend
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
else:
# in the second of two adjacent bars, use the trend predicted in the first one
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
@@ -204,13 +213,14 @@ class SBBStrategyBase(BaseStrategy):
_order_amount = None
# considering trade unit
if _amount_trade_unit is None:
# divide the order equally
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# cal how many trade unit
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# divide the order equally
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
_order_amount = (
(trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
@@ -262,9 +272,9 @@ class SBBStrategyBase(BaseStrategy):
if _order_amount:
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
if trade_step % 2 == 0:
# in the first of two adjacent bar
# in the first one of two adjacent bars
# if look short on the price, sell the stock more
# if look long on the price, sell the stock more
# if look long on the price, buy the stock more
if (
_pred_trend == self.TREND_SHORT
and order.direction == order.SELL
@@ -281,7 +291,7 @@ class SBBStrategyBase(BaseStrategy):
)
order_list.append(_order)
else:
# in the second of two adjacent bar
# in the second one of two adjacent bars
# if look short on the price, buy the stock more
# if look long on the price, sell the stock more
if (
@@ -301,6 +311,7 @@ class SBBStrategyBase(BaseStrategy):
order_list.append(_order)
if trade_step % 2 == 0:
# in the first one of two adjacent bars, store the trend for the second one to use
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
return order_list
@@ -328,7 +339,7 @@ class SBBStrategyEMA(SBBStrategyBase):
instruments of EMA signal, by default "csi300"
freq : str, optional
freq of EMA signal, by default "day"
Note: `freq` may be different from `steb_bar`
Note: `freq` may be different from `time_per_step`
"""
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
@@ -349,8 +360,10 @@ class SBBStrategyEMA(SBBStrategyBase):
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
if not signal_df.empty:
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
def reset_level_infra(self, level_infra):
"""
@@ -367,16 +380,19 @@ class SBBStrategyEMA(SBBStrategyBase):
self._reset_signal()
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
# if no signal, return mid trend
if stock_id not in self.signal:
return self.TREND_MID
else:
_sample_signal = resam_ts_data(
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
)
# if EMA signal == 0 or None, return mid trend
if _sample_signal is None or _sample_signal.iloc[0] == 0:
return self.TREND_MID
# if EMA signal > 0, return long trend
elif _sample_signal.iloc[0] > 0:
return self.TREND_LONG
# if EMA signal > 0, return short trend
else:
return self.TREND_SHORT

View File

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from ..model.base import BaseModel
from ..data.dataset import DatasetH
@@ -141,8 +142,8 @@ class RLIntStrategy(RLStrategy):
def __init__(
self,
policy,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
@@ -151,9 +152,9 @@ class RLIntStrategy(RLStrategy):
"""
Parameters
----------
state_interpreter : StateInterpreter
interpretor that interprets the qlib execute result into rl env state.
action_interpreter : ActionInterpreter
state_interpreter : Union[dict, StateInterpreter]
interpretor that interprets the qlib execute result into rl env state
action_interpreter : Union[dict, ActionInterpreter]
interpretor that interprets the rl agent action into qlib order list
start_time : Union[str, pd.Timestamp], optional
start time of trading, by default None
@@ -163,8 +164,8 @@ class RLIntStrategy(RLStrategy):
super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
self.state_interpreter = init_instance_by_config(state_interpreter)
self.action_interpreter = init_instance_by_config(action_interpreter)
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def generate_trade_decision(self, execute_result=None):
_interpret_state = self.state_interpretor.interpret(execute_result=execute_result)

View File

@@ -288,11 +288,13 @@ def resam_ts_data(
from ..data.dataset.utils import get_level_index
feature = lazy_sort_index(ts_feature)
datetime_level = get_level_index(feature, level="datetime") == 0
if datetime_level:
feature = feature.loc[selector_datetime]
else:
feature = feature.loc[(slice(None), selector_datetime)]
if feature.empty:
return None
if isinstance(feature.index, pd.MultiIndex):