1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

rename var in backtest

This commit is contained in:
bxdd
2021-05-27 17:03:53 +08:00
parent ee74489c37
commit 2ad61f12b3
10 changed files with 165 additions and 166 deletions

View File

@@ -61,24 +61,24 @@ class MultiLevelTradingWorkflow:
}
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
trade_end_time = "2017-02-01"
port_analysis_config = {
"executor": {
"class": "SplitExecutor",
"module_path": "qlib.contrib.backtest.executor",
"kwargs": {
"step_bar": "week",
"sub_executor": {
"time_per_step": "week",
"inner_executor": {
"class": "SimulatorExecutor",
"module_path": "qlib.contrib.backtest.executor",
"kwargs": {
"step_bar": "day",
"time_per_step": "day",
"verbose": True,
"generate_report": True,
},
},
"sub_strategy": {
"inner_strategy": {
"class": "SBBStrategyEMA",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
@@ -107,7 +107,6 @@ class MultiLevelTradingWorkflow:
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")

View File

@@ -23,7 +23,6 @@ class RollingDataWorkflow:
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")

View File

@@ -8,10 +8,10 @@ def backtest(start_time, end_time, trade_strategy, trade_executor):
level_infra = trade_executor.get_level_infra()
trade_strategy.reset(level_infra=level_infra)
sub_execute_state = trade_executor.get_init_state()
_execute_result = None
while not trade_executor.finished():
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = trade_executor.execute(sub_trade_decision)
_trade_decision = trade_strategy.generate_trade_decision(_execute_result)
_execute_result = trade_executor.execute(_trade_decision)
return trade_executor.get_report()
@@ -22,9 +22,9 @@ def collect_data(start_time, end_time, trade_strategy, trade_executor):
level_infra = trade_executor.get_level_infra()
trade_strategy.reset(level_infra=level_infra)
sub_execute_state = trade_executor.get_init_state()
_execute_result = None
while not trade_executor.finished():
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = yield from trade_executor.collect_data(sub_trade_decision)
_trade_decision = trade_strategy.generate_trade_decision(_execute_result)
_execute_result = yield from trade_executor.collect_data(_trade_decision)
return trade_executor.get_report()

View File

@@ -8,7 +8,6 @@ from ...utils.resam import parse_freq
from .order import Order
from .account import Account
from .exchange import Exchange
from .utils import TradeCalendarManager
@@ -18,7 +17,7 @@ class BaseExecutor:
def __init__(
self,
step_bar: str,
time_per_step: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
generate_report: bool = False,
@@ -30,6 +29,8 @@ class BaseExecutor:
"""
Parameters
----------
time_per_step : str
trade time per trading step, used for genreate trade calendar
generate_report : bool, optional
whether to generate report, by default False
verbose : bool, optional
@@ -46,7 +47,7 @@ class BaseExecutor:
exchange that provides market info
"""
self.step_bar = step_bar
self.time_per_step = time_per_step
self.generate_report = generate_report
self.verbose = verbose
self.track_data = track_data
@@ -64,7 +65,7 @@ class BaseExecutor:
if "trade_account" in common_infra:
self.trade_account = copy.copy(common_infra.get("trade_account"))
self.trade_account.reset(freq=self.step_bar, init_report=True)
self.trade_account.reset(freq=self.time_per_step, init_report=True)
def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs):
"""
@@ -76,19 +77,19 @@ class BaseExecutor:
if track_data is not None:
self.track_data = track_data
if common_infra is not None:
self.reset_common_infra(common_infra)
if "start_time" in kwargs or "end_time" in kwargs:
start_time = kwargs.get("start_time")
end_time = kwargs.get("end_time")
self.trade_calendar = TradeCalendarManager(step_bar=self.step_bar, start_time=start_time, end_time=end_time)
self.calendar = TradeCalendarManager(freq=self.time_per_step, start_time=start_time, end_time=end_time)
if common_infra is not None:
self.reset_common_infra(common_infra)
def get_level_infra(self):
return {"trade_calendar": self.trade_calendar}
return {"calendar": self.calendar}
def finished(self):
return self.trade_calendar.finished()
return self.calendar.finished()
def execute(self, trade_decision):
"""execute the trade decision and return the executed result
@@ -99,8 +100,8 @@ class BaseExecutor:
Returns
----------
executed state : List[Tuple[Order, float, float, float]]
- Each element in the list represents (order, trade value, trade cost, trade price)
execute_result : List[object]
the executed result for trade decison
"""
raise NotImplementedError("execute is not implemented!")
@@ -109,9 +110,6 @@ class BaseExecutor:
yield trade_decision
return self.execute(trade_decision)
def get_init_state(self):
raise NotImplementedError("get_init_state in not implemeted!")
def get_trade_account(self):
raise NotImplementedError("get_trade_account is not implemented!")
@@ -124,9 +122,9 @@ class SplitExecutor(BaseExecutor):
def __init__(
self,
step_bar: str,
sub_executor: Union[BaseExecutor, dict],
sub_strategy: Union[BaseStrategy, dict],
time_per_step: str,
inner_executor: Union[BaseExecutor, dict],
inner_strategy: Union[BaseStrategy, dict],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_exchange: Exchange = None,
@@ -139,22 +137,24 @@ class SplitExecutor(BaseExecutor):
"""
Parameters
----------
sub_executor : BaseExecutor
inner_executor : BaseExecutor
trading env in each trading bar.
sub_strategy : BaseStrategy
inner_strategy : BaseStrategy
trading strategy in each trading bar
trade_exchange : Exchange
exchange that provides market info, used to generate report
- If generate_report is None, trade_exchange will be ignored
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
self.sub_executor = init_instance_by_config(sub_executor, common_infra=common_infra, accept_types=BaseExecutor)
self.sub_strategy = init_instance_by_config(
sub_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
self.inner_executor = init_instance_by_config(
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
)
self.inner_strategy = init_instance_by_config(
inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
)
super(SplitExecutor, self).__init__(
step_bar=step_bar,
time_per_step=time_per_step,
start_time=start_time,
end_time=end_time,
generate_report=generate_report,
@@ -171,29 +171,26 @@ class SplitExecutor(BaseExecutor):
"""
reset infrastructure for trading
- reset trade_exchange
- reset substrategy and subexecutor common infra
- reset inner_strategyand inner_executor common infra
"""
super(SplitExecutor, self).reset_common_infra(common_infra)
if self.generate_report and "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
self.sub_executor.reset_common_infra(common_infra)
self.sub_strategy.reset_common_infra(common_infra)
def get_init_state(self):
return []
self.inner_executor.reset_common_infra(common_infra)
self.inner_strategy.reset_common_infra(common_infra)
def _init_sub_trading(self, trade_decision):
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
self.sub_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
sub_level_infra = self.sub_executor.get_level_infra()
self.sub_strategy.reset(level_infra=sub_level_infra, rely_trade_decision=trade_decision)
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
sub_level_infra = self.inner_executor.get_level_infra()
self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
def _update_trade_account(self):
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
self.trade_account.update_bar_count()
if self.generate_report:
self.trade_account.update_bar_report(
@@ -203,41 +200,41 @@ class SplitExecutor(BaseExecutor):
)
def execute(self, trade_decision):
self.trade_calendar.step()
self.calendar.step()
self._init_sub_trading(trade_decision)
execute_state = []
sub_execute_state = self.sub_executor.get_init_state()
while not self.sub_executor.finished():
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = self.sub_executor.execute(trade_decision=sub_trade_decison)
execute_state.extend(sub_execute_state)
execute_result = []
_inner_execute_result = None
while not self.inner_executor.finished():
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
_inner_execute_result = self.inner_executor.execute(trade_decision=_inner_trade_decision)
execute_result.extend(_inner_execute_result)
if hasattr(self, "trade_account"):
self._update_trade_account()
return execute_state
return execute_result
def collect_data(self, trade_decision):
if self.track_data:
yield trade_decision
self.trade_calendar.step()
self.calendar.step()
self._init_sub_trading(trade_decision)
execute_state = []
sub_execute_state = self.sub_executor.get_init_state()
while not self.sub_executor.finished():
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = yield from self.sub_executor.collect_data(trade_decision=sub_trade_decison)
execute_state.extend(sub_execute_state)
execute_result = []
_inner_execute_result = None
while not self.inner_executor.finished():
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
execute_result.extend(_inner_execute_result)
if hasattr(self, "trade_account"):
self._update_trade_account()
return execute_state
return execute_result
def get_report(self):
sub_env_report_dict = self.sub_executor.get_report()
sub_env_report_dict = self.inner_executor.get_report()
if self.generate_report:
_report = self.trade_account.report.generate_report_dataframe()
_positions = self.trade_account.get_positions()
_count, _freq = parse_freq(self.step_bar)
_count, _freq = parse_freq(self.time_per_step)
sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)})
return sub_env_report_dict
@@ -245,7 +242,7 @@ class SplitExecutor(BaseExecutor):
class SimulatorExecutor(BaseExecutor):
def __init__(
self,
step_bar: str,
time_per_step: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_exchange: Exchange = None,
@@ -263,7 +260,7 @@ class SimulatorExecutor(BaseExecutor):
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
super(SimulatorExecutor, self).__init__(
step_bar=step_bar,
time_per_step=time_per_step,
start_time=start_time,
end_time=end_time,
generate_report=generate_report,
@@ -284,21 +281,18 @@ class SimulatorExecutor(BaseExecutor):
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def get_init_state(self):
return []
def execute(self, trade_decision):
self.trade_calendar.step()
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
execute_state = []
self.calendar.step()
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
execute_result = []
for order in trade_decision:
if self.trade_exchange.check_order(order) is True:
# execute the order
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
order, trade_account=self.trade_account
)
execute_state.append((order, trade_val, trade_cost, trade_price))
execute_result.append((order, trade_val, trade_cost, trade_price))
if self.verbose:
if order.direction == Order.SELL: # sell
print(
@@ -340,13 +334,13 @@ class SimulatorExecutor(BaseExecutor):
trade_exchange=self.trade_exchange,
)
return execute_state
return execute_result
def get_report(self):
if self.generate_report:
_report = self.trade_account.report.generate_report_dataframe()
_positions = self.trade_account.get_positions()
_count, _freq = parse_freq(self.step_bar)
_count, _freq = parse_freq(self.time_per_step)
return {f"{_count}{_freq}": (_report, _positions)}
else:
return {}

View File

@@ -15,13 +15,13 @@ class TradeCalendarManager:
"""
def __init__(
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
):
"""
Parameters
----------
step_bar : str
frequency of each trading calendar
freq : str
frequency of trading calendar, also trade time per trading step
start_time : Union[str, pd.Timestamp], optional
closed start of the trading calendar, by default None
If `start_time` is None, it must be reset before trading.
@@ -29,14 +29,14 @@ class TradeCalendarManager:
closed end of the trade time range, by default None
If `end_time` is None, it must be reset before trading.
"""
self.step_bar = step_bar
self.freq = freq
self.start_time = pd.Timestamp(start_time) if start_time else None
self.end_time = pd.Timestamp(start_time) if start_time else None
self._init_trade_calendar(step_bar=step_bar, start_time=start_time, end_time=end_time)
self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time)
def _init_trade_calendar(self, step_bar, start_time, end_time):
def _init_trade_calendar(self, freq, start_time, end_time):
"""reset trade calendar"""
_calendar, freq, freq_sam = get_resam_calendar(freq=step_bar)
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
self.calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
self.start_index = _start_index
@@ -52,8 +52,8 @@ class TradeCalendarManager:
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
self.trade_index = self.trade_index + 1
def get_step_bar(self):
return self.step_bar
def get_freq(self):
return self.freq
def get_trade_len(self):
return self.trade_len

View File

@@ -81,10 +81,10 @@ class TopkDropoutStrategy(ModelStrategy):
# It will use 95% amoutn of your total value by default
return self.risk_degree
def generate_trade_decision(self, execute_state):
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
def generate_trade_decision(self, execute_result=None):
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
@@ -179,8 +179,8 @@ class TopkDropoutStrategy(ModelStrategy):
continue
if code in sell:
# check hold limit
step_bar = self.trade_calendar.get_step_bar()
if current_temp.get_stock_count(code, bar=step_bar) < self.hold_thresh:
time_per_step = self.calendar.get_freq()
if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
continue
# sell order
sell_amount = current_temp.get_stock_amount(code=code)
@@ -292,7 +292,7 @@ class WeightStrategyBase(ModelStrategy):
"""
raise NotImplementedError()
def generate_trade_decision(self, execute_state):
def generate_trade_decision(self, execute_result=None):
"""
Parameters
-----------
@@ -307,9 +307,9 @@ class WeightStrategyBase(ModelStrategy):
"""
# generate_trade_decision
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
trade_index = self.calendar.get_trade_index()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []

View File

@@ -24,31 +24,31 @@ class TWAPStrategy(RuleStrategy):
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, rely_trade_decision: object = None, **kwargs):
def reset(self, outer_trade_decision: object = None, **kwargs):
"""
Parameters
----------
rely_trade_decision : object, optional
outer_trade_decision : object, optional
"""
super(TWAPStrategy, self).reset(rely_trade_decision=rely_trade_decision, common_infra=common_infra, **kwargs)
if rely_trade_decision is not None:
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs)
if outer_trade_decision is not None:
self.trade_amount = {}
for order in rely_trade_decision:
for order in outer_trade_decision:
self.trade_amount[(order.stock_id, order.direction)] = order.amount
def generate_trade_decision(self, execute_state):
def generate_trade_decision(self, execute_result=None):
# update the order amount
trade_info = execute_state
for order, _, _, _ in trade_info:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
trade_index = self.trade_calendar.get_trade_index()
trade_len = self.trade_calendar.get_trade_len()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
trade_index = self.calendar.get_trade_index()
trade_len = self.calendar.get_trade_len()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
order_list = []
for order in self.rely_trade_decision:
for order in self.outer_trade_decision:
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
@@ -104,41 +104,41 @@ class SBBStrategyBase(RuleStrategy):
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, rely_trade_decision=None, **kwargs):
def reset(self, outer_trade_decision=None, **kwargs):
"""
Parameters
----------
rely_trade_decision : object, optional
outer_trade_decision : object, optional
common_infra : None, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(SBBStrategyBase, self).reset(rely_trade_decision=rely_trade_decision, **kwargs)
if rely_trade_decision is not None:
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.trade_trend = {}
self.trade_amount = {}
# init the trade amount of order and predicted trade trend
for order in rely_trade_decision:
for order in outer_trade_decision:
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
self.trade_amount[(order.stock_id, order.direction)] = order.amount
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
def generate_trade_decision(self, execute_state):
def generate_trade_decision(self, execute_result=None):
# update the order amount
trade_info = execute_state
for order, _, _, _ in trade_info:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
trade_index = self.trade_calendar.get_trade_index()
trade_len = self.trade_calendar.get_trade_len()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
if execute_result is not None:
for order, _, _, _ in execute_result:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
trade_index = self.calendar.get_trade_index()
trade_len = self.calendar.get_trade_len()
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
order_list = []
# for each order in in self.rely_trade_decision
for order in self.rely_trade_decision:
# for each order in in self.outer_trade_decision
for order in self.outer_trade_decision:
# predict the price trend
if trade_index % 2 == 1:
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
@@ -266,7 +266,7 @@ class SBBStrategyEMA(SBBStrategyBase):
def __init__(
self,
rely_trade_decision=[],
outer_trade_decision=[],
instruments="csi300",
freq="day",
level_infra={},
@@ -288,13 +288,13 @@ class SBBStrategyEMA(SBBStrategyBase):
if isinstance(instruments, str):
self.instruments = D.instruments(instruments)
self.freq = freq
super(SBBStrategyEMA, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
super(SBBStrategyEMA, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
def _reset_signal(self):
trade_len = self.trade_calendar.get_trade_len()
trade_len = self.calendar.get_trade_len()
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self.trade_calendar.get_calendar_time(trade_index=1, shift=1)
_, signal_end_time = self.trade_calendar.get_calendar_time(trade_index=trade_len, shift=1)
signal_start_time, _ = self.calendar.get_calendar_time(trade_index=1, shift=1)
_, signal_end_time = self.calendar.get_calendar_time(trade_index=trade_len, shift=1)
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
@@ -307,15 +307,15 @@ class SBBStrategyEMA(SBBStrategyBase):
def reset_level_infra(self, level_infra):
"""
reset level-shared infra
- After reset the trade_calendar, the signal will be changed
- After reset the trade calendar, the signal will be changed
"""
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
else:
self.level_infra.update(level_infra)
if "trade_calendar" in level_infra:
self.trade_calendar = level_infra.get("trade_calendar")
if "calendar" in level_infra:
self.calendar = level_infra.get("calendar")
self._reset_signal()
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):

View File

@@ -6,6 +6,7 @@ from typing import Union
from .interpreter import StateInterpreter, ActionInterpreter
from ..contrib.backtest.executor import BaseExecutor
from ..utils import init_instance_by_config
from .interpreter import BaseInterpreter
class BaseRLEnv:
@@ -68,8 +69,8 @@ class QlibIntRLEnv(QlibRLEnv):
interpretor that interprets the rl agent action into qlib order list
"""
super(QlibIntRLEnv, self).__init__(executor=executor)
self.state_interpreter = init_instance_by_config(state_interpreter)
self.action_interpreter = init_instance_by_config(action_interpreter)
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def step(self, action):
"""
@@ -87,7 +88,7 @@ class QlibIntRLEnv(QlibRLEnv):
-------
env state to rl policy
"""
_interpret_action = self.action_interpreter.interpret(action=action)
_execute_result = self.executor.execute(_interpret_action)
_interpret_decision = self.action_interpreter.interpret(action=action)
_execute_result = self.executor.execute(trade_decision=_interpret_decision)
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
return _interpret_state

View File

@@ -19,24 +19,24 @@ class BaseStrategy:
def __init__(
self,
rely_trade_decision: object = None,
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
):
"""
Parameters
----------
rely_trade_decision : object, optional
the high-level trade decison on which the startegy rely, and it will be traded in [start_time , end_time] , by default None
outer_trade_decision : object, optional
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
- If the strategy is used to split trade decison, it will be used
- If the strategy is used for portfolio management, it can be ignored
level_infra : dict, optional
level shared infrastructure for backtesting, including trade_calendar
level shared infrastructure for backtesting, including trade calendar
common_infra : dict, optional
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
"""
self.reset(level_infra=level_infra, common_infra=common_infra, rely_trade_decision=rely_trade_decision)
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
def reset_level_infra(self, level_infra):
if not hasattr(self, "level_infra"):
@@ -44,8 +44,8 @@ class BaseStrategy:
else:
self.level_infra.update(level_infra)
if "trade_calendar" in level_infra:
self.trade_calendar = level_infra.get("trade_calendar")
if "calendar" in level_infra:
self.calendar = level_infra.get("calendar")
def reset_common_infra(self, common_infra):
if not hasattr(self, "common_infra"):
@@ -56,11 +56,11 @@ class BaseStrategy:
if "trade_account" in common_infra:
self.trade_position = common_infra.get("trade_account").current
def reset(self, level_infra: dict = None, common_infra: dict = None, rely_trade_decision=None, **kwargs):
def reset(self, level_infra: dict = None, common_infra: dict = None, outer_trade_decision=None, **kwargs):
"""
- reset `level_infra`, used to reset trade_calendar, .etc
- reset `level_infra`, used to reset trade calendar, .etc
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
- reset `rely_trade_decision`, used to make split decison
- reset `outer_trade_decision`, used to make split decison
"""
if level_infra is not None:
self.reset_level_infra(level_infra)
@@ -68,11 +68,18 @@ class BaseStrategy:
if common_infra is not None:
self.reset_common_infra(common_infra)
if rely_trade_decision is not None:
self.rely_trade_decision = rely_trade_decision
if outer_trade_decision is not None:
self.outer_trade_decision = outer_trade_decision
def generate_trade_decision(self, execute_state):
"""Generate trade decision in each trading bar"""
def generate_trade_decision(self, execute_result=None):
"""Generate trade decision in each trading bar
Parameters
----------
execute_result : List[object], optional
the executed result for trade decison, by default None
- When call the generate_trade_decision firstly, `execute_result` could be None
"""
raise NotImplementedError("generate_trade_decision is not implemented!")
@@ -89,7 +96,7 @@ class ModelStrategy(BaseStrategy):
self,
model: BaseModel,
dataset: DatasetH,
rely_trade_decision: object = None,
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
**kwargs,
@@ -104,7 +111,7 @@ class ModelStrategy(BaseStrategy):
kwargs : dict
arguments that will be passed into `reset` method
"""
super(ModelStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
self.model = model
self.dataset = dataset
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
@@ -125,7 +132,7 @@ class RLStrategy(BaseStrategy):
def __init__(
self,
policy,
rely_trade_decision: object = None,
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
**kwargs,
@@ -136,7 +143,7 @@ class RLStrategy(BaseStrategy):
policy :
RL policy for generate action
"""
super(RLStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
super(RLStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
@@ -148,7 +155,7 @@ class RLIntStrategy(RLStrategy):
policy,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
rely_trade_decision: object = None,
outer_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
**kwargs,
@@ -165,15 +172,14 @@ class RLIntStrategy(RLStrategy):
end_time : Union[str, pd.Timestamp], optional
end time of trading, by default None
"""
super(RLIntStrategy, self).__init__(policy, rely_trade_decision, level_infra, common_infra, **kwargs)
super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
self.state_interpreter = init_instance_by_config(state_interpreter)
self.action_interpreter = init_instance_by_config(action_interpreter)
def generate_trade_decision(self, execute_state):
super(RLStrategy, self).step()
_interpret_state = self.state_interpretor.interpret(execute_result=execute_state)
_policy_action = self.policy.step(_interpret_state)
_order_list = self.action_interpreter.interpret(action=_policy_action)
return _order_list
def generate_trade_decision(self, execute_result=None):
_interpret_state = self.state_interpretor.interpret(execute_result=execute_result)
_action = self.policy.step(_interpret_state)
_trade_decision = self.action_interpreter.interpret(action=_action)
return _trade_decision

View File

@@ -317,7 +317,7 @@ class PortAnaRecord(RecordTemp):
def _get_report_freq(self, executor_config):
ret_freq = []
if executor_config["kwargs"].get("generate_report", False):
_count, _freq = parse_freq(executor_config["kwargs"]["step_bar"])
_count, _freq = parse_freq(executor_config["kwargs"]["time_per_step"])
ret_freq.append(f"{_count}{_freq}")
if "sub_env" in executor_config["kwargs"]:
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))