diff --git a/examples/nested_decision_execution/rl_dummy.py b/examples/nested_decision_execution/rl_dummy.py index 1ea444cdf..3eec91789 100644 --- a/examples/nested_decision_execution/rl_dummy.py +++ b/examples/nested_decision_execution/rl_dummy.py @@ -1,7 +1,7 @@ import pickle from dataclasses import dataclass, asdict from pprint import pprint -from typing import Iterable, Any, Optional, Tuple, Dict +from typing import Iterable, Any, Optional, Tuple, Dict, List import gym import numpy as np @@ -128,6 +128,48 @@ class EpisodicState: } return logs + @classmethod + def from_order_and_executor(cls, order: Order, executor: BaseExecutor, frequency: str) -> "EpisodicState": + # Synchronous state for executor to EpisodicState + executor.reset(start_time=order.start_time, end_time=order.end_time) + state = cls( + stock_id=order.stock_id, + start_time=order.start_time, + end_time=order.end_time, + direction=order.direction, + target=order.amount, + num_step=executor.trade_calendar.get_trade_len(), + market_price=_retrieve_backtest_data(order, '$close', frequency), + market_vol=_retrieve_backtest_data(order, '$volume', frequency), + ) + state.cur_step = executor.trade_calendar.get_trade_step() + assert state.cur_step == 0 + state.cur_time, _ = executor.trade_calendar.get_step_time(state.cur_step) + return state + + def update(self, execute_result: List[Order], executor: BaseExecutor) -> "StepState": + exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result]) + # Synchronous exec_vol to executor and synchronous back to EpisodicState + calendar = executor.trade_calendar + cur_tick = self.cur_tick + ticks_this_step = len(exec_vol) + self.cur_step = trade_step = calendar.get_trade_step() + self.cur_tick += ticks_this_step + self.position -= np.sum(exec_vol) + self.position_history[trade_step] = self.position + self.done = executor.finished() + self.exec_vol = exec_vol if self.exec_vol is None else \ + np.concatenate((self.exec_vol, exec_vol)) + + if self.done: + self.update_stats() + else: + self.cur_time, _ = calendar.get_step_time(trade_step) + + l, r = cur_tick, cur_tick + ticks_this_step + assert 0 <= l < r + return StepState(exec_vol, self.market_vol[l:r], self.market_price[l:r], self) + @dataclass class StepState: @@ -158,6 +200,28 @@ class StepState: self.episode_state.direction) +def _retrieve_backtest_data(order: Order, field: str, frequency: str) -> np.ndarray: + # Retrieve backtest data for RL-specific use (including reward calculation) + return D.features( + [order.stock_id], + ['$open', '$close', '$high', '$low', '$volume'], + start_time=order.start_time, + end_time=order.end_time, + freq=frequency + )[field].to_numpy() + + +def create_sub_order(exec_vol: float, executor: BaseExecutor, original_order: Order) -> Order: + # Convert a real number to an order + calendar = executor.trade_calendar + trade_step = calendar.get_trade_step() + trade_start_time, trade_end_time = calendar.get_step_time(trade_step) + order_kwargs = asdict(original_order) + order_kwargs.update(start_time=trade_start_time, end_time=trade_end_time, amount=exec_vol) + trade_decision = Order(**order_kwargs) + return trade_decision + + class SingleOrderEnv(gym.Env): def __init__(self, observation: StateInterpreter, @@ -181,66 +245,6 @@ class SingleOrderEnv(gym.Env): def observation_space(self): return self.observation.observation_space - def retrieve_backtest_data(self, field: str): - # Retrieve backtest data for RL-specific use (including reward calculation) - return D.features( - [self.cur_order.stock_id], - ['$open', '$close', '$high', '$low', '$volume'], - start_time=self.cur_order.start_time, - end_time=self.cur_order.end_time, - freq=self.inner_frequency - )[field].to_numpy() - - def initialize_state(self): - # Synchronous state for executor to EpisodicState - self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time) - state = EpisodicState( - stock_id=self.cur_order.stock_id, - start_time=self.cur_order.start_time, - end_time=self.cur_order.end_time, - direction=self.cur_order.direction, - target=self.cur_order.amount, - num_step=self.executor.trade_calendar.get_trade_len(), - market_price=self.retrieve_backtest_data('$close'), - market_vol=self.retrieve_backtest_data('$volume'), - ) - state.cur_step = self.executor.trade_calendar.get_trade_step() - assert state.cur_step == 0 - state.cur_time, _ = self.executor.trade_calendar.get_step_time(state.cur_step) - return state - - def update_state(self, exec_vol): - # Synchronous exec_vol to executor and synchronous back to EpisodicState - calendar = self.executor.trade_calendar - state = self.ep_state - - trade_step = calendar.get_trade_step() - trade_start_time, trade_end_time = calendar.get_step_time(trade_step) - order_kwargs = asdict(self.cur_order) - order_kwargs.update(start_time=trade_start_time, end_time=trade_end_time, amount=exec_vol) - trade_decision = Order(**order_kwargs) - execute_result = self.executor.execute([trade_decision]) - cur_tick = state.cur_tick - - inner_exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result]) - ticks_this_step = len(inner_exec_vol) - state.cur_step = trade_step = calendar.get_trade_step() - state.cur_tick += ticks_this_step - state.position -= np.sum(inner_exec_vol) - state.position_history[trade_step] = state.position - state.done = self.executor.finished() - state.exec_vol = inner_exec_vol if state.exec_vol is None else \ - np.concatenate((state.exec_vol, inner_exec_vol)) - - if state.done: - state.update_stats() - else: - state.cur_time, _ = calendar.get_step_time(trade_step) - - l, r = cur_tick, cur_tick + ticks_this_step - assert 0 <= l < r - return StepState(inner_exec_vol, state.market_vol[l:r], state.market_price[l:r], state) - def reset(self): try: self.cur_order = next(self.dataloader) @@ -249,7 +253,9 @@ class SingleOrderEnv(gym.Env): return None self.execute_result = [] - self.ep_state = self.initialize_state() + self.ep_state = EpisodicState.from_order_and_executor( + self.cur_order, self.executor, self.inner_frequency + ) self.action_history = np.full(self.ep_state.num_step, np.nan) return self.observation(self.ep_state) @@ -260,7 +266,9 @@ class SingleOrderEnv(gym.Env): self.action_history[self.ep_state.cur_step] = action exec_vol = self.action(action, self.ep_state) - step_state = self.update_state(exec_vol) + trade_decision = create_sub_order(exec_vol, self.executor, self.cur_order) + execute_result = self.executor.execute([trade_decision]) + step_state = self.ep_state.update(execute_result, self.executor) if self.executor.finished(): assert self.ep_state.done