mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
rename var in backtest
This commit is contained in:
@@ -61,24 +61,24 @@ class MultiLevelTradingWorkflow:
|
||||
}
|
||||
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
trade_end_time = "2017-02-01"
|
||||
|
||||
port_analysis_config = {
|
||||
"executor": {
|
||||
"class": "SplitExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"kwargs": {
|
||||
"step_bar": "week",
|
||||
"sub_executor": {
|
||||
"time_per_step": "week",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"time_per_step": "day",
|
||||
"verbose": True,
|
||||
"generate_report": True,
|
||||
},
|
||||
},
|
||||
"sub_strategy": {
|
||||
"inner_strategy": {
|
||||
"class": "SBBStrategyEMA",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
@@ -107,7 +107,6 @@ class MultiLevelTradingWorkflow:
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
@@ -23,7 +23,6 @@ class RollingDataWorkflow:
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
@@ -8,10 +8,10 @@ def backtest(start_time, end_time, trade_strategy, trade_executor):
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
sub_execute_state = trade_executor.get_init_state()
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = trade_executor.execute(sub_trade_decision)
|
||||
_trade_decision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = trade_executor.execute(_trade_decision)
|
||||
|
||||
return trade_executor.get_report()
|
||||
|
||||
@@ -22,9 +22,9 @@ def collect_data(start_time, end_time, trade_strategy, trade_executor):
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
sub_execute_state = trade_executor.get_init_state()
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = yield from trade_executor.collect_data(sub_trade_decision)
|
||||
_trade_decision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
|
||||
return trade_executor.get_report()
|
||||
|
||||
@@ -8,7 +8,6 @@ from ...utils.resam import parse_freq
|
||||
|
||||
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager
|
||||
|
||||
@@ -18,7 +17,7 @@ class BaseExecutor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
time_per_step: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
generate_report: bool = False,
|
||||
@@ -30,6 +29,8 @@ class BaseExecutor:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
time_per_step : str
|
||||
trade time per trading step, used for genreate trade calendar
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
verbose : bool, optional
|
||||
@@ -46,7 +47,7 @@ class BaseExecutor:
|
||||
exchange that provides market info
|
||||
|
||||
"""
|
||||
self.step_bar = step_bar
|
||||
self.time_per_step = time_per_step
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
self.track_data = track_data
|
||||
@@ -64,7 +65,7 @@ class BaseExecutor:
|
||||
|
||||
if "trade_account" in common_infra:
|
||||
self.trade_account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.step_bar, init_report=True)
|
||||
self.trade_account.reset(freq=self.time_per_step, init_report=True)
|
||||
|
||||
def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs):
|
||||
"""
|
||||
@@ -76,19 +77,19 @@ class BaseExecutor:
|
||||
if track_data is not None:
|
||||
self.track_data = track_data
|
||||
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
if "start_time" in kwargs or "end_time" in kwargs:
|
||||
start_time = kwargs.get("start_time")
|
||||
end_time = kwargs.get("end_time")
|
||||
self.trade_calendar = TradeCalendarManager(step_bar=self.step_bar, start_time=start_time, end_time=end_time)
|
||||
self.calendar = TradeCalendarManager(freq=self.time_per_step, start_time=start_time, end_time=end_time)
|
||||
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
def get_level_infra(self):
|
||||
return {"trade_calendar": self.trade_calendar}
|
||||
return {"calendar": self.calendar}
|
||||
|
||||
def finished(self):
|
||||
return self.trade_calendar.finished()
|
||||
return self.calendar.finished()
|
||||
|
||||
def execute(self, trade_decision):
|
||||
"""execute the trade decision and return the executed result
|
||||
@@ -99,8 +100,8 @@ class BaseExecutor:
|
||||
|
||||
Returns
|
||||
----------
|
||||
executed state : List[Tuple[Order, float, float, float]]
|
||||
- Each element in the list represents (order, trade value, trade cost, trade price)
|
||||
execute_result : List[object]
|
||||
the executed result for trade decison
|
||||
"""
|
||||
raise NotImplementedError("execute is not implemented!")
|
||||
|
||||
@@ -109,9 +110,6 @@ class BaseExecutor:
|
||||
yield trade_decision
|
||||
return self.execute(trade_decision)
|
||||
|
||||
def get_init_state(self):
|
||||
raise NotImplementedError("get_init_state in not implemeted!")
|
||||
|
||||
def get_trade_account(self):
|
||||
raise NotImplementedError("get_trade_account is not implemented!")
|
||||
|
||||
@@ -124,9 +122,9 @@ class SplitExecutor(BaseExecutor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
sub_executor: Union[BaseExecutor, dict],
|
||||
sub_strategy: Union[BaseStrategy, dict],
|
||||
time_per_step: str,
|
||||
inner_executor: Union[BaseExecutor, dict],
|
||||
inner_strategy: Union[BaseStrategy, dict],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -139,22 +137,24 @@ class SplitExecutor(BaseExecutor):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sub_executor : BaseExecutor
|
||||
inner_executor : BaseExecutor
|
||||
trading env in each trading bar.
|
||||
sub_strategy : BaseStrategy
|
||||
inner_strategy : BaseStrategy
|
||||
trading strategy in each trading bar
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to generate report
|
||||
- If generate_report is None, trade_exchange will be ignored
|
||||
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
"""
|
||||
self.sub_executor = init_instance_by_config(sub_executor, common_infra=common_infra, accept_types=BaseExecutor)
|
||||
self.sub_strategy = init_instance_by_config(
|
||||
sub_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
|
||||
self.inner_executor = init_instance_by_config(
|
||||
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
|
||||
)
|
||||
self.inner_strategy = init_instance_by_config(
|
||||
inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
|
||||
)
|
||||
|
||||
super(SplitExecutor, self).__init__(
|
||||
step_bar=step_bar,
|
||||
time_per_step=time_per_step,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generate_report=generate_report,
|
||||
@@ -171,29 +171,26 @@ class SplitExecutor(BaseExecutor):
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_exchange
|
||||
- reset substrategy and subexecutor common infra
|
||||
- reset inner_strategyand inner_executor common infra
|
||||
"""
|
||||
super(SplitExecutor, self).reset_common_infra(common_infra)
|
||||
|
||||
if self.generate_report and "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
self.sub_executor.reset_common_infra(common_infra)
|
||||
self.sub_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def get_init_state(self):
|
||||
return []
|
||||
self.inner_executor.reset_common_infra(common_infra)
|
||||
self.inner_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
self.sub_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.sub_executor.get_level_infra()
|
||||
self.sub_strategy.reset(level_infra=sub_level_infra, rely_trade_decision=trade_decision)
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
|
||||
|
||||
def _update_trade_account(self):
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
self.trade_account.update_bar_count()
|
||||
if self.generate_report:
|
||||
self.trade_account.update_bar_report(
|
||||
@@ -203,41 +200,41 @@ class SplitExecutor(BaseExecutor):
|
||||
)
|
||||
|
||||
def execute(self, trade_decision):
|
||||
self.trade_calendar.step()
|
||||
self.calendar.step()
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_state = []
|
||||
sub_execute_state = self.sub_executor.get_init_state()
|
||||
while not self.sub_executor.finished():
|
||||
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = self.sub_executor.execute(trade_decision=sub_trade_decison)
|
||||
execute_state.extend(sub_execute_state)
|
||||
execute_result = []
|
||||
_inner_execute_result = None
|
||||
while not self.inner_executor.finished():
|
||||
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
_inner_execute_result = self.inner_executor.execute(trade_decision=_inner_trade_decision)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
if hasattr(self, "trade_account"):
|
||||
self._update_trade_account()
|
||||
|
||||
return execute_state
|
||||
return execute_result
|
||||
|
||||
def collect_data(self, trade_decision):
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
self.trade_calendar.step()
|
||||
self.calendar.step()
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_state = []
|
||||
sub_execute_state = self.sub_executor.get_init_state()
|
||||
while not self.sub_executor.finished():
|
||||
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = yield from self.sub_executor.collect_data(trade_decision=sub_trade_decison)
|
||||
execute_state.extend(sub_execute_state)
|
||||
execute_result = []
|
||||
_inner_execute_result = None
|
||||
while not self.inner_executor.finished():
|
||||
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
if hasattr(self, "trade_account"):
|
||||
self._update_trade_account()
|
||||
|
||||
return execute_state
|
||||
return execute_result
|
||||
|
||||
def get_report(self):
|
||||
sub_env_report_dict = self.sub_executor.get_report()
|
||||
sub_env_report_dict = self.inner_executor.get_report()
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
_count, _freq = parse_freq(self.step_bar)
|
||||
_count, _freq = parse_freq(self.time_per_step)
|
||||
sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)})
|
||||
return sub_env_report_dict
|
||||
|
||||
@@ -245,7 +242,7 @@ class SplitExecutor(BaseExecutor):
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
time_per_step: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -263,7 +260,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
"""
|
||||
super(SimulatorExecutor, self).__init__(
|
||||
step_bar=step_bar,
|
||||
time_per_step=time_per_step,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generate_report=generate_report,
|
||||
@@ -284,21 +281,18 @@ class SimulatorExecutor(BaseExecutor):
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_init_state(self):
|
||||
return []
|
||||
|
||||
def execute(self, trade_decision):
|
||||
self.trade_calendar.step()
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
execute_state = []
|
||||
self.calendar.step()
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
execute_result = []
|
||||
for order in trade_decision:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
order, trade_account=self.trade_account
|
||||
)
|
||||
execute_state.append((order, trade_val, trade_cost, trade_price))
|
||||
execute_result.append((order, trade_val, trade_cost, trade_price))
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
print(
|
||||
@@ -340,13 +334,13 @@ class SimulatorExecutor(BaseExecutor):
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
|
||||
return execute_state
|
||||
return execute_result
|
||||
|
||||
def get_report(self):
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
_count, _freq = parse_freq(self.step_bar)
|
||||
_count, _freq = parse_freq(self.time_per_step)
|
||||
return {f"{_count}{_freq}": (_report, _positions)}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@@ -15,13 +15,13 @@ class TradeCalendarManager:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
step_bar : str
|
||||
frequency of each trading calendar
|
||||
freq : str
|
||||
frequency of trading calendar, also trade time per trading step
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
closed start of the trading calendar, by default None
|
||||
If `start_time` is None, it must be reset before trading.
|
||||
@@ -29,14 +29,14 @@ class TradeCalendarManager:
|
||||
closed end of the trade time range, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
self.step_bar = step_bar
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(start_time) if start_time else None
|
||||
self._init_trade_calendar(step_bar=step_bar, start_time=start_time, end_time=end_time)
|
||||
self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def _init_trade_calendar(self, step_bar, start_time, end_time):
|
||||
def _init_trade_calendar(self, freq, start_time, end_time):
|
||||
"""reset trade calendar"""
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=step_bar)
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
|
||||
self.calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
|
||||
self.start_index = _start_index
|
||||
@@ -52,8 +52,8 @@ class TradeCalendarManager:
|
||||
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
def get_step_bar(self):
|
||||
return self.step_bar
|
||||
def get_freq(self):
|
||||
return self.freq
|
||||
|
||||
def get_trade_len(self):
|
||||
return self.trade_len
|
||||
|
||||
@@ -81,10 +81,10 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
# It will use 95% amoutn of your total value by default
|
||||
return self.risk_degree
|
||||
|
||||
def generate_trade_decision(self, execute_state):
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
@@ -179,8 +179,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
continue
|
||||
if code in sell:
|
||||
# check hold limit
|
||||
step_bar = self.trade_calendar.get_step_bar()
|
||||
if current_temp.get_stock_count(code, bar=step_bar) < self.hold_thresh:
|
||||
time_per_step = self.calendar.get_freq()
|
||||
if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
|
||||
continue
|
||||
# sell order
|
||||
sell_amount = current_temp.get_stock_amount(code=code)
|
||||
@@ -292,7 +292,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_trade_decision(self, execute_state):
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
"""
|
||||
Parameters
|
||||
-----------
|
||||
@@ -307,9 +307,9 @@ class WeightStrategyBase(ModelStrategy):
|
||||
"""
|
||||
# generate_trade_decision
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
|
||||
@@ -24,31 +24,31 @@ class TWAPStrategy(RuleStrategy):
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, rely_trade_decision: object = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: object = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rely_trade_decision : object, optional
|
||||
outer_trade_decision : object, optional
|
||||
"""
|
||||
|
||||
super(TWAPStrategy, self).reset(rely_trade_decision=rely_trade_decision, common_infra=common_infra, **kwargs)
|
||||
if rely_trade_decision is not None:
|
||||
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
for order in rely_trade_decision:
|
||||
for order in outer_trade_decision:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
|
||||
def generate_trade_decision(self, execute_state):
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
# update the order amount
|
||||
trade_info = execute_state
|
||||
for order, _, _, _ in trade_info:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_len = self.calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
order_list = []
|
||||
for order in self.rely_trade_decision:
|
||||
for order in self.outer_trade_decision:
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
@@ -104,41 +104,41 @@ class SBBStrategyBase(RuleStrategy):
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, rely_trade_decision=None, **kwargs):
|
||||
def reset(self, outer_trade_decision=None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rely_trade_decision : object, optional
|
||||
outer_trade_decision : object, optional
|
||||
common_infra : None, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset(rely_trade_decision=rely_trade_decision, **kwargs)
|
||||
if rely_trade_decision is not None:
|
||||
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_trend = {}
|
||||
self.trade_amount = {}
|
||||
# init the trade amount of order and predicted trade trend
|
||||
for order in rely_trade_decision:
|
||||
for order in outer_trade_decision:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
|
||||
def generate_trade_decision(self, execute_state):
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
# update the order amount
|
||||
trade_info = execute_state
|
||||
for order, _, _, _ in trade_info:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
trade_index = self.calendar.get_trade_index()
|
||||
trade_len = self.calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1)
|
||||
order_list = []
|
||||
# for each order in in self.rely_trade_decision
|
||||
for order in self.rely_trade_decision:
|
||||
# for each order in in self.outer_trade_decision
|
||||
for order in self.outer_trade_decision:
|
||||
# predict the price trend
|
||||
if trade_index % 2 == 1:
|
||||
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
|
||||
@@ -266,7 +266,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rely_trade_decision=[],
|
||||
outer_trade_decision=[],
|
||||
instruments="csi300",
|
||||
freq="day",
|
||||
level_infra={},
|
||||
@@ -288,13 +288,13 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
super(SBBStrategyEMA, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
super(SBBStrategyEMA, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
def _reset_signal(self):
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
trade_len = self.calendar.get_trade_len()
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self.trade_calendar.get_calendar_time(trade_index=1, shift=1)
|
||||
_, signal_end_time = self.trade_calendar.get_calendar_time(trade_index=trade_len, shift=1)
|
||||
signal_start_time, _ = self.calendar.get_calendar_time(trade_index=1, shift=1)
|
||||
_, signal_end_time = self.calendar.get_calendar_time(trade_index=trade_len, shift=1)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
@@ -307,15 +307,15 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
def reset_level_infra(self, level_infra):
|
||||
"""
|
||||
reset level-shared infra
|
||||
- After reset the trade_calendar, the signal will be changed
|
||||
- After reset the trade calendar, the signal will be changed
|
||||
"""
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "trade_calendar" in level_infra:
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
if "calendar" in level_infra:
|
||||
self.calendar = level_infra.get("calendar")
|
||||
self._reset_signal()
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Union
|
||||
from .interpreter import StateInterpreter, ActionInterpreter
|
||||
from ..contrib.backtest.executor import BaseExecutor
|
||||
from ..utils import init_instance_by_config
|
||||
from .interpreter import BaseInterpreter
|
||||
|
||||
|
||||
class BaseRLEnv:
|
||||
@@ -68,8 +69,8 @@ class QlibIntRLEnv(QlibRLEnv):
|
||||
interpretor that interprets the rl agent action into qlib order list
|
||||
"""
|
||||
super(QlibIntRLEnv, self).__init__(executor=executor)
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter)
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
@@ -87,7 +88,7 @@ class QlibIntRLEnv(QlibRLEnv):
|
||||
-------
|
||||
env state to rl policy
|
||||
"""
|
||||
_interpret_action = self.action_interpreter.interpret(action=action)
|
||||
_execute_result = self.executor.execute(_interpret_action)
|
||||
_interpret_decision = self.action_interpreter.interpret(action=action)
|
||||
_execute_result = self.executor.execute(trade_decision=_interpret_decision)
|
||||
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
|
||||
return _interpret_state
|
||||
|
||||
@@ -19,24 +19,24 @@ class BaseStrategy:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rely_trade_decision: object = None,
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rely_trade_decision : object, optional
|
||||
the high-level trade decison on which the startegy rely, and it will be traded in [start_time , end_time] , by default None
|
||||
outer_trade_decision : object, optional
|
||||
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decison, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
level_infra : dict, optional
|
||||
level shared infrastructure for backtesting, including trade_calendar
|
||||
level shared infrastructure for backtesting, including trade calendar
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
|
||||
"""
|
||||
|
||||
self.reset(level_infra=level_infra, common_infra=common_infra, rely_trade_decision=rely_trade_decision)
|
||||
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
if not hasattr(self, "level_infra"):
|
||||
@@ -44,8 +44,8 @@ class BaseStrategy:
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "trade_calendar" in level_infra:
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
if "calendar" in level_infra:
|
||||
self.calendar = level_infra.get("calendar")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
if not hasattr(self, "common_infra"):
|
||||
@@ -56,11 +56,11 @@ class BaseStrategy:
|
||||
if "trade_account" in common_infra:
|
||||
self.trade_position = common_infra.get("trade_account").current
|
||||
|
||||
def reset(self, level_infra: dict = None, common_infra: dict = None, rely_trade_decision=None, **kwargs):
|
||||
def reset(self, level_infra: dict = None, common_infra: dict = None, outer_trade_decision=None, **kwargs):
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade_calendar, .etc
|
||||
- reset `level_infra`, used to reset trade calendar, .etc
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
- reset `rely_trade_decision`, used to make split decison
|
||||
- reset `outer_trade_decision`, used to make split decison
|
||||
"""
|
||||
if level_infra is not None:
|
||||
self.reset_level_infra(level_infra)
|
||||
@@ -68,11 +68,18 @@ class BaseStrategy:
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
if rely_trade_decision is not None:
|
||||
self.rely_trade_decision = rely_trade_decision
|
||||
if outer_trade_decision is not None:
|
||||
self.outer_trade_decision = outer_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_state):
|
||||
"""Generate trade decision in each trading bar"""
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
"""Generate trade decision in each trading bar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
execute_result : List[object], optional
|
||||
the executed result for trade decison, by default None
|
||||
- When call the generate_trade_decision firstly, `execute_result` could be None
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
|
||||
@@ -89,7 +96,7 @@ class ModelStrategy(BaseStrategy):
|
||||
self,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
rely_trade_decision: object = None,
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
@@ -104,7 +111,7 @@ class ModelStrategy(BaseStrategy):
|
||||
kwargs : dict
|
||||
arguments that will be passed into `reset` method
|
||||
"""
|
||||
super(ModelStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
|
||||
@@ -125,7 +132,7 @@ class RLStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
rely_trade_decision: object = None,
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
@@ -136,7 +143,7 @@ class RLStrategy(BaseStrategy):
|
||||
policy :
|
||||
RL policy for generate action
|
||||
"""
|
||||
super(RLStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
super(RLStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
self.policy = policy
|
||||
|
||||
|
||||
@@ -148,7 +155,7 @@ class RLIntStrategy(RLStrategy):
|
||||
policy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
rely_trade_decision: object = None,
|
||||
outer_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
@@ -165,15 +172,14 @@ class RLIntStrategy(RLStrategy):
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time of trading, by default None
|
||||
"""
|
||||
super(RLIntStrategy, self).__init__(policy, rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
self.policy = policy
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter)
|
||||
|
||||
def generate_trade_decision(self, execute_state):
|
||||
super(RLStrategy, self).step()
|
||||
_interpret_state = self.state_interpretor.interpret(execute_result=execute_state)
|
||||
_policy_action = self.policy.step(_interpret_state)
|
||||
_order_list = self.action_interpreter.interpret(action=_policy_action)
|
||||
return _order_list
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
_interpret_state = self.state_interpretor.interpret(execute_result=execute_result)
|
||||
_action = self.policy.step(_interpret_state)
|
||||
_trade_decision = self.action_interpreter.interpret(action=_action)
|
||||
return _trade_decision
|
||||
|
||||
@@ -317,7 +317,7 @@ class PortAnaRecord(RecordTemp):
|
||||
def _get_report_freq(self, executor_config):
|
||||
ret_freq = []
|
||||
if executor_config["kwargs"].get("generate_report", False):
|
||||
_count, _freq = parse_freq(executor_config["kwargs"]["step_bar"])
|
||||
_count, _freq = parse_freq(executor_config["kwargs"]["time_per_step"])
|
||||
ret_freq.append(f"{_count}{_freq}")
|
||||
if "sub_env" in executor_config["kwargs"]:
|
||||
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))
|
||||
|
||||
Reference in New Issue
Block a user