1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

Update simple playground

This commit is contained in:
Yuge Zhang
2021-06-01 11:33:44 +08:00
parent c26bee126b
commit d3dac068df
3 changed files with 141 additions and 0 deletions

View File

@@ -1,2 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base import *

View File

@@ -7,6 +7,8 @@ from ..data.dataset.utils import convert_index_format
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
# Names exported by a star import (``from <this module> import *``).
__all__ = ['BaseStrategy', 'ModelStrategy', 'RLStrategy', 'RLIntStrategy']
class BaseStrategy:
"""Base strategy for trading"""

137
rl_playground.py Normal file
View File

@@ -0,0 +1,137 @@
import logging
import pickle
from enum import Enum
from typing import Iterable, Optional, Any
import gym
import numpy as np
import torch
from torch.utils.data import Dataset
from qlib.backtest import get_exchange, Account, BaseExecutor
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.utils import init_instance_by_config
def get_executor(start_time, end_time, executor, benchmark="SH000300", account=1e9, exchange_kwargs=None):
    """Create the shared trading infrastructure and an executor built on it.

    Parameters
    ----------
    start_time, end_time
        Stored as ``start_time`` / ``end_time`` in the account's benchmark config.
    executor
        Executor configuration handed to ``init_instance_by_config`` with
        ``accept_types=BaseExecutor``.
    benchmark
        Benchmark identifier stored in the account's benchmark config.
    account
        Initial cash of the trading account.
    exchange_kwargs : dict, optional
        Keyword arguments forwarded to ``get_exchange``. ``None`` (the
        default) forwards no extra arguments.

    Returns
    -------
    tuple
        ``(common_infra, trade_executor)``: the dict holding the
        ``"trade_account"`` and ``"trade_exchange"`` objects, and the
        executor that was initialised with that dict.
    """
    # A ``None`` sentinel avoids sharing one mutable dict across all calls.
    if exchange_kwargs is None:
        exchange_kwargs = {}

    trade_account = Account(
        init_cash=account,
        benchmark_config={
            "benchmark": benchmark,
            "start_time": start_time,
            "end_time": end_time,
        },
    )
    trade_exchange = get_exchange(**exchange_kwargs)

    common_infra = {
        "trade_account": trade_account,
        "trade_exchange": trade_exchange,
    }
    trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)

    return common_infra, trade_executor
class QlibOrderDataset(Dataset):
    """Dataset backed by a pickled, indexable collection of orders."""

    def __init__(self, order_file):
        """Read the pickled orders stored at ``order_file``."""
        # NOTE: unpickling can execute arbitrary code; only load trusted order files.
        with open(order_file, 'rb') as order_stream:
            loaded_orders = pickle.load(order_stream)
        self.orders = loaded_orders

    def __len__(self):
        """Return how many orders were loaded."""
        order_count = len(self.orders)
        return order_count

    def __getitem__(self, index):
        """Return the order stored at ``index``."""
        order = self.orders[index]
        return order
class OrderEnv(gym.Env):
    """Gym environment that feeds orders from a data loader through an executor.

    Every ``reset`` takes the next order from ``dataloader`` and resets the
    executor to that order's time range. Every ``step`` turns an action into
    a trade decision, executes it, and returns the new interpreted state.
    """

    def __init__(self,
                 state_interpreter: StateInterpreter,
                 action_interpreter: ActionInterpreter,
                 reward: Any,
                 dataloader: Iterable,
                 executor: BaseExecutor):
        """Store the collaborators used by ``reset`` and ``step``.

        Parameters
        ----------
        state_interpreter
            Callable applied to ``(cur_order, level_infra)`` to build observations.
        action_interpreter
            Callable that turns an action into a trade decision for the executor.
        reward
            Callable invoked without arguments; expected to return
            ``(reward, reward_info)``.
        dataloader
            Source of orders; one order is consumed per ``reset``.
        executor
            Executor that is reset per order and executes the trade decisions.
        """
        self.action_interpreter = action_interpreter
        self.state_interpreter = state_interpreter
        self.reward = reward
        # ``reset`` pulls orders with ``next``, which needs an iterator rather
        # than a plain iterable; ``iter`` returns iterators unchanged.
        self.dataloader = iter(dataloader)
        self.executor = executor

    @property
    def action_space(self):
        # NOTE(review): assumes the action interpreter exposes ``action_space`` - confirm.
        return self.action_interpreter.action_space

    @property
    def observation_space(self):
        # NOTE(review): assumes the state interpreter exposes ``observation_space`` - confirm.
        return self.state_interpreter.observation_space

    def reset(self):
        """Start an episode on the next order.

        Returns the interpreted initial state, or ``None`` once the data
        loader has no more orders.
        """
        if self.dataloader is None:
            # The orders were already exhausted by an earlier reset.
            return None

        try:
            self.cur_order = next(self.dataloader)
        except StopIteration:
            # Mark the environment as exhausted so ``step`` refuses to run.
            self.dataloader = None
            return None

        self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time)
        self.level_infra = self.executor.get_level_infra()
        self.execute_result = []
        self.action_history = []

        # TODO: how to fetch data after feature engineering?
        # TODO: can be rewritten as dataclasses.asdict(self.cur_order) if Order is written to be a dataclass
        return self.state_interpreter(self.cur_order, self.level_infra)

    def step(self, action):
        """Execute one action and return ``(state, reward, done, info)``.

        ``info`` carries the actions taken so far in this episode under
        ``'action_history'`` and the reward details under ``'reward'``.
        """
        assert self.dataloader is not None
        assert not self.executor.finished()

        self.action_history.append(action)
        trade_decision = self.action_interpreter(action)
        self.execute_result.extend(self.executor.execute(trade_decision))

        reward, rew_info = self.reward()
        done = self.executor.finished()

        info = {
            'action_history': self.action_history,
            'reward': rew_info,
        }
        # TODO: report the order category, episode logs and sample index once
        # this environment tracks the episode state they would be read from.
        # TODO: how to collect metrics

        return self.state_interpreter(self.cur_order, self.level_infra), reward, done, info
def _main():
    """Build a daily simulator executor with fixed playground settings.

    The executor and its shared infrastructure are created but not used
    further yet; nothing is returned.
    """
    executor_config = {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "verbose": True,
            "generate_report": True,
        },
    }

    # TODO: why is there a benchmark?
    trade_start_time = "2017-01-01"
    trade_end_time = "2020-08-01"
    benchmark = "SH000300"

    exchange_kwargs = {
        "freq": "day",
        "limit_threshold": 0.095,
        "deal_price": "close",
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5,
    }

    # ``get_executor`` returns ``(common_infra, trade_executor)``; unpack it so
    # that ``executor`` really is the executor and not the whole tuple.
    common_infra, executor = get_executor(
        trade_start_time, trade_end_time, executor_config,
        benchmark, 1000000000, exchange_kwargs=exchange_kwargs,
    )