diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index c48aa4c18..63f71f75b 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -21,7 +21,7 @@ from qlib.rl.order_execution.utils import ( _convert_tick_str_to_int, _dataframe_append, _get_common_infra, - _get_ticks_slice, + _get_portfolio_and_indicator, _get_ticks_slice, _price_advantage, ) from qlib.rl.simulator import Simulator @@ -103,83 +103,63 @@ class StateMaintainer: self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") self.metrics = None - def update(self, inner_executor: BaseExecutor, inner_strategy: DecomposedStrategy, done: bool) -> None: + def update( + self, + inner_executor: BaseExecutor, + inner_strategy: DecomposedStrategy, + done: bool, + all_indicators: dict, + ) -> None: execute_order = inner_strategy.execute_order execute_result = inner_strategy.execute_result exec_vol = np.array([e[0].deal_amount for e in execute_result]) - ticks_position = self.position - np.cumsum(exec_vol) - self.position -= exec_vol.sum() + num_step = len(execute_result) - if len(execute_result) > 0: - exchange = inner_executor.trade_exchange - market_price = np.array( - [ - exchange.get_deal_price( - execute_order.stock_id, - execute_result[0][0].start_time, - execute_result[-1][0].start_time, - direction=execute_order.direction, - method=None, - ) - ] - ).reshape(-1) - market_volume = np.array( - [ - exchange.get_volume( - execute_order.stock_id, - execute_result[0][0].start_time, - execute_result[-1][0].start_time, - method=None, - ) - ] - ).reshape(-1) - - datetime_list = _get_ticks_slice( - self._tick_index, - execute_result[0][0].start_time, - execute_result[-1][0].start_time, - include_end=True, - ) - else: - market_price = np.array([]) + if num_step == 0: market_volume = np.array([]) + market_price = np.array([]) datetime_list = pd.DatetimeIndex([]) + else: + market_volume = np.array( + inner_executor.trade_exchange.get_volume( + execute_order.stock_id, + execute_result[0][0].start_time, + execute_result[-1][0].start_time, + method=None, + ) + ) + + trade_value = all_indicators["1min"].iloc[-num_step:]["value"].values + deal_amount = all_indicators["1min"].iloc[-num_step:]["deal_amount"].values + market_price = trade_value / deal_amount + + datetime_list = all_indicators["1min"].index[-num_step:] assert market_price.shape == market_volume.shape == exec_vol.shape self.history_exec = _dataframe_append( self.history_exec, - SAOEMetrics( - # It should have the same keys with SAOEMetrics, - # but the values do not necessarily have the annotated type. - # Some values could be vectorized (e.g., exec_vol). - stock_id=self._order.stock_id, + self._collect_multi_order_metric( + order=self._order, datetime=datetime_list, - direction=self._order.direction, - market_volume=market_volume, + market_vol=market_volume, market_price=market_price, - amount=exec_vol, - inner_amount=exec_vol, - deal_amount=exec_vol, - trade_price=market_price, - trade_value=market_price * exec_vol, - position=ticks_position, - ffr=exec_vol / self._order.amount, - pa=_price_advantage(market_price, self._twap_price, self._order.direction), + exec_vol=exec_vol, + pa=all_indicators["30min"].iloc[-1]["pa"], ), ) self.history_steps = _dataframe_append( self.history_steps, [ - self._metrics_collect( - execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol + self._collect_single_order_metric( + execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol, ) ], ) if done: - self.metrics = self._metrics_collect( + self.metrics = self._collect_single_order_metric( self._order, self._tick_index[0], # start time self.history_exec["market_volume"], @@ -188,7 +168,38 @@ class StateMaintainer: self.history_exec["deal_amount"], ) - def _metrics_collect( + # Do this at the end + self.position -= exec_vol.sum() + + def _collect_multi_order_metric( + self, + order: Order, + datetime: pd.Timestamp, + market_vol: np.ndarray, + market_price: np.ndarray, + exec_vol: np.ndarray, + pa: float, + ) -> SAOEMetrics: + return SAOEMetrics( + # It should have the same keys with SAOEMetrics, + # but the values do not necessarily have the annotated type. + # Some values could be vectorized (e.g., exec_vol). + stock_id=order.stock_id, + datetime=datetime, + direction=order.direction, + market_volume=market_vol, + market_price=market_price, + amount=exec_vol, + inner_amount=exec_vol, + deal_amount=exec_vol, + trade_price=market_price, + trade_value=market_price * exec_vol, + position=self.position - np.cumsum(exec_vol), + ffr=exec_vol / order.amount, + pa=pa, + ) + + def _collect_single_order_metric( self, order: Order, datetime: pd.Timestamp, @@ -206,6 +217,7 @@ class StateMaintainer: if hasattr(exec_avg_price, "item"): # could be numpy scalar exec_avg_price = exec_avg_price.item() # type: ignore + exec_sum = exec_vol.sum() return SAOEMetrics( stock_id=order.stock_id, datetime=datetime, @@ -213,12 +225,12 @@ class StateMaintainer: market_volume=market_vol.sum(), market_price=market_price.mean() if len(market_price) > 0 else np.nan, amount=amount, - inner_amount=exec_vol.sum(), - deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions + inner_amount=exec_sum, + deal_amount=exec_sum, # in this simulator, there's no other restrictions trade_price=exec_avg_price, trade_value=float(np.sum(market_price * exec_vol)), - position=self.position, - ffr=float(exec_vol.sum() / order.amount), + position=self.position - exec_sum, + ffr=float(exec_sum / order.amount), pa=_price_advantage(exec_avg_price, self._twap_price, order.direction), ) @@ -326,10 +338,13 @@ class SingleAssetQlibSimulator(Simulator[Order, SAOEState, float]): except StopIteration: self._done = True + _, all_indicators = _get_portfolio_and_indicator(self._executor) + self._maintainer.update( inner_executor=self._inner_executor, inner_strategy=self._inner_strategy, done=self._done, + all_indicators=all_indicators, ) def get_state(self) -> SAOEState: diff --git a/qlib/rl/order_execution/tests/test_simulator_qlib.py b/qlib/rl/order_execution/tests/test_simulator_qlib.py index 0e7c3227d..bd706a2ea 100644 --- a/qlib/rl/order_execution/tests/test_simulator_qlib.py +++ b/qlib/rl/order_execution/tests/test_simulator_qlib.py @@ -144,9 +144,10 @@ def test_simulator_stop_twap() -> None: def test_interpreter() -> None: + NUM_EXECUTION = 3 order = get_order() simulator = get_simulator(order) - interpreter_action = CategoricalActionInterpreter(values=4) + interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION) NUM_STEPS = 7 state = simulator.get_state() @@ -156,13 +157,7 @@ def test_interpreter() -> None: state = simulator.get_state() position_history.append(state.position) - assert position_history[0] == TOTAL_POSITION - TOTAL_POSITION / 4 * 1 - assert position_history[1] == TOTAL_POSITION - TOTAL_POSITION / 4 * 2 - assert position_history[2] == TOTAL_POSITION - TOTAL_POSITION / 4 * 3 - assert position_history[3] == 0.0 - assert position_history[4] == 0.0 - assert position_history[5] == 0.0 - assert position_history[6] == 0.0 + assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0) if __name__ == "__main__": diff --git a/qlib/rl/order_execution/utils.py b/qlib/rl/order_execution/utils.py index 6e7f22c54..ef22716b8 100644 --- a/qlib/rl/order_execution/utils.py +++ b/qlib/rl/order_execution/utils.py @@ -1,14 +1,17 @@ from __future__ import annotations -from typing import Any, List, Optional, cast +from typing import Any, cast, List, Optional, Tuple import numpy as np import pandas as pd -from qlib.backtest import Account, CommonInfrastructure, get_exchange +from qlib.backtest import CommonInfrastructure, get_exchange +from qlib.backtest.account import Account from qlib.backtest.decision import OrderDir +from qlib.backtest.executor import BaseExecutor from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig -from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray +from qlib.rl.order_execution.simulator_simple import _float_or_ndarray, ONE_SEC +from qlib.utils.time import Freq def _get_common_infra( @@ -93,3 +96,20 @@ def _price_advantage( return res_wo_nan.item() else: return cast(_float_or_ndarray, res_wo_nan) + + +def _get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]: + all_executors = executor.get_all_executors() + all_portfolio_metrics = { + "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics() + for _executor in all_executors + if _executor.trade_account.is_port_metr_enabled() + } + + all_indicators = {} + for _executor in all_executors: + key = "{}{}".format(*Freq.parse(_executor.time_per_step)) + all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator() + + return all_portfolio_metrics, all_indicators