mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 09:31:18 +08:00
Add RL strategy demo
This commit is contained in:
@@ -1,17 +1,19 @@
|
||||
import pickle
|
||||
from collections import OrderedDict, defaultdict
|
||||
from dataclasses import dataclass, asdict
|
||||
from pprint import pprint
|
||||
from typing import Iterable, Any, Optional, Tuple, Dict, List
|
||||
from typing import Iterable, Any, Optional, OrderedDict, Tuple, Dict, List
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from gym import spaces
|
||||
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order
|
||||
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data import D
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.strategy import BaseStrategy
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
@@ -129,35 +131,36 @@ class EpisodicState:
|
||||
return logs
|
||||
|
||||
@classmethod
|
||||
def from_order_and_executor(cls, order: Order, executor: BaseExecutor, frequency: str) -> "EpisodicState":
|
||||
def from_order_and_executor(cls, order: Order, calendar: TradeCalendarManager, frequency: str) -> "EpisodicState":
|
||||
# Synchronous state for executor to EpisodicState
|
||||
executor.reset(start_time=order.start_time, end_time=order.end_time)
|
||||
state = cls(
|
||||
stock_id=order.stock_id,
|
||||
start_time=order.start_time,
|
||||
end_time=order.end_time,
|
||||
direction=order.direction,
|
||||
target=order.amount,
|
||||
num_step=executor.trade_calendar.get_trade_len(),
|
||||
num_step=calendar.get_trade_len(),
|
||||
market_price=_retrieve_backtest_data(order, '$close', frequency),
|
||||
market_vol=_retrieve_backtest_data(order, '$volume', frequency),
|
||||
)
|
||||
state.cur_step = executor.trade_calendar.get_trade_step()
|
||||
state.cur_step = calendar.get_trade_step()
|
||||
assert state.cur_step == 0
|
||||
state.cur_time, _ = executor.trade_calendar.get_step_time(state.cur_step)
|
||||
state.cur_time, _ = calendar.get_step_time(state.cur_step)
|
||||
return state
|
||||
|
||||
def update(self, execute_result: List[Order], executor: BaseExecutor) -> "StepState":
|
||||
def update(self, execute_result: List[Order], calendar: TradeCalendarManager, done: Optional[bool] = None) -> "StepState":
|
||||
exec_vol = np.array([order.deal_amount for order, _, __, ___ in execute_result])
|
||||
# Synchronous exec_vol to executor and synchronous back to EpisodicState
|
||||
calendar = executor.trade_calendar
|
||||
cur_tick = self.cur_tick
|
||||
ticks_this_step = len(exec_vol)
|
||||
self.cur_step = trade_step = calendar.get_trade_step()
|
||||
self.cur_tick += ticks_this_step
|
||||
self.position -= np.sum(exec_vol)
|
||||
self.position_history[trade_step] = self.position
|
||||
self.done = executor.finished()
|
||||
if done is not None:
|
||||
self.done = done
|
||||
else:
|
||||
self.done = self.position < 1e-5
|
||||
self.exec_vol = exec_vol if self.exec_vol is None else \
|
||||
np.concatenate((self.exec_vol, exec_vol))
|
||||
|
||||
@@ -211,9 +214,8 @@ def _retrieve_backtest_data(order: Order, field: str, frequency: str) -> np.ndar
|
||||
)[field].to_numpy()
|
||||
|
||||
|
||||
def create_sub_order(exec_vol: float, executor: BaseExecutor, original_order: Order) -> Order:
|
||||
def create_sub_order(exec_vol: float, calendar: TradeCalendarManager, original_order: Order) -> Order:
|
||||
# Convert a real number to an order
|
||||
calendar = executor.trade_calendar
|
||||
trade_step = calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = calendar.get_step_time(trade_step)
|
||||
order_kwargs = asdict(original_order)
|
||||
@@ -253,8 +255,9 @@ class SingleOrderEnv(gym.Env):
|
||||
return None
|
||||
|
||||
self.execute_result = []
|
||||
self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time)
|
||||
self.ep_state = EpisodicState.from_order_and_executor(
|
||||
self.cur_order, self.executor, self.inner_frequency
|
||||
self.cur_order, self.executor.trade_calendar, self.inner_frequency
|
||||
)
|
||||
|
||||
self.action_history = np.full(self.ep_state.num_step, np.nan)
|
||||
@@ -266,9 +269,9 @@ class SingleOrderEnv(gym.Env):
|
||||
self.action_history[self.ep_state.cur_step] = action
|
||||
|
||||
exec_vol = self.action(action, self.ep_state)
|
||||
trade_decision = create_sub_order(exec_vol, self.executor, self.cur_order)
|
||||
trade_decision = create_sub_order(exec_vol, self.executor.trade_calendar, self.cur_order)
|
||||
execute_result = self.executor.execute([trade_decision])
|
||||
step_state = self.ep_state.update(execute_result, self.executor)
|
||||
step_state = self.ep_state.update(execute_result, self.executor.trade_calendar)
|
||||
if self.executor.finished():
|
||||
assert self.ep_state.done
|
||||
|
||||
@@ -291,6 +294,47 @@ class SingleOrderEnv(gym.Env):
|
||||
return self.observation(self.ep_state), reward, self.ep_state.done, info
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy):
    """Strategy that splits outer orders into sub-orders using an RL policy.

    When inference and backtest are run end to end, use this strategy.
    For every outer order it keeps one ``EpisodicState``; at each step the
    states are converted to observations, fed to the policy in a single
    batch, and the resulting actions are converted to one sub-order per
    outer order.
    """

    # TODO This strategy is still for code demo purpose only.
    # It has not been end-to-end tested.

    def __init__(
        self,
        observation: "Observation",
        action: "Action",
        policy: "BasePolicy",
        **kwargs
    ):
        """Store the interpreters and the policy.

        Parameters
        ----------
        observation
            Callable mapping an ``EpisodicState`` to a policy observation.
        action
            Callable mapping a policy action (and the matching state) to an
            execution volume.
        policy
            Callable policy; it is invoked on a ``Batch`` of observations and
            its result must expose an ``act`` attribute.
        **kwargs
            Forwarded unchanged to ``BaseStrategy.__init__``.
        """
        # The annotation for ``policy`` is quoted (like the other two) so that
        # defining the class does not require the name to be imported.
        super().__init__(**kwargs)
        self.observation = observation
        self.action = action
        self.policy = policy

    def reset(self, outer_trade_decision: Optional[List[Order]] = None, **kwargs):
        """Reset the strategy and rebuild per-order states if new orders arrive.

        Parameters
        ----------
        outer_trade_decision
            Orders handed down from the outer level. When ``None`` the
            existing states are left untouched.
        **kwargs
            Forwarded unchanged to ``BaseStrategy.reset``.
        """
        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
        if outer_trade_decision is not None:
            # Both mappings share the same (stock_id, direction) keys and the
            # same insertion order, so they can be zipped later on.
            self.states = OrderedDict()  # explicitly make it ordered
            self.orders = OrderedDict()
            for order in outer_trade_decision:
                key = (order.stock_id, order.direction)
                # TODO: how to get inner frequency
                self.states[key] = EpisodicState.from_order_and_executor(order, self.trade_calendar, "day")
                # Keep the outer order: sub-orders are derived from it.
                self.orders[key] = order

    def generate_trade_decision(self, execute_result=None):
        """Produce one sub-order per outer order for the current step.

        Parameters
        ----------
        execute_result
            Results of the previous step, or ``None`` on the first step.
            Each element is a sequence whose first item is the executed
            ``Order``.

        Returns
        -------
        list
            Sub-orders created by ``create_sub_order``, in the same order as
            the outer orders received in ``reset``.
        """
        # apply results from the last step
        if execute_result is not None:
            grouped = defaultdict(list)
            for result in execute_result:
                # Index instead of unpacking so the grouping does not depend
                # on how many fields a result carries; the full result is
                # kept because ``EpisodicState.update`` unpacks it itself.
                executed = result[0]
                grouped[executed.stock_id, executed.direction].append(result)
            for key, state in self.states.items():
                # A state without results this step receives an empty list.
                state.update(grouped[key], self.trade_calendar)

        states = list(self.states.values())
        obs_batch = Batch([{"obs": self.observation(state)} for state in states])
        act = self.policy(obs_batch)
        # Actions come back in batch order, i.e. aligned with ``states``.
        # The state is passed along, mirroring how SingleOrderEnv calls the
        # action interpreter.
        exec_vols = [self.action(a, state) for a, state in zip(act.act, states)]
        return [
            create_sub_order(vol, self.trade_calendar, order)
            for vol, order in zip(exec_vols, self.orders.values())
        ]
|
||||
|
||||
|
||||
def _init_qlib():
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
@@ -299,7 +343,7 @@ def _init_qlib():
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
def _main():
|
||||
def _main_tianshou():
|
||||
_init_qlib()
|
||||
|
||||
# TODO: why is there a benchmark?
|
||||
@@ -483,4 +527,4 @@ def _to_float32(val): return np.array(val, dtype=np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_main()
|
||||
_main_tianshou()
|
||||
|
||||
@@ -7,6 +7,7 @@ from .executor import BaseExecutor
|
||||
from .backtest import backtest as backtest_func
|
||||
from .backtest import collect_data as data_generator
|
||||
from .order import Order
|
||||
from .utils import TradeCalendarManager
|
||||
|
||||
from .utils import CommonInfrastructure
|
||||
from .order import Order
|
||||
|
||||
Reference in New Issue
Block a user