1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00

optimize rule_strategy performance

This commit is contained in:
bxdd
2021-05-14 15:50:27 +08:00
parent ea60e608ba
commit eaa719df17
2 changed files with 62 additions and 18 deletions

View File

@@ -74,11 +74,12 @@ class BaseTradeCalendar:
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)
def finished(self):
return self.trade_index >= self.trade_len - 1
return self.trade_index >= self.trade_len
def step(self):
if self.finished():
raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!")
# trade count += 1
self.trade_index = self.trade_index + 1
@@ -165,6 +166,7 @@ class SplitExecutor(BaseExecutor):
trading strategy in each trading bar
trade_exchange : Exchange
exchange that provides market info
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
"""
super(SplitExecutor, self).__init__(
step_bar=step_bar,

View File

@@ -2,7 +2,7 @@ import copy
import warnings
import numpy as np
import pandas as pd
from typing import Union
from ...utils.sample import sample_feature
from ...data.data import D
@@ -13,6 +13,8 @@ from ..backtest.faculty import common_faculty
class TWAPStrategy(RuleStrategy, OrderEnhancement):
"""TWAP Strategy for trading"""
def __init__(
self,
step_bar,
@@ -22,6 +24,15 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
trade_order_list=[],
**kwargs,
):
"""
Parameters
----------
trade_exchange : Exchange, optional
exchange that provides market info, by default None
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
trade_order_list : list, optional
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
"""
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.trade_order_list = trade_order_list
@@ -51,19 +62,19 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
_order_amount = None
if _amount_trade_unit is None:
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
self.trade_len - self.trade_index
self.trade_len - self.trade_index + 1
)
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
_order_amount = (
(trade_unit_cnt + self.trade_len - self.trade_index - 1)
// (self.trade_len - self.trade_index)
(trade_unit_cnt + self.trade_len - self.trade_index)
// (self.trade_len - self.trade_index + 1)
* _amount_trade_unit
)
if order.direction == order.SELL:
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount is None or self.trade_index == self.trade_len - 1
_order_amount is None or self.trade_index == self.trade_len
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
@@ -99,6 +110,15 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
trade_order_list=[],
**kwargs,
):
"""
Parameters
----------
trade_exchange : Exchange, optional
exchange that provides market info, by default None
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
trade_order_list : list, optional
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
"""
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.trade_order_list = trade_order_list
@@ -144,18 +164,18 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
_order_amount = None
if _amount_trade_unit is None:
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
self.trade_len - self.trade_index
self.trade_len - self.trade_index + 1
)
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
_order_amount = (
(trade_unit_cnt + self.trade_len - self.trade_index - 1)
// (self.trade_len - self.trade_index)
(trade_unit_cnt + self.trade_len - self.trade_index)
// (self.trade_len - self.trade_index + 1)
* _amount_trade_unit
)
if order.direction == order.SELL:
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount is None or self.trade_index == self.trade_len - 1
_order_amount is None or self.trade_index == self.trade_len
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
@@ -176,19 +196,19 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
_order_amount = (
2
* self.trade_amount[(order.stock_id, order.direction)]
/ (self.trade_len - self.trade_index + 1)
/ (self.trade_len - self.trade_index + 2)
)
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
_order_amount = (
(trade_unit_cnt + self.trade_len - self.trade_index)
// (self.trade_len - self.trade_index + 1)
(trade_unit_cnt + self.trade_len - self.trade_index + 1)
// (self.trade_len - self.trade_index + 2)
* 2
* _amount_trade_unit
)
if order.direction == order.SELL:
if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
_order_amount is None or self.trade_index == self.trade_len - 1
_order_amount is None or self.trade_index == self.trade_len
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
@@ -235,7 +255,7 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
class SBBStrategyEMA(SBBStrategyBase):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal.
"""
def __init__(
@@ -249,6 +269,15 @@ class SBBStrategyEMA(SBBStrategyBase):
freq="day",
**kwargs,
):
"""
Parameters
----------
instruments : str, optional
instruments of EMA signal, by default "csi300"
freq : str, optional
freq of EMA signal, by default "day"
Note: `freq` may be different from `steb_bar`
"""
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs)
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
@@ -257,13 +286,25 @@ class SBBStrategyEMA(SBBStrategyBase):
self.instruments = D.instruments(instruments)
self.freq = freq
def reset(self, start_time=None, end_time=None, **kwargs):
def reset(self, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, **kwargs):
"""
Reset EMA signal for trading
Parameters
----------
start_time : Union[str, pd.Timestamp], optional
start time for trading, also used to calculate the start time of EMA signal, by default None
end_time : Union[str, pd.Timestamp], optional
end time for trading, also used to calculate the end time of EMA signal, by default None
"""
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
if self.start_time and self.end_time and (start_time or end_time):
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
signal_start_time, _ = self._get_calendar_time(trade_index=1, shift=1)
_, signal_end_time = self._get_calendar_time(trade_index=self.trade_len, shift=1)
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
@@ -272,6 +313,7 @@ class SBBStrategyEMA(SBBStrategyBase):
self.signal[stock_id] = stock_val
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
if stock_id not in self.signal:
return self.TREND_MID
else: