mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add default executor config & update bug in indicator
This commit is contained in:
@@ -64,22 +64,41 @@ class NestedDecisonExecutionWorkflow:
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "week",
|
||||
"time_per_step": "day",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"time_per_step": "30min",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "5min",
|
||||
"generate_report": True,
|
||||
"verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"generate_report": True,
|
||||
"verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"class": "SBBStrategyEMA",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"instruments": market,
|
||||
"freq": "1min",
|
||||
},
|
||||
},
|
||||
"track_data": True,
|
||||
"generate_report": True,
|
||||
@@ -92,9 +111,8 @@ class NestedDecisonExecutionWorkflow:
|
||||
"start_time": "2020-01-01",
|
||||
"end_time": "2020-12-31",
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"exchange_kwargs": {
|
||||
"freq": "day",
|
||||
"freq": "1min",
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
@@ -106,14 +124,14 @@ class NestedDecisonExecutionWorkflow:
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
provider_uri_day = "/data1/v-xiabi/qlib/qlib_data/cn_data" # target_dir
|
||||
# provider_uri_day = "/data/stock_data/huaxia/qlib"
|
||||
# provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib"
|
||||
provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True)
|
||||
# provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
|
||||
provider_uri_1min = "/data1/v-xiabi/qlib/qlib_data/cn_data_highfreq"
|
||||
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
|
||||
GetData().qlib_data(
|
||||
target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True
|
||||
)
|
||||
provider_uri_day = "/data/csdesign/qlib"
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
@@ -139,7 +157,7 @@ class NestedDecisonExecutionWorkflow:
|
||||
},
|
||||
},
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri_day, **client_config)
|
||||
qlib.init(provider_uri=provider_uri_day, **client_config, redis_port=-1)
|
||||
|
||||
def _train_model(self, model, dataset):
|
||||
with R.start(experiment_name="train"):
|
||||
@@ -177,8 +195,8 @@ class NestedDecisonExecutionWorkflow:
|
||||
par = PortAnaRecord(
|
||||
recorder,
|
||||
self.port_analysis_config,
|
||||
risk_analysis_freq=["week", "day"],
|
||||
indicator_analysis_freq=["week", "day"],
|
||||
risk_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_method="value_weighted",
|
||||
)
|
||||
par.generate()
|
||||
|
||||
@@ -166,6 +166,7 @@ class BaseExecutor:
|
||||
return self.execute(trade_decision)
|
||||
|
||||
def get_report(self):
|
||||
"""get the history report and postions instance"""
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
@@ -173,13 +174,14 @@ class BaseExecutor:
|
||||
else:
|
||||
raise ValueError("generate_report should be True if you want to generate report")
|
||||
|
||||
def get_all_executors(self):
|
||||
"""Return all executors"""
|
||||
return [self]
|
||||
|
||||
def get_trade_indicator(self):
|
||||
"""get the trade indicator instance, which has pa/pos/ffr info."""
|
||||
return self.trade_account.indicator
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors"""
|
||||
return [self]
|
||||
|
||||
|
||||
class NestedExecutor(BaseExecutor):
|
||||
"""
|
||||
@@ -295,7 +297,7 @@ class NestedExecutor(BaseExecutor):
|
||||
return execute_result
|
||||
|
||||
def get_all_executors(self):
|
||||
"""Return all executors, including self and inner_executor.get_all_executors()"""
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from pandas.core import groupby
|
||||
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from ..utils.resam import parse_freq, resam_ts_data, get_higher_freq_feature
|
||||
from ..utils.resam import parse_freq, resam_ts_data, get_higher_eq_freq_feature
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
|
||||
@@ -82,7 +82,7 @@ class Report:
|
||||
raise ValueError("benchmark freq can't be None!")
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
fields = ["$close/Ref($close,1)-1"]
|
||||
_temp_result, _ = get_higher_freq_feature(_codes, fields, start_time, end_time, freq=freq)
|
||||
_temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
@@ -308,6 +308,7 @@ class Indicator:
|
||||
raise ValueError(f"base_price {base_price} is not supported!")
|
||||
|
||||
self.order_indicator["pa"] = self.order_indicator["trade_price"] / self.order_indicator["base_price"] - 1
|
||||
# print("trade_price", self.order_indicator["trade_price"], "base_price", self.order_indicator["base_price"], "pa", self.order_indicator["pa"]* (2 * (self.order_indicator["amount"] < 0).astype(int) - 1))
|
||||
|
||||
def _cal_trade_fulfill_rate(self, method="mean"):
|
||||
if method == "mean":
|
||||
@@ -322,8 +323,7 @@ class Indicator:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
|
||||
def _cal_trade_price_advantage(self, method="mean"):
|
||||
|
||||
pa_order = self.order_indicator["pa"] * (self.order_indicator["amount"] < 0).astype(int)
|
||||
pa_order = self.order_indicator["pa"] * (2 * (self.order_indicator["amount"] < 0).astype(int) - 1)
|
||||
if method == "mean":
|
||||
return pa_order.mean()
|
||||
elif method == "amount_weighted":
|
||||
@@ -336,8 +336,8 @@ class Indicator:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
|
||||
def _cal_trade_positive_rate(self):
|
||||
pa_order = self.order_indicator["pa"] * (self.order_indicator["amount"] < 0).astype(int)
|
||||
return (pa_order > 0).astype(int).sum() / len(pa_order)
|
||||
pa_order = self.order_indicator["pa"] * (2 * (self.order_indicator["amount"] < 0).astype(int) - 1)
|
||||
return (pa_order > 0).astype(int).sum() / pa_order.count()
|
||||
|
||||
def _cal_trade_amount(self):
|
||||
return self.order_indicator["deal_amount"].abs().sum()
|
||||
@@ -345,6 +345,9 @@ class Indicator:
|
||||
def _cal_trade_value(self):
|
||||
return self.order_indicator["trade_value"].abs().sum()
|
||||
|
||||
def _cal_trade_order_count(self):
|
||||
return self.order_indicator["amount"].count()
|
||||
|
||||
def update_order_indicators(self, trade_start_time, trade_end_time, trade_info, trade_exchange):
|
||||
self._update_order_trade_info(trade_info=trade_info)
|
||||
self._update_order_fulfill_rate()
|
||||
@@ -365,11 +368,13 @@ class Indicator:
|
||||
positive_rate = self._cal_trade_positive_rate()
|
||||
trade_amount = self._cal_trade_amount()
|
||||
trade_value = self._cal_trade_value()
|
||||
order_count = self._cal_trade_order_count()
|
||||
self.trade_indicator["ffr"] = fulfill_rate
|
||||
self.trade_indicator["pa"] = price_advantage
|
||||
self.trade_indicator["pos"] = positive_rate
|
||||
self.trade_indicator["amount"] = trade_amount
|
||||
self.trade_indicator["value"] = trade_value
|
||||
self.trade_indicator["count"] = order_count
|
||||
if show_indicator:
|
||||
print(
|
||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
|
||||
@@ -84,29 +84,33 @@ def indicator_analysis(df, method="mean"):
|
||||
|
||||
index: Index(datetime)
|
||||
method : str, optional
|
||||
statistics method, by default "mean"
|
||||
statistics method of pa/ffr, by default "mean"
|
||||
- if method is 'mean', count the mean statistical value of each trade indicator
|
||||
- if method is 'amount_weighted', count the amount weighted mean statistical value of each trade indicator
|
||||
- if method is 'value_weighted', count the value weighted mean statistical value of each trade indicator
|
||||
Note: statistics method of pos is always "mean"
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
statistical value of each trade indicator
|
||||
statistical value of each trade indicators
|
||||
"""
|
||||
indicators_df = df[["pa", "pos", "ffr"]]
|
||||
|
||||
if method == "mean":
|
||||
res = indicators_df.mean()
|
||||
elif method == "amount_weighted":
|
||||
weights = df["amount"].abs()
|
||||
res = indicators_df.mul(weights, axis=0).sum() / weights.sum()
|
||||
elif method == "value_weighted":
|
||||
weights = df["value"].abs()
|
||||
res = indicators_df.mul(weights, axis=0).sum() / weights.sum()
|
||||
else:
|
||||
weights_dict = {
|
||||
"mean": df["count"],
|
||||
"amount_weighted": df["amount"].abs(),
|
||||
"value_weighted": df["value"].abs(),
|
||||
}
|
||||
if method not in weights_dict:
|
||||
raise ValueError(f"indicator_analysis method {method} is not supported!")
|
||||
|
||||
# statistic pa/ffr indicator
|
||||
indicators_df = df[["ffr", "pa"]]
|
||||
weights = weights_dict.get(method)
|
||||
res = indicators_df.mul(weights, axis=0).sum() / weights.sum()
|
||||
|
||||
# statistic pos
|
||||
weights = weights_dict.get("mean")
|
||||
res.loc["pos"] = df["pos"].mul(weights).sum() / weights.sum()
|
||||
res = res.to_frame("value")
|
||||
return res
|
||||
|
||||
|
||||
@@ -414,12 +414,12 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
# if EMA signal > 0, return long trend
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
return self.TREND_LONG
|
||||
# if EMA signal > 0, return short trend
|
||||
# if EMA signal < 0, return short trend
|
||||
else:
|
||||
return self.TREND_SHORT
|
||||
|
||||
|
||||
class VAStrategy(BaseStrategy):
|
||||
class ACStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
lamb: float = 1e-6,
|
||||
@@ -451,7 +451,7 @@ class VAStrategy(BaseStrategy):
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
super(VAStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
super(ACStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
@@ -483,7 +483,7 @@ class VAStrategy(BaseStrategy):
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(VAStrategy, self).reset_common_infra(common_infra)
|
||||
super(ACStrategy, self).reset_common_infra(common_infra)
|
||||
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
@@ -508,7 +508,7 @@ class VAStrategy(BaseStrategy):
|
||||
----------
|
||||
outer_trade_decision : List[Order], optional
|
||||
"""
|
||||
super(VAStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
# init the trade amount of order and predicted trade trend
|
||||
|
||||
@@ -210,33 +210,12 @@ def get_resam_calendar(
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
|
||||
def get_higher_freq_feature(instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
"""[summary]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
instruments : [type]
|
||||
[description]
|
||||
fields : [type]
|
||||
[description]
|
||||
start_time : [type], optional
|
||||
[description], by default None
|
||||
end_time : [type], optional
|
||||
[description], by default None
|
||||
freq : str, optional
|
||||
[description], by default "day"
|
||||
disk_cache : int, optional
|
||||
[description], by default 1
|
||||
|
||||
def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
"""get the feature with higher or equal frequency than `freq`.
|
||||
Returns
|
||||
-------
|
||||
[type]
|
||||
[description]
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
[description]
|
||||
pd.DataFrame
|
||||
the feature with higher or equal frequency
|
||||
"""
|
||||
|
||||
from ..data.data import D
|
||||
@@ -331,13 +310,12 @@ def resam_ts_data(
|
||||
sample method, apply method function to each stock series data, by default "last"
|
||||
- If type(method) is str or callable function, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and applies groupy.method for the sliced time-series data
|
||||
- If method is None, do nothing for the sliced time-series data.
|
||||
- Only when the index `feature` is MultiIndex[instrument, datetime], the method is valid.
|
||||
method_kwargs : dict, optional
|
||||
arguments of method, by default {}
|
||||
|
||||
Returns
|
||||
-------
|
||||
The Resampled DataFrame/Series/Value
|
||||
The resampled DataFrame/Series/value, return None when the resampled data is empty.
|
||||
"""
|
||||
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
|
||||
@@ -299,8 +299,8 @@ class PortAnaRecord(RecordTemp):
|
||||
self,
|
||||
recorder,
|
||||
config,
|
||||
risk_analysis_freq: Union[List, str] = [],
|
||||
indicator_analysis_freq: Union[List, str] = [],
|
||||
risk_analysis_freq: Union[List, str] = None,
|
||||
indicator_analysis_freq: Union[List, str] = None,
|
||||
indicator_analysis_method=None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -321,8 +321,23 @@ class PortAnaRecord(RecordTemp):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.executor_config = config["executor"]
|
||||
_default_executor_config = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"generate_report": True,
|
||||
},
|
||||
}
|
||||
self.executor_config = config.get("executor", _default_executor_config)
|
||||
self.backtest_config = config["backtest"]
|
||||
|
||||
self.all_freq = self._get_report_freq(self.executor_config)
|
||||
if risk_analysis_freq is None:
|
||||
risk_analysis_freq = [self.all_freq[0]]
|
||||
if indicator_analysis_freq is None:
|
||||
indicator_analysis_freq = [self.all_freq[0]]
|
||||
|
||||
if isinstance(risk_analysis_freq, str):
|
||||
risk_analysis_freq = [risk_analysis_freq]
|
||||
if isinstance(indicator_analysis_freq, str):
|
||||
@@ -335,7 +350,6 @@ class PortAnaRecord(RecordTemp):
|
||||
"{0}{1}".format(*parse_freq(_analysis_freq)) for _analysis_freq in indicator_analysis_freq
|
||||
]
|
||||
self.indicator_analysis_method = indicator_analysis_method
|
||||
self.all_freq = self._get_report_freq(self.executor_config)
|
||||
|
||||
def _get_report_freq(self, executor_config):
|
||||
ret_freq = []
|
||||
@@ -399,21 +413,26 @@ class PortAnaRecord(RecordTemp):
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
|
||||
for _analysis_freq in self.indicator_analysis_freq:
|
||||
indicators_normal = indicator_dict.get(_analysis_freq)
|
||||
if self.indicator_analysis_method is None:
|
||||
analysis_df = indicator_analysis(indicators_normal)
|
||||
if _analysis_freq not in indicator_dict:
|
||||
warnings.warn(f"the freq {_analysis_freq} indicator is not found")
|
||||
else:
|
||||
analysis_df = indicator_analysis(indicators_normal, method=self.indicator_analysis_method)
|
||||
|
||||
# log metrics
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
|
||||
pprint(analysis_df)
|
||||
indicators_normal = indicator_dict.get(_analysis_freq)
|
||||
if self.indicator_analysis_method is None:
|
||||
analysis_df = indicator_analysis(indicators_normal)
|
||||
else:
|
||||
analysis_df = indicator_analysis(indicators_normal, method=self.indicator_analysis_method)
|
||||
# log metrics
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
|
||||
pprint(analysis_df)
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
@@ -424,10 +443,16 @@ class PortAnaRecord(RecordTemp):
|
||||
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
|
||||
]
|
||||
)
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
|
||||
else:
|
||||
warnings.warn(f"{_analysis_freq} is not found")
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
for _analysis_freq in self.indicator_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
return list_path
|
||||
|
||||
Reference in New Issue
Block a user