1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Qlib RL framework (stage 2) - trainer (#1125)

* checkpoint

(cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050)

* Not a workable version

(cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
This commit is contained in:
Yuge Zhang
2022-06-28 19:53:05 +08:00
committed by GitHub
parent 2ca0d88d2d
commit 25ecb1135f
17 changed files with 1410 additions and 145 deletions

View File

@@ -1,7 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Train, test, inference utilities.
The APIs in this directory are NOT considered final and are subject to change!
"""

View File

@@ -1,99 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import Callable, Sequence
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.log import get_module_logger
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
_logger = get_module_logger(__name__)
def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Run a backtest using the parallelism of the RL framework.

    Parameters
    ----------
    simulator_fn
        Callable that receives an initial seed and returns a simulator.
    state_interpreter
        Interpreter for the simulator states.
    action_interpreter
        Interpreter for the policy actions.
    initial_states
        Initial states to iterate over; they are served through a :class:`DataQueue`.
    policy
        Policy to test. It is put into eval mode before collection starts.
    logger
        A log writer, or a list of log writers, recording the backtest results.
        It must be given, otherwise all information would be lost.
    reward
        Optional reward function; in a backtest it only serves to compute and log rewards.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Number of parallel workers.
    """
    # To save bandwidth, the per-env collectors use the lowest level among the writers.
    if isinstance(logger, list):
        min_loglevel = min(writer.loglevel for writer in logger)
    else:
        min_loglevel = logger.loglevel

    def env_factory():
        # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
        # and could be thread unsafe.
        # I'm not sure whether it's a design flaw.
        # I'll rethink about this when designing the trainer.
        components = (state_interpreter, action_interpreter, reward)
        if finite_env_type == "dummy":
            # We could only experience the "threading-unsafe" problem in dummy,
            # so every env gets its own copies there.
            components = tuple(copy.deepcopy(component) for component in components)
        state, action, rew = components

        return EnvWrapper(
            simulator_fn,
            state,
            action,
            seed_iterator,
            rew,
            logger=LogCollector(min_loglevel=min_loglevel),
        )

    with DataQueue(initial_states) as seed_iterator:
        vector_env = vectorize_env(env_factory, finite_env_type, concurrency, logger)

        policy.eval()

        with vector_env.collector_guard():
            test_collector = Collector(policy, vector_env)
            _logger.info("All ready. Start backtest.")
            test_collector.collect(n_step=INF * len(vector_env))

View File

@@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TBD

View File

@@ -9,4 +9,5 @@ Multi-asset is on the way.
from .interpreter import *
from .network import *
from .policy import *
from .reward import *
from .simulator_simple import *

View File

@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
import numpy as np
from qlib.rl.reward import Reward
from .simulator_simple import SAOEState, SAOEMetrics
__all__ = ["PAPenaltyReward"]
class PAPenaltyReward(Reward[SAOEState]):
    """Reward higher PAs while penalizing piling up the whole amount in a very short time.

    Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.

    Parameters
    ----------
    penalty
        The penalty for large volume in a short time.
    """

    def __init__(self, penalty: float = 100.0):
        self.penalty = penalty

    def reward(self, simulator_state: SAOEState) -> float:
        """Compute the reward of the latest step recorded in ``simulator_state``."""
        order_amount = simulator_state.order.amount
        assert order_amount > 0

        # Last row of the step history as a plain dict (the index is turned into a column first).
        step_records = simulator_state.history_steps.reset_index()
        latest_step = cast(SAOEMetrics, step_records.iloc[-1].to_dict())
        pa = latest_step["pa"] * latest_step["amount"] / order_amount

        # Inspect the "break-down" of the latest step: trading amount at every tick
        latest_breakdown = simulator_state.history_exec.loc[latest_step["datetime"] :]
        amount_ratio = latest_breakdown["amount"] / order_amount
        penalty = -self.penalty * (amount_ratio**2).sum()

        reward = pa + penalty

        # Throw error in case of NaN (or infinity)
        assert not np.isnan(reward) and not np.isinf(
            reward
        ), f"Invalid reward for simulator state: {simulator_state}"

        self.log("reward/pa", pa)
        self.log("reward/penalty", penalty)
        return reward

View File

@@ -131,11 +131,14 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""
history_exec: pd.DataFrame
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
Index is ``datetime``.
"""
history_steps: pd.DataFrame
"""Positions at each step. The position before first step is also recorded.
See :class:`SAOEMetrics` for available columns."""
See :class:`SAOEMetrics` for available columns.
Index is ``datetime``, which is the **starting** time of each step."""
metrics: SAOEMetrics | None
"""Metrics. Only available when done."""

View File

@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Train, test, inference utilities."""
from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase

118
qlib/rl/trainer/api.py Normal file
View File

@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Callable, Sequence, cast, Any
from tianshou.policy import BasePolicy
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter
from .vessel import TrainingVessel
from .trainer import Trainer
def train(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    reward: Reward,
    vessel_kwargs: dict[str, Any],
    trainer_kwargs: dict[str, Any],
) -> None:
    """Train a policy using the parallelism of the RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable that receives an initial seed and returns a simulator.
    state_interpreter
        Interpreter for the simulator states.
    action_interpreter
        Interpreter for the policy actions.
    initial_states
        Initial states used for training; handed to the vessel as ``train_initial_states``.
    policy
        Policy to train.
    reward
        Reward function.
    vessel_kwargs
        Extra keyword arguments forwarded to :class:`TrainingVessel`, like ``episode_per_iter``.
    trainer_kwargs
        Keyword arguments forwarded to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
    """
    # The vessel bundles everything algorithm-related ...
    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,
        **vessel_kwargs,
    )
    # ... and the trainer runs it.
    Trainer(**trainer_kwargs).fit(vessel)
def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Run a backtest using the parallelism of the RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable that receives an initial seed and returns a simulator.
    state_interpreter
        Interpreter for the simulator states.
    action_interpreter
        Interpreter for the policy actions.
    initial_states
        Initial states to iterate over; handed to the vessel as ``test_initial_states``.
    policy
        Policy to test against.
    logger
        A log writer, or a list of log writers, recording the backtest results.
        It must be given, otherwise all information would be lost.
    reward
        Optional reward function. For backtest, this is for testing the rewards
        and logging them only.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Number of parallel workers.
    """
    # ``reward`` may be None here; the cast only satisfies the vessel's type hint.
    test_reward = cast(Reward, reward)
    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        test_initial_states=initial_states,
        reward=test_reward,
    )
    Trainer(
        finite_env_type=finite_env_type,
        concurrency=concurrency,
        loggers=logger,
    ).test(vessel)

View File

@@ -0,0 +1,267 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Callbacks to insert customized recipes during the training.
Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
"""
from __future__ import annotations
import copy
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, TYPE_CHECKING
import numpy as np
import torch
from qlib.log import get_module_logger
from qlib.typehint import Literal
if TYPE_CHECKING:
from .trainer import Trainer
from .vessel import TrainingVesselBase
_logger = get_module_logger(__name__)
class Callback:
    """Base class of all callbacks.

    Every hook below does nothing by default; subclasses override the ones they need.
    Each hook receives the running trainer and the training vessel.
    """

    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked before the whole fit process begins."""

    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked after the whole fit process ends."""

    def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when each collect for training begins."""

    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when the training ends.

        To access all outputs produced during training, cache the data in either trainer and vessel,
        and post-process them in this hook.
        """

    def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when every run for validation begins."""

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when the validation ends."""

    def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when every run of testing begins."""

    def on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when the testing ends."""

    def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when every iteration (i.e., collect) starts."""

    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked upon every end of iteration.

        It runs **after** ``current_iter`` has been bumped,
        i.e. when the previous iteration is considered complete.
        """

    def state_dict(self) -> Any:
        """Return the state of the callback for pause and resume (nothing by default)."""

    def load_state_dict(self, state_dict: Any) -> None:
        """Resume the callback from a previously saved state (no-op by default)."""
class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    The early-stopping callback is triggered each time validation ends.
    It examines the metrics produced in validation,
    and gets the metric with name ``monitor`` (``monitor`` is ``reward`` by default),
    to check whether it's no longer increasing / decreasing.
    It takes ``min_delta`` and ``patience`` into account if applicable.
    If the metric is found to be not increasing / decreasing any more,
    ``trainer.should_stop`` will be set to true,
    and the training terminates.

    Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893

    Parameters
    ----------
    monitor
        Key looked up in ``trainer.metrics`` each time validation ends.
    min_delta
        Margin by which the metric has to beat the reference value to count as an improvement.
        Only its absolute value is used.
    patience
        Number of validations without improvement after which the training is stopped.
    mode
        ``"max"`` when a larger metric is better, ``"min"`` when a smaller one is better.
    baseline
        When given, the patience counter is only reset if the metric also beats this value.
    restore_best_weights
        Whether to load the vessel state recorded at the best iteration back when stopping.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        monitor: str = "reward",
        min_delta: float = 0.0,
        patience: int = 0,
        mode: Literal["min", "max"] = "max",
        baseline: float | None = None,
        restore_best_weights: bool = False,
    ):
        super().__init__()

        if mode not in ("min", "max"):
            # Formatting (rather than concatenating) keeps this a ValueError
            # even when ``mode`` is not a string.
            raise ValueError(f"Unsupported earlystopping mode: {mode}")

        self.monitor = monitor
        self.patience = patience
        self.baseline = baseline
        self.restore_best_weights = restore_best_weights

        # The sign of ``min_delta`` follows the direction of improvement,
        # so that ``_is_improvement`` can always subtract it.
        if mode == "min":
            self.monitor_op = np.less
            self.min_delta = -abs(min_delta)
        else:
            self.monitor_op = np.greater
            self.min_delta = abs(min_delta)

        # Make ``state_dict()`` usable even before ``on_fit_start`` has run.
        self._reset_state()

    def _reset_state(self) -> None:
        """(Re-)initialize the bookkeeping of the monitored metric."""
        self.wait = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf
        self.best_weights: Any | None = None
        self.best_iter = 0

    def state_dict(self) -> dict:
        """Return the bookkeeping needed to resume early stopping."""
        return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the bookkeeping produced by :meth:`state_dict`."""
        self.wait = state_dict["wait"]
        self.best = state_dict["best"]
        self.best_weights = state_dict["best_weights"]
        self.best_iter = state_dict["best_iter"]

    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        # Allow instances to be re-used
        self._reset_state()

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Update the best value / patience counter and request a stop when patience runs out."""
        current = self.get_monitor_value(trainer)
        if current is None:
            # Metric unavailable (a warning has been logged); nothing to compare.
            return

        if self.restore_best_weights and self.best_weights is None:
            # Restore the weights after first iteration if no progress is ever made.
            self.best_weights = copy.deepcopy(vessel.state_dict())

        self.wait += 1
        if self._is_improvement(current, self.best):
            self.best = current
            self.best_iter = trainer.current_iter
            if self.restore_best_weights:
                self.best_weights = copy.deepcopy(vessel.state_dict())
            # Only restart wait if we beat both the baseline and our previous best.
            if self.baseline is None or self._is_improvement(current, self.baseline):
                self.wait = 0

        # Only check after the first epoch.
        if self.wait >= self.patience and trainer.current_iter > 0:
            trainer.should_stop = True
            _logger.info("On iteration %d: early stopping", trainer.current_iter + 1)
            if self.restore_best_weights and self.best_weights is not None:
                _logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1)
                vessel.load_state_dict(self.best_weights)

    def get_monitor_value(self, trainer: Trainer) -> Any:
        """Fetch the monitored metric from ``trainer.metrics``; warn and return None if it is missing."""
        monitor_value = trainer.metrics.get(self.monitor)
        if monitor_value is None:
            _logger.warning(
                "Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
                self.monitor,
                ",".join(list(trainer.metrics.keys())),
            )
        return monitor_value

    def _is_improvement(self, monitor_value, reference_value):
        """Whether ``monitor_value`` beats ``reference_value`` by more than ``min_delta``."""
        return self.monitor_op(monitor_value - self.min_delta, reference_value)
class Checkpoint(Callback):
    """Save checkpoints periodically for persistence and recovery.

    Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py

    Parameters
    ----------
    dirpath
        Directory to save the checkpoint file. Created on the first save if missing.
    filename
        Checkpoint filename. Can contain named formatting options to be auto-filled.
        For example: ``{iter:03d}-{reward:.2f}.pth``.
        Supported argument names are:

        - iter (int)
        - metrics in ``trainer.metrics``
        - time string, in the format of ``%Y%m%d%H%M%S``
    save_latest
        Save the latest checkpoint in ``latest.pth``.
        If ``link``, ``latest.pth`` will be created as a softlink.
        If ``copy``, ``latest.pth`` will be stored as an individual copy.
        Set to none to disable this.
    every_n_iters
        Checkpoints are saved at the end of every n iterations of training,
        after validation if applicable.
    time_interval
        Maximum time (seconds) before checkpoints save again.
    save_on_fit_end
        Save one last checkpoint at the end to fit.
        Do nothing if a checkpoint is already saved there.
    """

    def __init__(
        self,
        dirpath: Path,
        filename: str = "{iter:03d}.pth",
        save_latest: Literal["link", "copy"] | None = "link",
        every_n_iters: int | None = None,
        time_interval: int | None = None,
        save_on_fit_end: bool = True,
    ):
        self.dirpath = Path(dirpath)
        self.filename = filename
        self.save_latest = save_latest
        self.every_n_iters = every_n_iters
        self.time_interval = time_interval
        self.save_on_fit_end = save_on_fit_end

        # Bookkeeping of the most recent checkpoint; all None until the first save.
        self._last_checkpoint_name: str | None = None
        self._last_checkpoint_iter: int | None = None
        self._last_checkpoint_time: float | None = None

    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        # Skip when the current iteration has already been checkpointed.
        if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):
            self._save_checkpoint(trainer)

    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        should_save_ckpt = False

        # NOTE(review): this hook is documented to run after ``current_iter`` has been bumped,
        # so ``current_iter + 1`` may be off by one w.r.t. "every n iterations" - confirm intent.
        if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0:
            should_save_ckpt = True

        # Time-based trigger: fires on the first iteration and whenever the interval has elapsed.
        if self.time_interval is not None and (
            self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval
        ):
            should_save_ckpt = True

        if should_save_ckpt:
            self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer: Trainer) -> None:
        """Write ``trainer.state_dict()`` to a new file and refresh ``latest.pth`` if requested."""
        self.dirpath.mkdir(exist_ok=True, parents=True)

        self._last_checkpoint_name = self._new_checkpoint_name(trainer)
        self._last_checkpoint_iter = trainer.current_iter
        self._last_checkpoint_time = time.time()
        checkpoint_path = self.dirpath / self._last_checkpoint_name
        torch.save(trainer.state_dict(), checkpoint_path)

        if not self.save_latest:
            return

        latest_pth = self.dirpath / "latest.pth"

        # Remove first before saving.
        # ``exists()`` follows symlinks and is False for a dangling link,
        # hence the extra ``is_symlink()`` check.
        if latest_pth.exists() or latest_pth.is_symlink():
            latest_pth.unlink()

        if self.save_latest == "link":
            # A relative symlink target is resolved against the directory holding the link,
            # so point at the bare file name; this stays valid for a relative ``dirpath`` too.
            latest_pth.symlink_to(self._last_checkpoint_name)
        elif self.save_latest == "copy":
            shutil.copyfile(checkpoint_path, latest_pth)

    def _new_checkpoint_name(self, trainer: Trainer) -> str:
        """Fill the ``filename`` template with the iteration, a timestamp and the trainer metrics."""
        return self.filename.format(
            iter=trainer.current_iter, time=datetime.now().strftime("%Y%m%d%H%M%S"), **trainer.metrics
        )

343
qlib/rl/trainer/trainer.py Normal file
View File

@@ -0,0 +1,343 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any, Iterable, TypeVar, Sequence, cast
import torch
from qlib.rl.simulator import InitialStateType
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel
from qlib.log import get_module_logger
from qlib.rl.utils.finite_env import FiniteVectorEnv
from qlib.typehint import Literal
from .callbacks import Callback
from .vessel import TrainingVesselBase
_logger = get_module_logger(__name__)
T = TypeVar("T")
class Trainer:
    """
    Utility to train a policy on a particular task.

    Different from traditional DL trainer, the iteration of this trainer is "collect",
    rather than "epoch", or "mini-batch".
    In each collect, :class:`Collector` collects a number of policy-env interactions, and accumulates
    them into a replay buffer. This buffer is used as the "data" to train the policy.
    At the end of each collect, the policy is *updated* several times.

    The API has some resemblance with `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__,
    but it's essentially different because this trainer is built for RL applications, and thus
    most configurations are under RL context.
    We are still looking for ways to incorporate existing trainer libraries, because it looks like
    big efforts to build a trainer as powerful as those libraries, and also, that's not our primary goal.

    It's essentially different from
    `tianshou's built-in trainers <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__,
    as it's far more complicated than that.

    Parameters
    ----------
    max_iters
        Maximum iterations before stopping.
    val_every_n_iters
        Perform validation every n iterations (i.e., training collects).
    loggers
        A log writer, or a list of log writers, to record the results.
        The list is copied; an internal :class:`LogBuffer` that feeds ``metrics``
        is always appended to the trainer's own copy.
    callbacks
        Callbacks whose hooks are called at the corresponding stages of fit / test.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Parallel workers.
    fast_dev_run
        Create a subset for debugging.
        How this is implemented depends on the implementation of training vessel.
        For :class:`~qlib.rl.vessel.TrainingVessel`, if greater than zero,
        a random subset sized ``fast_dev_run`` will be used
        instead of ``train_initial_states`` and ``val_initial_states``.
    """

    should_stop: bool
    """Set to stop the training."""

    metrics: dict
    """Numeric metrics of produced in train/val/test.
    In the middle of training / validation, metrics will be of the latest episode.
    When each iteration of training / validation finishes, metrics will be the aggregation
    of all episodes encountered in this iteration.

    Cleared on every new iteration of training.

    In fit, validation metrics will be prefixed with ``val/``.
    """

    current_iter: int
    """Current iteration (collect) of training."""

    loggers: list[LogWriter]
    """A list of log writers."""

    def __init__(
        self,
        *,
        max_iters: int | None = None,
        val_every_n_iters: int | None = None,
        loggers: LogWriter | list[LogWriter] | None = None,
        callbacks: list[Callback] | None = None,
        finite_env_type: FiniteEnvType = "subproc",
        concurrency: int = 2,
        fast_dev_run: int | None = None,
    ):
        self.max_iters = max_iters
        self.val_every_n_iters = val_every_n_iters

        self.loggers = self._normalize_loggers(loggers)
        # The buffer's level is derived from the user-provided writers only,
        # which is why it is computed before the buffer joins the list.
        self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))

        self.callbacks: list[Callback] = callbacks if callbacks is not None else []
        self.finite_env_type = finite_env_type
        self.concurrency = concurrency
        self.fast_dev_run = fast_dev_run

        self.current_stage: Literal["train", "val", "test"] = "train"

        # Assigned for real in ``fit`` / ``test``.
        self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None)

    @staticmethod
    def _normalize_loggers(loggers: LogWriter | list[LogWriter] | None) -> list[LogWriter]:
        """Turn the ``loggers`` argument into a fresh list owned by the trainer.

        A list is shallow-copied so that the internal buffer appended afterwards
        never leaks into the list object the caller passed in.
        """
        if isinstance(loggers, list):
            return list(loggers)
        if isinstance(loggers, LogWriter):
            return [loggers]
        return []

    def initialize(self):
        """Initialize the whole training process.

        The states here should be synchronized with state_dict.
        """
        self.should_stop = False
        self.current_iter = 0
        self.current_episode = 0
        self.current_stage = "train"

    def initialize_iter(self):
        """Initialize one iteration / collect."""
        self.metrics = {}

    def state_dict(self) -> dict:
        """Putting every states of current training into a dict, at best effort.

        It doesn't try to handle all the possible kinds of states in the middle of one training collect.
        For most cases at the end of each iteration, things should be usually correct.

        Note that it's also intended behavior that replay buffer data in the collector will be lost.
        """
        return {
            "vessel": self.vessel.state_dict(),
            "callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()},
            "loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()},
            "should_stop": self.should_stop,
            "current_iter": self.current_iter,
            "current_episode": self.current_episode,
            "current_stage": self.current_stage,
            "metrics": self.metrics,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Load all states into current trainer."""
        self.vessel.load_state_dict(state_dict["vessel"])
        for name, callback in self.named_callbacks().items():
            callback.load_state_dict(state_dict["callbacks"][name])
        for name, logger in self.named_loggers().items():
            logger.load_state_dict(state_dict["loggers"][name])
        self.should_stop = state_dict["should_stop"]
        self.current_iter = state_dict["current_iter"]
        self.current_episode = state_dict["current_episode"]
        self.current_stage = state_dict["current_stage"]
        self.metrics = state_dict["metrics"]

    def named_callbacks(self) -> dict[str, Callback]:
        """Retrieve a collection of callbacks where each one has a name.
        Useful when saving checkpoints.
        """
        return _named_collection(self.callbacks)

    def named_loggers(self) -> dict[str, LogWriter]:
        """Retrieve a collection of loggers where each one has a name.
        Useful when saving checkpoints.
        """
        return _named_collection(self.loggers)

    def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None:
        """Train the RL policy upon the defined simulator.

        Parameters
        ----------
        vessel
            A bundle of all elements used in training.
        ckpt_path
            Load a pre-trained / paused training checkpoint.
        """
        self.vessel = vessel
        vessel.assign_trainer(self)

        if ckpt_path is not None:
            _logger.info("Resuming states from %s", str(ckpt_path))
            # NOTE: ``torch.load`` unpickles the file; only resume from checkpoints you trust.
            self.load_state_dict(torch.load(ckpt_path))
        else:
            self.initialize()

        self._call_callback_hooks("on_fit_start")

        while not self.should_stop:
            self.initialize_iter()

            self._call_callback_hooks("on_iter_start")

            self.current_stage = "train"
            self._call_callback_hooks("on_train_start")

            # TODO
            # Add a feature that supports reloading the training environment every few iterations.
            with _wrap_context(vessel.train_seed_iterator()) as iterator:
                vector_env = self.venv_from_iterator(iterator)
                self.vessel.train(vector_env)

            self._call_callback_hooks("on_train_end")

            if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
                # Implementation of validation loop
                self.current_stage = "val"
                self._call_callback_hooks("on_validate_start")
                with _wrap_context(vessel.val_seed_iterator()) as iterator:
                    vector_env = self.venv_from_iterator(iterator)
                    self.vessel.validate(vector_env)

                self._call_callback_hooks("on_validate_end")

            # This iteration is considered complete.
            # Bumping the current iteration counter.
            self.current_iter += 1

            if self.max_iters is not None and self.current_iter >= self.max_iters:
                self.should_stop = True

            self._call_callback_hooks("on_iter_end")

        self._call_callback_hooks("on_fit_end")

    def test(self, vessel: TrainingVesselBase) -> None:
        """Test the RL policy against the simulator.

        The simulator will be fed with data generated in ``test_seed_iterator``.

        Parameters
        ----------
        vessel
            A bundle of all related elements.
        """
        self.vessel = vessel
        vessel.assign_trainer(self)

        self.initialize_iter()

        self.current_stage = "test"
        self._call_callback_hooks("on_test_start")
        with _wrap_context(vessel.test_seed_iterator()) as iterator:
            vector_env = self.venv_from_iterator(iterator)
            self.vessel.test(vector_env)
        self._call_callback_hooks("on_test_end")

    def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
        """Create a vectorized environment from iterator and the training vessel."""

        def env_factory():
            # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
            # and could be thread unsafe.
            # I'm not sure whether it's a design flaw.
            # I'll rethink about this when designing the trainer.

            if self.finite_env_type == "dummy":
                # We could only experience the "threading-unsafe" problem in dummy.
                state = copy.deepcopy(self.vessel.state_interpreter)
                action = copy.deepcopy(self.vessel.action_interpreter)
                rew = copy.deepcopy(self.vessel.reward)
            else:
                state = self.vessel.state_interpreter
                action = self.vessel.action_interpreter
                rew = self.vessel.reward

            return EnvWrapper(
                self.vessel.simulator_fn,
                state,
                action,
                iterator,
                rew,
                logger=LogCollector(min_loglevel=self._min_loglevel()),
            )

        return vectorize_env(
            env_factory,
            self.finite_env_type,
            self.concurrency,
            self.loggers,
        )

    def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None:
        """Callback of the internal log buffer: mirror its latest metrics into ``self.metrics``."""
        if on_episode:
            # Update the global counter.
            self.current_episode = log_buffer.global_episode
            metrics = log_buffer.episode_metrics()
        elif on_collect:
            # Update the latest metrics.
            metrics = log_buffer.collect_metrics()
        else:
            # Neither event: there is nothing to record.
            return
        if self.current_stage == "val":
            metrics = {"val/" + name: value for name, value in metrics.items()}
        self.metrics.update(metrics)

    def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
        """Invoke ``hook_name`` on every callback, in registration order."""
        for callback in self.callbacks:
            fn = getattr(callback, hook_name)
            fn(self, self.vessel, *args, **kwargs)

    def _min_loglevel(self):
        """Lowest log level among the current loggers (``LogLevel.PERIODIC`` if there are none)."""
        if not self.loggers:
            return LogLevel.PERIODIC
        else:
            # To save bandwidth
            return min(lg.loglevel for lg in self.loggers)
@contextmanager
def _wrap_context(obj):
"""Make any object a (possibly dummy) context manager."""
if isinstance(obj, AbstractContextManager):
# obj has __enter__ and __exit__
with obj as ctx:
yield ctx
else:
yield obj
def _named_collection(seq: Sequence[T]) -> dict[str, T]:
"""Convert a list into a dict, where each item is named with its type."""
res = {}
for item in seq:
typename = type(item).__name__.lower()
if typename not in res:
res[typename] = item
else:
# names are auto-labelled as earlystop1, earlystop2, ...
for retry in range(1, 1000):
if f"{typename}{retry}" not in res:
res[f"{typename}{retry}"] = item
return res

214
qlib/rl/trainer/vessel.py Normal file
View File

@@ -0,0 +1,214 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import weakref
from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict
import numpy as np
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import DataQueue
from qlib.log import get_module_logger
from qlib.rl.utils.finite_env import FiniteVectorEnv
if TYPE_CHECKING:
from .trainer import Trainer
T = TypeVar("T")
_logger = get_module_logger(__name__)
class SeedIteratorNotAvailable(BaseException):
    """Raised by the default seed-iterator hooks of a vessel that have not been overridden."""
class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
    """A ship that carries simulator, interpreters and policy, to be handed to a trainer.

    The vessel owns the algorithm-related parts of training, whereas the trainer takes care
    of the runtime part. The vessel also defines the core logic of one training / validation /
    test run, and (optionally) some callbacks to insert customized logics at specific events.
    """

    simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]]
    state_interpreter: StateInterpreter[StateType, ObsType]
    action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType]
    policy: BasePolicy
    reward: Reward
    trainer: Trainer

    def assign_trainer(self, trainer: Trainer) -> None:
        """Attach the trainer; only a weak proxy is stored."""
        self.trainer = weakref.proxy(trainer)  # type: ignore

    def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
        """Override this to create a seed iterator for training.

        If the iterable is a context manager, the whole training will be invoked in the with-block,
        and the iterator will be automatically closed after the training is done.
        """
        raise SeedIteratorNotAvailable("Seed iterator for training is not available.")

    def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
        """Override this to create a seed iterator for validation."""
        raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")

    def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
        """Override this to create a seed iterator for testing."""
        raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")

    def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]:
        """Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
        raise NotImplementedError()

    def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
        """Implement this to validate the policy once."""
        raise NotImplementedError()

    def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
        """Implement this to evaluate the policy on test environment once."""
        raise NotImplementedError()

    def log(self, name: str, value: Any) -> None:
        """Write ``name = value`` to the module logger; arrays and lists are reduced to their mean."""
        # FIXME: this is a workaround to make the log at least show somewhere.
        # Need a refactor in logger to formalize this.
        value = np.mean(value) if isinstance(value, (np.ndarray, list)) else value
        _logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}")

    def log_dict(self, data: dict[str, Any]) -> None:
        """Log every entry of ``data`` through :meth:`log`."""
        for entry in data.items():
            self.log(*entry)

    def state_dict(self) -> dict:
        """Return a checkpoint of current vessel state (the policy's state dict)."""
        policy_state = self.policy.state_dict()
        return {"policy": policy_state}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore a checkpoint from a previously saved state dict."""
        policy_state = state_dict["policy"]
        self.policy.load_state_dict(policy_state)
class TrainingVessel(TrainingVesselBase):
    """Default training vessel.

    ``__init__`` takes sequences of initial states, from which the seed iterators are built.
    ``train``, ``validate`` and ``test`` each perform one collect (``train`` additionally
    updates the policy).

    During training the train initial states are repeated infinitely, and the collector
    decides how many episodes are consumed per iteration.
    During validation and testing the val / test initial states are used exactly once.

    Extra hyper-parameters (only used in train) include:

    - ``buffer_size``: Size of replay buffer.
    - ``episode_per_iter``: Episodes per collect at training. Can be overridden by fast dev run.
    - ``update_kwargs``: Keyword arguments appearing in ``policy.update``.
      For example, ``dict(repeat=10, batch_size=64)``.
    """

    def __init__(
        self,
        *,
        simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]],
        state_interpreter: StateInterpreter[StateType, ObsType],
        action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
        policy: BasePolicy,
        reward: Reward,
        train_initial_states: Sequence[InitialStateType] | None = None,
        val_initial_states: Sequence[InitialStateType] | None = None,
        test_initial_states: Sequence[InitialStateType] | None = None,
        buffer_size: int = 20000,
        episode_per_iter: int = 1000,
        update_kwargs: dict[str, Any] = cast(Dict[str, Any], None),
    ):
        # Core components shared with the trainer.
        self.simulator_fn = simulator_fn  # type: ignore
        self.state_interpreter = state_interpreter
        self.action_interpreter = action_interpreter
        self.policy = policy
        self.reward = reward

        # Seeds for each stage; ``None`` means the stage has no seed iterator.
        self.train_initial_states = train_initial_states
        self.val_initial_states = val_initial_states
        self.test_initial_states = test_initial_states

        # Training-only hyper-parameters.
        self.buffer_size = buffer_size
        self.episode_per_iter = episode_per_iter
        # A missing (or empty) ``update_kwargs`` becomes a fresh empty dict.
        self.update_kwargs = update_kwargs or {}

    def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
        """Return an endlessly repeating, shuffled queue over the training initial states."""
        if self.train_initial_states is None:
            return super().train_seed_iterator()
        _logger.info("Training initial states collection size: %d", len(self.train_initial_states))
        # Implement fast_dev_run here.
        seeds = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
        return DataQueue(seeds, repeat=-1, shuffle=True)

    def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
        """Return a single-pass queue over the validation initial states."""
        if self.val_initial_states is None:
            return super().val_seed_iterator()
        _logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
        seeds = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
        return DataQueue(seeds, repeat=1)

    def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
        """Return a single-pass queue over the testing initial states."""
        if self.test_initial_states is None:
            return super().test_seed_iterator()
        _logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
        seeds = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
        return DataQueue(seeds, repeat=1)

    def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
        """Collect ``episode_per_iter`` episodes with a fresh collector, then update the policy.

        The policy is updated on the replay buffer filled by this collect. The merged
        collect / update results are logged and returned.
        """
        self.policy.train()

        with vector_env.collector_guard():
            replay_buffer = VectorReplayBuffer(self.buffer_size, len(vector_env))
            collector = Collector(self.policy, vector_env, replay_buffer)

            # Number of episodes collected in each training iteration can be overridden by fast dev run.
            fast_dev_run = self.trainer.fast_dev_run
            episodes = self.episode_per_iter if fast_dev_run is None else fast_dev_run

            collected = collector.collect(n_episode=episodes)
            updated = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs)
            res = {**collected, **updated}
            self.log_dict(res)
            return res

    def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
        """Run the policy in eval mode over the validation env and return the collect result."""
        return self._collect_in_eval_mode(vector_env)

    def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
        """Run the policy in eval mode over the test env and return the collect result."""
        return self._collect_in_eval_mode(vector_env)

    def _collect_in_eval_mode(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
        """Shared body of ``validate`` / ``test``: one guarded collect with the policy in eval mode."""
        self.policy.eval()

        with vector_env.collector_guard():
            eval_collector = Collector(self.policy, vector_env)
            # The step budget is effectively unbounded (INF per env).
            res = eval_collector.collect(n_step=INF * len(vector_env))
            self.log_dict(res)
            return res

    @staticmethod
    def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]:
        """Pick at most ``size`` random elements of ``collection`` (fast-dev-run helper).

        ``size=None`` returns the original collection untouched.
        """
        if size is None:
            # Size = None -> original collection
            return collection
        shuffled_indices = np.random.permutation(len(collection))
        subset = [collection[idx] for idx in shuffled_indices[:size]]
        _logger.info(
            "Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(subset)
        )
        return subset

View File

@@ -145,7 +145,9 @@ class DataQueue(Generic[T]):
def __iter__(self):
if not self._activated:
raise ValueError(
"Need to call activate() to launch a daemon worker to produce data into data queue before using it."
"Need to call activate() to launch a daemon worker "
"to produce data into data queue before using it. "
"You probably have forgotten to use the DataQueue in a with block."
)
return self._consumer()
@@ -161,19 +163,21 @@ class DataQueue(Generic[T]):
# pytorch dataloader is used here only because we need its sampler and multi-processing
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
dataloader = DataLoader(
cast(Dataset[T], self.dataset),
batch_size=None,
num_workers=self.producer_num_workers,
shuffle=self.shuffle,
collate_fn=lambda t: t, # identity collate fn
)
repeat = 10**18 if self.repeat == -1 else self.repeat
for _rep in range(repeat):
for data in dataloader:
if self._done.value:
# Already done.
return
self._queue.put(data)
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
self.mark_as_done()
try:
dataloader = DataLoader(
cast(Dataset[T], self.dataset),
batch_size=None,
num_workers=self.producer_num_workers,
shuffle=self.shuffle,
collate_fn=lambda t: t, # identity collate fn
)
repeat = 10**18 if self.repeat == -1 else self.repeat
for _rep in range(repeat):
for data in dataloader:
if self._done.value:
# Already done.
return
self._queue.put(data)
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
finally:
self.mark_as_done()

View File

@@ -120,12 +120,19 @@ class FiniteVectorEnv(BaseVectorEnv):
from child workers. See :class:`qlib.rl.utils.LogWriter`.
"""
_logger: list[LogWriter]
def __init__(
self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
self, logger: LogWriter | list[LogWriter] | None, env_fns: list[Callable[..., gym.Env]], **kwargs: Any
) -> None:
super().__init__(env_fns, **kwargs)
self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
if isinstance(logger, list):
self._logger = logger
elif isinstance(logger, LogWriter):
self._logger = [logger]
else:
self._logger = []
self._alive_env_ids: Set[int] = set()
self._reset_alive_envs()
self._default_obs = self._default_info = self._default_rew = None
@@ -177,7 +184,7 @@ class FiniteVectorEnv(BaseVectorEnv):
1. Catch and ignore the StopIteration exception, which is the stopping signal
thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
2. Notify the loggers that the collect is done what it's done.
2. Notify the loggers that the collect is ready / done when it's ready / done.
Examples
--------
@@ -186,6 +193,9 @@ class FiniteVectorEnv(BaseVectorEnv):
"""
self._collector_guarded = True
for logger in self._logger:
logger.on_env_all_ready()
try:
yield self
except StopIteration:
@@ -298,7 +308,21 @@ def vectorize_env(
concurrency: int,
logger: LogWriter | list[LogWriter],
) -> FiniteVectorEnv:
"""Helper function to create a vector env.
"""Helper function to create a vector env. Can be used to replace usual VectorEnv.
For example, once you wrote: ::
DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
Now you can replace it with: ::
vectorize_env(lambda: gym.make(task), "dummy", env_num, my_logger)
By doing such replacement, you have two additional features enabled (compared to normal VectorEnv):
1. The vector env will check for NaN observation and kill the worker when it's found.
See :class:`FiniteVectorEnv` for why we need this.
2. A logger to explicitly collect logs from environment workers.
Parameters
----------

View File

@@ -12,13 +12,16 @@ in each worker, and writes them to console, log files, or tensorboard...
The two modules communicate by the "log" field in "info" returned by ``env.step()``.
"""
# NOTE: This file contains many hardcoded / ad-hoc rules.
# Refactoring it will be one of the future tasks.
from __future__ import annotations
import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable
import numpy as np
import pandas as pd
@@ -29,7 +32,7 @@ if TYPE_CHECKING:
from .env_wrapper import InfoDict
__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
__all__ = ["LogCollector", "LogWriter", "LogLevel", "LogBuffer", "ConsoleWriter", "CsvWriter"]
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
@@ -175,18 +178,53 @@ class LogWriter(Generic[ObsType, ActType]):
self.clear()
def clear(self):
    """Clear all the metrics for a fresh start.

    To make the logger instance reusable.
    """
    # Both counters restart from zero.
    self.episode_count = self.step_count = 0
    # Forget which environment ids were being tracked.
    self.active_env_ids = set()
    # Drop any buffered log records.
    self.logs = []
def aggregation(self, array: Sequence[Any]) -> Any:
def state_dict(self) -> dict:
    """Snapshot the logger's counters and per-environment bookkeeping into a dict.

    The keys of the returned dict are the attribute names themselves.
    """
    persisted_fields = (
        "episode_count",
        "step_count",
        "global_step",
        "global_episode",
        "active_env_ids",
        "episode_lengths",
        "episode_rewards",
        "episode_logs",
    )
    return {field: getattr(self, field) for field in persisted_fields}
def load_state_dict(self, state_dict: dict) -> None:
    """Load the states of current logger from a dict.

    Parameters
    ----------
    state_dict
        A dict carrying the keys ``episode_count``, ``step_count``, ``global_step``,
        ``global_episode``, ``active_env_ids``, ``episode_lengths``,
        ``episode_rewards`` and ``episode_logs``.

    Raises
    ------
    KeyError
        If any of the expected keys is missing.
    """
    self.episode_count = state_dict["episode_count"]
    self.step_count = state_dict["step_count"]
    self.global_step = state_dict["global_step"]
    self.global_episode = state_dict["global_episode"]

    # These are runtime infos.
    # Though they are loaded, I don't think it really helps.
    self.active_env_ids = state_dict["active_env_ids"]
    # Fixed: the attribute was misspelled ``episode_lenghts``, so the real
    # ``episode_lengths`` was never restored.
    self.episode_lengths = state_dict["episode_lengths"]
    self.episode_rewards = state_dict["episode_rewards"]
    self.episode_logs = state_dict["episode_logs"]
def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any:
    """Reduce step-wise values of one metric to a single episode-wise value.

    A sequence made only of ``float`` values is averaged, except when ``name`` is
    ``"reward"``, in which case it is summed. Any other sequence is reduced to its
    first element.

    The sequence must not be empty (checked with an ``assert``).
    """
    assert len(array) > 0, "The aggregated array must be not empty."

    all_floats = all(isinstance(item, float) for item in array)
    if not all_floats:
        return array[0]

    reducer = np.sum if name == "reward" else np.mean
    return reducer(array)
@@ -253,10 +291,93 @@ class LogWriter(Generic[ObsType, ActType]):
self.episode_rewards[env_id] = []
self.episode_logs[env_id] = []
def on_env_all_ready(self) -> None:
    """When all environments are ready to run.

    Usually, loggers should be reset here.
    """
    # Default behaviour: start from a clean slate via ``clear()``.
    self.clear()
def on_env_all_done(self) -> None:
    """All done. Time for cleanup.

    This implementation has no body beyond the docstring, i.e. it does nothing.
    """
class LogBuffer(LogWriter):
    """Keep all numbers in memory.

    Objects that can't be aggregated like strings, tensors, images can't be stored in the buffer.
    To persist them, please use :class:`PickleWriter`.

    Every time the log buffer receives a new metric, the callback is triggered,
    which is useful when tracking metrics inside a trainer.

    Parameters
    ----------
    callback
        A callback receiving three arguments:

        - on_episode: Whether it's called at the end of an episode
        - on_collect: Whether it's called at the end of a collect
        - log_buffer: the :class:`LogBuffer` object

        No return value is expected.
    """

    # FIXME: needs a metric count

    def __init__(self, callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC):
        super().__init__(loglevel)
        self.callback = callback

    def state_dict(self) -> dict:
        """Extend the parent state with the latest and the accumulated metrics."""
        return dict(
            super().state_dict(),
            latest_metrics=self._latest_metrics,
            aggregated_metrics=self._aggregated_metrics,
        )

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the buffer-specific metrics first, then delegate the rest to the parent."""
        self._latest_metrics = state_dict["latest_metrics"]
        self._aggregated_metrics = state_dict["aggregated_metrics"]
        return super().load_state_dict(state_dict)

    def clear(self):
        """Reset the parent state and drop every stored metric."""
        super().clear()
        # ``None`` marks "no episode has been logged yet".
        self._latest_metrics: dict[str, float] | None = None
        # Running per-metric sums; unseen metrics start at 0.0.
        self._aggregated_metrics: dict[str, float] = defaultdict(float)

    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
        """Aggregate the float-valued entries of one episode and notify the callback.

        ``length`` and ``rewards`` are not used by this implementation.
        """
        # FIXME Dup of ConsoleWriter
        per_metric_values: dict[str, list] = defaultdict(list)
        for step_log in contents:
            for key, item in step_log.items():
                # FIXME This could be false-negative for some numpy types
                if isinstance(item, float):
                    per_metric_values[key].append(item)

        episode_summary: dict[str, float] = {}
        for key, series in per_metric_values.items():
            episode_summary[key] = self.aggregation(series, key)  # type: ignore
            # Accumulate so that ``collect_metrics`` can average over episodes later.
            self._aggregated_metrics[key] += episode_summary[key]

        self._latest_metrics = episode_summary
        self.callback(True, False, self)

    def on_env_all_done(self) -> None:
        # This happens when collect exits
        self.callback(False, True, self)

    def episode_metrics(self) -> dict[str, float]:
        """Retrieve the numeric metrics of the latest episode.

        Raises
        ------
        ValueError
            If no episode has been logged since the last ``clear()``.
        """
        latest = self._latest_metrics
        if latest is None:
            raise ValueError("No episode metrics available yet.")
        return latest

    def collect_metrics(self) -> dict[str, float]:
        """Retrieve the aggregated metrics of the latest collect.

        Each accumulated sum is divided by ``episode_count``.
        """
        return {key: total / self.episode_count for key, total in self._aggregated_metrics.items()}
class ConsoleWriter(LogWriter):
"""Write log messages to console periodically.
@@ -289,6 +410,8 @@ class ConsoleWriter(LogWriter):
self.console_logger = get_module_logger(__name__, level=logging.INFO)
# FIXME: save & reload
def clear(self):
super().clear()
# Clear average meters
@@ -308,7 +431,7 @@ class ConsoleWriter(LogWriter):
# This should be done at every step, regardless of periodic or not.
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
logs[name] = self.aggregation(values, name) # type: ignore
for name, value in logs.items():
self.metric_counts[name] += 1
@@ -350,6 +473,8 @@ class CsvWriter(LogWriter):
all_records: list[dict[str, Any]]
# FIXME: save & reload
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
super().__init__(loglevel)
self.output_dir = output_dir
@@ -370,7 +495,7 @@ class CsvWriter(LogWriter):
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
logs[name] = self.aggregation(values, name) # type: ignore
self.all_records.append(logs)
@@ -392,7 +517,3 @@ class TensorboardWriter(LogWriter):
class MlflowWriter(LogWriter):
"""Add logs to mlflow."""
class LogBuffer(LogWriter):
"""Keep everything in memory."""

View File

@@ -81,7 +81,7 @@ def test_simple_env_logger(caplog):
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3

View File

@@ -17,7 +17,7 @@ from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.entries.test import backtest
from qlib.rl.trainer import backtest, train
from qlib.rl.order_execution import *
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
@@ -306,3 +306,26 @@ def test_cn_ppo_strategy():
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
def test_ppo_train():
    """Smoke test: run the ``train`` entry point with a PPO policy for two iterations."""
    set_log_with_config(C.logging_config)

    # The data starts with 9:31 and ends with 15:00
    window_start = pd.Timestamp("9:31")
    window_end = pd.Timestamp("14:58")
    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=window_start, end_time=window_end)
    assert len(orders) == 40

    state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
    action_interp = CategoricalActionInterpreter(4)
    network = Recurrent(state_interp.observation_space)
    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)

    simulator_factory = partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30)
    train(
        simulator_factory,
        state_interp,
        action_interp,
        orders,
        policy,
        PAPenaltyReward(),
        vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}},
        trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)},
    )

