mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Use indicator to simplify state update
This commit is contained in:
@@ -21,7 +21,7 @@ from qlib.rl.order_execution.utils import (
|
||||
_convert_tick_str_to_int,
|
||||
_dataframe_append,
|
||||
_get_common_infra,
|
||||
_get_ticks_slice,
|
||||
_get_portfolio_and_indicator, _get_ticks_slice,
|
||||
_price_advantage,
|
||||
)
|
||||
from qlib.rl.simulator import Simulator
|
||||
@@ -103,83 +103,63 @@ class StateMaintainer:
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics = None
|
||||
|
||||
def update(self, inner_executor: BaseExecutor, inner_strategy: DecomposedStrategy, done: bool) -> None:
|
||||
def update(
|
||||
self,
|
||||
inner_executor: BaseExecutor,
|
||||
inner_strategy: DecomposedStrategy,
|
||||
done: bool,
|
||||
all_indicators: dict,
|
||||
) -> None:
|
||||
execute_order = inner_strategy.execute_order
|
||||
execute_result = inner_strategy.execute_result
|
||||
exec_vol = np.array([e[0].deal_amount for e in execute_result])
|
||||
ticks_position = self.position - np.cumsum(exec_vol)
|
||||
self.position -= exec_vol.sum()
|
||||
num_step = len(execute_result)
|
||||
|
||||
if len(execute_result) > 0:
|
||||
exchange = inner_executor.trade_exchange
|
||||
market_price = np.array(
|
||||
[
|
||||
exchange.get_deal_price(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
direction=execute_order.direction,
|
||||
method=None,
|
||||
)
|
||||
]
|
||||
).reshape(-1)
|
||||
market_volume = np.array(
|
||||
[
|
||||
exchange.get_volume(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
method=None,
|
||||
)
|
||||
]
|
||||
).reshape(-1)
|
||||
|
||||
datetime_list = _get_ticks_slice(
|
||||
self._tick_index,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
include_end=True,
|
||||
)
|
||||
else:
|
||||
market_price = np.array([])
|
||||
if num_step == 0:
|
||||
market_volume = np.array([])
|
||||
market_price = np.array([])
|
||||
datetime_list = pd.DatetimeIndex([])
|
||||
else:
|
||||
market_volume = np.array(
|
||||
inner_executor.trade_exchange.get_volume(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
method=None,
|
||||
)
|
||||
)
|
||||
|
||||
trade_value = all_indicators["1min"].iloc[-num_step:]["value"].values
|
||||
deal_amount = all_indicators["1min"].iloc[-num_step:]["deal_amount"].values
|
||||
market_price = trade_value / deal_amount
|
||||
|
||||
datetime_list = all_indicators["1min"].index[-num_step:]
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
self.history_exec = _dataframe_append(
|
||||
self.history_exec,
|
||||
SAOEMetrics(
|
||||
# It should have the same keys with SAOEMetrics,
|
||||
# but the values do not necessarily have the annotated type.
|
||||
# Some values could be vectorized (e.g., exec_vol).
|
||||
stock_id=self._order.stock_id,
|
||||
self._collect_multi_order_metric(
|
||||
order=self._order,
|
||||
datetime=datetime_list,
|
||||
direction=self._order.direction,
|
||||
market_volume=market_volume,
|
||||
market_vol=market_volume,
|
||||
market_price=market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=market_price,
|
||||
trade_value=market_price * exec_vol,
|
||||
position=ticks_position,
|
||||
ffr=exec_vol / self._order.amount,
|
||||
pa=_price_advantage(market_price, self._twap_price, self._order.direction),
|
||||
exec_vol=exec_vol,
|
||||
pa=all_indicators["30min"].iloc[-1]["pa"],
|
||||
),
|
||||
)
|
||||
|
||||
self.history_steps = _dataframe_append(
|
||||
self.history_steps,
|
||||
[
|
||||
self._metrics_collect(
|
||||
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol
|
||||
self._collect_single_order_metric(
|
||||
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if done:
|
||||
self.metrics = self._metrics_collect(
|
||||
self.metrics = self._collect_single_order_metric(
|
||||
self._order,
|
||||
self._tick_index[0], # start time
|
||||
self.history_exec["market_volume"],
|
||||
@@ -188,7 +168,38 @@ class StateMaintainer:
|
||||
self.history_exec["deal_amount"],
|
||||
)
|
||||
|
||||
def _metrics_collect(
|
||||
# Do this at the end
|
||||
self.position -= exec_vol.sum()
|
||||
|
||||
def _collect_multi_order_metric(
    self,
    order: Order,
    datetime: pd.DatetimeIndex,
    market_vol: np.ndarray,
    market_price: np.ndarray,
    exec_vol: np.ndarray,
    pa: float,
) -> SAOEMetrics:
    """Assemble the vectorized (one entry per tick) metrics of an execution step.

    The returned object uses the same keys as ``SAOEMetrics``, but most values
    are arrays aligned with ``datetime`` rather than scalars.

    Args:
        order: The order being executed; supplies ``stock_id``, ``direction``
            and ``amount``.
        datetime: Index of the ticks covered by this step. The caller passes
            a ``pd.DatetimeIndex`` (possibly empty), not a single timestamp.
        market_vol: Market volume for each tick.
        market_price: Market price for each tick; also reported as the trade
            price.
        exec_vol: Volume executed at each tick.
        pa: Price advantage value, stored unchanged.

    Returns:
        A ``SAOEMetrics`` whose per-tick fields are arrays.

    Note:
        ``position`` is computed as ``self.position - cumsum(exec_vol)``, so
        ``self.position`` must still hold the position from *before* this
        step's executions when this method is called.
    """
    return SAOEMetrics(
        # It should have the same keys with SAOEMetrics,
        # but the values do not necessarily have the annotated type.
        # Some values could be vectorized (e.g., exec_vol).
        stock_id=order.stock_id,
        datetime=datetime,
        direction=order.direction,
        market_volume=market_vol,
        market_price=market_price,
        amount=exec_vol,
        inner_amount=exec_vol,
        deal_amount=exec_vol,
        trade_price=market_price,
        trade_value=market_price * exec_vol,
        # Remaining position after each tick's execution.
        position=self.position - np.cumsum(exec_vol),
        # Fulfillment ratio of each tick relative to the whole order.
        ffr=exec_vol / order.amount,
        pa=pa,
    )
|
||||
|
||||
def _collect_single_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
@@ -206,6 +217,7 @@ class StateMaintainer:
|
||||
if hasattr(exec_avg_price, "item"): # could be numpy scalar
|
||||
exec_avg_price = exec_avg_price.item() # type: ignore
|
||||
|
||||
exec_sum = exec_vol.sum()
|
||||
return SAOEMetrics(
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
@@ -213,12 +225,12 @@ class StateMaintainer:
|
||||
market_volume=market_vol.sum(),
|
||||
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
|
||||
amount=amount,
|
||||
inner_amount=exec_vol.sum(),
|
||||
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
|
||||
inner_amount=exec_sum,
|
||||
deal_amount=exec_sum, # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position,
|
||||
ffr=float(exec_vol.sum() / order.amount),
|
||||
position=self.position - exec_sum,
|
||||
ffr=float(exec_sum / order.amount),
|
||||
pa=_price_advantage(exec_avg_price, self._twap_price, order.direction),
|
||||
)
|
||||
|
||||
@@ -326,10 +338,13 @@ class SingleAssetQlibSimulator(Simulator[Order, SAOEState, float]):
|
||||
except StopIteration:
|
||||
self._done = True
|
||||
|
||||
_, all_indicators = _get_portfolio_and_indicator(self._executor)
|
||||
|
||||
self._maintainer.update(
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
done=self._done,
|
||||
all_indicators=all_indicators,
|
||||
)
|
||||
|
||||
def get_state(self) -> SAOEState:
|
||||
|
||||
@@ -144,9 +144,10 @@ def test_simulator_stop_twap() -> None:
|
||||
|
||||
|
||||
def test_interpreter() -> None:
|
||||
NUM_EXECUTION = 3
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
interpreter_action = CategoricalActionInterpreter(values=4)
|
||||
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
|
||||
|
||||
NUM_STEPS = 7
|
||||
state = simulator.get_state()
|
||||
@@ -156,13 +157,7 @@ def test_interpreter() -> None:
|
||||
state = simulator.get_state()
|
||||
position_history.append(state.position)
|
||||
|
||||
assert position_history[0] == TOTAL_POSITION - TOTAL_POSITION / 4 * 1
|
||||
assert position_history[1] == TOTAL_POSITION - TOTAL_POSITION / 4 * 2
|
||||
assert position_history[2] == TOTAL_POSITION - TOTAL_POSITION / 4 * 3
|
||||
assert position_history[3] == 0.0
|
||||
assert position_history[4] == 0.0
|
||||
assert position_history[5] == 0.0
|
||||
assert position_history[6] == 0.0
|
||||
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import Any, cast, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import Account, CommonInfrastructure, get_exchange
|
||||
from qlib.backtest import CommonInfrastructure, get_exchange
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray
|
||||
from qlib.rl.order_execution.simulator_simple import _float_or_ndarray, ONE_SEC
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
|
||||
def _get_common_infra(
|
||||
@@ -93,3 +96,20 @@ def _price_advantage(
|
||||
return res_wo_nan.item()
|
||||
else:
|
||||
return cast(_float_or_ndarray, res_wo_nan)
|
||||
|
||||
|
||||
def _get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]:
    """Collect portfolio metrics and trade indicators from an executor tree.

    Every executor returned by ``executor.get_all_executors()`` is visited.
    Results are keyed by the string built from ``Freq.parse`` applied to the
    executor's ``time_per_step``.

    Args:
        executor: The executor whose ``get_all_executors()`` result is scanned.

    Returns:
        A pair ``(all_portfolio_metrics, all_indicators)``:

        * ``all_portfolio_metrics`` maps the frequency key to the account's
          portfolio metrics, only for executors whose account reports
          ``is_port_metr_enabled()``.
        * ``all_indicators`` maps the frequency key to the trade-indicator
          dataframe, and ``key + "_obj"`` to the trade-indicator object itself.
    """

    def _freq_key(sub_executor: BaseExecutor) -> str:
        # Frequency label derived from the executor's step length.
        return "{}{}".format(*Freq.parse(sub_executor.time_per_step))

    sub_executors = executor.get_all_executors()

    # First pass: portfolio metrics, restricted to accounts that track them.
    portfolio_by_freq = {}
    for sub_executor in sub_executors:
        if not sub_executor.trade_account.is_port_metr_enabled():
            continue
        freq = _freq_key(sub_executor)
        portfolio_by_freq[freq] = sub_executor.trade_account.get_portfolio_metrics()

    # Second pass: trade indicators, both as a dataframe and as the raw object.
    indicator_by_freq = {}
    for sub_executor in sub_executors:
        freq = _freq_key(sub_executor)
        indicator_frame = sub_executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
        indicator_by_freq[freq] = indicator_frame
        indicator_by_freq[freq + "_obj"] = sub_executor.trade_account.get_trade_indicator()

    return portfolio_by_freq, indicator_by_freq
|
||||
|
||||
Reference in New Issue
Block a user