mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 03:50:57 +08:00
Refactor for strategy
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import pickle
|
||||
from dataclasses import dataclass, asdict
|
||||
from pprint import pprint
|
||||
from typing import Iterable, Any, Optional, Tuple, Dict
|
||||
from typing import Iterable, Any, Optional, Tuple, Dict, List
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@@ -128,6 +128,48 @@ class EpisodicState:
|
||||
}
|
||||
return logs
|
||||
|
||||
@classmethod
|
||||
def from_order_and_executor(cls, order: Order, executor: BaseExecutor, frequency: str) -> "EpisodicState":
|
||||
# Synchronous state for executor to EpisodicState
|
||||
executor.reset(start_time=order.start_time, end_time=order.end_time)
|
||||
state = cls(
|
||||
stock_id=order.stock_id,
|
||||
start_time=order.start_time,
|
||||
end_time=order.end_time,
|
||||
direction=order.direction,
|
||||
target=order.amount,
|
||||
num_step=executor.trade_calendar.get_trade_len(),
|
||||
market_price=_retrieve_backtest_data(order, '$close', frequency),
|
||||
market_vol=_retrieve_backtest_data(order, '$volume', frequency),
|
||||
)
|
||||
state.cur_step = executor.trade_calendar.get_trade_step()
|
||||
assert state.cur_step == 0
|
||||
state.cur_time, _ = executor.trade_calendar.get_step_time(state.cur_step)
|
||||
return state
|
||||
|
||||
def update(self, execute_result: List[Order], executor: BaseExecutor) -> "StepState":
|
||||
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
|
||||
# Synchronous exec_vol to executor and synchronous back to EpisodicState
|
||||
calendar = executor.trade_calendar
|
||||
cur_tick = self.cur_tick
|
||||
ticks_this_step = len(exec_vol)
|
||||
self.cur_step = trade_step = calendar.get_trade_step()
|
||||
self.cur_tick += ticks_this_step
|
||||
self.position -= np.sum(exec_vol)
|
||||
self.position_history[trade_step] = self.position
|
||||
self.done = executor.finished()
|
||||
self.exec_vol = exec_vol if self.exec_vol is None else \
|
||||
np.concatenate((self.exec_vol, exec_vol))
|
||||
|
||||
if self.done:
|
||||
self.update_stats()
|
||||
else:
|
||||
self.cur_time, _ = calendar.get_step_time(trade_step)
|
||||
|
||||
l, r = cur_tick, cur_tick + ticks_this_step
|
||||
assert 0 <= l < r
|
||||
return StepState(exec_vol, self.market_vol[l:r], self.market_price[l:r], self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepState:
|
||||
@@ -158,6 +200,28 @@ class StepState:
|
||||
self.episode_state.direction)
|
||||
|
||||
|
||||
def _retrieve_backtest_data(order: Order, field: str, frequency: str) -> np.ndarray:
|
||||
# Retrieve backtest data for RL-specific use (including reward calculation)
|
||||
return D.features(
|
||||
[order.stock_id],
|
||||
['$open', '$close', '$high', '$low', '$volume'],
|
||||
start_time=order.start_time,
|
||||
end_time=order.end_time,
|
||||
freq=frequency
|
||||
)[field].to_numpy()
|
||||
|
||||
|
||||
def create_sub_order(exec_vol: float, executor: BaseExecutor, original_order: Order) -> Order:
|
||||
# Convert a real number to an order
|
||||
calendar = executor.trade_calendar
|
||||
trade_step = calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = calendar.get_step_time(trade_step)
|
||||
order_kwargs = asdict(original_order)
|
||||
order_kwargs.update(start_time=trade_start_time, end_time=trade_end_time, amount=exec_vol)
|
||||
trade_decision = Order(**order_kwargs)
|
||||
return trade_decision
|
||||
|
||||
|
||||
class SingleOrderEnv(gym.Env):
|
||||
def __init__(self,
|
||||
observation: StateInterpreter,
|
||||
@@ -181,66 +245,6 @@ class SingleOrderEnv(gym.Env):
|
||||
def observation_space(self):
|
||||
return self.observation.observation_space
|
||||
|
||||
def retrieve_backtest_data(self, field: str):
|
||||
# Retrieve backtest data for RL-specific use (including reward calculation)
|
||||
return D.features(
|
||||
[self.cur_order.stock_id],
|
||||
['$open', '$close', '$high', '$low', '$volume'],
|
||||
start_time=self.cur_order.start_time,
|
||||
end_time=self.cur_order.end_time,
|
||||
freq=self.inner_frequency
|
||||
)[field].to_numpy()
|
||||
|
||||
def initialize_state(self):
|
||||
# Synchronous state for executor to EpisodicState
|
||||
self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time)
|
||||
state = EpisodicState(
|
||||
stock_id=self.cur_order.stock_id,
|
||||
start_time=self.cur_order.start_time,
|
||||
end_time=self.cur_order.end_time,
|
||||
direction=self.cur_order.direction,
|
||||
target=self.cur_order.amount,
|
||||
num_step=self.executor.trade_calendar.get_trade_len(),
|
||||
market_price=self.retrieve_backtest_data('$close'),
|
||||
market_vol=self.retrieve_backtest_data('$volume'),
|
||||
)
|
||||
state.cur_step = self.executor.trade_calendar.get_trade_step()
|
||||
assert state.cur_step == 0
|
||||
state.cur_time, _ = self.executor.trade_calendar.get_step_time(state.cur_step)
|
||||
return state
|
||||
|
||||
def update_state(self, exec_vol):
|
||||
# Synchronous exec_vol to executor and synchronous back to EpisodicState
|
||||
calendar = self.executor.trade_calendar
|
||||
state = self.ep_state
|
||||
|
||||
trade_step = calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = calendar.get_step_time(trade_step)
|
||||
order_kwargs = asdict(self.cur_order)
|
||||
order_kwargs.update(start_time=trade_start_time, end_time=trade_end_time, amount=exec_vol)
|
||||
trade_decision = Order(**order_kwargs)
|
||||
execute_result = self.executor.execute([trade_decision])
|
||||
cur_tick = state.cur_tick
|
||||
|
||||
inner_exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
|
||||
ticks_this_step = len(inner_exec_vol)
|
||||
state.cur_step = trade_step = calendar.get_trade_step()
|
||||
state.cur_tick += ticks_this_step
|
||||
state.position -= np.sum(inner_exec_vol)
|
||||
state.position_history[trade_step] = state.position
|
||||
state.done = self.executor.finished()
|
||||
state.exec_vol = inner_exec_vol if state.exec_vol is None else \
|
||||
np.concatenate((state.exec_vol, inner_exec_vol))
|
||||
|
||||
if state.done:
|
||||
state.update_stats()
|
||||
else:
|
||||
state.cur_time, _ = calendar.get_step_time(trade_step)
|
||||
|
||||
l, r = cur_tick, cur_tick + ticks_this_step
|
||||
assert 0 <= l < r
|
||||
return StepState(inner_exec_vol, state.market_vol[l:r], state.market_price[l:r], state)
|
||||
|
||||
def reset(self):
|
||||
try:
|
||||
self.cur_order = next(self.dataloader)
|
||||
@@ -249,7 +253,9 @@ class SingleOrderEnv(gym.Env):
|
||||
return None
|
||||
|
||||
self.execute_result = []
|
||||
self.ep_state = self.initialize_state()
|
||||
self.ep_state = EpisodicState.from_order_and_executor(
|
||||
self.cur_order, self.executor, self.inner_frequency
|
||||
)
|
||||
|
||||
self.action_history = np.full(self.ep_state.num_step, np.nan)
|
||||
return self.observation(self.ep_state)
|
||||
@@ -260,7 +266,9 @@ class SingleOrderEnv(gym.Env):
|
||||
self.action_history[self.ep_state.cur_step] = action
|
||||
|
||||
exec_vol = self.action(action, self.ep_state)
|
||||
step_state = self.update_state(exec_vol)
|
||||
trade_decision = create_sub_order(exec_vol, self.executor, self.cur_order)
|
||||
execute_result = self.executor.execute([trade_decision])
|
||||
step_state = self.ep_state.update(execute_result, self.executor)
|
||||
if self.executor.finished():
|
||||
assert self.ep_state.done
|
||||
|
||||
|
||||
Reference in New Issue
Block a user