202
tests/rl/test_trainer.py Normal file
View File

@@ -0,0 +1,202 @@
import os
import random
import sys
from pathlib import Path
import pytest
import torch
import torch.nn as nn
from gym import spaces
from tianshou.policy import PPOPolicy
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.reward import Reward
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
# Module-wide marker: skip every test in this file on Python older than 3.8.
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
class ZeroSimulator(Simulator):
    """Toy simulator in which action ``0`` is the only "correct" action.

    Each step records the action, marks it correct iff it equals 0, and ends the
    episode with a coin flip. When the episode ends, the accuracy (0 or 100) is
    reported through the environment logger under ``"acc"``.
    """

    def __init__(self, *args, **kwargs):
        # The initial state (and any other argument) is accepted but ignored.
        self.action = self.correct = 0
        # Initialise the flag so that ``done()`` is safe to call before the first
        # ``step()``; previously this raised AttributeError.
        self._done = False

    def step(self, action):
        """Record ``action`` and randomly decide whether the episode terminates."""
        self.action = action
        self.correct = action == 0

        self._done = random.choice([False, True])
        if self._done:
            self.env.logger.add_scalar("acc", self.correct * 100)

    def get_state(self):
        """Return the accuracy (0 or 100) and the latest action as a dict."""
        return {
            "acc": self.correct * 100,
            "action": self.action,
        }

    def done(self) -> bool:
        """Whether the latest step terminated the episode (``False`` before any step)."""
        return self._done
