1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Qlib RL framework (stage 2) - trainer (#1125)

* checkpoint

(cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050)

* Not a workable version

(cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
This commit is contained in:
Yuge Zhang
2022-06-28 19:53:05 +08:00
committed by GitHub
parent 2ca0d88d2d
commit 25ecb1135f
17 changed files with 1410 additions and 145 deletions

View File

@@ -81,7 +81,7 @@ def test_simple_env_logger(caplog):
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3

View File

@@ -17,7 +17,7 @@ from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.entries.test import backtest
from qlib.rl.trainer import backtest, train
from qlib.rl.order_execution import *
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
@@ -306,3 +306,26 @@ def test_cn_ppo_strategy():
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
def test_ppo_train():
set_log_with_config(C.logging_config)
# The data starts with 9:31 and ends with 15:00
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
action_interp = CategoricalActionInterpreter(4)
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
train(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
PAPenaltyReward(),
vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}},
trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)},
)

202
tests/rl/test_trainer.py Normal file
View File

@@ -0,0 +1,202 @@
import os
import random
import sys
from pathlib import Path
import pytest
import torch
import torch.nn as nn
from gym import spaces
from tianshou.policy import PPOPolicy
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.reward import Reward
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
class ZeroSimulator(Simulator):
def __init__(self, *args, **kwargs):
self.action = self.correct = 0
def step(self, action):
self.action = action
self.correct = action == 0
self._done = random.choice([False, True])
if self._done:
self.env.logger.add_scalar("acc", self.correct * 100)
def get_state(self):
return {
"acc": self.correct * 100,
"action": self.action,
}
def done(self) -> bool:
return self._done
class NoopStateInterpreter(StateInterpreter):
observation_space = spaces.Dict(
{
"acc": spaces.Discrete(200),
"action": spaces.Discrete(2),
}
)
def interpret(self, simulator_state):
return simulator_state
class NoopActionInterpreter(ActionInterpreter):
action_space = spaces.Discrete(2)
def interpret(self, simulator_state, action):
return action
class AccReward(Reward):
def reward(self, simulator_state):
if self.env.status["done"]:
return simulator_state["acc"] / 100
return 0.0
class PolicyNet(nn.Module):
def __init__(self, out_features=1, return_state=False):
super().__init__()
self.fc = nn.Linear(32, out_features)
self.return_state = return_state
def forward(self, obs, state=None, **kwargs):
res = self.fc(torch.randn(obs["acc"].shape[0], 32))
if self.return_state:
return nn.functional.softmax(res, dim=-1), state
else:
return res
def _ppo_policy():
actor = PolicyNet(2, True)
critic = PolicyNet()
policy = PPOPolicy(
actor,
critic,
torch.optim.Adam(tuple(actor.parameters()) + tuple(critic.parameters())),
torch.distributions.Categorical,
action_space=NoopActionInterpreter().action_space,
)
return policy
def test_trainer():
set_log_with_config(C.logging_config)
trainer = Trainer(max_iters=10, finite_env_type="subproc")
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=500,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert trainer.current_iter == 10
assert trainer.current_episode == 5000
assert abs(trainer.metrics["acc"] - trainer.metrics["reward"] * 100) < 1e-4
assert trainer.metrics["acc"] > 80
trainer.test(vessel)
assert trainer.metrics["acc"] > 60
def test_trainer_fast_dev_run():
set_log_with_config(C.logging_config)
trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=500,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert trainer.current_episode == 4
def test_trainer_earlystop():
# TODO this is just sanity check.
# need to see the logs to check whether it works.
set_log_with_config(C.logging_config)
trainer = Trainer(
max_iters=10,
val_every_n_iters=1,
finite_env_type="dummy",
callbacks=[EarlyStopping("val/reward", restore_best_weights=True)],
)
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=500,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert trainer.metrics["val/acc"] > 30
assert trainer.current_iter == 2 # second iteration
def test_trainer_checkpoint():
set_log_with_config(C.logging_config)
output_dir = Path(__file__).parent / ".output"
trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=100,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert (output_dir / "001.pth").exists()
assert (output_dir / "002.pth").exists()
assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")
trainer.load_state_dict(torch.load(output_dir / "001.pth"))
assert trainer.current_iter == 1
assert trainer.current_episode == 100
# Reload the checkpoint at first iteration
trainer.fit(vessel, ckpt_path=output_dir / "001.pth")