1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/tests/rl/test_trainer.py
Linlang 4d621bff99 fix pkl file not loading in StaticDataLoader (#1896)
* fix pkl file not loading in StaticDataLoader

* resolve hard code

* resolve hard code
2025-03-18 16:05:24 +08:00

203 lines
6.1 KiB
Python

import os
import random
import sys
from pathlib import Path
import pytest
import torch
import torch.nn as nn
from gym import spaces
from tianshou.policy import PPOPolicy
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.reward import Reward
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
# Module-level pytest mark: skip the tests in this file when running on an interpreter older than Python 3.8.
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
class ZeroSimulator(Simulator):
    """Toy simulator in which an action counts as correct when it equals ``0``.

    Every step stores the action, flags whether it was correct and ends the
    episode on a coin flip (``random.choice([False, True])``).
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments (e.g. the initial state) are accepted but not used.
        self.action = self.correct = 0
        # Define the flag up front so ``done()`` is safe to call before the first ``step()``.
        self._done = False

    def step(self, action):
        """Record ``action``, randomly decide whether the episode ends, and log "acc" when it does."""
        self.action = action
        self.correct = action == 0
        self._done = random.choice([False, True])
        if self._done:
            # "acc" is 100 for a correct final action and 0 otherwise.
            self.env.logger.add_scalar("acc", self.correct * 100)

    def get_state(self):
        """Return the current state as a dict with keys ``"acc"`` and ``"action"``."""
        return {
            "acc": self.correct * 100,
            "action": self.action,
        }

    def done(self) -> bool:
        """Return whether the most recent step ended the episode (``False`` before any step)."""
        return self._done
class NoopStateInterpreter(StateInterpreter):
    """State interpreter that forwards the simulator state without modification."""

    # Two discrete entries, keyed like the dict returned by ``ZeroSimulator.get_state``.
    observation_space = spaces.Dict({"acc": spaces.Discrete(200), "action": spaces.Discrete(2)})

    def interpret(self, simulator_state):
        """Return ``simulator_state`` as-is."""
        return simulator_state
class NoopActionInterpreter(ActionInterpreter):
    """Action interpreter that forwards the policy action without modification."""

    # Binary action space: the policy picks one of two discrete actions.
    action_space = spaces.Discrete(2)

    def interpret(self, simulator_state, action):
        """Return ``action`` as-is; ``simulator_state`` is not consulted."""
        return action
class AccReward(Reward):
    """Reward that pays out the state's "acc" value (scaled by 1/100) only once the episode is done."""

    def reward(self, simulator_state):
        """Return ``simulator_state["acc"] / 100`` when the env status reports done, else ``0.0``."""
        episode_finished = self.env.status["done"]
        if not episode_finished:
            return 0.0
        return simulator_state["acc"] / 100
class PolicyNet(nn.Module):
    """Single linear layer fed with random noise; only the batch size is taken from the observation.

    With ``return_state=True`` the forward pass returns ``(softmax probabilities, state)``;
    otherwise it returns the raw linear output.
    """

    def __init__(self, out_features=1, return_state=False):
        super().__init__()
        self.fc = nn.Linear(32, out_features)
        self.return_state = return_state

    def forward(self, obs, state=None, **kwargs):
        """Run the linear layer on fresh ``torch.randn`` noise with one row per entry of ``obs["acc"]``."""
        batch_size = obs["acc"].shape[0]
        noise = torch.randn(batch_size, 32)
        logits = self.fc(noise)
        if not self.return_state:
            return logits
        # Probabilities over the last dimension, with ``state`` passed through untouched.
        return nn.functional.softmax(logits, dim=-1), state
def _ppo_policy():
    """Build a ``PPOPolicy`` from a two-output ``PolicyNet`` actor and a single-output ``PolicyNet`` critic."""
    actor_net = PolicyNet(2, True)
    critic_net = PolicyNet()
    # One Adam optimizer over the parameters of both networks.
    trainable_params = tuple(actor_net.parameters()) + tuple(critic_net.parameters())
    optimizer = torch.optim.Adam(trainable_params)
    return PPOPolicy(
        actor_net,
        critic_net,
        optimizer,
        torch.distributions.Categorical,
        action_space=NoopActionInterpreter().action_space,
    )
def test_trainer():
    """Train for 10 iterations with "subproc" environments, then check counters and accuracy metrics."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=10, finite_env_type="subproc")
    policy = _ppo_policy()
    vessel_kwargs = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": policy,
        "train_initial_states": list(range(100)),
        "val_initial_states": list(range(10)),
        "test_initial_states": list(range(10)),
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_kwargs)

    trainer.fit(vessel)

    # 10 iterations at 500 episodes each.
    assert trainer.current_iter == 10
    assert trainer.current_episode == 5000
    # The logged "acc" must agree with the reward, which is acc / 100.
    train_acc = trainer.metrics["acc"]
    train_reward = trainer.metrics["reward"]
    assert abs(train_acc - train_reward * 100) < 1e-4
    assert train_acc > 80

    trainer.test(vessel)
    assert trainer.metrics["acc"] > 60
def test_trainer_fast_dev_run():
    """Fit with ``fast_dev_run=2`` on "shmem" environments and expect exactly 4 episodes in total."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
    policy = _ppo_policy()
    vessel_kwargs = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": policy,
        "train_initial_states": list(range(100)),
        "val_initial_states": list(range(10)),
        "test_initial_states": list(range(10)),
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_kwargs)

    trainer.fit(vessel)

    episodes_seen = trainer.current_episode
    assert episodes_seen == 4
