1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00

Finish RL dummy example

This commit is contained in:
Yuge Zhang
2021-06-02 16:41:18 +08:00
parent 3200bb88c8
commit d515efb46e
2 changed files with 183 additions and 166 deletions

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
import pandas as pd
from dataclasses import dataclass, field
from typing import ClassVar
from typing import ClassVar, Optional
@dataclass
@@ -26,7 +26,7 @@ class Order:
end_time: pd.Timestamp
direction: int
factor: float
deal_amount: float = field(init=False)
deal_amount: Optional[float] = None
SELL: ClassVar[int] = 0
BUY: ClassVar[int] = 1

View File

@@ -1,5 +1,6 @@
import pickle
from dataclasses import dataclass, asdict
from pprint import pprint
from typing import Iterable, Any, Optional, Tuple, Dict
import gym
@@ -22,7 +23,7 @@ from tianshou.policy import BasePolicy
MAX_STEPS = 10
def get_executor(start_time, end_time, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}) -> BaseExecutor:
def get_executor(start_time, end_time, executor, exchange, benchmark="SH000300", account=1e9) -> BaseExecutor:
trade_account = Account(
init_cash=account,
benchmark_config={
@@ -31,9 +32,8 @@ def get_executor(start_time, end_time, executor, benchmark="SH000300", account=1
"end_time": end_time,
},
)
trade_exchange = get_exchange(**exchange_kwargs)
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
return trade_executor
@@ -48,31 +48,6 @@ def price_advantage(exec_price: float, baseline_price: float, direction: int) ->
return (exec_price / baseline_price - 1) * 10000
def _to_int32(val): return np.array(int(val), dtype=np.int32)
def _to_float32(val): return np.array(val, dtype=np.float32)
class QlibOrderDataset(Dataset):
def __init__(self, order_file):
with open(order_file, 'rb') as f:
self.orders = pickle.load(f)
def __len__(self):
return len(self.orders)
def __getitem__(self, index):
return self.orders[index]
class DummyPolicy(BasePolicy):
def forward(self, batch, state=None, **kwargs):
print(batch)
return Batch(act=np.random.randint(5))
def learn(self, *args, **kwargs):
pass
@dataclass
class EpisodicState:
"""
@@ -182,103 +157,6 @@ class StepState:
self.episode_state.direction)
class Observation:
def __init__(self, time_per_step):
self.time_per_step = time_per_step
def __call__(self, ep_state: EpisodicState) -> Any:
obs = self.observe(ep_state)
if not self.validate(obs):
raise ValueError(f'Observation space does not contain obs. Space: {self.observation_space} Sample: {obs}')
return obs
def validate(self, obs: Any) -> bool:
return self.observation_space.contains(obs)
@property
def observation_space(self):
space = {
'direction': spaces.Discrete(2),
'cur_step': spaces.Box(0, MAX_STEPS - 1, shape=(), dtype=np.int32),
'num_step': spaces.Box(MAX_STEPS, MAX_STEPS, shape=(), dtype=np.int32),
'target': spaces.Box(-1e-5, np.inf, shape=()),
'position': spaces.Box(-1e-5, np.inf, shape=()),
'features': spaces.Box(-np.inf, np.inf, shape=(5, ))
}
return spaces.Dict(space)
def observe(self, ep_state: EpisodicState) -> Any:
return {
'acquiring': _to_int32(ep_state.direction),
'cur_step': _to_int32(min(ep_state.cur_step, ep_state.num_step - 1)),
'num_step': _to_int32(ep_state.num_step),
'target': _to_float32(ep_state.target),
'position': _to_float32(ep_state.position),
'features': D.features(
[ep_state.stock_id],
['$open', '$close', '$high', '$low', '$volume'],
start_time=ep_state.start_time,
end_time=ep_state.end_time,
freq=self.time_per_step
)
}
class Action:
@property
def action_space(self):
return spaces.Discrete(5)
def __call__(self, action: Any, ep_state: EpisodicState) -> Any:
if not self.validate(action):
raise ValueError(f'Action space does not contain action. Space: {self.action_space} Sample: {action}')
act_ = self.to_volume(action, ep_state)
return act_
def validate(self, action: Any) -> bool:
return self.action_space.contains(action)
def to_volume(self, action: Any, ep_state: EpisodicState):
exec_vol = ep_state.position / 5 * action
if ep_state.cur_step + 1 >= ep_state.num_step:
exec_vol = ep_state.position
# TODO: might need to check whether the stock is tradable or whether it satisfies trade unit?
return exec_vol
class Reward:
weight = 1.0
def __call__(self, ep_state: EpisodicState, st_state: StepState) -> Tuple[float, Dict[str, float]]:
rew, info = 0., {}
if ep_state.done:
ep_rew, ep_info = self._to_tuple(self.episode_end(ep_state))
rew += ep_rew
info.update({f'ep/{k}': v for k, v in ep_info.items()})
st_rew, st_info = self._to_tuple(self.step_end(ep_state, st_state))
rew += st_rew
info.update({f'st/{k}': v for k, v in st_info.items()})
return rew * self.weight, info
@staticmethod
def _to_tuple(x):
if isinstance(x, tuple):
return x
return x, {}
def episode_end(self, ep_state: EpisodicState) -> Tuple[float, Dict[str, float]]:
return 0.
def step_end(self, ep_state: EpisodicState, st_state: StepState) -> Tuple[float, Dict[str, float]]:
assert ep_state.target > 0
baseline_price = st_state.pa_twap
pa = baseline_price * st_state.exec_vol.sum() / ep_state.target
penalty = -self.penalty * ((st_state.exec_vol / ep_state.target) ** 2).sum()
reward = pa + penalty
return reward, {'pa': pa, 'penalty': penalty}
class SingleOrderEnv(gym.Env):
def __init__(self,
observation: StateInterpreter,
@@ -313,7 +191,7 @@ class SingleOrderEnv(gym.Env):
def initialize_state(self):
self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time)
return EpisodicState(
state = EpisodicState(
stock_id=self.cur_order.stock_id,
start_time=self.cur_order.start_time,
end_time=self.cur_order.end_time,
@@ -323,29 +201,37 @@ class SingleOrderEnv(gym.Env):
market_price=self.retrieve_backtest_data('$close'),
market_vol=self.retrieve_backtest_data('$volume'),
)
state.cur_step = self.executor.trade_calendar.get_trade_step()
assert state.cur_step == 0
state.cur_time, _ = self.executor.trade_calendar.get_step_time(state.cur_step)
return state
def update_state(self, exec_vol):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time = self.executor.trade_calendar.get_step_time(trade_step)
trade_end_time = self.executor.trade_calendar.get_step_time(trade_step, shift=1)
trade_decision = Order(**asdict(self.cur_order),
start_time=trade_start_time, end_time=trade_end_time, amount=exec_vol)
calendar = self.executor.trade_calendar
state = self.ep_state
trade_step = calendar.get_trade_step()
trade_start_time, trade_end_time = calendar.get_step_time(trade_step)
order_kwargs = asdict(self.cur_order)
order_kwargs.update(start_time=trade_start_time, end_time=trade_end_time, amount=exec_vol)
trade_decision = Order(**order_kwargs)
execute_result = self.executor.execute([trade_decision])
cur_tick = self.ep_state.cur_tick
cur_tick = state.cur_tick
inner_exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
ticks_this_step = len(inner_exec_vol)
state = self.ep_state
state.cur_step = trade_step = self.executor.trade_calendar.get_trade_step()
state.cur_time = self.executor.trade_calendar.get_step_time(trade_step)
state.cur_step = trade_step = calendar.get_trade_step()
state.cur_tick += ticks_this_step
state.position -= np.sum(inner_exec_vol)
state.position_history[trade_step] = state.position
state.exec_vol = inner_exec_vol if state.exec_vol is None else np.concatenate((state.exec_vol, inner_exec_vol))
state.done = self.executor.finished()
state.exec_vol = inner_exec_vol if state.exec_vol is None else \
np.concatenate((state.exec_vol, inner_exec_vol))
if state.done:
state.update_stats()
else:
state.cur_time, _ = calendar.get_step_time(trade_step)
l, r = cur_tick, cur_tick + ticks_this_step
assert 0 <= l < r
@@ -362,19 +248,23 @@ class SingleOrderEnv(gym.Env):
self.ep_state = self.initialize_state()
self.action_history = np.full(self.ep_state.num_step, np.nan)
return self.observation(self.cur_sample, self.ep_state)
return self.observation(self.ep_state)
def step(self, action):
assert self.dataloader is not None
assert not self.executor.finished()
self.action_history[self.ep_state.cur_step] = action
exec_vol = self.action(action, self.ep_state)
step_state = self.update_state(exec_vol)
if self.executor.finished():
assert self.ep_state.done
reward, rew_info = self.reward(self.ep_state, step_state)
info = {
'action_history': self.action_history,
'category': self.ep_state.flow_dir.value,
'category': self.ep_state.direction,
'reward': rew_info
}
if self.ep_state.done:
@@ -383,8 +273,9 @@ class SingleOrderEnv(gym.Env):
'ins': self.ep_state.stock_id,
'date': self.ep_state.start_time,
}
pprint(info)
return self.observation(self.cur_sample, self.ep_state), reward, self.ep_state.done, info
return self.observation(self.ep_state), reward, self.ep_state.done, info
def _init_qlib():
@@ -412,39 +303,165 @@ def _main():
"generate_report": False,
}
}
executor = get_executor(
trade_start_time,
trade_end_time,
executor_config,
benchmark,
1000000000,
exchange_kwargs={
"freq": "day",
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
exchange = get_exchange(
freq="day",
limit_threshold=0.095,
deal_price="close",
open_cost=0.0005,
close_cost=0.0015,
min_cost=5
)
observation = Observation(time_per_step)
action = Action()
reward_fn = Reward()
def dummy_env(): return SingleOrderEnv(
observation, action, reward_fn,
DataLoader(QlibOrderDataset('rl.pkl'), batch_size=None, shuffle=True), executor)
def dummy_env():
executor = get_executor(
trade_start_time,
trade_end_time,
executor_config,
exchange,
benchmark,
1000000000,
)
return SingleOrderEnv(
observation, action, reward_fn,
iter(DataLoader(QlibOrderDataset('rl.pkl'), batch_size=None, shuffle=True)), executor)
policy = DummyPolicy()
env = dummy_env()
obs = env.reset()
print(obs)
envs = DummyVectorEnv([dummy_env for _ in range(4)])
test_collector = Collector(policy, envs)
policy.eval()
test_collector.collect(n_episode=10)
# envs = DummyVectorEnv([dummy_env for _ in range(4)])
# test_collector = Collector(policy, envs)
# policy.eval()
# test_collector.collect(n_episode=10)
### This is a full RL strategy ###
class QlibOrderDataset(Dataset):
def __init__(self, order_file):
with open(order_file, 'rb') as f:
self.orders = pickle.load(f)
def __len__(self):
return len(self.orders)
def __getitem__(self, index):
return self.orders[index]
class DummyPolicy(BasePolicy):
def forward(self, batch, state=None, **kwargs):
return Batch(act=np.random.randint(0, 5, size=(len(batch), )))
def learn(self, *args, **kwargs):
pass
class Observation:
def __init__(self, time_per_step):
self.time_per_step = time_per_step
def __call__(self, ep_state: EpisodicState) -> Any:
obs = self.observe(ep_state)
if not self.validate(obs):
raise ValueError(f'Observation space does not contain obs. Space: {self.observation_space} Sample: {obs}')
return obs
def validate(self, obs: Any) -> bool:
return self.observation_space.contains(obs)
@property
def observation_space(self):
space = {
'direction': spaces.Discrete(2),
'cur_step': spaces.Box(0, MAX_STEPS, shape=(), dtype=np.int32),
'num_step': spaces.Box(0, MAX_STEPS, shape=(), dtype=np.int32),
'target': spaces.Box(-1e-5, np.inf, shape=()),
'position': spaces.Box(-1e-5, np.inf, shape=()),
'features': spaces.Box(-np.inf, np.inf, shape=(5, ))
}
return spaces.Dict(space)
def observe(self, ep_state: EpisodicState) -> Any:
return {
'direction': _to_int32(ep_state.direction),
'cur_step': _to_int32(min(ep_state.cur_step, ep_state.num_step - 1)),
'num_step': _to_int32(ep_state.num_step),
'target': _to_float32(ep_state.target),
'position': _to_float32(ep_state.position),
'features': D.features(
[ep_state.stock_id],
['$open', '$close', '$high', '$low', '$volume'],
start_time=ep_state.start_time,
end_time=ep_state.end_time,
freq=self.time_per_step
).loc[(ep_state.stock_id, ep_state.cur_time)].to_numpy(),
}
class Action:
denominator = 4
@property
def action_space(self):
return spaces.Discrete(self.denominator + 1)
def __call__(self, action: Any, ep_state: EpisodicState) -> Any:
if not self.validate(action):
raise ValueError(f'Action space does not contain action. Space: {self.action_space} Sample: {action}')
act_ = self.to_volume(action, ep_state)
return act_
def validate(self, action: Any) -> bool:
return self.action_space.contains(action)
def to_volume(self, action: Any, ep_state: EpisodicState):
exec_vol = ep_state.position / self.denominator * action
if ep_state.cur_step + 1 >= ep_state.num_step:
exec_vol = ep_state.position
# TODO: might need to check whether the stock is tradable or whether it satisfies trade unit?
return exec_vol
class Reward:
weight = 1.0
def __call__(self, ep_state: EpisodicState, st_state: StepState) -> Tuple[float, Dict[str, float]]:
rew, info = 0., {}
if ep_state.done:
ep_rew, ep_info = self._to_tuple(self.episode_end(ep_state))
rew += ep_rew
info.update({f'ep/{k}': v for k, v in ep_info.items()})
st_rew, st_info = self._to_tuple(self.step_end(ep_state, st_state))
rew += st_rew
info.update({f'st/{k}': v for k, v in st_info.items()})
return rew * self.weight, info
@staticmethod
def _to_tuple(x):
if isinstance(x, tuple):
return x
return x, {}
def episode_end(self, ep_state: EpisodicState) -> Tuple[float, Dict[str, float]]:
return 0.
def step_end(self, ep_state: EpisodicState, st_state: StepState) -> Tuple[float, Dict[str, float]]:
assert ep_state.target > 0
baseline_price = st_state.pa_twap
pa = baseline_price * st_state.exec_vol.sum() / ep_state.target
penalty = -100 * ((st_state.exec_vol / ep_state.target) ** 2).sum() # penalize too much volume at one step
reward = pa + penalty
return reward, {'pa': pa, 'penalty': penalty}
def _to_int32(val): return np.array(int(val), dtype=np.int32)
def _to_float32(val): return np.array(val, dtype=np.float32)
### End of RL strategy ###
if __name__ == '__main__':