1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 09:31:18 +08:00

Add RL strategy demo

This commit is contained in:
Yuge Zhang
2021-06-02 23:20:27 +08:00
parent f5ac6230e1
commit bf02fc23f8
2 changed files with 62 additions and 17 deletions

View File

@@ -1,17 +1,19 @@
import pickle
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, asdict
from pprint import pprint
from typing import Iterable, Any, Optional, Tuple, Dict, List
from typing import Iterable, Any, Optional, OrderedDict, Tuple, Dict, List
import gym
import numpy as np
import pandas as pd
import qlib
from gym import spaces
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager
from qlib.config import REG_CN
from qlib.data import D
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.strategy import BaseStrategy
from qlib.tests.data import GetData
from qlib.utils import init_instance_by_config, exists_qlib_data
from torch.utils.data import Dataset, DataLoader
@@ -129,35 +131,36 @@ class EpisodicState:
return logs
@classmethod
def from_order_and_executor(cls, order: Order, executor: BaseExecutor, frequency: str) -> "EpisodicState":
def from_order_and_executor(cls, order: Order, calendar: TradeCalendarManager, frequency: str) -> "EpisodicState":
# Synchronous state for executor to EpisodicState
executor.reset(start_time=order.start_time, end_time=order.end_time)
state = cls(
stock_id=order.stock_id,
start_time=order.start_time,
end_time=order.end_time,
direction=order.direction,
target=order.amount,
num_step=executor.trade_calendar.get_trade_len(),
num_step=calendar.get_trade_len(),
market_price=_retrieve_backtest_data(order, '$close', frequency),
market_vol=_retrieve_backtest_data(order, '$volume', frequency),
)
state.cur_step = executor.trade_calendar.get_trade_step()
state.cur_step = calendar.get_trade_step()
assert state.cur_step == 0
state.cur_time, _ = executor.trade_calendar.get_step_time(state.cur_step)
state.cur_time, _ = calendar.get_step_time(state.cur_step)
return state
def update(self, execute_result: List[Order], executor: BaseExecutor) -> "StepState":
def update(self, execute_result: List[Order], calendar: TradeCalendarManager, done: Optional[bool] = None) -> "StepState":
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
# Synchronous exec_vol to executor and synchronous back to EpisodicState
calendar = executor.trade_calendar
cur_tick = self.cur_tick
ticks_this_step = len(exec_vol)
self.cur_step = trade_step = calendar.get_trade_step()
self.cur_tick += ticks_this_step
self.position -= np.sum(exec_vol)
self.position_history[trade_step] = self.position
self.done = executor.finished()
if done is not None:
self.done = done
else:
self.done = self.position < 1e-5
self.exec_vol = exec_vol if self.exec_vol is None else \
np.concatenate((self.exec_vol, exec_vol))
@@ -211,9 +214,8 @@ def _retrieve_backtest_data(order: Order, field: str, frequency: str) -> np.ndar
)[field].to_numpy()
def create_sub_order(exec_vol: float, executor: BaseExecutor, original_order: Order) -> Order:
def create_sub_order(exec_vol: float, calendar: TradeCalendarManager, original_order: Order) -> Order:
# Convert a real number to an order
calendar = executor.trade_calendar
trade_step = calendar.get_trade_step()
trade_start_time, trade_end_time = calendar.get_step_time(trade_step)
order_kwargs = asdict(original_order)
@@ -253,8 +255,9 @@ class SingleOrderEnv(gym.Env):
return None
self.execute_result = []
self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time)
self.ep_state = EpisodicState.from_order_and_executor(
self.cur_order, self.executor, self.inner_frequency
self.cur_order, self.executor.trade_calendar, self.inner_frequency
)
self.action_history = np.full(self.ep_state.num_step, np.nan)
@@ -266,9 +269,9 @@ class SingleOrderEnv(gym.Env):
self.action_history[self.ep_state.cur_step] = action
exec_vol = self.action(action, self.ep_state)
trade_decision = create_sub_order(exec_vol, self.executor, self.cur_order)
trade_decision = create_sub_order(exec_vol, self.executor.trade_calendar, self.cur_order)
execute_result = self.executor.execute([trade_decision])
step_state = self.ep_state.update(execute_result, self.executor)
step_state = self.ep_state.update(execute_result, self.executor.trade_calendar)
if self.executor.finished():
assert self.ep_state.done
@@ -291,6 +294,47 @@ class SingleOrderEnv(gym.Env):
return self.observation(self.ep_state), reward, self.ep_state.done, info
class RLStrategy(BaseStrategy):
"""When inference and do the backtest from end to end, use this strategy."""
# TODO This strategy is still for code demo purpose only.
# It has not been end-to-end tested.
def __init__(
self,
observation: "Observation",
action: "Action",
policy: BasePolicy,
**kwargs
):
super().__init__(**kwargs)
self.observation = observation
self.action = action
self.policy = policy
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
self.states = OrderedDict() # explicitly make it ordered
for order in outer_trade_decision:
# TODO: how to get inner frequency
state = EpisodicState.from_order_and_executor(order, self.trade_calendar, "day")
self.states[order.stock_id, order.direction] = state
def generate_trade_decision(self, execute_result=None):
# apply results from the last step
if execute_result is not None:
orders = defaultdict(list)
for order, _, __, in execute_result:
orders[order.stock_id, order.direction].append(order)
for (stock_id, direction), state in self.states.items():
state.update(orders[stock_id, direction])
obs_batch = Batch([{"obs": self.observation(state)} for state in self.states.values()])
act = self.policy(obs_batch)
exec_vols = [self.action(a) for a in act.act]
return [create_sub_order(v, self.trade_calendar, order) for v in exec_vols]
def _init_qlib():
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
@@ -299,7 +343,7 @@ def _init_qlib():
qlib.init(provider_uri=provider_uri, region=REG_CN)
def _main():
def _main_tianshou():
_init_qlib()
# TODO: why is there a benchmark?
@@ -483,4 +527,4 @@ def _to_float32(val): return np.array(val, dtype=np.float32)
if __name__ == '__main__':
_main()
_main_tianshou()

View File

@@ -7,6 +7,7 @@ from .executor import BaseExecutor
from .backtest import backtest as backtest_func
from .backtest import collect_data as data_generator
from .order import Order
from .utils import TradeCalendarManager
from .utils import CommonInfrastructure
from .order import Order