mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
* checkpoint (cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050) * Not a workable version (cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a) * vessel * ckpt * . * vessel * . * . * checkpoint callback * . * cleanup * logger * . * test * . * add test * . * . * . * . * New reward * Add train API * fix mypy * fix lint * More comment * 3.7 compat * fix test * fix test * . * Resolve comments * fix typehint
203 lines
6.1 KiB
Python
203 lines
6.1 KiB
Python
import os
|
|
import random
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from gym import spaces
|
|
from tianshou.policy import PPOPolicy
|
|
|
|
from qlib.config import C
|
|
from qlib.log import set_log_with_config
|
|
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
|
from qlib.rl.simulator import Simulator
|
|
from qlib.rl.reward import Reward
|
|
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
|
|
|
|
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
|
|
|
|
|
class ZeroSimulator(Simulator):
    """Toy simulator in which an action counts as correct only when it equals 0.

    Every step stores the action, records whether it was 0, and decides by a
    coin flip whether the episode is over. When the episode ends, the accuracy
    (``correct * 100``) is reported under the ``"acc"`` key via ``self.env.logger``.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments (e.g. the initial state) are accepted but unused.
        self.action = self.correct = 0
        # Set up front so that done() can be called safely before the first step().
        self._done = False

    def step(self, action):
        """Record ``action``, score it, and randomly decide whether to finish."""
        self.action = action
        self.correct = action == 0
        self._done = random.choice([False, True])
        if self._done:
            self.env.logger.add_scalar("acc", self.correct * 100)

    def get_state(self):
        """Return the latest accuracy (``correct * 100``) and the latest action."""
        return {
            "acc": self.correct * 100,
            "action": self.action,
        }

    def done(self) -> bool:
        """Return whether the last step ended the episode (False before any step)."""
        return self._done
|
|
|
|
|
|
class NoopStateInterpreter(StateInterpreter):
    """State interpreter that forwards the simulator state untouched."""

    # Keys match the dict returned by ZeroSimulator.get_state().
    observation_space = spaces.Dict(
        dict(
            acc=spaces.Discrete(200),
            action=spaces.Discrete(2),
        )
    )

    def interpret(self, simulator_state):
        """Return ``simulator_state`` exactly as received."""
        return simulator_state
|
|
|
|
|
|
class NoopActionInterpreter(ActionInterpreter):
    """Action interpreter that forwards the policy action untouched."""

    # Two discrete actions: 0 and 1.
    action_space = spaces.Discrete(2)

    def interpret(self, simulator_state, action):
        """Return ``action`` as-is; ``simulator_state`` is not consulted."""
        return action
|
|
|
|
|
|
class AccReward(Reward):
    """Reward that pays out only on the step where the episode is done."""

    def reward(self, simulator_state):
        """Return ``simulator_state["acc"] / 100`` when done, otherwise ``0.0``."""
        # Intermediate steps carry no reward.
        if not self.env.status["done"]:
            return 0.0
        return simulator_state["acc"] / 100
|
|
|
|
|
|
class PolicyNet(nn.Module):
    """Single linear layer applied to random noise; used as a dummy actor/critic.

    The observation only determines the batch size: the layer input is freshly
    sampled Gaussian noise of shape ``(batch, 32)``.
    """

    def __init__(self, out_features=1, return_state=False):
        """Create the layer.

        Args:
            out_features: Width of the linear layer's output.
            return_state: When true, ``forward`` returns ``(softmax probabilities, state)``;
                otherwise it returns the raw layer output alone.
        """
        super().__init__()
        self.fc = nn.Linear(32, out_features)
        self.return_state = return_state

    def forward(self, obs, state=None, **kwargs):
        """Run the layer on random input sized by the leading dimension of ``obs["acc"]``."""
        batch_size = obs["acc"].shape[0]
        noise = torch.randn(batch_size, 32)
        logits = self.fc(noise)
        if not self.return_state:
            return logits
        # Actor mode: emit normalised probabilities and pass the state through.
        probs = torch.softmax(logits, dim=-1)
        return probs, state
|
|
|
|
|
|
def _ppo_policy():
    """Build a PPO policy from a two-output actor and a single-output critic.

    Both networks are optimised jointly by one Adam optimizer, and actions are
    drawn from a categorical distribution over the actor's output.
    """
    actor = PolicyNet(2, True)
    critic = PolicyNet()

    # One optimizer over the parameters of both networks.
    joint_params = (*actor.parameters(), *critic.parameters())
    optimizer = torch.optim.Adam(joint_params)

    return PPOPolicy(
        actor,
        critic,
        optimizer,
        torch.distributions.Categorical,
        action_space=NoopActionInterpreter().action_space,
    )
|
|
|
|
|
|
def test_trainer():
    """Fit PPO for 10 iterations with ``"subproc"`` envs, then run the test stage."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=10, finite_env_type="subproc")
    policy = _ppo_policy()

    vessel_kwargs = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": policy,
        "train_initial_states": [*range(100)],
        "val_initial_states": [*range(10)],
        "test_initial_states": [*range(10)],
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_kwargs)
    trainer.fit(vessel)

    # 10 iterations at 500 episodes each.
    assert trainer.current_iter == 10
    assert trainer.current_episode == 5000

    # AccReward is acc / 100, so the two metrics must agree up to scaling.
    acc_vs_reward_gap = trainer.metrics["acc"] - trainer.metrics["reward"] * 100
    assert abs(acc_vs_reward_gap) < 1e-4
    assert trainer.metrics["acc"] > 80

    trainer.test(vessel)
    assert trainer.metrics["acc"] > 60
