1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

Refactor for strategy

This commit is contained in:
Yuge Zhang
2021-06-02 22:04:54 +08:00
parent 2314405613
commit f5ac6230e1

View File

@@ -1,7 +1,7 @@
import pickle
from dataclasses import dataclass, asdict
from pprint import pprint
from typing import Iterable, Any, Optional, Tuple, Dict
from typing import Iterable, Any, Optional, Tuple, Dict, List
import gym
import numpy as np
@@ -128,6 +128,48 @@ class EpisodicState:
}
return logs
@classmethod
def from_order_and_executor(cls, order: Order, executor: BaseExecutor, frequency: str) -> "EpisodicState":
    """Create the initial episode state for ``order`` and align it with ``executor``.

    The executor is reset to the order's ``start_time``/``end_time`` window
    before anything is read from its trade calendar.

    Parameters
    ----------
    order
        Supplies ``stock_id``, ``start_time``, ``end_time``, ``direction``
        and ``amount`` (stored as ``target``).
    executor
        Executor to reset; its ``trade_calendar`` provides the number of
        steps, the current step and the current step's start time.
    frequency
        Passed through to ``_retrieve_backtest_data`` when loading the
        ``$close`` and ``$volume`` series.

    Returns
    -------
    EpisodicState
        A new state whose ``cur_step`` is 0 (asserted) and whose ``cur_time``
        is the first element returned by ``get_step_time`` for that step.
    """
    # Reset first: every calendar query below must refer to this order's window.
    executor.reset(start_time=order.start_time, end_time=order.end_time)
    calendar = executor.trade_calendar

    total_steps = calendar.get_trade_len()
    close_series = _retrieve_backtest_data(order, '$close', frequency)
    volume_series = _retrieve_backtest_data(order, '$volume', frequency)

    episode = cls(
        stock_id=order.stock_id,
        start_time=order.start_time,
        end_time=order.end_time,
        direction=order.direction,
        target=order.amount,
        num_step=total_steps,
        market_price=close_series,
        market_vol=volume_series,
    )

    # A freshly reset executor is expected to sit on the very first step.
    episode.cur_step = calendar.get_trade_step()
    assert episode.cur_step == 0
    episode.cur_time, _step_end = calendar.get_step_time(episode.cur_step)
    return episode
def update(self, execute_result: List[Tuple[Order, Any, Any, Any]], executor: BaseExecutor) -> "StepState":
    """Fold the outcome of one executor step into this episode state.

    Parameters
    ----------
    execute_result
        Records produced by the executor for the step just executed. Each
        record is unpacked as a 4-tuple whose first item is the executed
        order; only that order's ``deal_amount`` is read here. Must contain
        at least one record, otherwise the final assertion fails.
    executor
        The executor that produced ``execute_result``. Its ``trade_calendar``
        and ``finished()`` are queried to synchronise this state.

    Returns
    -------
    StepState
        Built from this step's executed volumes, the ``market_vol`` and
        ``market_price`` slices covering the ticks consumed by this step,
        and this episode state.
    """
    exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])

    # Synchronise the executor's progress back into this state.
    calendar = executor.trade_calendar
    tick_begin = self.cur_tick  # remembered before cur_tick is advanced below
    ticks_this_step = len(exec_vol)

    self.cur_step = trade_step = calendar.get_trade_step()
    self.cur_tick += ticks_this_step
    self.position -= np.sum(exec_vol)
    self.position_history[trade_step] = self.position
    self.done = executor.finished()

    if self.exec_vol is None:
        self.exec_vol = exec_vol
    else:
        self.exec_vol = np.concatenate((self.exec_vol, exec_vol))

    if self.done:
        self.update_stats()
    else:
        # cur_time is only refreshed while the episode is still running.
        self.cur_time, _ = calendar.get_step_time(trade_step)

    tick_end = tick_begin + ticks_this_step
    # Invariant: the step consumed at least one tick (a non-empty slice).
    assert 0 <= tick_begin < tick_end
    return StepState(
        exec_vol,
        self.market_vol[tick_begin:tick_end],
        self.market_price[tick_begin:tick_end],
        self,
    )
@dataclass
class StepState:
@@ -158,6 +200,28 @@ class StepState:
self.episode_state.direction)
def _retrieve_backtest_data(order: Order, field: str, frequency: str) -> np.ndarray:
    """Load one market-data column for ``order`` as a NumPy array.

    Used for RL-specific purposes (including reward calculation).

    Parameters
    ----------
    order
        Supplies the ``stock_id`` and the ``start_time``/``end_time`` window.
    field
        Column to return; it is looked up in a frame that is requested with
        ``$open``, ``$close``, ``$high``, ``$low`` and ``$volume``.
    frequency
        Forwarded to ``D.features`` as ``freq``.

    Returns
    -------
    np.ndarray
        ``to_numpy()`` of the selected column.
    """
    requested_fields = ['$open', '$close', '$high', '$low', '$volume']
    frame = D.features(
        [order.stock_id],
        requested_fields,
        start_time=order.start_time,
        end_time=order.end_time,
        freq=frequency,
    )
    column = frame[field]
    return column.to_numpy()
def create_sub_order(exec_vol: float, executor: BaseExecutor, original_order: Order) -> Order:
    """Turn a plain volume into an order for the executor's current step.

    Parameters
    ----------
    exec_vol
        Amount to place on the new order.
    executor
        Its ``trade_calendar`` supplies the current step and that step's
        start/end times.
    original_order
        Template order; every field is copied via ``asdict`` and then
        ``start_time``, ``end_time`` and ``amount`` are overridden.

    Returns
    -------
    Order
        A new order restricted to the current step with amount ``exec_vol``.
    """
    calendar = executor.trade_calendar
    step_start, step_end = calendar.get_step_time(calendar.get_trade_step())

    # Start from a copy of the parent order's fields, then narrow it to this step.
    fields = asdict(original_order)
    fields["start_time"] = step_start
    fields["end_time"] = step_end
    fields["amount"] = exec_vol
    return Order(**fields)
