1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

Merge pull request #493 from bxdd/optimize_resam_data

optimize performance of resam data in rule_strategy & exchange
This commit is contained in:
bxdd
2021-07-04 02:44:53 +08:00
committed by GitHub
6 changed files with 99 additions and 35 deletions

View File

@@ -12,7 +12,7 @@ import pandas as pd
from ..data.data import D
from ..data.dataset.utils import get_level_index
from ..config import C, REG_CN
from ..utils.resam import resam_ts_data
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from .order import Order, OrderDir, OrderHelper
@@ -166,7 +166,7 @@ class Exchange:
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
self.quote = quote_dict
@@ -186,13 +186,13 @@ class Exchange:
"""
if direction is None:
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
return buy_limit or sell_limit
elif direction == Order.BUY:
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
elif direction == Order.SELL:
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
else:
raise ValueError(f"direction {direction} is not supported!")
@@ -258,16 +258,16 @@ class Exchange:
return trade_val, trade_cost, trade_price
def get_quote_info(self, stock_id, start_time, end_time):
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=ts_data_last)
def get_close(self, stock_id, start_time, end_time):
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0]
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=ts_data_last)
def get_volume(self, stock_id, start_time, end_time):
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0]
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum")
def get_deal_price(self, stock_id, start_time, end_time):
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0]
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last)
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
@@ -286,10 +286,7 @@ class Exchange:
"""
if stock_id not in self.quote:
return None
res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last")
if res is not None:
res = res.iloc[0]
return res
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
"""

View File

@@ -91,7 +91,7 @@ class Report:
if freq is None:
raise ValueError("benchmark freq can't be None!")
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
_codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark]
fields = ["$close/Ref($close,1)-1"]
_temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
if len(_temp_result) == 0:

View File

@@ -7,7 +7,7 @@ from qlib.data.dataset.utils import convert_index_format
from qlib.utils import lazy_sort_index
from ...utils.resam import resam_ts_data
from ...utils.resam import resam_ts_data, ts_data_last
from ...data.data import D
from ...strategy.base import BaseStrategy
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
@@ -432,7 +432,7 @@ class SBBStrategyEMA(SBBStrategyBase):
if not signal_df.empty:
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument")
def reset_level_infra(self, level_infra):
"""
@@ -454,13 +454,16 @@ class SBBStrategyEMA(SBBStrategyBase):
return self.TREND_MID
else:
_sample_signal = resam_ts_data(
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
self.signal[stock_id],
pred_start_time,
pred_end_time,
method=ts_data_last,
)
# if EMA signal == 0 or None, return mid trend
if _sample_signal is None or _sample_signal.iloc[0] == 0:
if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0:
return self.TREND_MID
# if EMA signal > 0, return long trend
elif _sample_signal.iloc[0] > 0:
elif _sample_signal > 0:
return self.TREND_LONG
# if EMA signal < 0, return short trend
else:
@@ -523,7 +526,7 @@ class ACStrategy(BaseStrategy):
if not signal_df.empty:
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument")
def reset_common_infra(self, common_infra):
"""
@@ -590,12 +593,12 @@ class ACStrategy(BaseStrategy):
# considering trade unit
sig_sam = (
resam_ts_data(self.signal[order.stock_id]["volatility"], pred_start_time, pred_end_time, method="last")
resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last)
if order.stock_id in self.signal
else None
)
if sig_sam is None or sig_sam.iloc[0] is None:
if sig_sam is None or np.isnan(sig_sam):
# no signal, TWAP
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
if _amount_trade_unit is None:
@@ -612,7 +615,7 @@ class ACStrategy(BaseStrategy):
)
else:
# VA strategy
kappa_tild = self.lamb / self.eta * sig_sam.iloc[0] * sig_sam.iloc[0]
kappa_tild = self.lamb / self.eta * sig_sam * sig_sam
kappa = np.arccosh(kappa_tild / 2 + 1)
amount_ratio = (
np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))

View File

@@ -197,7 +197,7 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
from .storage import HasingStockStorage
from .storage import BaseHandlerStorage
data_storage = self._data
if isinstance(data_storage, pd.DataFrame):
@@ -211,10 +211,17 @@ class DataHandler(Serializable):
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
raise ValueError("proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, BaseHandlerStorage):
if not data_storage.is_proc_func_supported():
if proc_func is not None:
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
)
else:
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
@@ -522,7 +529,7 @@ class DataHandlerLP(DataHandler):
-------
pd.DataFrame:
"""
from .storage import HasingStockStorage
from .storage import BaseHandlerStorage
data_storage = self._get_df_by_key(data_key)
if isinstance(data_storage, pd.DataFrame):
@@ -537,10 +544,17 @@ class DataHandlerLP(DataHandler):
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
raise ValueError("proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, BaseHandlerStorage):
if not data_storage.is_proc_func_supported():
if proc_func is not None:
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
)
else:
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")

View File

@@ -14,6 +14,7 @@ class BaseHandlerStorage:
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
proc_func: Callable = None,
**kwargs,
) -> pd.DataFrame:
"""fetch data from the data storage
@@ -24,6 +25,7 @@ class BaseHandlerStorage:
describe how to select data by index
level : Union[str, int]
which index level to select the data
- if level is None, apply selector to df directly
col_set : Union[str, List[str]]
- if isinstance(col_set, str):
select a set of meaningful columns.(e.g. features, columns)
@@ -33,15 +35,24 @@ class BaseHandlerStorage:
select several sets of meaningful columns, the returned data has multiple level
fetch_orig : bool
Return the original data instead of copy if possible.
proc_func: Callable
please refer to the doc of DataHandler.fetch
Returns
-------
pd.DataFrame
the dataframe fetched
"""
raise NotImplementedError("fetch is method not implemented!")
@staticmethod
def from_df(df: pd.DataFrame):
raise NotImplementedError("from_df method is not implemented!")
def is_proc_func_supported(self):
"""whether the arg `proc_func` in `fetch` method is supported."""
raise NotImplementedError("is_proc_func_supported method is not implemented!")
class HasingStockStorage(BaseHandlerStorage):
def __init__(self, df):
@@ -105,3 +116,7 @@ class HasingStockStorage(BaseHandlerStorage):
return fetch_stock_df_list[0]
else:
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
def is_proc_func_supported(self):
"""the arg `proc_func` in `fetch` method is not supported in HasingStockStorage"""
return False

View File

@@ -3,6 +3,8 @@ import datetime
import numpy as np
import pandas as pd
from functools import partial
from typing import Tuple, List, Union, Optional, Callable
from . import lazy_sort_index
@@ -263,3 +265,36 @@ def resam_ts_data(
elif isinstance(method, str):
return getattr(feature, method)(**method_kwargs)
return feature
def get_valid_value(series, last=True):
"""get the first/last not nan value of pd.Series with single level index
Parameters
----------
series : pd.Seires
series should not be empty
last : bool, optional
wether to get the last valid value, by default True
- if last is True, get the last valid value
- else, get the first valid value
Returns
-------
Nan | float
the first/last valid value
"""
return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0]
def _ts_data_valid(ts_feature, last=False):
"""get the first/last not nan value of pd.Series|DataFrame with single level index"""
if isinstance(ts_feature, pd.DataFrame):
return ts_feature.apply(lambda column: get_valid_value(column, last=last))
elif isinstance(ts_feature, pd.Series):
return get_valid_value(ts_feature, last=last)
else:
raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}")
ts_data_last = partial(_ts_data_valid, last=False)
ts_data_first = partial(_ts_data_valid, last=True)