class NoopStateInterpreter(StateInterpreter):
    """State interpreter that forwards the simulator state unchanged as the observation."""

    observation_space = spaces.Dict({"acc": spaces.Discrete(200), "action": spaces.Discrete(2)})

    def interpret(self, simulator_state):
        # Identity mapping: the simulator state already has the observation layout.
        return simulator_state
class NoopActionInterpreter(ActionInterpreter):
    """Action interpreter that passes the policy action straight to the simulator."""

    action_space = spaces.Discrete(2)

    def interpret(self, simulator_state, action):
        # Identity mapping: the simulator state is not consulted.
        return action
class AccReward(Reward):
    """Reward equal to ``acc / 100`` on the step where the env reports done, else ``0.0``."""

    def reward(self, simulator_state):
        episode_finished = self.env.status["done"]
        return simulator_state["acc"] / 100 if episode_finished else 0.0
class PolicyNet(nn.Module):
    """Tiny network for tests: one linear layer fed with random noise.

    The observation is only used to determine the batch size (from ``obs["acc"]``).
    With ``return_state=True`` the output is ``(softmax probabilities, state)``;
    otherwise it is the raw linear output.
    """

    def __init__(self, out_features=1, return_state=False):
        super().__init__()
        self.fc = nn.Linear(32, out_features)
        self.return_state = return_state

    def forward(self, obs, state=None, **kwargs):
        batch_size = obs["acc"].shape[0]
        noise = torch.randn(batch_size, 32)
        logits = self.fc(noise)
        if not self.return_state:
            return logits
        return nn.functional.softmax(logits, dim=-1), state
