1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

Update end-to-end example and requirements

This commit is contained in:
Yuge Zhang
2021-06-04 12:20:27 +08:00
parent bf02fc23f8
commit c43805eff6
2 changed files with 113 additions and 64 deletions

View File

@@ -0,0 +1,2 @@
tianshou>=0.4.1
torch>=1.8.0

View File

@@ -4,12 +4,14 @@ from dataclasses import dataclass, asdict
from pprint import pprint
from typing import Iterable, Any, Optional, OrderedDict, Tuple, Dict, List
import fire
import gym
import numpy as np
import pandas as pd
import qlib
from gym import spaces
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager, backtest_func
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
from qlib.config import REG_CN
from qlib.data import D
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
@@ -21,6 +23,8 @@ from tianshou.data import Batch, Collector
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy
from workflow import NestedDecisonExecutionWorkflow
MAX_STEPS = 10
@@ -324,79 +328,122 @@ class RLStrategy(BaseStrategy):
# apply results from the last step
if execute_result is not None:
orders = defaultdict(list)
for order, _, __, in execute_result:
orders[order.stock_id, order.direction].append(order)
for e in execute_result:
orders[e[0].stock_id, e[0].direction].append(e)
for (stock_id, direction), state in self.states.items():
state.update(orders[stock_id, direction])
state.update(orders[stock_id, direction], self.trade_calendar)
if not self.states:
return []
obs_batch = Batch([{"obs": self.observation(state)} for state in self.states.values()])
act = self.policy(obs_batch)
exec_vols = [self.action(a) for a in act.act]
return [create_sub_order(v, self.trade_calendar, order) for v in exec_vols]
exec_vols = [self.action(a, s) for a, s in zip(act.act, self.states.values())]
return [create_sub_order(v, self.trade_calendar, o) for v, o in zip(exec_vols, self.outer_trade_decision)]
def _init_qlib():
    """Make sure the CN qlib dataset is available, then initialise qlib with it.

    The data directory is checked with ``exists_qlib_data``.  When the check
    fails, a notice is printed and ``GetData().qlib_data`` is called with that
    directory as ``target_dir`` (presumably this downloads the dataset --
    confirm against ``GetData``).  ``qlib.init`` is always called afterwards
    with the same directory and the ``REG_CN`` region.
    """
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir

    data_ready = exists_qlib_data(provider_uri)
    if not data_ready:
        print(f"Qlib data is not found in {provider_uri}")
        GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

    qlib.init(provider_uri=provider_uri, region=REG_CN)
