1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00
Files
qlib/qlib/rl/trainer/callbacks.py
Yuge Zhang 25ecb1135f Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint

(cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050)

* Not a workable version

(cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
2022-06-28 19:53:05 +08:00

268 lines
10 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Callbacks to insert customized recipes during the training.
Mimics the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
"""
from __future__ import annotations
import copy
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, TYPE_CHECKING
import numpy as np
import torch
from qlib.log import get_module_logger
from qlib.typehint import Literal
if TYPE_CHECKING:
from .trainer import Trainer
from .vessel import TrainingVesselBase
_logger = get_module_logger(__name__)
class Callback:
    """Common ancestor of every callback.

    All hooks are no-ops here; a subclass overrides only the ones it needs.
    """

    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked before the whole fit process begins."""

    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked after the whole fit process has ended."""

    def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked at the beginning of each collect for training."""

    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when the training ends.

        Outputs produced during training are reachable here only if they were
        cached in either the trainer or the vessel; post-process them in this hook.
        """

    def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked at the beginning of every validation run."""

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when the validation ends."""

    def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked at the beginning of every testing run."""

    def on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked when the testing ends."""

    def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked at the start of every iteration (i.e., collect)."""

    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Hook invoked at the end of every iteration.

        It runs **after** ``current_iter`` has been bumped,
        i.e. once the previous iteration is considered complete.
        """

    def state_dict(self) -> Any:
        """Return the callback state needed to pause and resume (nothing by default)."""
        return None

    def load_state_dict(self, state_dict: Any) -> None:
        """Restore the callback from a previously saved state dict."""
class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    The earlystopping callback will be triggered each time validation ends.
    It will examine the metrics produced in validation,
    and get the metric with name ``monitor`` (``monitor`` is ``reward`` by default),
    to check whether it's no longer increasing / decreasing.
    It takes ``min_delta`` and ``patience`` if applicable.
    If it's found to be not increasing / decreasing any more,
    ``trainer.should_stop`` will be set to true,
    and the training terminates.

    Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893

    Parameters
    ----------
    monitor
        Key looked up in ``trainer.metrics`` after each validation.
    min_delta
        Margin by which the metric must beat the reference to count as an improvement.
        Only its absolute value is used.
    patience
        Training stops once the number of consecutive validations without
        a counter reset reaches this value.
    mode
        ``"max"`` if a larger metric is better, ``"min"`` if a smaller one is.
    baseline
        If given, the wait counter is only reset when the metric also beats this value.
    restore_best_weights
        If true, the vessel's state dict from the best iteration is loaded back
        into the vessel when early stopping triggers.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        monitor: str = "reward",
        min_delta: float = 0.0,
        patience: int = 0,
        mode: Literal["min", "max"] = "max",
        baseline: float | None = None,
        restore_best_weights: bool = False,
    ):
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.restore_best_weights = restore_best_weights
        self.best_weights: Any | None = None

        if mode not in ["min", "max"]:
            # f-string (rather than ``+``) so that a non-string mode still yields a ValueError.
            raise ValueError(f"Unsupported earlystopping mode: {mode}")

        self.monitor_op = np.less if mode == "min" else np.greater
        if mode == "min":
            # ``_is_improvement`` subtracts min_delta from the candidate; flipping the sign
            # turns that into "candidate + margin < reference" when minimizing.
            self.min_delta *= -1

        # Progress-tracking state. Initialized here as well as in ``on_fit_start`` so that
        # ``state_dict()`` does not raise AttributeError before a fit has started.
        self.wait: int = 0
        self.best: Any = np.inf if mode == "min" else -np.inf
        self.best_iter: int = 0

    def state_dict(self) -> dict:
        """Return the progress-tracking state (wait counter, best value / weights / iteration)."""
        return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the progress-tracking state produced by :meth:`state_dict`."""
        self.wait = state_dict["wait"]
        self.best = state_dict["best"]
        self.best_weights = state_dict["best_weights"]
        self.best_iter = state_dict["best_iter"]

    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Reset the progress-tracking state so that the instance can be re-used."""
        # NOTE(review): if the trainer calls ``load_state_dict`` before this hook when resuming,
        # the restored state is discarded here -- confirm the call order in the trainer.
        self.wait = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf
        self.best_weights = None
        self.best_iter = 0

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Update the best value / wait counter and request a stop when patience runs out."""
        current = self.get_monitor_value(trainer)
        if current is None:
            return
        if self.restore_best_weights and self.best_weights is None:
            # Restore the weights after first iteration if no progress is ever made.
            self.best_weights = copy.deepcopy(vessel.state_dict())

        self.wait += 1
        if self._is_improvement(current, self.best):
            self.best = current
            self.best_iter = trainer.current_iter
            if self.restore_best_weights:
                self.best_weights = copy.deepcopy(vessel.state_dict())
            # Only restart wait if we beat both the baseline and our previous best.
            if self.baseline is None or self._is_improvement(current, self.baseline):
                self.wait = 0

        # Only check after the first epoch.
        if self.wait >= self.patience and trainer.current_iter > 0:
            trainer.should_stop = True
            _logger.info("On iteration %d: early stopping", trainer.current_iter + 1)
            if self.restore_best_weights and self.best_weights is not None:
                _logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1)
                vessel.load_state_dict(self.best_weights)

    def get_monitor_value(self, trainer: Trainer) -> Any:
        """Return ``trainer.metrics[self.monitor]``, or ``None`` (with a warning) if it is absent."""
        monitor_value = trainer.metrics.get(self.monitor)
        if monitor_value is None:
            _logger.warning(
                "Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
                self.monitor,
                ",".join(trainer.metrics.keys()),
            )
        return monitor_value

    def _is_improvement(self, monitor_value: Any, reference_value: Any) -> bool:
        """Tell whether ``monitor_value`` beats ``reference_value`` by more than ``min_delta``."""
        return self.monitor_op(monitor_value - self.min_delta, reference_value)
