1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00
Files
qlib/tests/rl/test_trainer.py
Yuge Zhang 25ecb1135f Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint

(cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050)

* Not a workable version

(cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
2022-06-28 19:53:05 +08:00

203 lines
6.1 KiB
Python

import os
import random
import sys
from pathlib import Path
import pytest
import torch
import torch.nn as nn
from gym import spaces
from tianshou.policy import PPOPolicy
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.reward import Reward
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
# Module-level pytest marker: skip the tests in this file on interpreters older than 3.8.
pytestmark = pytest.mark.skipif(
    sys.version_info < (3, 8),
    reason="Pickle styled data only supports Python >= 3.8",
)
class ZeroSimulator(Simulator):
    """Toy simulator that counts an action as correct when it equals ``0``.

    Each :meth:`step` stores the action, scores it, and picks uniformly
    between ``False`` and ``True`` to decide whether the episode is over.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted for interface compatibility but
        # are not used by this simulator.
        self.action = self.correct = 0
        # Initialise the termination flag so that ``done()`` can be called
        # before the first ``step()`` without raising ``AttributeError``.
        self._done = False

    def step(self, action):
        """Record ``action``, score it and randomly decide whether to stop."""
        self.action = action
        self.correct = action == 0
        self._done = random.choice([False, True])
        if self._done:
            # The accuracy (scaled to percent) is logged only when the episode ends.
            self.env.logger.add_scalar("acc", self.correct * 100)

    def get_state(self):
        """Return the latest action together with its accuracy in percent."""
        return {
            "acc": self.correct * 100,
            "action": self.action,
        }

    def done(self) -> bool:
        """Return whether the current episode has finished."""
        return self._done
class NoopStateInterpreter(StateInterpreter):
    """State interpreter that hands the simulator state through unchanged."""

    # Two discrete entries: "acc" with 200 possible values and "action" with 2.
    observation_space = spaces.Dict({"acc": spaces.Discrete(200), "action": spaces.Discrete(2)})

    def interpret(self, simulator_state):
        """Return ``simulator_state`` as the observation, without transformation."""
        return simulator_state
class NoopActionInterpreter(ActionInterpreter):
    """Action interpreter that forwards the policy action unchanged."""

    # Two discrete actions are available.
    action_space = spaces.Discrete(2)

    def interpret(self, simulator_state, action):
        """Return ``action`` as is; ``simulator_state`` is not consulted."""
        return action
class AccReward(Reward):
    """Reward that pays out the accuracy only when the episode has ended."""

    def reward(self, simulator_state):
        """Return ``acc / 100`` if the environment reports done, otherwise ``0.0``."""
        episode_finished = self.env.status["done"]
        return simulator_state["acc"] / 100 if episode_finished else 0.0
class PolicyNet(nn.Module):
    """Single linear layer fed with random features, used as a stand-in network.

    The observation only determines the batch size; the layer input is fresh
    Gaussian noise of width 32 on every call.

    Args:
        out_features: Width of the linear layer's output.
        return_state: When true, ``forward`` returns ``(softmax output, state)``;
            otherwise it returns the raw linear output alone.
    """

    def __init__(self, out_features=1, return_state=False):
        super().__init__()
        self.fc = nn.Linear(32, out_features)
        self.return_state = return_state

    def forward(self, obs, state=None, **kwargs):
        """Run the layer on random input sized by the first dimension of ``obs["acc"]``."""
        batch_size = obs["acc"].shape[0]
        logits = self.fc(torch.randn(batch_size, 32))
        if not self.return_state:
            return logits
        # Normalise over the last dimension and pass ``state`` straight through.
        return nn.functional.softmax(logits, dim=-1), state
def _ppo_policy():
    """Build a ``PPOPolicy`` from two ``PolicyNet`` instances sharing one Adam optimizer.

    The actor has two outputs and returns ``(probabilities, state)``; the
    critic uses the ``PolicyNet`` defaults.
    """
    actor, critic = PolicyNet(2, True), PolicyNet()
    # A single optimizer updates the parameters of both networks.
    optimizer = torch.optim.Adam((*actor.parameters(), *critic.parameters()))
    return PPOPolicy(
        actor,
        critic,
        optimizer,
        torch.distributions.Categorical,
        action_space=NoopActionInterpreter().action_space,
    )
def test_trainer():
    """Fit for ten iterations with ``finite_env_type="subproc"`` and check the metrics."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=10, finite_env_type="subproc")
    ppo = _ppo_policy()
    vessel_options = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": ppo,
        "train_initial_states": [*range(100)],
        "val_initial_states": [*range(10)],
        "test_initial_states": [*range(10)],
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_options)

    trainer.fit(vessel)

    # 10 iterations of 500 episodes each.
    assert trainer.current_iter == 10
    assert trainer.current_episode == 5000

    # The logged accuracy must equal the reward rescaled to percent.
    train_metrics = trainer.metrics
    accuracy, reward = train_metrics["acc"], train_metrics["reward"]
    assert abs(accuracy - reward * 100) < 1e-4
    assert accuracy > 80

    trainer.test(vessel)
    assert trainer.metrics["acc"] > 60
