mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Update end-to-end example and requirements
This commit is contained in:
2
examples/nested_decision_execution/requirements.txt
Normal file
2
examples/nested_decision_execution/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
tianshou>=0.4.1
|
||||
torch>=1.8.0
|
||||
@@ -4,12 +4,14 @@ from dataclasses import dataclass, asdict
|
||||
from pprint import pprint
|
||||
from typing import Iterable, Any, Optional, OrderedDict, Tuple, Dict, List
|
||||
|
||||
import fire
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from gym import spaces
|
||||
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager
|
||||
from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager, backtest_func
|
||||
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data import D
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
@@ -21,6 +23,8 @@ from tianshou.data import Batch, Collector
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from workflow import NestedDecisonExecutionWorkflow
|
||||
|
||||
|
||||
MAX_STEPS = 10
|
||||
|
||||
@@ -324,79 +328,122 @@ class RLStrategy(BaseStrategy):
|
||||
# apply results from the last step
|
||||
if execute_result is not None:
|
||||
orders = defaultdict(list)
|
||||
for order, _, __, in execute_result:
|
||||
orders[order.stock_id, order.direction].append(order)
|
||||
for e in execute_result:
|
||||
orders[e[0].stock_id, e[0].direction].append(e)
|
||||
for (stock_id, direction), state in self.states.items():
|
||||
state.update(orders[stock_id, direction])
|
||||
|
||||
state.update(orders[stock_id, direction], self.trade_calendar)
|
||||
|
||||
if not self.states:
|
||||
return []
|
||||
|
||||
obs_batch = Batch([{"obs": self.observation(state)} for state in self.states.values()])
|
||||
act = self.policy(obs_batch)
|
||||
exec_vols = [self.action(a) for a in act.act]
|
||||
return [create_sub_order(v, self.trade_calendar, order) for v in exec_vols]
|
||||
exec_vols = [self.action(a, s) for a, s in zip(act.act, self.states.values())]
|
||||
return [create_sub_order(v, self.trade_calendar, o) for v, o in zip(exec_vols, self.outer_trade_decision)]
|
||||
|
||||
|
||||
def _init_qlib():
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
class RlWorkflow(NestedDecisonExecutionWorkflow):
|
||||
|
||||
def tianshou(self):
|
||||
self._init_qlib()
|
||||
|
||||
def _main_tianshou():
|
||||
_init_qlib()
|
||||
|
||||
# TODO: why is there a benchmark?
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
benchmark = "SH000300"
|
||||
time_per_step = "day"
|
||||
executor_config = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": time_per_step,
|
||||
"verbose": True,
|
||||
"generate_report": False,
|
||||
# TODO: why is there a benchmark?
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
benchmark = "SH000300"
|
||||
time_per_step = "day"
|
||||
executor_config = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": time_per_step,
|
||||
"verbose": True,
|
||||
"generate_report": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
exchange = get_exchange(
|
||||
freq="day",
|
||||
limit_threshold=0.095,
|
||||
deal_price="close",
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
min_cost=5
|
||||
)
|
||||
|
||||
observation = Observation(time_per_step)
|
||||
action = Action()
|
||||
reward_fn = Reward()
|
||||
|
||||
def dummy_env():
|
||||
executor = get_executor(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
executor_config,
|
||||
exchange,
|
||||
benchmark,
|
||||
1000000000,
|
||||
exchange = get_exchange(
|
||||
freq="day",
|
||||
limit_threshold=0.095,
|
||||
deal_price="close",
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
min_cost=5
|
||||
)
|
||||
return SingleOrderEnv(
|
||||
observation, action, reward_fn,
|
||||
iter(DataLoader(QlibOrderDataset('assets/orders'), batch_size=None, shuffle=True)), executor)
|
||||
|
||||
policy = DummyPolicy()
|
||||
observation = Observation(time_per_step)
|
||||
action = Action()
|
||||
reward_fn = Reward()
|
||||
|
||||
# This cannot be replaced with SubprocVectorEnv
|
||||
# File "/xxx/qlib/qlib/data/data.py", line 462, in dataset_processor
|
||||
# p = Pool(processes=workers)
|
||||
# AssertionError: daemonic processes are not allowed to have children
|
||||
envs = DummyVectorEnv([dummy_env for _ in range(4)])
|
||||
test_collector = Collector(policy, envs)
|
||||
policy.eval()
|
||||
# TODO: create a queue for all orders and make it auto-complete when all the orders are processed
|
||||
test_collector.collect(n_episode=10)
|
||||
def dummy_env():
|
||||
executor = get_executor(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
executor_config,
|
||||
exchange,
|
||||
benchmark,
|
||||
1000000000,
|
||||
)
|
||||
return SingleOrderEnv(
|
||||
observation, action, reward_fn,
|
||||
iter(DataLoader(QlibOrderDataset('assets/orders'), batch_size=None, shuffle=True)), executor)
|
||||
|
||||
policy = DummyPolicy()
|
||||
|
||||
# This cannot be replaced with SubprocVectorEnv
|
||||
# File "/xxx/qlib/qlib/data/data.py", line 462, in dataset_processor
|
||||
# p = Pool(processes=workers)
|
||||
# AssertionError: daemonic processes are not allowed to have children
|
||||
envs = DummyVectorEnv([dummy_env for _ in range(4)])
|
||||
test_collector = Collector(policy, envs)
|
||||
policy.eval()
|
||||
# TODO: create a queue for all orders and make it auto-complete when all the orders are processed
|
||||
test_collector.collect(n_episode=10)
|
||||
|
||||
def rl_day(self, load_model: Optional[str] = None):
|
||||
self._init_qlib()
|
||||
model = init_instance_by_config(self.task["model"])
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
if load_model is None:
|
||||
self._train_model(model, dataset)
|
||||
else:
|
||||
model = self._load_model(load_model)
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
trade_account = Account(
|
||||
init_cash=int(1e9),
|
||||
benchmark_config={
|
||||
"benchmark": "SH000300",
|
||||
"start_time": trade_start_time,
|
||||
"end_time": trade_end_time,
|
||||
},
|
||||
)
|
||||
exchange = get_exchange(
|
||||
freq="day",
|
||||
limit_threshold=0.095,
|
||||
deal_price="close",
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
min_cost=5
|
||||
)
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
strategy = init_instance_by_config({
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}, common_infra=common_infra)
|
||||
executor = NestedExecutor(
|
||||
time_per_step="week",
|
||||
inner_executor=SimulatorExecutor(time_per_step="day", verbose=True),
|
||||
inner_strategy=RLStrategy(Observation("day"), Action(), DummyPolicy()),
|
||||
common_infra=common_infra
|
||||
)
|
||||
report_dict = backtest_func(trade_start_time, trade_end_time, strategy, executor)
|
||||
print(report_dict)
|
||||
|
||||
|
||||
### This is a full RL strategy ###
|
||||
@@ -527,4 +574,4 @@ def _to_float32(val): return np.array(val, dtype=np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_main_tianshou()
|
||||
fire.Fire(RlWorkflow)
|
||||
|
||||
Reference in New Issue
Block a user