mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint (cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050) * Not a workable version (cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a) * vessel * ckpt * . * vessel * . * . * checkpoint callback * . * cleanup * logger * . * test * . * add test * . * . * . * . * New reward * Add train API * fix mypy * fix lint * More comment * 3.7 compat * fix test * fix test * . * Resolve comments * fix typehint
This commit is contained in:
@@ -81,7 +81,7 @@ def test_simple_env_logger(caplog):
|
||||
line = line.strip()
|
||||
if line:
|
||||
line_counter += 1
|
||||
assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
|
||||
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
|
||||
assert line_counter >= 3
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from qlib.backtest import Order
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.entries.test import backtest
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
|
||||
@@ -306,3 +306,26 @@ def test_cn_ppo_strategy():
|
||||
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
|
||||
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
|
||||
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
|
||||
|
||||
|
||||
def test_ppo_train():
|
||||
set_log_with_config(C.logging_config)
|
||||
# The data starts with 9:31 and ends with 15:00
|
||||
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
|
||||
assert len(orders) == 40
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
|
||||
action_interp = CategoricalActionInterpreter(4)
|
||||
network = Recurrent(state_interp.observation_space)
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
|
||||
train(
|
||||
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
policy,
|
||||
PAPenaltyReward(),
|
||||
vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}},
|
||||
trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)},
|
||||
)
|
||||
|
||||
202
tests/rl/test_trainer.py
Normal file
202
tests/rl/test_trainer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym import spaces
|
||||
from tianshou.policy import PPOPolicy
|
||||
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
|
||||
|
||||
class ZeroSimulator(Simulator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.action = self.correct = 0
|
||||
|
||||
def step(self, action):
|
||||
self.action = action
|
||||
self.correct = action == 0
|
||||
self._done = random.choice([False, True])
|
||||
if self._done:
|
||||
self.env.logger.add_scalar("acc", self.correct * 100)
|
||||
|
||||
def get_state(self):
|
||||
return {
|
||||
"acc": self.correct * 100,
|
||||
"action": self.action,
|
||||
}
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
|
||||
class NoopStateInterpreter(StateInterpreter):
|
||||
observation_space = spaces.Dict(
|
||||
{
|
||||
"acc": spaces.Discrete(200),
|
||||
"action": spaces.Discrete(2),
|
||||
}
|
||||
)
|
||||
|
||||
def interpret(self, simulator_state):
|
||||
return simulator_state
|
||||
|
||||
|
||||
class NoopActionInterpreter(ActionInterpreter):
|
||||
action_space = spaces.Discrete(2)
|
||||
|
||||
def interpret(self, simulator_state, action):
|
||||
return action
|
||||
|
||||
|
||||
class AccReward(Reward):
|
||||
def reward(self, simulator_state):
|
||||
if self.env.status["done"]:
|
||||
return simulator_state["acc"] / 100
|
||||
return 0.0
|
||||
|
||||
|
||||
class PolicyNet(nn.Module):
|
||||
def __init__(self, out_features=1, return_state=False):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(32, out_features)
|
||||
self.return_state = return_state
|
||||
|
||||
def forward(self, obs, state=None, **kwargs):
|
||||
res = self.fc(torch.randn(obs["acc"].shape[0], 32))
|
||||
if self.return_state:
|
||||
return nn.functional.softmax(res, dim=-1), state
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def _ppo_policy():
|
||||
actor = PolicyNet(2, True)
|
||||
critic = PolicyNet()
|
||||
policy = PPOPolicy(
|
||||
actor,
|
||||
critic,
|
||||
torch.optim.Adam(tuple(actor.parameters()) + tuple(critic.parameters())),
|
||||
torch.distributions.Categorical,
|
||||
action_space=NoopActionInterpreter().action_space,
|
||||
)
|
||||
return policy
|
||||
|
||||
|
||||
def test_trainer():
|
||||
set_log_with_config(C.logging_config)
|
||||
trainer = Trainer(max_iters=10, finite_env_type="subproc")
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=500,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
assert trainer.current_iter == 10
|
||||
assert trainer.current_episode == 5000
|
||||
assert abs(trainer.metrics["acc"] - trainer.metrics["reward"] * 100) < 1e-4
|
||||
assert trainer.metrics["acc"] > 80
|
||||
trainer.test(vessel)
|
||||
assert trainer.metrics["acc"] > 60
|
||||
|
||||
|
||||
def test_trainer_fast_dev_run():
|
||||
set_log_with_config(C.logging_config)
|
||||
trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=500,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
assert trainer.current_episode == 4
|
||||
|
||||
|
||||
def test_trainer_earlystop():
|
||||
# TODO this is just sanity check.
|
||||
# need to see the logs to check whether it works.
|
||||
set_log_with_config(C.logging_config)
|
||||
trainer = Trainer(
|
||||
max_iters=10,
|
||||
val_every_n_iters=1,
|
||||
finite_env_type="dummy",
|
||||
callbacks=[EarlyStopping("val/reward", restore_best_weights=True)],
|
||||
)
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=500,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
assert trainer.metrics["val/acc"] > 30
|
||||
assert trainer.current_iter == 2 # second iteration
|
||||
|
||||
|
||||
def test_trainer_checkpoint():
|
||||
set_log_with_config(C.logging_config)
|
||||
output_dir = Path(__file__).parent / ".output"
|
||||
trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=100,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
|
||||
assert (output_dir / "001.pth").exists()
|
||||
assert (output_dir / "002.pth").exists()
|
||||
assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")
|
||||
|
||||
trainer.load_state_dict(torch.load(output_dir / "001.pth"))
|
||||
assert trainer.current_iter == 1
|
||||
assert trainer.current_episode == 100
|
||||
|
||||
# Reload the checkpoint at first iteration
|
||||
trainer.fit(vessel, ckpt_path=output_dir / "001.pth")
|
||||
Reference in New Issue
Block a user