def test_trainer_fast_dev_run():
    """Fit with ``fast_dev_run=2`` on ``finite_env_type="shmem"`` and check the episode count."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
    ppo = _ppo_policy()
    vessel_options = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": ppo,
        "train_initial_states": [*range(100)],
        "val_initial_states": [*range(10)],
        "test_initial_states": [*range(10)],
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_options)

    trainer.fit(vessel)

    # Only four episodes are expected despite ``episode_per_iter=500``
    # (presumably 2 iterations x 2 episodes under fast_dev_run -- confirm against Trainer).
    assert trainer.current_episode == 4
def test_trainer_earlystop():
    """Fit with an ``EarlyStopping`` callback on "val/reward" and expect a stop at iteration 2."""
    # TODO this is just sanity check.
    # need to see the logs to check whether it works.
    set_log_with_config(C.logging_config)
    early_stopping = EarlyStopping("val/reward", restore_best_weights=True)
    trainer = Trainer(max_iters=10, val_every_n_iters=1, finite_env_type="dummy", callbacks=[early_stopping])
    ppo = _ppo_policy()
    vessel_options = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": ppo,
        "train_initial_states": [*range(100)],
        "val_initial_states": [*range(10)],
        "test_initial_states": [*range(10)],
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_options)

    trainer.fit(vessel)

    assert trainer.metrics["val/acc"] > 30
    assert trainer.current_iter == 2  # second iteration
def test_trainer_checkpoint():
    """Fit with a ``Checkpoint`` callback, then verify the saved files and reload one.

    The test checks that a checkpoint is written per iteration, that
    ``latest.pth`` links to the newest one, and that loading ``001.pth``
    restores the trainer counters of the first iteration.
    """
    set_log_with_config(C.logging_config)
    output_dir = Path(__file__).parent / ".output"

    # Remove artefacts left behind by an earlier run. Otherwise the existence
    # and symlink assertions below could pass on stale files even if this run
    # failed to write any checkpoint.
    for stale_name in ("001.pth", "002.pth", "latest.pth"):
        stale_path = output_dir / stale_name
        # ``is_symlink`` also catches a dangling link, for which ``exists`` is False.
        if stale_path.is_symlink() or stale_path.exists():
            stale_path.unlink()

    trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
    policy = _ppo_policy()

    vessel = TrainingVessel(
        simulator_fn=lambda init: ZeroSimulator(init),
        state_interpreter=NoopStateInterpreter(),
        action_interpreter=NoopActionInterpreter(),
        policy=policy,
        train_initial_states=list(range(100)),
        val_initial_states=list(range(10)),
        test_initial_states=list(range(10)),
        reward=AccReward(),
        episode_per_iter=100,
        update_kwargs=dict(repeat=10, batch_size=64),
    )
    trainer.fit(vessel)

    # One checkpoint per iteration, with "latest.pth" pointing at the last one.
    assert (output_dir / "001.pth").exists()
    assert (output_dir / "002.pth").exists()
    assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")

    # Restoring the first checkpoint rewinds the counters: 1 iteration x 100 episodes.
    trainer.load_state_dict(torch.load(output_dir / "001.pth"))
    assert trainer.current_iter == 1
    assert trainer.current_episode == 100

    # Reload the checkpoint at first iteration
    trainer.fit(vessel, ckpt_path=output_dir / "001.pth")