1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/qlib/rl/trainer/trainer.py
2025-03-12 14:11:42 +08:00

356 lines
13 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import collections
import copy
from contextlib import AbstractContextManager, contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
import torch
from qlib.log import get_module_logger
from qlib.rl.simulator import InitialStateType
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env
from qlib.rl.utils.finite_env import FiniteVectorEnv
from qlib.typehint import Literal
from .callbacks import Callback
from .vessel import TrainingVesselBase
# Module-level logger, obtained from qlib's ``get_module_logger`` with this module's name.
_logger = get_module_logger(__name__)
# Generic element type used in the signature of ``_named_collection``.
T = TypeVar("T")
class Trainer:
    """
    Utility to train a policy on a particular task.

    Different from traditional DL trainer, the iteration of this trainer is "collect",
    rather than "epoch", or "mini-batch".
    In each collect, :class:`Collector` collects a number of policy-env interactions, and accumulates
    them into a replay buffer. This buffer is used as the "data" to train the policy.
    At the end of each collect, the policy is *updated* several times.

    The API has some resemblance with `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__,
    but it's essentially different because this trainer is built for RL applications, and thus
    most configurations are under RL context.
    We are still looking for ways to incorporate existing trainer libraries, because it looks like
    big efforts to build a trainer as powerful as those libraries, and also, that's not our primary goal.

    It's essentially different from
    `tianshou's built-in trainers <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__,
    as it's far much more complicated than that.

    Parameters
    ----------
    max_iters
        Maximum iterations before stopping.
    val_every_n_iters
        Perform validation every n iterations (i.e., training collects).
    loggers
        Logger(s) to record the backtest results. Logger must be present because
        without logger, all information will be lost.
        Either a single log writer or a list of them; a list is copied, so the
        caller's list is never modified.
    callbacks
        Callbacks whose hooks (``on_fit_start``, ``on_iter_start``, ``on_train_start``, ...)
        are invoked by the trainer at the corresponding points of :meth:`fit` / :meth:`test`.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Parallel workers.
    fast_dev_run
        Create a subset for debugging.
        How this is implemented depends on the implementation of training vessel.
        For :class:`~qlib.rl.vessel.TrainingVessel`, if greater than zero,
        a random subset sized ``fast_dev_run`` will be used
        instead of ``train_initial_states`` and ``val_initial_states``.
    """

    should_stop: bool
    """Set to stop the training."""

    metrics: dict
    """Numeric metrics of produced in train/val/test.
    In the middle of training / validation, metrics will be of the latest episode.
    When each iteration of training / validation finishes, metrics will be the aggregation
    of all episodes encountered in this iteration.

    Cleared on every new iteration of training.

    In fit, validation metrics will be prefixed with ``val/``.
    """

    current_iter: int
    """Current iteration (collect) of training."""

    loggers: List[LogWriter]
    """A list of log writers."""

    def __init__(
        self,
        *,
        max_iters: int | None = None,
        val_every_n_iters: int | None = None,
        loggers: LogWriter | List[LogWriter] | None = None,
        callbacks: List[Callback] | None = None,
        finite_env_type: FiniteEnvType = "subproc",
        concurrency: int = 2,
        fast_dev_run: int | None = None,
    ):
        self.max_iters = max_iters
        self.val_every_n_iters = val_every_n_iters

        if isinstance(loggers, list):
            # Copy: the internal LogBuffer is appended below, and that must not
            # leak into (or accumulate in) the list object owned by the caller.
            self.loggers = list(loggers)
        elif isinstance(loggers, LogWriter):
            self.loggers = [loggers]
        else:
            self.loggers = []

        # The buffer feeds episode / collect metrics back into ``self.metrics``.
        # Its log level is computed from the user-provided loggers only,
        # because ``_min_loglevel()`` is evaluated before the append happens.
        self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))

        self.callbacks: List[Callback] = callbacks if callbacks is not None else []
        self.finite_env_type = finite_env_type
        self.concurrency = concurrency
        self.fast_dev_run = fast_dev_run

        self.current_stage: Literal["train", "val", "test"] = "train"

        # Placeholder until ``fit`` / ``test`` assigns the real vessel.
        self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None)

    def initialize(self):
        """Initialize the whole training process.

        The states here should be synchronized with state_dict.
        """
        self.should_stop = False
        self.current_iter = 0
        self.current_episode = 0
        self.current_stage = "train"

    def initialize_iter(self):
        """Initialize one iteration / collect."""
        self.metrics = {}

    def state_dict(self) -> dict:
        """Putting every states of current training into a dict, at best effort.

        It doesn't try to handle all the possible kinds of states in the middle of one training collect.
        For most cases at the end of each iteration, things should be usually correct.

        Note that it's also intended behavior that replay buffer data in the collector will be lost.
        """
        return {
            "vessel": self.vessel.state_dict(),
            "callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()},
            "loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()},
            "should_stop": self.should_stop,
            "current_iter": self.current_iter,
            "current_episode": self.current_episode,
            "current_stage": self.current_stage,
            "metrics": self.metrics,
        }

    @staticmethod
    def get_policy_state_dict(ckpt_path: Path) -> OrderedDict:
        """Load a checkpoint onto CPU and return the policy state dict stored in it.

        If the loaded object has a ``"vessel"`` entry (i.e., it has the layout produced by
        :meth:`state_dict`), ``state_dict["vessel"]["policy"]`` is returned.
        Otherwise the loaded object itself is returned unchanged.

        Parameters
        ----------
        ckpt_path
            Path of the checkpoint file.
        """
        # ``weights_only=False`` matches the loading done in ``fit``: a full trainer
        # checkpoint carries more than tensors (callback / logger states, metrics, ...).
        # NOTE: this unpickles arbitrary objects, so only load checkpoints from trusted sources.
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if "vessel" in state_dict:
            state_dict = state_dict["vessel"]["policy"]
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load all states into current trainer."""
        self.vessel.load_state_dict(state_dict["vessel"])
        for name, callback in self.named_callbacks().items():
            callback.load_state_dict(state_dict["callbacks"][name])
        for name, logger in self.named_loggers().items():
            logger.load_state_dict(state_dict["loggers"][name])
        self.should_stop = state_dict["should_stop"]
        self.current_iter = state_dict["current_iter"]
        self.current_episode = state_dict["current_episode"]
        self.current_stage = state_dict["current_stage"]
        self.metrics = state_dict["metrics"]

    def named_callbacks(self) -> Dict[str, Callback]:
        """Retrieve a collection of callbacks where each one has a name.
        Useful when saving checkpoints.
        """
        return _named_collection(self.callbacks)

    def named_loggers(self) -> Dict[str, LogWriter]:
        """Retrieve a collection of loggers where each one has a name.
        Useful when saving checkpoints.
        """
        return _named_collection(self.loggers)

    def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None:
        """Train the RL policy upon the defined simulator.

        Parameters
        ----------
        vessel
            A bundle of all elements used in training.
        ckpt_path
            Load a pre-trained / paused training checkpoint.
        """
        self.vessel = vessel
        vessel.assign_trainer(self)

        if ckpt_path is not None:
            _logger.info("Resuming states from %s", str(ckpt_path))
            self.load_state_dict(torch.load(ckpt_path, weights_only=False))
        else:
            self.initialize()

        self._call_callback_hooks("on_fit_start")

        while not self.should_stop:
            msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
            _logger.info(msg)

            self.initialize_iter()

            self._call_callback_hooks("on_iter_start")

            self.current_stage = "train"
            self._call_callback_hooks("on_train_start")

            # TODO
            # Add a feature that supports reloading the training environment every few iterations.
            with _wrap_context(vessel.train_seed_iterator()) as iterator:
                vector_env = self.venv_from_iterator(iterator)
                self.vessel.train(vector_env)
                del vector_env  # FIXME: Explicitly delete this object to avoid memory leak.

            self._call_callback_hooks("on_train_end")

            if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
                # Implementation of validation loop
                self.current_stage = "val"
                self._call_callback_hooks("on_validate_start")
                with _wrap_context(vessel.val_seed_iterator()) as iterator:
                    vector_env = self.venv_from_iterator(iterator)
                    self.vessel.validate(vector_env)
                    del vector_env  # FIXME: Explicitly delete this object to avoid memory leak.

                self._call_callback_hooks("on_validate_end")

            # This iteration is considered complete.
            # Bumping the current iteration counter.
            self.current_iter += 1

            # The stop flag is set before ``on_iter_end`` so that callbacks observe it.
            if self.max_iters is not None and self.current_iter >= self.max_iters:
                self.should_stop = True

            self._call_callback_hooks("on_iter_end")

        self._call_callback_hooks("on_fit_end")

    def test(self, vessel: TrainingVesselBase) -> None:
        """Test the RL policy against the simulator.

        The simulator will be fed with data generated in ``test_seed_iterator``.

        Parameters
        ----------
        vessel
            A bundle of all related elements.
        """
        self.vessel = vessel
        vessel.assign_trainer(self)

        self.initialize_iter()

        self.current_stage = "test"
        self._call_callback_hooks("on_test_start")
        with _wrap_context(vessel.test_seed_iterator()) as iterator:
            vector_env = self.venv_from_iterator(iterator)
            self.vessel.test(vector_env)
            del vector_env  # FIXME: Explicitly delete this object to avoid memory leak.
        self._call_callback_hooks("on_test_end")

    def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
        """Create a vectorized environment from iterator and the training vessel."""

        def env_factory():
            # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
            # and could be thread unsafe.
            # I'm not sure whether it's a design flaw.
            # I'll rethink about this when designing the trainer.

            if self.finite_env_type == "dummy":
                # We could only experience the "threading-unsafe" problem in dummy.
                state = copy.deepcopy(self.vessel.state_interpreter)
                action = copy.deepcopy(self.vessel.action_interpreter)
                rew = copy.deepcopy(self.vessel.reward)
            else:
                state = self.vessel.state_interpreter
                action = self.vessel.action_interpreter
                rew = self.vessel.reward

            return EnvWrapper(
                self.vessel.simulator_fn,
                state,
                action,
                iterator,
                rew,
                logger=LogCollector(min_loglevel=self._min_loglevel()),
            )

        return vectorize_env(
            env_factory,
            self.finite_env_type,
            self.concurrency,
            self.loggers,
        )

    def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None:
        """Merge the metrics held by ``log_buffer`` into ``self.metrics``.

        Parameters
        ----------
        on_episode
            If true, the global episode counter is synced from the buffer and the
            episode metrics are used. Takes precedence over ``on_collect``.
        on_collect
            If true (and ``on_episode`` is false), the collect metrics are used.
        log_buffer
            The buffer to read the metrics from.

        When neither flag is set there is nothing to record and the call is a no-op.
        During the ``val`` stage, metric names are prefixed with ``val/``.
        """
        if on_episode:
            # Update the global counter.
            self.current_episode = log_buffer.global_episode
            metrics = log_buffer.episode_metrics()
        elif on_collect:
            # Update the latest metrics.
            metrics = log_buffer.collect_metrics()
        else:
            # No event to handle; previously this fell through with ``metrics`` unbound.
            return

        if self.current_stage == "val":
            metrics = {"val/" + name: value for name, value in metrics.items()}

        self.metrics.update(metrics)

    def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
        """Invoke the hook called ``hook_name`` on every callback, in registration order.

        Each hook receives the trainer and the current vessel, followed by ``args`` / ``kwargs``.
        """
        for callback in self.callbacks:
            fn = getattr(callback, hook_name)
            fn(self, self.vessel, *args, **kwargs)

    def _min_loglevel(self):
        """Return the lowest ``loglevel`` among the loggers, or ``LogLevel.PERIODIC`` if there are none."""
        if not self.loggers:
            return LogLevel.PERIODIC
        else:
            # To save bandwidth
            return min(lg.loglevel for lg in self.loggers)
@contextmanager
def _wrap_context(obj):
    """Expose ``obj`` through the ``with`` protocol, whether or not it is a context manager.

    A real context manager is entered and whatever its ``__enter__`` returns is yielded;
    any other object is yielded as is, with no setup or teardown.
    """
    if not isinstance(obj, AbstractContextManager):
        # Plain object: act as a dummy context.
        yield obj
        return

    # obj has __enter__ and __exit__
    with obj as entered:
        yield entered
def _named_collection(seq: Sequence[T]) -> Dict[str, T]:
    """Convert a list into a dict, where each item is named with its type.

    The name is the lower-cased class name of the item. The first item of a type gets
    the bare name; later items of the same type get a numeric suffix (``foo``, ``foo1``,
    ``foo2``, ...). If a generated name is already taken -- e.g. the second ``Foo``
    and the first ``Foo1`` would both be ``foo1`` -- the suffix is increased until the
    name is unused, so no item is ever silently overwritten.

    Parameters
    ----------
    seq
        Items to be named. Order is preserved in the returned dict.

    Returns
    -------
    A dict with exactly one entry per item of ``seq``.
    """
    res: Dict[str, T] = {}
    retry_cnt: collections.Counter = collections.Counter()
    for item in seq:
        typename = type(item).__name__.lower()
        cnt = retry_cnt[typename]
        key = typename if cnt == 0 else f"{typename}{cnt}"
        # Skip over names claimed by a different type whose name happens to clash.
        while key in res:
            cnt += 1
            key = f"{typename}{cnt}"
        retry_cnt[typename] = cnt + 1
        res[key] = item
    return res