def test_trainer_earlystop():
    """Fit with an ``EarlyStopping`` callback on "val/reward" and expect training to stop at iteration 2."""
    # TODO this is just sanity check.
    # need to see the logs to check whether it works.
    set_log_with_config(C.logging_config)
    early_stopping = EarlyStopping("val/reward", restore_best_weights=True)
    trainer = Trainer(
        max_iters=10,
        val_every_n_iters=1,
        finite_env_type="dummy",
        callbacks=[early_stopping],
    )
    policy = _ppo_policy()
    vessel_kwargs = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": policy,
        "train_initial_states": list(range(100)),
        "val_initial_states": list(range(10)),
        "test_initial_states": list(range(10)),
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_kwargs)

    trainer.fit(vessel)

    assert trainer.metrics["val/acc"] > 30
    assert trainer.current_iter == 2  # second iteration
def test_trainer_checkpoint():
    """Check that the ``Checkpoint`` callback writes per-iteration files and a ``latest.pth`` link, and that they reload.

    Checkpoints are written to a ``.output`` directory next to this file.
    """
    set_log_with_config(C.logging_config)
    output_dir = Path(__file__).parent / ".output"
    # Remove checkpoints left by an earlier run; otherwise the existence and
    # link asserts below could pass on stale files even if nothing was saved.
    for stale_name in ("001.pth", "002.pth", "latest.pth"):
        (output_dir / stale_name).unlink(missing_ok=True)

    trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
    policy = _ppo_policy()

    vessel = TrainingVessel(
        simulator_fn=lambda init: ZeroSimulator(init),
        state_interpreter=NoopStateInterpreter(),
        action_interpreter=NoopActionInterpreter(),
        policy=policy,
        train_initial_states=list(range(100)),
        val_initial_states=list(range(10)),
        test_initial_states=list(range(10)),
        reward=AccReward(),
        episode_per_iter=100,
        update_kwargs=dict(repeat=10, batch_size=64),
    )
    trainer.fit(vessel)

    # One checkpoint per iteration, with "latest.pth" linking to the last one.
    assert (output_dir / "001.pth").exists()
    assert (output_dir / "002.pth").exists()
    assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")

    # Restoring the first checkpoint rewinds the counters to the end of iteration 1.
    trainer.load_state_dict(torch.load(output_dir / "001.pth", weights_only=False))
    assert trainer.current_iter == 1
    assert trainer.current_episode == 100

    # Reload the checkpoint at first iteration
    trainer.fit(vessel, ckpt_path=output_dir / "001.pth")