mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
optimize rule_strategy performance
This commit is contained in:
@@ -74,11 +74,12 @@ class BaseTradeCalendar:
|
||||
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len - 1
|
||||
return self.trade_index >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!")
|
||||
# trade count += 1
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
|
||||
@@ -165,6 +166,7 @@ class SplitExecutor(BaseExecutor):
|
||||
trading strategy in each trading bar
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
|
||||
"""
|
||||
super(SplitExecutor, self).__init__(
|
||||
step_bar=step_bar,
|
||||
|
||||
@@ -2,7 +2,7 @@ import copy
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ...utils.sample import sample_feature
|
||||
from ...data.data import D
|
||||
@@ -13,6 +13,8 @@ from ..backtest.faculty import common_faculty
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
"""TWAP Strategy for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
@@ -22,6 +24,15 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
trade_order_list=[],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange, optional
|
||||
exchange that provides market info, by default None
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
|
||||
trade_order_list : list, optional
|
||||
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
|
||||
"""
|
||||
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.trade_order_list = trade_order_list
|
||||
@@ -51,19 +62,19 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
_order_amount = None
|
||||
if _amount_trade_unit is None:
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
|
||||
self.trade_len - self.trade_index
|
||||
self.trade_len - self.trade_index + 1
|
||||
)
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index - 1)
|
||||
// (self.trade_len - self.trade_index)
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index)
|
||||
// (self.trade_len - self.trade_index + 1)
|
||||
* _amount_trade_unit
|
||||
)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or self.trade_index == self.trade_len - 1
|
||||
_order_amount is None or self.trade_index == self.trade_len
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -99,6 +110,15 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
trade_order_list=[],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange, optional
|
||||
exchange that provides market info, by default None
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
|
||||
trade_order_list : list, optional
|
||||
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
|
||||
"""
|
||||
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.trade_order_list = trade_order_list
|
||||
@@ -144,18 +164,18 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
_order_amount = None
|
||||
if _amount_trade_unit is None:
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
|
||||
self.trade_len - self.trade_index
|
||||
self.trade_len - self.trade_index + 1
|
||||
)
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index - 1)
|
||||
// (self.trade_len - self.trade_index)
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index)
|
||||
// (self.trade_len - self.trade_index + 1)
|
||||
* _amount_trade_unit
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or self.trade_index == self.trade_len - 1
|
||||
_order_amount is None or self.trade_index == self.trade_len
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -176,19 +196,19 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
_order_amount = (
|
||||
2
|
||||
* self.trade_amount[(order.stock_id, order.direction)]
|
||||
/ (self.trade_len - self.trade_index + 1)
|
||||
/ (self.trade_len - self.trade_index + 2)
|
||||
)
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index)
|
||||
// (self.trade_len - self.trade_index + 1)
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index + 1)
|
||||
// (self.trade_len - self.trade_index + 2)
|
||||
* 2
|
||||
* _amount_trade_unit
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
|
||||
_order_amount is None or self.trade_index == self.trade_len - 1
|
||||
_order_amount is None or self.trade_index == self.trade_len
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -235,7 +255,7 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
|
||||
class SBBStrategyEMA(SBBStrategyBase):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -249,6 +269,15 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
freq="day",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
instruments : str, optional
|
||||
instruments of EMA signal, by default "csi300"
|
||||
freq : str, optional
|
||||
freq of EMA signal, by default "day"
|
||||
Note: `freq` may be different from `steb_bar`
|
||||
"""
|
||||
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs)
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
@@ -257,13 +286,25 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
def reset(self, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, **kwargs):
|
||||
"""
|
||||
Reset EMA signal for trading
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start time for trading, also used to calculate the start time of EMA signal, by default None
|
||||
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time for trading, also used to calculate the end time of EMA signal, by default None
|
||||
"""
|
||||
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
if self.start_time and self.end_time and (start_time or end_time):
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=1, shift=1)
|
||||
_, signal_end_time = self._get_calendar_time(trade_index=self.trade_len, shift=1)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
@@ -272,6 +313,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
self.signal[stock_id] = stock_val
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
|
||||
if stock_id not in self.signal:
|
||||
return self.TREND_MID
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user