From c43805eff60475eddc5f3f17ce39936cc81de335 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 4 Jun 2021 12:20:27 +0800 Subject: [PATCH] Update end-to-end example and requirements --- .../requirements.txt | 2 + .../nested_decision_execution/rl_dummy.py | 175 +++++++++++------- 2 files changed, 113 insertions(+), 64 deletions(-) create mode 100644 examples/nested_decision_execution/requirements.txt diff --git a/examples/nested_decision_execution/requirements.txt b/examples/nested_decision_execution/requirements.txt new file mode 100644 index 000000000..2ad0a826f --- /dev/null +++ b/examples/nested_decision_execution/requirements.txt @@ -0,0 +1,2 @@ +tianshou>=0.4.1 +torch>=1.8.0 diff --git a/examples/nested_decision_execution/rl_dummy.py b/examples/nested_decision_execution/rl_dummy.py index 61f1bba59..4a8f50ad0 100644 --- a/examples/nested_decision_execution/rl_dummy.py +++ b/examples/nested_decision_execution/rl_dummy.py @@ -4,12 +4,14 @@ from dataclasses import dataclass, asdict from pprint import pprint from typing import Iterable, Any, Optional, OrderedDict, Tuple, Dict, List +import fire import gym import numpy as np import pandas as pd import qlib from gym import spaces -from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager +from qlib.backtest import get_exchange, Account, BaseExecutor, CommonInfrastructure, Order, TradeCalendarManager, backtest_func +from qlib.backtest.executor import NestedExecutor, SimulatorExecutor from qlib.config import REG_CN from qlib.data import D from qlib.rl.interpreter import StateInterpreter, ActionInterpreter @@ -21,6 +23,8 @@ from tianshou.data import Batch, Collector from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy +from workflow import NestedDecisonExecutionWorkflow + MAX_STEPS = 10 @@ -324,79 +328,122 @@ class RLStrategy(BaseStrategy): # apply results from the last step if execute_result is not None: orders = defaultdict(list) - for order, _, __, in execute_result: - orders[order.stock_id, order.direction].append(order) + for e in execute_result: + orders[e[0].stock_id, e[0].direction].append(e) for (stock_id, direction), state in self.states.items(): - state.update(orders[stock_id, direction]) - + state.update(orders[stock_id, direction], self.trade_calendar) + + if not self.states: + return [] + obs_batch = Batch([{"obs": self.observation(state)} for state in self.states.values()]) act = self.policy(obs_batch) - exec_vols = [self.action(a) for a in act.act] - return [create_sub_order(v, self.trade_calendar, order) for v in exec_vols] + exec_vols = [self.action(a, s) for a, s in zip(act.act, self.states.values())] + return [create_sub_order(v, self.trade_calendar, o) for v, o in zip(exec_vols, self.outer_trade_decision)] -def _init_qlib(): - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - qlib.init(provider_uri=provider_uri, region=REG_CN) +class RlWorkflow(NestedDecisonExecutionWorkflow): + def tianshou(self): + self._init_qlib() -def _main_tianshou(): - _init_qlib() - - # TODO: why is there a benchmark? - trade_start_time = "2017-01-01" - trade_end_time = "2020-08-01" - benchmark = "SH000300" - time_per_step = "day" - executor_config = { - "class": "SimulatorExecutor", - "module_path": "qlib.backtest.executor", - "kwargs": { - "time_per_step": time_per_step, - "verbose": True, - "generate_report": False, + # TODO: why is there a benchmark? + trade_start_time = "2017-01-01" + trade_end_time = "2020-08-01" + benchmark = "SH000300" + time_per_step = "day" + executor_config = { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": time_per_step, + "verbose": True, + "generate_report": False, + } } - } - exchange = get_exchange( - freq="day", - limit_threshold=0.095, - deal_price="close", - open_cost=0.0005, - close_cost=0.0015, - min_cost=5 - ) - - observation = Observation(time_per_step) - action = Action() - reward_fn = Reward() - - def dummy_env(): - executor = get_executor( - trade_start_time, - trade_end_time, - executor_config, - exchange, - benchmark, - 1000000000, + exchange = get_exchange( + freq="day", + limit_threshold=0.095, + deal_price="close", + open_cost=0.0005, + close_cost=0.0015, + min_cost=5 ) - return SingleOrderEnv( - observation, action, reward_fn, - iter(DataLoader(QlibOrderDataset('assets/orders'), batch_size=None, shuffle=True)), executor) - policy = DummyPolicy() + observation = Observation(time_per_step) + action = Action() + reward_fn = Reward() - # This can not be replaced with SubprocVectorEnv - # File "/xxx/qlib/qlib/data/data.py", line 462, in dataset_processor - # p = Pool(processes=workers) - # AssertionError: daemonic processes are not allowed to have children - envs = DummyVectorEnv([dummy_env for _ in range(4)]) - test_collector = Collector(policy, envs) - policy.eval() - # TODO: create a queue for all orders and make it auto-complete when all the orders are processed - test_collector.collect(n_episode=10) + def dummy_env(): + executor = get_executor( + trade_start_time, + trade_end_time, + executor_config, + exchange, + benchmark, + 1000000000, + ) + return SingleOrderEnv( + observation, action, reward_fn, + iter(DataLoader(QlibOrderDataset('assets/orders'), batch_size=None, shuffle=True)), executor) + + policy = DummyPolicy() + + # This can not be replaced with SubprocVectorEnv + # File "/xxx/qlib/qlib/data/data.py", line 462, in dataset_processor + # p = Pool(processes=workers) + # AssertionError: daemonic processes are not allowed to have children + envs = DummyVectorEnv([dummy_env for _ in range(4)]) + test_collector = Collector(policy, envs) + policy.eval() + # TODO: create a queue for all orders and make it auto-complete when all the orders are processed + test_collector.collect(n_episode=10) + + def rl_day(self, load_model: Optional[str] = None): + self._init_qlib() + model = init_instance_by_config(self.task["model"]) + dataset = init_instance_by_config(self.task["dataset"]) + if load_model is None: + self._train_model(model, dataset) + else: + model = self._load_model(load_model) + trade_start_time = "2017-01-01" + trade_end_time = "2020-08-01" + trade_account = Account( + init_cash=int(1e9), + benchmark_config={ + "benchmark": "SH000300", + "start_time": trade_start_time, + "end_time": trade_end_time, + }, + ) + exchange = get_exchange( + freq="day", + limit_threshold=0.095, + deal_price="close", + open_cost=0.0005, + close_cost=0.0015, + min_cost=5 + ) + common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange) + strategy = init_instance_by_config({ + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.model_strategy", + "kwargs": { + "model": model, + "dataset": dataset, + "topk": 50, + "n_drop": 5, + }, + }, common_infra=common_infra) + executor = NestedExecutor( + time_per_step="week", + inner_executor=SimulatorExecutor(time_per_step="day", verbose=True), + inner_strategy=RLStrategy(Observation("day"), Action(), DummyPolicy()), + common_infra=common_infra + ) + report_dict = backtest_func(trade_start_time, trade_end_time, strategy, executor) + print(report_dict) ### This is a full RL strategy ### @@ -527,4 +574,4 @@ def _to_float32(val): return np.array(val, dtype=np.float32) if __name__ == '__main__': - _main_tianshou() + fire.Fire(RlWorkflow)