diff --git a/examples/nested_decision_execution/rl_dummy.py b/examples/nested_decision_execution/rl_dummy.py index 3eec91789..61f1bba59 100644 --- a/examples/nested_decision_execution/rl_dummy.py +++ b/examples/nested_decision_execution/rl_dummy.py @@ -1,17 +1,19 @@ import pickle +from collections import OrderedDict, defaultdict from dataclasses import dataclass, asdict from pprint import pprint -from typing import Iterable, Any, Optional, Tuple, Dict, List +from typing import Iterable, Any, Optional, OrderedDict, Tuple, Dict, List import gym import numpy as np import pandas as pd import qlib from gym import spaces -from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order +from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager from qlib.config import REG_CN from qlib.data import D from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.strategy import BaseStrategy from qlib.tests.data import GetData from qlib.utils import init_instance_by_config, exists_qlib_data from torch.utils.data import Dataset, DataLoader @@ -129,35 +131,36 @@ class EpisodicState: return logs @classmethod - def from_order_and_executor(cls, order: Order, executor: BaseExecutor, frequency: str) -> "EpisodicState": + def from_order_and_executor(cls, order: Order, calendar: TradeCalendarManager, frequency: str) -> "EpisodicState": # Synchronous state for executor to EpisodicState - executor.reset(start_time=order.start_time, end_time=order.end_time) state = cls( stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time, direction=order.direction, target=order.amount, - num_step=executor.trade_calendar.get_trade_len(), + num_step=calendar.get_trade_len(), market_price=_retrieve_backtest_data(order, '$close', frequency), market_vol=_retrieve_backtest_data(order, '$volume', frequency), ) - state.cur_step = executor.trade_calendar.get_trade_step() + state.cur_step = calendar.get_trade_step() assert state.cur_step == 0 - state.cur_time, _ = executor.trade_calendar.get_step_time(state.cur_step) + state.cur_time, _ = calendar.get_step_time(state.cur_step) return state - def update(self, execute_result: List[Order], executor: BaseExecutor) -> "StepState": + def update(self, execute_result: List[Order], calendar: TradeCalendarManager, done: Optional[bool] = None) -> "StepState": exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result]) # Synchronous exec_vol to executor and synchronous back to EpisodicState - calendar = executor.trade_calendar cur_tick = self.cur_tick ticks_this_step = len(exec_vol) self.cur_step = trade_step = calendar.get_trade_step() self.cur_tick += ticks_this_step self.position -= np.sum(exec_vol) self.position_history[trade_step] = self.position - self.done = executor.finished() + if done is not None: + self.done = done + else: + self.done = self.position < 1e-5 self.exec_vol = exec_vol if self.exec_vol is None else \ np.concatenate((self.exec_vol, exec_vol)) @@ -211,9 +214,8 @@ def _retrieve_backtest_data(order: Order, field: str, frequency: str) -> np.ndar )[field].to_numpy() -def create_sub_order(exec_vol: float, executor: BaseExecutor, original_order: Order) -> Order: +def create_sub_order(exec_vol: float, calendar: TradeCalendarManager, original_order: Order) -> Order: # Convert a real number to an order - calendar = executor.trade_calendar trade_step = calendar.get_trade_step() trade_start_time, trade_end_time = calendar.get_step_time(trade_step) order_kwargs = asdict(original_order) @@ -253,8 +255,9 @@ class SingleOrderEnv(gym.Env): return None self.execute_result = [] + self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time) self.ep_state = EpisodicState.from_order_and_executor( - self.cur_order, self.executor, self.inner_frequency + self.cur_order, self.executor.trade_calendar, self.inner_frequency ) self.action_history = np.full(self.ep_state.num_step, np.nan) @@ -266,9 +269,9 @@ class SingleOrderEnv(gym.Env): self.action_history[self.ep_state.cur_step] = action exec_vol = self.action(action, self.ep_state) - trade_decision = create_sub_order(exec_vol, self.executor, self.cur_order) + trade_decision = create_sub_order(exec_vol, self.executor.trade_calendar, self.cur_order) execute_result = self.executor.execute([trade_decision]) - step_state = self.ep_state.update(execute_result, self.executor) + step_state = self.ep_state.update(execute_result, self.executor.trade_calendar) if self.executor.finished(): assert self.ep_state.done @@ -291,6 +294,47 @@ class SingleOrderEnv(gym.Env): return self.observation(self.ep_state), reward, self.ep_state.done, info +class RLStrategy(BaseStrategy): + """When inference and do the backtest from end to end, use this strategy.""" + # TODO This strategy is still for code demo purpose only. + # It has not been end-to-end tested. + + def __init__( + self, + observation: "Observation", + action: "Action", + policy: BasePolicy, + **kwargs + ): + super().__init__(**kwargs) + self.observation = observation + self.action = action + self.policy = policy + + def reset(self, outer_trade_decision: List[Order] = None, **kwargs): + super().reset(outer_trade_decision=outer_trade_decision, **kwargs) + if outer_trade_decision is not None: + self.states = OrderedDict() # explicitly make it ordered + for order in outer_trade_decision: + # TODO: how to get inner frequency + state = EpisodicState.from_order_and_executor(order, self.trade_calendar, "day") + self.states[order.stock_id, order.direction] = state + + def generate_trade_decision(self, execute_result=None): + # apply results from the last step + if execute_result is not None: + orders = defaultdict(list) + for order, _, __, in execute_result: + orders[order.stock_id, order.direction].append(order) + for (stock_id, direction), state in self.states.items(): + state.update(orders[stock_id, direction]) + + obs_batch = Batch([{"obs": self.observation(state)} for state in self.states.values()]) + act = self.policy(obs_batch) + exec_vols = [self.action(a) for a in act.act] + return [create_sub_order(v, self.trade_calendar, order) for v in exec_vols] + + def _init_qlib(): provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir if not exists_qlib_data(provider_uri): @@ -299,7 +343,7 @@ def _init_qlib(): qlib.init(provider_uri=provider_uri, region=REG_CN) -def _main(): +def _main_tianshou(): _init_qlib() # TODO: why is there a benchmark? @@ -483,4 +527,4 @@ def _to_float32(val): return np.array(val, dtype=np.float32) if __name__ == '__main__': - _main() + _main_tianshou() diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index f80f7ebeb..c053269ef 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -7,6 +7,7 @@ from .executor import BaseExecutor from .backtest import backtest as backtest_func from .backtest import collect_data as data_generator from .order import Order +from .utils import TradeCalendarManager from .utils import CommonInfrastructure from .order import Order