diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index ba1dd2c0b..6660f9ef6 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -444,7 +444,7 @@ class Exchange: stock_id: str, start_time: pd.Timestamp, end_time: pd.Timestamp, - method: str = "sum", + method: Optional[str] = "sum", ) -> float: """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)""" return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method) @@ -455,7 +455,7 @@ class Exchange: start_time: pd.Timestamp, end_time: pd.Timestamp, direction: OrderDir, - method: str = "ts_data_last", + method: Optional[str] = "ts_data_last", ) -> float: if direction == OrderDir.SELL: pstr = self.sell_price diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 7a58512df..687f3dbe0 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -19,15 +19,17 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge from __future__ import annotations +from abc import abstractmethod from functools import lru_cache from pathlib import Path -from typing import List, Sequence, cast +from typing import List, Optional, Sequence, cast import cachetools import numpy as np import pandas as pd from cachetools.keys import hashkey +from qlib.backtest import Exchange from qlib.backtest.decision import Order, OrderDir from qlib.typehint import Literal @@ -86,6 +88,31 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame: class IntradayBacktestData: + def __init__(self) -> None: + super(IntradayBacktestData, self).__init__() + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_deal_price(self) -> pd.Series: + raise NotImplementedError + + @abstractmethod + def get_volume(self) -> pd.Series: + raise NotImplementedError + + @abstractmethod + def get_time_index(self) -> pd.DatetimeIndex: + raise NotImplementedError + + +class SimpleIntradayBacktestData(IntradayBacktestData): """Raw market data that is often used in backtesting (thus called BacktestData).""" def __init__( @@ -96,6 +123,8 @@ class IntradayBacktestData: deal_price: DealPriceType = "close", order_dir: int = None, ) -> None: + super(SimpleIntradayBacktestData, self).__init__() + backtest = _read_pickle(data_dir / stock_id) backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]] @@ -146,6 +175,41 @@ class IntradayBacktestData: return cast(pd.DatetimeIndex, self.data.index) +class QlibIntradayBacktestData(IntradayBacktestData): + def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None: + super(QlibIntradayBacktestData, self).__init__() + self._order = order + self._exchange = exchange + self._start_time = start_time + self._end_time = end_time + + def __repr__(self) -> str: + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError + + def get_deal_price(self) -> pd.Series: + return self._exchange.get_deal_price( + self._order.stock_id, + self._start_time, + self._end_time, + direction=self._order.direction, + method=None, + ) + + def get_volume(self) -> pd.Series: + return self._exchange.get_volume( + self._order.stock_id, + self._start_time, + self._end_time, + method=None, + ) + + def get_time_index(self) -> pd.DatetimeIndex: + return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)]) + + class IntradayProcessedData: """Processed market data after data cleanup and feature engineering. @@ -202,14 +266,14 @@ class IntradayProcessedData: @lru_cache(maxsize=100) # 100 * 50K = 5MB -def load_intraday_backtest_data( +def load_simple_intraday_backtest_data( data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int = None, -) -> IntradayBacktestData: - return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) +) -> SimpleIntradayBacktestData: + return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) @cachetools.cached( # type: ignore diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index eacb67b70..b27fe55ff 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -14,10 +14,11 @@ from qlib.backtest.executor import BaseExecutor, NestedExecutor from qlib.backtest.utils import CommonInfrastructure from qlib.config import QlibConfig from qlib.constant import EPS +from qlib.rl.data.pickle_styled import QlibIntradayBacktestData from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig from qlib.rl.order_execution.from_neutrader.feature import init_qlib from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState -from qlib.rl.order_execution.utils import (_convert_tick_str_to_int, _dataframe_append, _get_common_infra, _get_minutes, +from qlib.rl.order_execution.utils import (_convert_tick_str_to_int, _dataframe_append, _get_common_infra, _get_ticks_slice, _price_advantage) from qlib.rl.simulator import Simulator from qlib.strategy.base import BaseStrategy @@ -107,14 +108,19 @@ class StateMaintainer: if len(execute_result) > 0: exchange = inner_executor.trade_exchange - minutes = _get_minutes(execute_result[0][0].start_time, execute_result[-1][0].start_time) - market_price = np.array( - [ - exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction) - for t in minutes - ] - ) - market_volume = np.array([exchange.get_volume(execute_order.stock_id, t, t) for t in minutes]) + market_price = np.array([exchange.get_deal_price( + execute_order.stock_id, + execute_result[0][0].start_time, + execute_result[-1][0].start_time, + direction=execute_order.direction, + method=None, + )]).reshape(-1) + market_volume = np.array([exchange.get_volume( + execute_order.stock_id, + execute_result[0][0].start_time, + execute_result[-1][0].start_time, + method=None, + )]).reshape(-1) datetime_list = _get_ticks_slice( self._tick_index, execute_result[0][0].start_time, execute_result[-1][0].start_time, include_end=True, @@ -265,14 +271,15 @@ class QlibSimulator(Simulator[Order, SAOEState, float]): include_end=True, ) - self.twap_price = exchange.get_deal_price( - order.stock_id, - pd.Timestamp(self._ticks_for_order[0]), - pd.Timestamp(self._ticks_for_order[-1]), - direction=order.direction, - method="mean", + self._backtest_data = QlibIntradayBacktestData( + order=self._order, + exchange=exchange, + start_time=self._ticks_for_order[0], + end_time=self._ticks_for_order[-1], ) + self.twap_price = self._backtest_data.get_deal_price().mean() + top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument) self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date)) top_strategy.reset(level_infra=self._executor.get_level_infra()) @@ -318,7 +325,7 @@ class QlibSimulator(Simulator[Order, SAOEState, float]): history_exec=self._maintainer.history_exec, history_steps=self._maintainer.history_steps, metrics=self._maintainer.metrics, - backtest_data=None, + backtest_data=self._backtest_data, ticks_per_step=self._ticks_per_step, ticks_index=self._ticks_index, ticks_for_order=self._ticks_for_order, diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 031c7ab14..ab4b88031 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -11,7 +11,7 @@ import pandas as pd from qlib.backtest.decision import Order, OrderDir from qlib.constant import EPS -from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_intraday_backtest_data +from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data from qlib.rl.simulator import Simulator from qlib.rl.utils import LogLevel from qlib.typehint import TypedDict @@ -165,7 +165,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): self.deal_price_type = deal_price_type self.vol_threshold = vol_threshold self.data_dir = data_dir - self.backtest_data = load_intraday_backtest_data( + self.backtest_data = load_simple_intraday_backtest_data( self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), diff --git a/qlib/rl/order_execution/tests/test_simulator_qlib.py b/qlib/rl/order_execution/tests/test_simulator_qlib.py index 2904c8f65..839524ec1 100644 --- a/qlib/rl/order_execution/tests/test_simulator_qlib.py +++ b/qlib/rl/order_execution/tests/test_simulator_qlib.py @@ -78,8 +78,8 @@ def test_simulator_stop_twap() -> None: assert is_close(state.position, 0.0) assert is_close(state.metrics["ffr"], 1.0) - # assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4 - # assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum()) + assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean()) + assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum()) assert is_close(state.metrics["trade_price"], state.metrics["market_price"]) assert is_close(state.metrics["pa"], 0.0) diff --git a/qlib/rl/order_execution/utils.py b/qlib/rl/order_execution/utils.py index dcfc03b31..6e7f22c54 100644 --- a/qlib/rl/order_execution/utils.py +++ b/qlib/rl/order_execution/utils.py @@ -63,15 +63,6 @@ def _get_ticks_slice( return ticks_index[ticks_index.slice_indexer(start, end)] -def _get_minutes(start_time: pd.Timestamp, end_time: pd.Timestamp) -> List[pd.Timestamp]: - minutes = [] - t = start_time - while t <= end_time: - minutes.append(t) - t += pd.Timedelta("1min") - return minutes - - def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame: # dataframe.append is deprecated other_df = pd.DataFrame(other).set_index("datetime") @@ -101,4 +92,4 @@ def _price_advantage( if res_wo_nan.size == 1: return res_wo_nan.item() else: - return cast(_float_or_ndarray, res_wo_nan) \ No newline at end of file + return cast(_float_or_ndarray, res_wo_nan) diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 7eb9377d5..51ba13b34 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -38,7 +38,7 @@ CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights" def test_pickle_data_inspect(): - data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0) + data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0) assert len(data) == 390 data = pickle_styled.load_intraday_processed_data(