class RlWorkflow(NestedDecisonExecutionWorkflow):
def tianshou(self):
self._init_qlib()
def _main_tianshou():
_init_qlib()
# TODO: why is there a benchmark?
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
benchmark = "SH000300"
time_per_step = "day"
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": time_per_step,
"verbose": True,
"generate_report": False,
# TODO: why is there a benchmark?
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
benchmark = "SH000300"
time_per_step = "day"
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": time_per_step,
"verbose": True,
"generate_report": False,
}
}
}
exchange = get_exchange(
freq="day",
limit_threshold=0.095,
deal_price="close",
open_cost=0.0005,
close_cost=0.0015,
min_cost=5
)
observation = Observation(time_per_step)
action = Action()
reward_fn = Reward()
def dummy_env():
executor = get_executor(
trade_start_time,
trade_end_time,
executor_config,
exchange,
benchmark,
1000000000,
exchange = get_exchange(
freq="day",
limit_threshold=0.095,
deal_price="close",
open_cost=0.0005,
close_cost=0.0015,
min_cost=5
)
return SingleOrderEnv(
observation, action, reward_fn,
iter(DataLoader(QlibOrderDataset('assets/orders'), batch_size=None, shuffle=True)), executor)
policy = DummyPolicy()
observation = Observation(time_per_step)
action = Action()
reward_fn = Reward()
# This cannot be replaced with SubprocVectorEnv, which fails with:
# File "/xxx/qlib/qlib/data/data.py", line 462, in dataset_processor
# p = Pool(processes=workers)
# AssertionError: daemonic processes are not allowed to have children
envs = DummyVectorEnv([dummy_env for _ in range(4)])
test_collector = Collector(policy, envs)
policy.eval()
# TODO: create a queue for all orders and make it auto-complete when all the orders are processed
test_collector.collect(n_episode=10)
def dummy_env():
executor = get_executor(
trade_start_time,
trade_end_time,
executor_config,
exchange,
benchmark,
1000000000,
)
return SingleOrderEnv(
observation, action, reward_fn,
iter(DataLoader(QlibOrderDataset('assets/orders'), batch_size=None, shuffle=True)), executor)
policy = DummyPolicy()
# This cannot be replaced with SubprocVectorEnv, which fails with:
# File "/xxx/qlib/qlib/data/data.py", line 462, in dataset_processor
# p = Pool(processes=workers)
# AssertionError: daemonic processes are not allowed to have children
envs = DummyVectorEnv([dummy_env for _ in range(4)])
test_collector = Collector(policy, envs)
policy.eval()
# TODO: create a queue for all orders and make it auto-complete when all the orders are processed
test_collector.collect(n_episode=10)
def rl_day(self, load_model: Optional[str] = None):
    """Backtest a ``TopkDropoutStrategy`` through a nested week/day executor.

    The outer ``NestedExecutor`` steps per "week"; its inner
    ``SimulatorExecutor`` steps per "day" and is driven by an ``RLStrategy``
    built from ``Observation("day")``, ``Action()`` and ``DummyPolicy()``.

    Parameters
    ----------
    load_model : Optional[str]
        When ``None``, the model built from ``self.task["model"]`` is handed
        to ``self._train_model`` together with the dataset.  Otherwise the
        value is passed to ``self._load_model`` and the returned object is
        used as the model.

    Returns
    -------
    None
        The value returned by ``backtest_func`` is printed, not returned.
    """
    self._init_qlib()

    # NOTE(review): the model is instantiated from the task config even when
    # ``load_model`` is given and the instance is then replaced -- confirm
    # whether the construction is needed in that case.
    model = init_instance_by_config(self.task["model"])
    dataset = init_instance_by_config(self.task["dataset"])
    if load_model is not None:
        model = self._load_model(load_model)
    else:
        self._train_model(model, dataset)

    trade_start_time = "2017-01-01"
    trade_end_time = "2020-08-01"

    # Account and exchange shared by the outer strategy and the executors.
    benchmark_config = {
        "benchmark": "SH000300",
        "start_time": trade_start_time,
        "end_time": trade_end_time,
    }
    trade_account = Account(init_cash=int(1e9), benchmark_config=benchmark_config)
    exchange_kwargs = dict(
        freq="day",
        limit_threshold=0.095,
        deal_price="close",
        open_cost=0.0005,
        close_cost=0.0015,
        min_cost=5,
    )
    exchange = get_exchange(**exchange_kwargs)
    common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)

    # Outer strategy, configured declaratively.
    strategy_config = {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.model_strategy",
        "kwargs": {
            "model": model,
            "dataset": dataset,
            "topk": 50,
            "n_drop": 5,
        },
    }
    strategy = init_instance_by_config(strategy_config, common_infra=common_infra)

    # Inner components are created in the same order as before: executor first,
    # then the RL strategy that drives it.
    inner_executor = SimulatorExecutor(time_per_step="day", verbose=True)
    inner_strategy = RLStrategy(Observation("day"), Action(), DummyPolicy())
    executor = NestedExecutor(
        time_per_step="week",
        inner_executor=inner_executor,
        inner_strategy=inner_strategy,
        common_infra=common_infra,
    )

    report_dict = backtest_func(trade_start_time, trade_end_time, strategy, executor)
    print(report_dict)
### This is a full RL strategy ###
@@ -527,4 +574,4 @@ def _to_float32(val): return np.array(val, dtype=np.float32)
if __name__ == '__main__':
_main_tianshou()
fire.Fire(RlWorkflow)