1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00
Files
qlib/tests/rl/test_saoe_simple.py
Yuge Zhang 25ecb1135f Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint

(cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050)

* Not a workable version

(cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
2022-06-28 19:53:05 +08:00

332 lines
13 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from functools import partial
from pathlib import Path
from typing import NamedTuple
import numpy as np
import pandas as pd
import pytest
import torch
from tianshou.data import Batch
from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.trainer import backtest, train
from qlib.rl.order_execution import *
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "intraday_saoe"
DATA_DIR = DATA_ROOT_DIR / "us"
BACKTEST_DATA_DIR = DATA_DIR / "backtest"
FEATURE_DATA_DIR = DATA_DIR / "processed"
ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
CN_DATA_DIR = DATA_ROOT_DIR / "cn"
CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
def test_pickle_data_inspect():
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
assert len(data) == 390
data = pickle_styled.load_intraday_processed_data(
DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index()
)
assert len(data.today) == len(data.yesterday) == 390
def test_simulator_first_step():
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
assert state.position == 30.0
simulator.step(15.0)
state = simulator.get_state()
assert len(state.history_exec) == 30
assert state.history_exec.index[0] == pd.Timestamp("2013-12-11 09:30:00")
assert state.history_exec["market_volume"].iloc[0] == 450072.0
assert abs(state.history_exec["market_price"].iloc[0] - 25.370001) < 1e-4
assert (state.history_exec["amount"] == 0.5).all()
assert (state.history_exec["deal_amount"] == 0.5).all()
assert abs(state.history_exec["trade_price"].iloc[0] - 25.370001) < 1e-4
assert abs(state.history_exec["trade_value"].iloc[0] - 12.68500) < 1e-4
assert state.history_exec["position"].iloc[0] == 29.5
assert state.history_exec["ffr"].iloc[0] == 1 / 60
assert state.history_steps["market_volume"].iloc[0] == 5041147.0
assert state.history_steps["amount"].iloc[0] == 15.0
assert state.history_steps["deal_amount"].iloc[0] == 15.0
assert state.history_steps["ffr"].iloc[0] == 0.5
assert (
state.history_steps["pa"].iloc[0]
== (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000
)
assert state.position == 15.0
assert state.cur_time == pd.Timestamp("2013-12-11 10:00:00")
def test_simulator_stop_twap():
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
for _ in range(13):
simulator.step(1.0)
state = simulator.get_state()
assert len(state.history_exec) == 390
assert (state.history_exec["deal_amount"] == 13 / 390).all()
assert state.history_steps["position"].iloc[0] == 12 and state.history_steps["position"].iloc[-1] == 0
assert (state.metrics["ffr"] - 1) < 1e-3
assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4
assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
assert state.position == 0.0
assert abs(state.metrics["trade_price"] - state.metrics["market_price"]) < 1e-4
assert abs(state.metrics["pa"]) < 1e-2
assert simulator.done()
def test_simulator_stop_early():
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
with pytest.raises(ValueError):
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator.step(2.0)
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator.step(1.0)
with pytest.raises(AssertionError):
simulator.step(1.0)
def test_simulator_start_middle():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
simulator.step(2.0)
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:30:00")
for _ in range(10):
simulator.step(1.0)
simulator.step(2.0)
assert len(simulator.history_exec) == 330
assert simulator.done()
assert abs(simulator.history_exec["amount"].iloc[-1] - (1 + 2 / 15)) < 1e-4
assert abs(simulator.metrics["ffr"] - 1) < 1e-4
def test_interpreter():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
# emulate a env status
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
interpreter_step = CurrentStepStateInterpreter(13)
interpreter_action = CategoricalActionInterpreter(20)
interpreter_action_twap = TwapRelativeActionInterpreter()
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
# first step
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 45
assert obs["cur_step"] == 0
assert obs["position"] == 15.0
assert obs["position_history"][0] == 15.0
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(45))
assert np.sum(obs["data_processed"][45:]) == 0
assert obs["data_processed_prev"].shape == (390, 5)
# first step: second interpreter
interpreter_step.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
obs = interpreter_step(simulator.get_state())
assert obs["acquiring"] == 1
assert obs["position"] == 15.0
# second step
simulator.step(5.0)
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 60
assert obs["cur_step"] == 1
assert obs["position"] == 10.0
assert obs["position_history"][:2].tolist() == [15.0, 10.0]
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(60))
assert np.sum(obs["data_processed"][60:]) == 0
# second step: action
action = interpreter_action(simulator.get_state(), 1)
assert action == 15 / 20
interpreter_action_twap.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)
)
action = interpreter_action_twap(simulator.get_state(), 1.5)
assert action == 1.5
# fast-forward
for _ in range(10):
simulator.step(0.0)
# last step
simulator.step(5.0)
interpreter.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs)
)
assert interpreter.env.status["done"]
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 375
assert obs["cur_step"] == 12
assert obs["position"] == 0.0
assert obs["position_history"][1:11].tolist() == [10.0] * 10
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(375))
assert np.sum(obs["data_processed"][375:]) == 0
def test_network_sanity():
# we won't check the correctness of networks here
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 390
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
action_interp = CategoricalActionInterpreter(13)
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
network = Recurrent(interpreter.observation_space)
policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3)
for i in range(14):
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
batch = Batch(obs=[obs])
output = policy(batch)
assert 0 <= output["act"].item() <= 13
if i < 13:
simulator.step(1.0)
else:
assert obs["cur_tick"] == 389
assert obs["cur_step"] == 12
assert obs["position_history"][-1] == 3
@pytest.mark.parametrize("finite_env_type", ["dummy", "subproc", "shmem"])
def test_twap_strategy(finite_env_type):
set_log_with_config(C.logging_config)
orders = pickle_styled.load_orders(ORDER_DIR)
assert len(orders) == 248
state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
action_interp = TwapRelativeActionInterpreter()
policy = AllOne(state_interp.observation_space, action_interp.action_space)
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
[ConsoleWriter(total_episodes=len(orders)), csv_writer],
concurrency=4,
finite_env_type=finite_env_type,
)
metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(metrics) == 248
assert np.isclose(metrics["ffr"].mean(), 1.0)
assert np.isclose(metrics["pa"].mean(), 0.0)
assert np.allclose(metrics["pa"], 0.0, atol=2e-3)
def test_cn_ppo_strategy():
set_log_with_config(C.logging_config)
# The data starts with 9:31 and ends with 15:00
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
action_interp = CategoricalActionInterpreter(4)
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
[ConsoleWriter(total_episodes=len(orders)), csv_writer],
concurrency=4,
)
metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(metrics) == len(orders)
assert np.isclose(metrics["ffr"].mean(), 1.0)
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
def test_ppo_train():
set_log_with_config(C.logging_config)
# The data starts with 9:31 and ends with 15:00
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
action_interp = CategoricalActionInterpreter(4)
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
train(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
PAPenaltyReward(),
vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}},
trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)},
)