class Checkpoint(Callback):
    """Save checkpoints periodically for persistence and recovery.

    Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py

    Parameters
    ----------
    dirpath
        Directory to save the checkpoint file.
    filename
        Checkpoint filename. Can contain named formatting options to be auto-filled.
        For example: ``{iter:03d}-{reward:.2f}.pth``.
        Supported argument names are:

        - iter (int)
        - metrics in ``trainer.metrics``
        - time (str), in the format of ``%Y%m%d%H%M%S``

        ``iter`` and ``time`` are reserved: a metric with the same name is shadowed by them.
    save_latest
        Save the latest checkpoint in ``latest.pth``.
        If ``link``, ``latest.pth`` will be created as a softlink.
        If ``copy``, ``latest.pth`` will be stored as an individual copy.
        Set to none to disable this.
    every_n_iters
        Checkpoints are saved at the end of every n iterations of training,
        after validation if applicable.
    time_interval
        Maximum time (seconds) before checkpoints save again.
    save_on_fit_end
        Save one last checkpoint at the end to fit.
        Do nothing if a checkpoint is already saved there.
    """

    def __init__(
        self,
        dirpath: Path,
        filename: str = "{iter:03d}.pth",
        save_latest: Literal["link", "copy"] | None = "link",
        every_n_iters: int | None = None,
        time_interval: int | None = None,
        save_on_fit_end: bool = True,
    ):
        super().__init__()

        self.dirpath = Path(dirpath)
        self.filename = filename
        self.save_latest = save_latest
        self.every_n_iters = every_n_iters
        self.time_interval = time_interval
        self.save_on_fit_end = save_on_fit_end

        # Bookkeeping of the most recent save; all ``None`` until the first checkpoint is written.
        self._last_checkpoint_name: str | None = None
        self._last_checkpoint_iter: int | None = None
        self._last_checkpoint_time: float | None = None

    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Write a final checkpoint unless one was already written for the current iteration."""
        if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):
            self._save_checkpoint(trainer)

    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Write a checkpoint if either the iteration-count or the elapsed-time rule fires."""
        should_save_ckpt = False
        # NOTE(review): ``Callback.on_iter_end`` documents that it runs after ``current_iter``
        # has been bumped; confirm that the extra ``+ 1`` here is intended.
        if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0:
            should_save_ckpt = True
        # With a time interval configured, the very first iteration end always saves
        # because no previous save time has been recorded yet.
        if self.time_interval is not None and (
            self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval
        ):
            should_save_ckpt = True
        if should_save_ckpt:
            self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer: Trainer) -> None:
        """Serialize ``trainer.state_dict()`` into ``dirpath`` and refresh ``latest.pth``."""
        self.dirpath.mkdir(exist_ok=True, parents=True)

        self._last_checkpoint_name = self._new_checkpoint_name(trainer)
        self._last_checkpoint_iter = trainer.current_iter
        self._last_checkpoint_time = time.time()
        checkpoint_path = self.dirpath / self._last_checkpoint_name
        torch.save(trainer.state_dict(), checkpoint_path)

        if not self.save_latest:
            return

        latest_pth = self.dirpath / "latest.pth"

        # Remove first before saving.
        # ``exists()`` follows symlinks and reports False for a dangling link (e.g. the
        # checkpoint it pointed to was deleted), so check ``is_symlink()`` as well;
        # otherwise ``symlink_to`` below would fail with FileExistsError.
        if latest_pth.exists() or latest_pth.is_symlink():
            latest_pth.unlink()

        if self.save_latest == "link":
            # A relative link target is resolved against the directory containing the link,
            # which is already ``dirpath``. Prefixing the target with ``dirpath`` again would
            # produce a broken link whenever ``dirpath`` is a relative path.
            latest_pth.symlink_to(self._last_checkpoint_name)
        elif self.save_latest == "copy":
            shutil.copyfile(checkpoint_path, latest_pth)

    def _new_checkpoint_name(self, trainer: Trainer) -> str:
        """Fill the ``filename`` template with the current iteration, time string and metrics."""
        # Merge into one dict with the built-in fields written last, so that a metric
        # named ``iter`` or ``time`` cannot trigger a duplicate-keyword TypeError.
        fields = dict(trainer.metrics)
        fields["iter"] = trainer.current_iter
        fields["time"] = datetime.now().strftime("%Y%m%d%H%M%S")
        return self.filename.format(**fields)