1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

Update impl for robustness

This commit is contained in:
Yuge Zhang
2021-06-04 13:01:49 +08:00
parent c43805eff6
commit 1581ef12ac

View File

@@ -152,8 +152,13 @@ class EpisodicState:
state.cur_time, _ = calendar.get_step_time(state.cur_step)
return state
def update(self, execute_result: List[Order], calendar: TradeCalendarManager, done: Optional[bool] = None) -> "StepState":
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
def update(self, execute_result: List[Order], calendar: TradeCalendarManager,
done: Optional[bool] = None, length: Optional[int] = None) -> "StepState":
if length is not None:
exec_vol = np.zeros(length)
exec_vol[:len(execute_result)] = np.array([order.deal_amount for order, _, __, ___ in execute_result])
else:
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
# Synchronous exec_vol to executor and synchronous back to EpisodicState
cur_tick = self.cur_tick
ticks_this_step = len(exec_vol)
@@ -300,8 +305,6 @@ class SingleOrderEnv(gym.Env):
class RLStrategy(BaseStrategy):
"""When inference and do the backtest from end to end, use this strategy."""
# TODO This strategy is still for code demo purpose only.
# It has not been end-to-end tested.
def __init__(
self,
@@ -315,12 +318,15 @@ class RLStrategy(BaseStrategy):
self.action = action
self.policy = policy
# TODO: how to get inner frequency and trade len
self.inner_frequency = "day"
self.inner_trade_len = 1
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.states = OrderedDict() # explicitly make it ordered
for order in outer_trade_decision:
# TODO: how to get inner frequency
state = EpisodicState.from_order_and_executor(order, self.trade_calendar, "day")
self.states[order.stock_id, order.direction] = state
@@ -331,7 +337,7 @@ class RLStrategy(BaseStrategy):
for e in execute_result:
orders[e[0].stock_id, e[0].direction].append(e)
for (stock_id, direction), state in self.states.items():
state.update(orders[stock_id, direction], self.trade_calendar)
state.update(orders[stock_id, direction], self.trade_calendar, length=self.inner_trade_len)
if not self.states:
return []
@@ -495,19 +501,21 @@ class Observation:
return spaces.Dict(space)
def observe(self, ep_state: EpisodicState) -> Any:
features = D.features(
[ep_state.stock_id],
['$open', '$close', '$high', '$low', '$volume'],
start_time=ep_state.start_time,
end_time=ep_state.end_time,
freq=self.time_per_step
).loc[(ep_state.stock_id, ep_state.cur_time)].to_numpy()
features = np.nan_to_num(features)
return {
'direction': _to_int32(ep_state.direction),
'cur_step': _to_int32(min(ep_state.cur_step, ep_state.num_step - 1)),
'num_step': _to_int32(ep_state.num_step),
'target': _to_float32(ep_state.target),
'position': _to_float32(ep_state.position),
'features': D.features(
[ep_state.stock_id],
['$open', '$close', '$high', '$low', '$volume'],
start_time=ep_state.start_time,
end_time=ep_state.end_time,
freq=self.time_per_step
).loc[(ep_state.stock_id, ep_state.cur_time)].to_numpy(),
'features': features,
}