From 2ad61f12b3e08b1fbf736eca6c9abed20b8341f6 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 27 May 2021 17:03:53 +0800 Subject: [PATCH] rename var in backtest --- examples/multi_level_trading/workflow.py | 11 +- examples/rolling_process_data/workflow.py | 1 - qlib/contrib/backtest/backtest.py | 12 +- qlib/contrib/backtest/executor.py | 128 +++++++++++----------- qlib/contrib/backtest/utils.py | 18 +-- qlib/contrib/strategy/model_strategy.py | 20 ++-- qlib/contrib/strategy/rule_strategy.py | 72 ++++++------ qlib/rl/env.py | 9 +- qlib/strategy/base.py | 58 +++++----- qlib/workflow/record_temp.py | 2 +- 10 files changed, 165 insertions(+), 166 deletions(-) diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 390044480..ea11d4e7f 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -61,24 +61,24 @@ class MultiLevelTradingWorkflow: } trade_start_time = "2017-01-01" - trade_end_time = "2020-08-01" + trade_end_time = "2017-02-01" port_analysis_config = { "executor": { "class": "SplitExecutor", "module_path": "qlib.contrib.backtest.executor", "kwargs": { - "step_bar": "week", - "sub_executor": { + "time_per_step": "week", + "inner_executor": { "class": "SimulatorExecutor", "module_path": "qlib.contrib.backtest.executor", "kwargs": { - "step_bar": "day", + "time_per_step": "day", "verbose": True, "generate_report": True, }, }, - "sub_strategy": { + "inner_strategy": { "class": "SBBStrategyEMA", "module_path": "qlib.contrib.strategy.rule_strategy", "kwargs": { @@ -107,7 +107,6 @@ class MultiLevelTradingWorkflow: def _init_qlib(self): """initialize qlib""" - # use yahoo_cn_1min data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir if not exists_qlib_data(provider_uri): print(f"Qlib data is not found in {provider_uri}") diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 5757aaa87..048253f0d 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -23,7 +23,6 @@ class RollingDataWorkflow: def _init_qlib(self): """initialize qlib""" - # use yahoo_cn_1min data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir if not exists_qlib_data(provider_uri): print(f"Qlib data is not found in {provider_uri}") diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index 33c73de7a..1f0d2ac38 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -8,10 +8,10 @@ def backtest(start_time, end_time, trade_strategy, trade_executor): level_infra = trade_executor.get_level_infra() trade_strategy.reset(level_infra=level_infra) - sub_execute_state = trade_executor.get_init_state() + _execute_result = None while not trade_executor.finished(): - sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state) - sub_execute_state = trade_executor.execute(sub_trade_decision) + _trade_decision = trade_strategy.generate_trade_decision(_execute_result) + _execute_result = trade_executor.execute(_trade_decision) return trade_executor.get_report() @@ -22,9 +22,9 @@ def collect_data(start_time, end_time, trade_strategy, trade_executor): level_infra = trade_executor.get_level_infra() trade_strategy.reset(level_infra=level_infra) - sub_execute_state = trade_executor.get_init_state() + _execute_result = None while not trade_executor.finished(): - sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state) - sub_execute_state = yield from trade_executor.collect_data(sub_trade_decision) + _trade_decision = trade_strategy.generate_trade_decision(_execute_result) + _execute_result = yield from trade_executor.collect_data(_trade_decision) return trade_executor.get_report() diff --git a/qlib/contrib/backtest/executor.py b/qlib/contrib/backtest/executor.py index 8a57d2986..c896f802d 100644 --- a/qlib/contrib/backtest/executor.py +++ b/qlib/contrib/backtest/executor.py @@ -8,7 +8,6 @@ from ...utils.resam import parse_freq from .order import Order -from .account import Account from .exchange import Exchange from .utils import TradeCalendarManager @@ -18,7 +17,7 @@ class BaseExecutor: def __init__( self, - step_bar: str, + time_per_step: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, generate_report: bool = False, @@ -30,6 +29,8 @@ class BaseExecutor: """ Parameters ---------- + time_per_step : str + trade time per trading step, used for genreate trade calendar generate_report : bool, optional whether to generate report, by default False verbose : bool, optional @@ -46,7 +47,7 @@ class BaseExecutor: exchange that provides market info """ - self.step_bar = step_bar + self.time_per_step = time_per_step self.generate_report = generate_report self.verbose = verbose self.track_data = track_data @@ -64,7 +65,7 @@ class BaseExecutor: if "trade_account" in common_infra: self.trade_account = copy.copy(common_infra.get("trade_account")) - self.trade_account.reset(freq=self.step_bar, init_report=True) + self.trade_account.reset(freq=self.time_per_step, init_report=True) def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs): """ @@ -76,19 +77,19 @@ class BaseExecutor: if track_data is not None: self.track_data = track_data - if common_infra is not None: - self.reset_common_infra(common_infra) - if "start_time" in kwargs or "end_time" in kwargs: start_time = kwargs.get("start_time") end_time = kwargs.get("end_time") - self.trade_calendar = TradeCalendarManager(step_bar=self.step_bar, start_time=start_time, end_time=end_time) + self.calendar = TradeCalendarManager(freq=self.time_per_step, start_time=start_time, end_time=end_time) + + if common_infra is not None: + self.reset_common_infra(common_infra) def get_level_infra(self): - return {"trade_calendar": self.trade_calendar} + return {"calendar": self.calendar} def finished(self): - return self.trade_calendar.finished() + return self.calendar.finished() def execute(self, trade_decision): """execute the trade decision and return the executed result @@ -99,8 +100,8 @@ class BaseExecutor: Returns ---------- - executed state : List[Tuple[Order, float, float, float]] - - Each element in the list represents (order, trade value, trade cost, trade price) + execute_result : List[object] + the executed result for trade decison """ raise NotImplementedError("execute is not implemented!") @@ -109,9 +110,6 @@ class BaseExecutor: yield trade_decision return self.execute(trade_decision) - def get_init_state(self): - raise NotImplementedError("get_init_state in not implemeted!") - def get_trade_account(self): raise NotImplementedError("get_trade_account is not implemented!") @@ -124,9 +122,9 @@ class SplitExecutor(BaseExecutor): def __init__( self, - step_bar: str, - sub_executor: Union[BaseExecutor, dict], - sub_strategy: Union[BaseStrategy, dict], + time_per_step: str, + inner_executor: Union[BaseExecutor, dict], + inner_strategy: Union[BaseStrategy, dict], start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, trade_exchange: Exchange = None, @@ -139,22 +137,24 @@ class SplitExecutor(BaseExecutor): """ Parameters ---------- - sub_executor : BaseExecutor + inner_executor : BaseExecutor trading env in each trading bar. - sub_strategy : BaseStrategy + inner_strategy : BaseStrategy trading strategy in each trading bar trade_exchange : Exchange exchange that provides market info, used to generate report - If generate_report is None, trade_exchange will be ignored - Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra """ - self.sub_executor = init_instance_by_config(sub_executor, common_infra=common_infra, accept_types=BaseExecutor) - self.sub_strategy = init_instance_by_config( - sub_strategy, common_infra=common_infra, accept_types=self.BaseStrategy + self.inner_executor = init_instance_by_config( + inner_executor, common_infra=common_infra, accept_types=BaseExecutor + ) + self.inner_strategy = init_instance_by_config( + inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy ) super(SplitExecutor, self).__init__( - step_bar=step_bar, + time_per_step=time_per_step, start_time=start_time, end_time=end_time, generate_report=generate_report, @@ -171,29 +171,26 @@ class SplitExecutor(BaseExecutor): """ reset infrastructure for trading - reset trade_exchange - - reset substrategy and subexecutor common infra + - reset inner_strategyand inner_executor common infra """ super(SplitExecutor, self).reset_common_infra(common_infra) if self.generate_report and "trade_exchange" in common_infra: self.trade_exchange = common_infra.get("trade_exchange") - self.sub_executor.reset_common_infra(common_infra) - self.sub_strategy.reset_common_infra(common_infra) - - def get_init_state(self): - return [] + self.inner_executor.reset_common_infra(common_infra) + self.inner_strategy.reset_common_infra(common_infra) def _init_sub_trading(self, trade_decision): - trade_index = self.trade_calendar.get_trade_index() - trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) - self.sub_executor.reset(start_time=trade_start_time, end_time=trade_end_time) - sub_level_infra = self.sub_executor.get_level_infra() - self.sub_strategy.reset(level_infra=sub_level_infra, rely_trade_decision=trade_decision) + trade_index = self.calendar.get_trade_index() + trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time) + sub_level_infra = self.inner_executor.get_level_infra() + self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision) def _update_trade_account(self): - trade_index = self.trade_calendar.get_trade_index() - trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) + trade_index = self.calendar.get_trade_index() + trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) self.trade_account.update_bar_count() if self.generate_report: self.trade_account.update_bar_report( @@ -203,41 +200,41 @@ class SplitExecutor(BaseExecutor): ) def execute(self, trade_decision): - self.trade_calendar.step() + self.calendar.step() self._init_sub_trading(trade_decision) - execute_state = [] - sub_execute_state = self.sub_executor.get_init_state() - while not self.sub_executor.finished(): - sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state) - sub_execute_state = self.sub_executor.execute(trade_decision=sub_trade_decison) - execute_state.extend(sub_execute_state) + execute_result = [] + _inner_execute_result = None + while not self.inner_executor.finished(): + _inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result) + _inner_execute_result = self.inner_executor.execute(trade_decision=_inner_trade_decision) + execute_result.extend(_inner_execute_result) if hasattr(self, "trade_account"): self._update_trade_account() - return execute_state + return execute_result def collect_data(self, trade_decision): if self.track_data: yield trade_decision - self.trade_calendar.step() + self.calendar.step() self._init_sub_trading(trade_decision) - execute_state = [] - sub_execute_state = self.sub_executor.get_init_state() - while not self.sub_executor.finished(): - sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state) - sub_execute_state = yield from self.sub_executor.collect_data(trade_decision=sub_trade_decison) - execute_state.extend(sub_execute_state) + execute_result = [] + _inner_execute_result = None + while not self.inner_executor.finished(): + _inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result) + _inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision) + execute_result.extend(_inner_execute_result) if hasattr(self, "trade_account"): self._update_trade_account() - return execute_state + return execute_result def get_report(self): - sub_env_report_dict = self.sub_executor.get_report() + sub_env_report_dict = self.inner_executor.get_report() if self.generate_report: _report = self.trade_account.report.generate_report_dataframe() _positions = self.trade_account.get_positions() - _count, _freq = parse_freq(self.step_bar) + _count, _freq = parse_freq(self.time_per_step) sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)}) return sub_env_report_dict @@ -245,7 +242,7 @@ class SplitExecutor(BaseExecutor): class SimulatorExecutor(BaseExecutor): def __init__( self, - step_bar: str, + time_per_step: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, trade_exchange: Exchange = None, @@ -263,7 +260,7 @@ class SimulatorExecutor(BaseExecutor): - If `trade_exchange` is None, self.trade_exchange will be set with common_infra """ super(SimulatorExecutor, self).__init__( - step_bar=step_bar, + time_per_step=time_per_step, start_time=start_time, end_time=end_time, generate_report=generate_report, @@ -284,21 +281,18 @@ class SimulatorExecutor(BaseExecutor): if "trade_exchange" in common_infra: self.trade_exchange = common_infra.get("trade_exchange") - def get_init_state(self): - return [] - def execute(self, trade_decision): - self.trade_calendar.step() - trade_index = self.trade_calendar.get_trade_index() - trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) - execute_state = [] + self.calendar.step() + trade_index = self.calendar.get_trade_index() + trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + execute_result = [] for order in trade_decision: if self.trade_exchange.check_order(order) is True: # execute the order trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( order, trade_account=self.trade_account ) - execute_state.append((order, trade_val, trade_cost, trade_price)) + execute_result.append((order, trade_val, trade_cost, trade_price)) if self.verbose: if order.direction == Order.SELL: # sell print( @@ -340,13 +334,13 @@ class SimulatorExecutor(BaseExecutor): trade_exchange=self.trade_exchange, ) - return execute_state + return execute_result def get_report(self): if self.generate_report: _report = self.trade_account.report.generate_report_dataframe() _positions = self.trade_account.get_positions() - _count, _freq = parse_freq(self.step_bar) + _count, _freq = parse_freq(self.time_per_step) return {f"{_count}{_freq}": (_report, _positions)} else: return {} diff --git a/qlib/contrib/backtest/utils.py b/qlib/contrib/backtest/utils.py index 1a4173887..622816753 100644 --- a/qlib/contrib/backtest/utils.py +++ b/qlib/contrib/backtest/utils.py @@ -15,13 +15,13 @@ class TradeCalendarManager: """ def __init__( - self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None + self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None ): """ Parameters ---------- - step_bar : str - frequency of each trading calendar + freq : str + frequency of trading calendar, also trade time per trading step start_time : Union[str, pd.Timestamp], optional closed start of the trading calendar, by default None If `start_time` is None, it must be reset before trading. @@ -29,14 +29,14 @@ class TradeCalendarManager: closed end of the trade time range, by default None If `end_time` is None, it must be reset before trading. """ - self.step_bar = step_bar + self.freq = freq self.start_time = pd.Timestamp(start_time) if start_time else None self.end_time = pd.Timestamp(start_time) if start_time else None - self._init_trade_calendar(step_bar=step_bar, start_time=start_time, end_time=end_time) + self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time) - def _init_trade_calendar(self, step_bar, start_time, end_time): + def _init_trade_calendar(self, freq, start_time, end_time): """reset trade calendar""" - _calendar, freq, freq_sam = get_resam_calendar(freq=step_bar) + _calendar, freq, freq_sam = get_resam_calendar(freq=freq) self.calendar = _calendar _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam) self.start_index = _start_index @@ -52,8 +52,8 @@ class TradeCalendarManager: raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!") self.trade_index = self.trade_index + 1 - def get_step_bar(self): - return self.step_bar + def get_freq(self): + return self.freq def get_trade_len(self): return self.trade_len diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 336cfa534..d797729be 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -81,10 +81,10 @@ class TopkDropoutStrategy(ModelStrategy): # It will use 95% amoutn of your total value by default return self.risk_degree - def generate_trade_decision(self, execute_state): - trade_index = self.trade_calendar.get_trade_index() - trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) - pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1) + def generate_trade_decision(self, execute_result=None): + trade_index = self.calendar.get_trade_index() + trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: return [] @@ -179,8 +179,8 @@ class TopkDropoutStrategy(ModelStrategy): continue if code in sell: # check hold limit - step_bar = self.trade_calendar.get_step_bar() - if current_temp.get_stock_count(code, bar=step_bar) < self.hold_thresh: + time_per_step = self.calendar.get_freq() + if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh: continue # sell order sell_amount = current_temp.get_stock_amount(code=code) @@ -292,7 +292,7 @@ class WeightStrategyBase(ModelStrategy): """ raise NotImplementedError() - def generate_trade_decision(self, execute_state): + def generate_trade_decision(self, execute_result=None): """ Parameters ----------- @@ -307,9 +307,9 @@ class WeightStrategyBase(ModelStrategy): """ # generate_trade_decision # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list - trade_index = self.trade_calendar.get_trade_index() - trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) - pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1) + trade_index = self.calendar.get_trade_index() + trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: return [] diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 2265a9dc5..1f42c451c 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -24,31 +24,31 @@ class TWAPStrategy(RuleStrategy): if "trade_exchange" in common_infra: self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, rely_trade_decision: object = None, **kwargs): + def reset(self, outer_trade_decision: object = None, **kwargs): """ Parameters ---------- - rely_trade_decision : object, optional + outer_trade_decision : object, optional """ - super(TWAPStrategy, self).reset(rely_trade_decision=rely_trade_decision, common_infra=common_infra, **kwargs) - if rely_trade_decision is not None: + super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, common_infra=common_infra, **kwargs) + if outer_trade_decision is not None: self.trade_amount = {} - for order in rely_trade_decision: + for order in outer_trade_decision: self.trade_amount[(order.stock_id, order.direction)] = order.amount - def generate_trade_decision(self, execute_state): + def generate_trade_decision(self, execute_result=None): # update the order amount - trade_info = execute_state - for order, _, _, _ in trade_info: - self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount + if execute_result is not None: + for order, _, _, _ in execute_result: + self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - trade_index = self.trade_calendar.get_trade_index() - trade_len = self.trade_calendar.get_trade_len() - trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) + trade_index = self.calendar.get_trade_index() + trade_len = self.calendar.get_trade_len() + trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) order_list = [] - for order in self.rely_trade_decision: + for order in self.outer_trade_decision: if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): @@ -104,41 +104,41 @@ class SBBStrategyBase(RuleStrategy): if "trade_exchange" in common_infra: self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, rely_trade_decision=None, **kwargs): + def reset(self, outer_trade_decision=None, **kwargs): """ Parameters ---------- - rely_trade_decision : object, optional + outer_trade_decision : object, optional common_infra : None, optional common infrastructure for backtesting, by default None - It should include `trade_account`, used to get position - It should include `trade_exchange`, used to provide market info """ - super(SBBStrategyBase, self).reset(rely_trade_decision=rely_trade_decision, **kwargs) - if rely_trade_decision is not None: + super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) + if outer_trade_decision is not None: self.trade_trend = {} self.trade_amount = {} # init the trade amount of order and predicted trade trend - for order in rely_trade_decision: + for order in outer_trade_decision: self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID self.trade_amount[(order.stock_id, order.direction)] = order.amount def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") - def generate_trade_decision(self, execute_state): + def generate_trade_decision(self, execute_result=None): # update the order amount - trade_info = execute_state - for order, _, _, _ in trade_info: - self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - trade_index = self.trade_calendar.get_trade_index() - trade_len = self.trade_calendar.get_trade_len() - trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) - pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1) + if execute_result is not None: + for order, _, _, _ in execute_result: + self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount + trade_index = self.calendar.get_trade_index() + trade_len = self.calendar.get_trade_len() + trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1) order_list = [] - # for each order in in self.rely_trade_decision - for order in self.rely_trade_decision: + # for each order in in self.outer_trade_decision + for order in self.outer_trade_decision: # predict the price trend if trade_index % 2 == 1: _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) @@ -266,7 +266,7 @@ class SBBStrategyEMA(SBBStrategyBase): def __init__( self, - rely_trade_decision=[], + outer_trade_decision=[], instruments="csi300", freq="day", level_infra={}, @@ -288,13 +288,13 @@ class SBBStrategyEMA(SBBStrategyBase): if isinstance(instruments, str): self.instruments = D.instruments(instruments) self.freq = freq - super(SBBStrategyEMA, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs) + super(SBBStrategyEMA, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) def _reset_signal(self): - trade_len = self.trade_calendar.get_trade_len() + trade_len = self.calendar.get_trade_len() fields = ["EMA($close, 10)-EMA($close, 20)"] - signal_start_time, _ = self.trade_calendar.get_calendar_time(trade_index=1, shift=1) - _, signal_end_time = self.trade_calendar.get_calendar_time(trade_index=trade_len, shift=1) + signal_start_time, _ = self.calendar.get_calendar_time(trade_index=1, shift=1) + _, signal_end_time = self.calendar.get_calendar_time(trade_index=trade_len, shift=1) signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) @@ -307,15 +307,15 @@ class SBBStrategyEMA(SBBStrategyBase): def reset_level_infra(self, level_infra): """ reset level-shared infra - - After reset the trade_calendar, the signal will be changed + - After reset the trade calendar, the signal will be changed """ if not hasattr(self, "level_infra"): self.level_infra = level_infra else: self.level_infra.update(level_infra) - if "trade_calendar" in level_infra: - self.trade_calendar = level_infra.get("trade_calendar") + if "calendar" in level_infra: + self.calendar = level_infra.get("calendar") self._reset_signal() def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): diff --git a/qlib/rl/env.py b/qlib/rl/env.py index 2fef7a659..faf9c026e 100644 --- a/qlib/rl/env.py +++ b/qlib/rl/env.py @@ -6,6 +6,7 @@ from typing import Union from .interpreter import StateInterpreter, ActionInterpreter from ..contrib.backtest.executor import BaseExecutor from ..utils import init_instance_by_config +from .interpreter import BaseInterpreter class BaseRLEnv: @@ -68,8 +69,8 @@ class QlibIntRLEnv(QlibRLEnv): interpretor that interprets the rl agent action into qlib order list """ super(QlibIntRLEnv, self).__init__(executor=executor) - self.state_interpreter = init_instance_by_config(state_interpreter) - self.action_interpreter = init_instance_by_config(action_interpreter) + self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter) + self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter) def step(self, action): """ @@ -87,7 +88,7 @@ class QlibIntRLEnv(QlibRLEnv): ------- env state to rl policy """ - _interpret_action = self.action_interpreter.interpret(action=action) - _execute_result = self.executor.execute(_interpret_action) + _interpret_decision = self.action_interpreter.interpret(action=action) + _execute_result = self.executor.execute(trade_decision=_interpret_decision) _interpret_state = self.state_interpreter.interpret(execute_result=_execute_result) return _interpret_state diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index dad994303..59d9d72e3 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -19,24 +19,24 @@ class BaseStrategy: def __init__( self, - rely_trade_decision: object = None, + outer_trade_decision: object = None, level_infra: dict = {}, common_infra: dict = {}, ): """ Parameters ---------- - rely_trade_decision : object, optional - the high-level trade decison on which the startegy rely, and it will be traded in [start_time , end_time] , by default None + outer_trade_decision : object, optional + the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None - If the strategy is used to split trade decison, it will be used - If the strategy is used for portfolio management, it can be ignored level_infra : dict, optional - level shared infrastructure for backtesting, including trade_calendar + level shared infrastructure for backtesting, including trade calendar common_infra : dict, optional common infrastructure for backtesting, including trade_account, trade_exchange, .etc """ - self.reset(level_infra=level_infra, common_infra=common_infra, rely_trade_decision=rely_trade_decision) + self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision) def reset_level_infra(self, level_infra): if not hasattr(self, "level_infra"): @@ -44,8 +44,8 @@ class BaseStrategy: else: self.level_infra.update(level_infra) - if "trade_calendar" in level_infra: - self.trade_calendar = level_infra.get("trade_calendar") + if "calendar" in level_infra: + self.calendar = level_infra.get("calendar") def reset_common_infra(self, common_infra): if not hasattr(self, "common_infra"): @@ -56,11 +56,11 @@ class BaseStrategy: if "trade_account" in common_infra: self.trade_position = common_infra.get("trade_account").current - def reset(self, level_infra: dict = None, common_infra: dict = None, rely_trade_decision=None, **kwargs): + def reset(self, level_infra: dict = None, common_infra: dict = None, outer_trade_decision=None, **kwargs): """ - - reset `level_infra`, used to reset trade_calendar, .etc + - reset `level_infra`, used to reset trade calendar, .etc - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc - - reset `rely_trade_decision`, used to make split decison + - reset `outer_trade_decision`, used to make split decison """ if level_infra is not None: self.reset_level_infra(level_infra) @@ -68,11 +68,18 @@ class BaseStrategy: if common_infra is not None: self.reset_common_infra(common_infra) - if rely_trade_decision is not None: - self.rely_trade_decision = rely_trade_decision + if outer_trade_decision is not None: + self.outer_trade_decision = outer_trade_decision - def generate_trade_decision(self, execute_state): - """Generate trade decision in each trading bar""" + def generate_trade_decision(self, execute_result=None): + """Generate trade decision in each trading bar + + Parameters + ---------- + execute_result : List[object], optional + the executed result for trade decison, by default None + - When call the generate_trade_decision firstly, `execute_result` could be None + """ raise NotImplementedError("generate_trade_decision is not implemented!") @@ -89,7 +96,7 @@ class ModelStrategy(BaseStrategy): self, model: BaseModel, dataset: DatasetH, - rely_trade_decision: object = None, + outer_trade_decision: object = None, level_infra: dict = {}, common_infra: dict = {}, **kwargs, @@ -104,7 +111,7 @@ class ModelStrategy(BaseStrategy): kwargs : dict arguments that will be passed into `reset` method """ - super(ModelStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs) + super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) self.model = model self.dataset = dataset self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime") @@ -125,7 +132,7 @@ class RLStrategy(BaseStrategy): def __init__( self, policy, - rely_trade_decision: object = None, + outer_trade_decision: object = None, level_infra: dict = {}, common_infra: dict = {}, **kwargs, @@ -136,7 +143,7 @@ class RLStrategy(BaseStrategy): policy : RL policy for generate action """ - super(RLStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs) + super(RLStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) self.policy = policy @@ -148,7 +155,7 @@ class RLIntStrategy(RLStrategy): policy, state_interpreter: StateInterpreter, action_interpreter: ActionInterpreter, - rely_trade_decision: object = None, + outer_trade_decision: object = None, level_infra: dict = {}, common_infra: dict = {}, **kwargs, @@ -165,15 +172,14 @@ class RLIntStrategy(RLStrategy): end_time : Union[str, pd.Timestamp], optional end time of trading, by default None """ - super(RLIntStrategy, self).__init__(policy, rely_trade_decision, level_infra, common_infra, **kwargs) + super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs) self.policy = policy self.state_interpreter = init_instance_by_config(state_interpreter) self.action_interpreter = init_instance_by_config(action_interpreter) - def generate_trade_decision(self, execute_state): - super(RLStrategy, self).step() - _interpret_state = self.state_interpretor.interpret(execute_result=execute_state) - _policy_action = self.policy.step(_interpret_state) - _order_list = self.action_interpreter.interpret(action=_policy_action) - return _order_list + def generate_trade_decision(self, execute_result=None): + _interpret_state = self.state_interpretor.interpret(execute_result=execute_result) + _action = self.policy.step(_interpret_state) + _trade_decision = self.action_interpreter.interpret(action=_action) + return _trade_decision diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 1f80bd051..a32ef9729 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -317,7 +317,7 @@ class PortAnaRecord(RecordTemp): def _get_report_freq(self, executor_config): ret_freq = [] if executor_config["kwargs"].get("generate_report", False): - _count, _freq = parse_freq(executor_config["kwargs"]["step_bar"]) + _count, _freq = parse_freq(executor_config["kwargs"]["time_per_step"]) ret_freq.append(f"{_count}{_freq}") if "sub_env" in executor_config["kwargs"]: ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))