mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Update impl for robustness
This commit is contained in:
@@ -152,8 +152,13 @@ class EpisodicState:
|
||||
state.cur_time, _ = calendar.get_step_time(state.cur_step)
|
||||
return state
|
||||
|
||||
def update(self, execute_result: List[Order], calendar: TradeCalendarManager, done: Optional[bool] = None) -> "StepState":
|
||||
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
|
||||
def update(self, execute_result: List[Order], calendar: TradeCalendarManager,
|
||||
done: Optional[bool] = None, length: Optional[int] = None) -> "StepState":
|
||||
if length is not None:
|
||||
exec_vol = np.zeros(length)
|
||||
exec_vol[:len(execute_result)] = np.array([order.deal_amount for order, _, __, ___ in execute_result])
|
||||
else:
|
||||
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
|
||||
# Synchronize exec_vol to the executor and synchronize it back to EpisodicState
|
||||
cur_tick = self.cur_tick
|
||||
ticks_this_step = len(exec_vol)
|
||||
@@ -300,8 +305,6 @@ class SingleOrderEnv(gym.Env):
|
||||
|
||||
class RLStrategy(BaseStrategy):
|
||||
"""When inference and do the backtest from end to end, use this strategy."""
|
||||
# TODO This strategy is still for code demo purpose only.
|
||||
# It has not been end-to-end tested.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -315,12 +318,15 @@ class RLStrategy(BaseStrategy):
|
||||
self.action = action
|
||||
self.policy = policy
|
||||
|
||||
# TODO: how to get inner frequency and trade len
|
||||
self.inner_frequency = "day"
|
||||
self.inner_trade_len = 1
|
||||
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.states = OrderedDict() # explicitly make it ordered
|
||||
for order in outer_trade_decision:
|
||||
# TODO: how to get inner frequency
|
||||
state = EpisodicState.from_order_and_executor(order, self.trade_calendar, "day")
|
||||
self.states[order.stock_id, order.direction] = state
|
||||
|
||||
@@ -331,7 +337,7 @@ class RLStrategy(BaseStrategy):
|
||||
for e in execute_result:
|
||||
orders[e[0].stock_id, e[0].direction].append(e)
|
||||
for (stock_id, direction), state in self.states.items():
|
||||
state.update(orders[stock_id, direction], self.trade_calendar)
|
||||
state.update(orders[stock_id, direction], self.trade_calendar, length=self.inner_trade_len)
|
||||
|
||||
if not self.states:
|
||||
return []
|
||||
@@ -495,19 +501,21 @@ class Observation:
|
||||
return spaces.Dict(space)
|
||||
|
||||
def observe(self, ep_state: EpisodicState) -> Any:
|
||||
features = D.features(
|
||||
[ep_state.stock_id],
|
||||
['$open', '$close', '$high', '$low', '$volume'],
|
||||
start_time=ep_state.start_time,
|
||||
end_time=ep_state.end_time,
|
||||
freq=self.time_per_step
|
||||
).loc[(ep_state.stock_id, ep_state.cur_time)].to_numpy()
|
||||
features = np.nan_to_num(features)
|
||||
return {
|
||||
'direction': _to_int32(ep_state.direction),
|
||||
'cur_step': _to_int32(min(ep_state.cur_step, ep_state.num_step - 1)),
|
||||
'num_step': _to_int32(ep_state.num_step),
|
||||
'target': _to_float32(ep_state.target),
|
||||
'position': _to_float32(ep_state.position),
|
||||
'features': D.features(
|
||||
[ep_state.stock_id],
|
||||
['$open', '$close', '$high', '$low', '$volume'],
|
||||
start_time=ep_state.start_time,
|
||||
end_time=ep_state.end_time,
|
||||
freq=self.time_per_step
|
||||
).loc[(ep_state.stock_id, ep_state.cur_time)].to_numpy(),
|
||||
'features': features,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user