def _ppo_policy():
    """Build a PPO policy over the two-action space with a shared Adam optimizer."""
    actor = PolicyNet(2, True)
    critic = PolicyNet()
    # One optimizer over the parameters of both networks.
    trainable = tuple(actor.parameters()) + tuple(critic.parameters())
    optimizer = torch.optim.Adam(trainable)
    return PPOPolicy(
        actor,
        critic,
        optimizer,
        torch.distributions.Categorical,
        action_space=NoopActionInterpreter().action_space,
    )
def test_trainer():
    """Train for 10 iterations on subprocess envs, then check counters and metrics."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=10, finite_env_type="subproc")
    ppo = _ppo_policy()

    vessel = TrainingVessel(
        simulator_fn=lambda init: ZeroSimulator(init),
        state_interpreter=NoopStateInterpreter(),
        action_interpreter=NoopActionInterpreter(),
        policy=ppo,
        train_initial_states=list(range(100)),
        val_initial_states=list(range(10)),
        test_initial_states=list(range(10)),
        reward=AccReward(),
        episode_per_iter=500,
        update_kwargs=dict(repeat=10, batch_size=64),
    )
    trainer.fit(vessel)

    # 10 iterations x 500 episodes each.
    assert trainer.current_iter == 10
    assert trainer.current_episode == 5000

    # ``acc`` is logged as reward * 100, so both metrics must agree.
    train_metrics = trainer.metrics
    assert abs(train_metrics["acc"] - train_metrics["reward"] * 100) < 1e-4
    assert train_metrics["acc"] > 80

    trainer.test(vessel)
    assert trainer.metrics["acc"] > 60
def test_trainer_fast_dev_run():
    """With ``fast_dev_run=2`` and two iterations, exactly four episodes are run."""
    set_log_with_config(C.logging_config)
    trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
    ppo = _ppo_policy()

    vessel = TrainingVessel(
        simulator_fn=lambda init: ZeroSimulator(init),
        state_interpreter=NoopStateInterpreter(),
        action_interpreter=NoopActionInterpreter(),
        policy=ppo,
        train_initial_states=list(range(100)),
        val_initial_states=list(range(10)),
        test_initial_states=list(range(10)),
        reward=AccReward(),
        episode_per_iter=500,
        update_kwargs=dict(repeat=10, batch_size=64),
    )
    trainer.fit(vessel)

    # fast_dev_run overrides episode_per_iter: 2 iterations x 2 episodes.
    assert trainer.current_episode == 4
def test_trainer_earlystop():
    """Sanity-check that ``EarlyStopping`` on ``val/reward`` halts training early."""
    # TODO this is just sanity check.
    # need to see the logs to check whether it works.
    set_log_with_config(C.logging_config)

    early_stopping = EarlyStopping("val/reward", restore_best_weights=True)
    trainer = Trainer(
        max_iters=10,
        val_every_n_iters=1,
        finite_env_type="dummy",
        callbacks=[early_stopping],
    )
    ppo = _ppo_policy()

    vessel = TrainingVessel(
        simulator_fn=lambda init: ZeroSimulator(init),
        state_interpreter=NoopStateInterpreter(),
        action_interpreter=NoopActionInterpreter(),
        policy=ppo,
        train_initial_states=list(range(100)),
        val_initial_states=list(range(10)),
        test_initial_states=list(range(10)),
        reward=AccReward(),
        episode_per_iter=500,
        update_kwargs=dict(repeat=10, batch_size=64),
    )
    trainer.fit(vessel)

    assert trainer.metrics["val/acc"] > 30
    assert trainer.current_iter == 2  # second iteration
def test_trainer_checkpoint():
    """Check that the ``Checkpoint`` callback writes per-iteration files that can be reloaded."""
    set_log_with_config(C.logging_config)
    # Checkpoints go to a ``.output`` directory next to this test file.
    output_dir = Path(__file__).parent / ".output"
    trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
    policy = _ppo_policy()
    vessel = TrainingVessel(
        simulator_fn=lambda init: ZeroSimulator(init),
        state_interpreter=NoopStateInterpreter(),
        action_interpreter=NoopActionInterpreter(),
        policy=policy,
        train_initial_states=list(range(100)),
        val_initial_states=list(range(10)),
        test_initial_states=list(range(10)),
        reward=AccReward(),
        episode_per_iter=100,
        update_kwargs=dict(repeat=10, batch_size=64),
    )
    trainer.fit(vessel)
    # One checkpoint per iteration, plus a ``latest.pth`` symlink to the newest one.
    assert (output_dir / "001.pth").exists()
    assert (output_dir / "002.pth").exists()
    assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")
    # Loading the first checkpoint rewinds the trainer counters to iteration 1.
    trainer.load_state_dict(torch.load(output_dir / "001.pth"))
    assert trainer.current_iter == 1
    assert trainer.current_episode == 100
    # Reload the checkpoint at first iteration
    trainer.fit(vessel, ckpt_path=output_dir / "001.pth")