class SingleOrderEnv(gym.Env):
def __init__(self,
observation: StateInterpreter,
@@ -181,66 +245,6 @@ class SingleOrderEnv(gym.Env):
def observation_space(self):
    """Return the observation space exposed by ``self.observation``."""
    interpreter = self.observation
    return interpreter.observation_space
def retrieve_backtest_data(self, field: str):
    """Load one market-data column for the current order as a NumPy array.

    Used for RL-specific purposes (including reward calculation). The
    instrument and time window come from ``self.cur_order`` and the
    frequency from ``self.inner_frequency``. ``field`` is looked up in a
    frame requested with ``$open``, ``$close``, ``$high``, ``$low`` and
    ``$volume``.
    """
    current = self.cur_order
    requested_fields = ['$open', '$close', '$high', '$low', '$volume']
    frame = D.features(
        [current.stock_id],
        requested_fields,
        start_time=current.start_time,
        end_time=current.end_time,
        freq=self.inner_frequency,
    )
    column = frame[field]
    return column.to_numpy()
def initialize_state(self):
    """Reset the executor for the current order and build its ``EpisodicState``.

    The executor is reset to the order's ``start_time``/``end_time`` window
    first. The returned state carries the order's identifying fields, the
    calendar's step count, and the ``$close``/``$volume`` series. Its
    ``cur_step`` is asserted to be 0, and ``cur_time`` is the first element
    returned by ``get_step_time`` for that step.
    """
    current = self.cur_order

    # Reset first: every calendar query below must refer to this order's window.
    self.executor.reset(start_time=current.start_time, end_time=current.end_time)
    calendar = self.executor.trade_calendar

    total_steps = calendar.get_trade_len()
    close_series = self.retrieve_backtest_data('$close')
    volume_series = self.retrieve_backtest_data('$volume')

    episode = EpisodicState(
        stock_id=current.stock_id,
        start_time=current.start_time,
        end_time=current.end_time,
        direction=current.direction,
        target=current.amount,
        num_step=total_steps,
        market_price=close_series,
        market_vol=volume_series,
    )

    # A freshly reset executor is expected to sit on the very first step.
    episode.cur_step = calendar.get_trade_step()
    assert episode.cur_step == 0
    episode.cur_time, _step_end = calendar.get_step_time(episode.cur_step)
    return episode
def update_state(self, exec_vol):
    """Execute ``exec_vol`` for the current step and sync the episode state.

    A copy of ``self.cur_order`` restricted to the executor's current step
    (with ``amount=exec_vol``) is executed. The per-record ``deal_amount``
    values are then folded into ``self.ep_state``: step and tick counters,
    position and its history, the ``done`` flag and the accumulated
    ``exec_vol``. Statistics are updated when the episode is done;
    otherwise ``cur_time`` is refreshed.

    Returns a ``StepState`` built from the executed volumes, the
    ``market_vol`` and ``market_price`` slices for the ticks consumed by
    this step, and the episode state.
    """
    calendar = self.executor.trade_calendar
    episode = self.ep_state

    # Wrap the requested volume in an order covering only the current step.
    step_start, step_end = calendar.get_step_time(calendar.get_trade_step())
    fields = asdict(self.cur_order)
    fields["start_time"] = step_start
    fields["end_time"] = step_end
    fields["amount"] = exec_vol
    results = self.executor.execute([Order(**fields)])

    tick_begin = episode.cur_tick  # remembered before cur_tick is advanced below
    dealt = np.array([filled.deal_amount for filled, _a, _b, _c in results])
    consumed = len(dealt)

    # The calendar has moved on after execute(); read the step again.
    step_now = calendar.get_trade_step()
    episode.cur_step = step_now
    episode.cur_tick += consumed
    episode.position -= np.sum(dealt)
    episode.position_history[step_now] = episode.position
    episode.done = self.executor.finished()

    if episode.exec_vol is None:
        episode.exec_vol = dealt
    else:
        episode.exec_vol = np.concatenate((episode.exec_vol, dealt))

    if episode.done:
        episode.update_stats()
    else:
        episode.cur_time, _step_end = calendar.get_step_time(step_now)

    tick_end = tick_begin + consumed
    # Invariant: the step consumed at least one tick (a non-empty slice).
    assert 0 <= tick_begin < tick_end
    return StepState(
        dealt,
        episode.market_vol[tick_begin:tick_end],
        episode.market_price[tick_begin:tick_end],
        episode,
    )
def reset(self):
try:
self.cur_order = next(self.dataloader)
@@ -249,7 +253,9 @@ class SingleOrderEnv(gym.Env):
return None
self.execute_result = []
self.ep_state = self.initialize_state()
self.ep_state = EpisodicState.from_order_and_executor(
self.cur_order, self.executor, self.inner_frequency
)
self.action_history = np.full(self.ep_state.num_step, np.nan)
return self.observation(self.ep_state)
@@ -260,7 +266,9 @@ class SingleOrderEnv(gym.Env):
self.action_history[self.ep_state.cur_step] = action
exec_vol = self.action(action, self.ep_state)
step_state = self.update_state(exec_vol)
trade_decision = create_sub_order(exec_vol, self.executor, self.cur_order)
execute_result = self.executor.execute([trade_decision])
step_state = self.ep_state.update(execute_result, self.executor)
if self.executor.finished():
assert self.ep_state.done