mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
optimize performance of resam data in rule_strategy & exchange
This commit is contained in:
@@ -12,7 +12,7 @@ import pandas as pd
|
||||
from ..data.data import D
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..config import C, REG_CN
|
||||
from ..utils.resam import resam_ts_data
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from .order import Order
|
||||
|
||||
@@ -166,7 +166,7 @@ class Exchange:
|
||||
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val
|
||||
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
|
||||
|
||||
self.quote = quote_dict
|
||||
|
||||
@@ -186,13 +186,13 @@ class Exchange:
|
||||
|
||||
"""
|
||||
if direction is None:
|
||||
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
|
||||
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
|
||||
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
|
||||
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
|
||||
return buy_limit or sell_limit
|
||||
elif direction == Order.BUY:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
|
||||
elif direction == Order.SELL:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
@@ -267,16 +267,16 @@ class Exchange:
|
||||
)
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=ts_data_last)
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=ts_data_last)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum")
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0]
|
||||
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last)
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(
|
||||
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
|
||||
@@ -295,10 +295,7 @@ class Exchange:
|
||||
"""
|
||||
if stock_id not in self.quote:
|
||||
return None
|
||||
res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last")
|
||||
if res is not None:
|
||||
res = res.iloc[0]
|
||||
return res
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...utils.resam import resam_ts_data, ts_data_last
|
||||
from ...data.data import D
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
|
||||
@@ -427,7 +427,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument")
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
"""
|
||||
@@ -449,13 +449,16 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
return self.TREND_MID
|
||||
else:
|
||||
_sample_signal = resam_ts_data(
|
||||
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
|
||||
self.signal[stock_id],
|
||||
pred_start_time,
|
||||
pred_end_time,
|
||||
method=ts_data_last,
|
||||
)
|
||||
# if EMA signal == 0 or None, return mid trend
|
||||
if _sample_signal is None or _sample_signal.iloc[0] == 0:
|
||||
if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0:
|
||||
return self.TREND_MID
|
||||
# if EMA signal > 0, return long trend
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
elif _sample_signal > 0:
|
||||
return self.TREND_LONG
|
||||
# if EMA signal < 0, return short trend
|
||||
else:
|
||||
@@ -518,7 +521,7 @@ class ACStrategy(BaseStrategy):
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
@@ -585,12 +588,12 @@ class ACStrategy(BaseStrategy):
|
||||
# considering trade unit
|
||||
|
||||
sig_sam = (
|
||||
resam_ts_data(self.signal[order.stock_id]["volatility"], pred_start_time, pred_end_time, method="last")
|
||||
resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last)
|
||||
if order.stock_id in self.signal
|
||||
else None
|
||||
)
|
||||
|
||||
if sig_sam is None or sig_sam.iloc[0] is None:
|
||||
if sig_sam is None or np.isnan(sig_sam):
|
||||
# no signal, TWAP
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
if _amount_trade_unit is None:
|
||||
@@ -607,7 +610,7 @@ class ACStrategy(BaseStrategy):
|
||||
)
|
||||
else:
|
||||
# VA strategy
|
||||
kappa_tild = self.lamb / self.eta * sig_sam.iloc[0] * sig_sam.iloc[0]
|
||||
kappa_tild = self.lamb / self.eta * sig_sam * sig_sam
|
||||
kappa = np.arccosh(kappa_tild / 2 + 1)
|
||||
amount_ratio = (
|
||||
np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
|
||||
|
||||
@@ -263,3 +263,45 @@ def resam_ts_data(
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature, method)(**method_kwargs)
|
||||
return feature
|
||||
|
||||
|
||||
def get_valid_value(series, last=True):
|
||||
"""get the first/last not nan value of pd.Series with single level index
|
||||
Parameters
|
||||
----------
|
||||
series : pd.Seires
|
||||
last : bool, optional
|
||||
wether to get the last valid value, by default True
|
||||
- if last is True, get the last valid value
|
||||
- else, get the first valid value
|
||||
|
||||
Returns
|
||||
-------
|
||||
Nan | float
|
||||
the first/last valid value
|
||||
"""
|
||||
x = series.dropna()
|
||||
if x.empty:
|
||||
return np.nan
|
||||
else:
|
||||
return x.iloc[-1] if last else x.iloc[0]
|
||||
|
||||
|
||||
def ts_data_last(ts_feature):
|
||||
"""get the last not nan value of pd.Series|DataFrame with single level index"""
|
||||
if isinstance(ts_feature, pd.DataFrame):
|
||||
return ts_feature.apply(lambda column: get_valid_value(column, last=True))
|
||||
elif isinstance(ts_feature, pd.Series):
|
||||
return get_valid_value(ts_feature, last=True)
|
||||
else:
|
||||
raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}")
|
||||
|
||||
|
||||
def ts_data_first(ts_feature):
|
||||
"""get the first not nan value of pd.Series|DataFrame with single level index"""
|
||||
if isinstance(ts_feature, pd.DataFrame):
|
||||
return ts_feature.apply(lambda column: get_valid_value(column, last=False))
|
||||
elif isinstance(ts_feature, pd.Series):
|
||||
return get_valid_value(ts_feature, last=False)
|
||||
else:
|
||||
raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}")
|
||||
|
||||
Reference in New Issue
Block a user