1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

Use indicator to simplify state update

This commit is contained in:
Huoran Li
2022-07-22 11:55:51 +08:00
parent 036e5931f9
commit 53dde5146c
3 changed files with 100 additions and 70 deletions

View File

@@ -21,7 +21,7 @@ from qlib.rl.order_execution.utils import (
_convert_tick_str_to_int,
_dataframe_append,
_get_common_infra,
_get_ticks_slice,
_get_portfolio_and_indicator, _get_ticks_slice,
_price_advantage,
)
from qlib.rl.simulator import Simulator
@@ -103,83 +103,63 @@ class StateMaintainer:
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.metrics = None
def update(self, inner_executor: BaseExecutor, inner_strategy: DecomposedStrategy, done: bool) -> None:
def update(
self,
inner_executor: BaseExecutor,
inner_strategy: DecomposedStrategy,
done: bool,
all_indicators: dict,
) -> None:
execute_order = inner_strategy.execute_order
execute_result = inner_strategy.execute_result
exec_vol = np.array([e[0].deal_amount for e in execute_result])
ticks_position = self.position - np.cumsum(exec_vol)
self.position -= exec_vol.sum()
num_step = len(execute_result)
if len(execute_result) > 0:
exchange = inner_executor.trade_exchange
market_price = np.array(
[
exchange.get_deal_price(
execute_order.stock_id,
execute_result[0][0].start_time,
execute_result[-1][0].start_time,
direction=execute_order.direction,
method=None,
)
]
).reshape(-1)
market_volume = np.array(
[
exchange.get_volume(
execute_order.stock_id,
execute_result[0][0].start_time,
execute_result[-1][0].start_time,
method=None,
)
]
).reshape(-1)
datetime_list = _get_ticks_slice(
self._tick_index,
execute_result[0][0].start_time,
execute_result[-1][0].start_time,
include_end=True,
)
else:
market_price = np.array([])
if num_step == 0:
market_volume = np.array([])
market_price = np.array([])
datetime_list = pd.DatetimeIndex([])
else:
market_volume = np.array(
inner_executor.trade_exchange.get_volume(
execute_order.stock_id,
execute_result[0][0].start_time,
execute_result[-1][0].start_time,
method=None,
)
)
trade_value = all_indicators["1min"].iloc[-num_step:]["value"].values
deal_amount = all_indicators["1min"].iloc[-num_step:]["deal_amount"].values
market_price = trade_value / deal_amount
datetime_list = all_indicators["1min"].index[-num_step:]
assert market_price.shape == market_volume.shape == exec_vol.shape
self.history_exec = _dataframe_append(
self.history_exec,
SAOEMetrics(
# It should have the same keys with SAOEMetrics,
# but the values do not necessarily have the annotated type.
# Some values could be vectorized (e.g., exec_vol).
stock_id=self._order.stock_id,
self._collect_multi_order_metric(
order=self._order,
datetime=datetime_list,
direction=self._order.direction,
market_volume=market_volume,
market_vol=market_volume,
market_price=market_price,
amount=exec_vol,
inner_amount=exec_vol,
deal_amount=exec_vol,
trade_price=market_price,
trade_value=market_price * exec_vol,
position=ticks_position,
ffr=exec_vol / self._order.amount,
pa=_price_advantage(market_price, self._twap_price, self._order.direction),
exec_vol=exec_vol,
pa=all_indicators["30min"].iloc[-1]["pa"],
),
)
self.history_steps = _dataframe_append(
self.history_steps,
[
self._metrics_collect(
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol
self._collect_single_order_metric(
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol,
)
],
)
if done:
self.metrics = self._metrics_collect(
self.metrics = self._collect_single_order_metric(
self._order,
self._tick_index[0], # start time
self.history_exec["market_volume"],
@@ -188,7 +168,38 @@ class StateMaintainer:
self.history_exec["deal_amount"],
)
def _metrics_collect(
# Do this at the end
self.position -= exec_vol.sum()
def _collect_multi_order_metric(
    self,
    order: Order,
    datetime: pd.Timestamp,
    market_vol: np.ndarray,
    market_price: np.ndarray,
    exec_vol: np.ndarray,
    pa: float,
) -> SAOEMetrics:
    """Assemble a vectorized metrics record covering several executed ticks.

    The result carries the same keys as ``SAOEMetrics``, but most values are
    arrays (one element per tick) rather than the scalar types the record
    annotates.

    Args:
        order: Order whose ``stock_id``, ``direction`` and ``amount`` are read.
        datetime: Stored unchanged under the ``datetime`` key. NOTE(review):
            ``update`` passes an index of timestamps here (one per tick),
            not a single timestamp as the annotation suggests.
        market_vol: Market volume, stored unchanged as ``market_volume``.
        market_price: Market price; also reused as ``trade_price``.
        exec_vol: Executed volume; reused for ``amount``, ``inner_amount``
            and ``deal_amount``.
        pa: Price advantage value, stored unchanged.

    Returns:
        A ``SAOEMetrics`` built from the inputs and ``self.position``.
    """
    stock_id = order.stock_id
    direction = order.direction

    # Element-wise traded value for each tick.
    traded_value = market_price * exec_vol
    # Position left after each tick: starting position minus the running
    # total of executed volume. ``self.position`` is read, not modified.
    remaining_position = self.position - np.cumsum(exec_vol)
    # Share of the whole order amount filled by each tick.
    fulfill_rate = exec_vol / order.amount

    return SAOEMetrics(
        stock_id=stock_id,
        datetime=datetime,
        direction=direction,
        market_volume=market_vol,
        market_price=market_price,
        amount=exec_vol,
        inner_amount=exec_vol,
        deal_amount=exec_vol,
        trade_price=market_price,
        trade_value=traded_value,
        position=remaining_position,
        ffr=fulfill_rate,
        pa=pa,
    )
def _collect_single_order_metric(
self,
order: Order,
datetime: pd.Timestamp,
@@ -206,6 +217,7 @@ class StateMaintainer:
if hasattr(exec_avg_price, "item"): # could be numpy scalar
exec_avg_price = exec_avg_price.item() # type: ignore
exec_sum = exec_vol.sum()
return SAOEMetrics(
stock_id=order.stock_id,
datetime=datetime,
@@ -213,12 +225,12 @@ class StateMaintainer:
market_volume=market_vol.sum(),
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
amount=amount,
inner_amount=exec_vol.sum(),
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
inner_amount=exec_sum,
deal_amount=exec_sum, # in this simulator, there's no other restrictions
trade_price=exec_avg_price,
trade_value=float(np.sum(market_price * exec_vol)),
position=self.position,
ffr=float(exec_vol.sum() / order.amount),
position=self.position - exec_sum,
ffr=float(exec_sum / order.amount),
pa=_price_advantage(exec_avg_price, self._twap_price, order.direction),
)
@@ -326,10 +338,13 @@ class SingleAssetQlibSimulator(Simulator[Order, SAOEState, float]):
except StopIteration:
self._done = True
_, all_indicators = _get_portfolio_and_indicator(self._executor)
self._maintainer.update(
inner_executor=self._inner_executor,
inner_strategy=self._inner_strategy,
done=self._done,
all_indicators=all_indicators,
)
def get_state(self) -> SAOEState:

View File

@@ -144,9 +144,10 @@ def test_simulator_stop_twap() -> None:
def test_interpreter() -> None:
NUM_EXECUTION = 3
order = get_order()
simulator = get_simulator(order)
interpreter_action = CategoricalActionInterpreter(values=4)
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
NUM_STEPS = 7
state = simulator.get_state()
@@ -156,13 +157,7 @@ def test_interpreter() -> None:
state = simulator.get_state()
position_history.append(state.position)
assert position_history[0] == TOTAL_POSITION - TOTAL_POSITION / 4 * 1
assert position_history[1] == TOTAL_POSITION - TOTAL_POSITION / 4 * 2
assert position_history[2] == TOTAL_POSITION - TOTAL_POSITION / 4 * 3
assert position_history[3] == 0.0
assert position_history[4] == 0.0
assert position_history[5] == 0.0
assert position_history[6] == 0.0
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
if __name__ == "__main__":

View File

@@ -1,14 +1,17 @@
from __future__ import annotations
from typing import Any, List, Optional, cast
from typing import Any, cast, List, Optional, Tuple
import numpy as np
import pandas as pd
from qlib.backtest import Account, CommonInfrastructure, get_exchange
from qlib.backtest import CommonInfrastructure, get_exchange
from qlib.backtest.account import Account
from qlib.backtest.decision import OrderDir
from qlib.backtest.executor import BaseExecutor
from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig
from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray
from qlib.rl.order_execution.simulator_simple import _float_or_ndarray, ONE_SEC
from qlib.utils.time import Freq
def _get_common_infra(
@@ -93,3 +96,20 @@ def _price_advantage(
return res_wo_nan.item()
else:
return cast(_float_or_ndarray, res_wo_nan)
def _get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]:
    """Collect portfolio metrics and trade indicators from every executor.

    Iterates over ``executor.get_all_executors()`` and keys each entry by the
    two parts returned from ``Freq.parse(time_per_step)`` formatted back to
    back.

    Args:
        executor: Executor whose ``get_all_executors()`` result is inspected.

    Returns:
        A ``(all_portfolio_metrics, all_indicators)`` pair:

        * ``all_portfolio_metrics`` maps the frequency key to the account's
          portfolio metrics, only for executors whose trade account reports
          ``is_port_metr_enabled()``.
        * ``all_indicators`` maps the frequency key to the trade-indicator
          dataframe, and ``key + "_obj"`` to the trade-indicator object.
    """

    def _freq_key(sub_executor: BaseExecutor) -> str:
        # Same template as used for both dictionaries, so their keys line up.
        return "{}{}".format(*Freq.parse(sub_executor.time_per_step))

    executors = executor.get_all_executors()

    # First pass: portfolio metrics, restricted to accounts that track them.
    portfolio_metrics_by_freq = {}
    for sub_executor in executors:
        if not sub_executor.trade_account.is_port_metr_enabled():
            continue
        metrics_key = _freq_key(sub_executor)
        portfolio_metrics_by_freq[metrics_key] = sub_executor.trade_account.get_portfolio_metrics()

    # Second pass: trade indicators for every executor, stored both as a
    # dataframe and as the underlying indicator object.
    indicators_by_freq = {}
    for sub_executor in executors:
        indicator_key = _freq_key(sub_executor)
        account = sub_executor.trade_account
        indicator_frame = account.get_trade_indicator().generate_trade_indicators_dataframe()
        indicators_by_freq[indicator_key] = indicator_frame
        indicators_by_freq[indicator_key + "_obj"] = account.get_trade_indicator()

    return portfolio_metrics_by_freq, indicators_by_freq