|
|
|
|
|
|
def test_trainer_fast_dev_run():
    """Fit with ``fast_dev_run=2`` on ``"shmem"`` envs and check the episode count."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
    policy = _ppo_policy()

    vessel_kwargs = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": policy,
        "train_initial_states": [*range(100)],
        "val_initial_states": [*range(10)],
        "test_initial_states": [*range(10)],
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_kwargs)
    trainer.fit(vessel)

    # Only 4 episodes are expected in total despite episode_per_iter=500.
    assert trainer.current_episode == 4
|
|
|
|
|
|
def test_trainer_earlystop():
    """Fit with an ``EarlyStopping`` callback on ``"val/reward"`` using ``"dummy"`` envs."""
    # TODO this is just sanity check.
    # need to see the logs to check whether it works.
    set_log_with_config(C.logging_config)

    early_stopping = EarlyStopping("val/reward", restore_best_weights=True)
    trainer = Trainer(
        max_iters=10,
        val_every_n_iters=1,
        finite_env_type="dummy",
        callbacks=[early_stopping],
    )
    policy = _ppo_policy()

    vessel_kwargs = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": policy,
        "train_initial_states": [*range(100)],
        "val_initial_states": [*range(10)],
        "test_initial_states": [*range(10)],
        "reward": AccReward(),
        "episode_per_iter": 500,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_kwargs)
    trainer.fit(vessel)

    assert trainer.metrics["val/acc"] > 30
    assert trainer.current_iter == 2  # second iteration
|
|
|
|
|
|
def test_trainer_checkpoint():
    """Fit with a per-iteration ``Checkpoint`` callback and verify the saved files.

    Checkpoints are written to a ``.output`` directory next to this file. The
    test checks the numbered checkpoints and the ``latest.pth`` link, restores
    the first checkpoint into the trainer, and finally resumes fitting from it.
    """
    set_log_with_config(C.logging_config)
    output_dir = Path(__file__).parent / ".output"
    checkpoint_cb = Checkpoint(output_dir, every_n_iters=1)
    trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[checkpoint_cb])
    policy = _ppo_policy()

    vessel_kwargs = {
        "simulator_fn": lambda init: ZeroSimulator(init),
        "state_interpreter": NoopStateInterpreter(),
        "action_interpreter": NoopActionInterpreter(),
        "policy": policy,
        "train_initial_states": [*range(100)],
        "val_initial_states": [*range(10)],
        "test_initial_states": [*range(10)],
        "reward": AccReward(),
        "episode_per_iter": 100,
        "update_kwargs": {"repeat": 10, "batch_size": 64},
    }
    vessel = TrainingVessel(**vessel_kwargs)
    trainer.fit(vessel)

    first_ckpt = output_dir / "001.pth"
    second_ckpt = output_dir / "002.pth"
    latest_link = output_dir / "latest.pth"

    # One checkpoint per iteration, with the link pointing at the newest one.
    assert first_ckpt.exists()
    assert second_ckpt.exists()
    assert os.readlink(latest_link) == str(second_ckpt)

    # Restoring the first checkpoint rewinds the trainer's counters.
    trainer.load_state_dict(torch.load(first_ckpt))
    assert trainer.current_iter == 1
    assert trainer.current_episode == 100

    # Reload the checkpoint at first iteration
    trainer.fit(vessel, ckpt_path=first_ckpt)
|