mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint (cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050) * Not a workable version (cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a) * vessel * ckpt * . * vessel * . * . * checkpoint callback * . * cleanup * logger * . * test * . * add test * . * . * . * . * New reward * Add train API * fix mypy * fix lint * More comment * 3.7 compat * fix test * fix test * . * Resolve comments * fix typehint
This commit is contained in:
@@ -1,7 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Train, test, inference utilities.
|
||||
|
||||
The APIs in this directory are NOT considered final and are subject to change!
|
||||
"""
|
||||
@@ -1,99 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
) -> None:
|
||||
"""Backtest with the parallelism provided by RL framework.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_fn
|
||||
Callable receiving initial seed, returning a simulator.
|
||||
state_interpreter
|
||||
Interprets the state of simulators.
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
Logger to record the backtest results. Logger must be present because
|
||||
without logger, all information will be lost.
|
||||
reward
|
||||
Optional reward function. For backtest, this is for testing the rewards
|
||||
and logging them only.
|
||||
finite_env_type
|
||||
Type of finite env implementation.
|
||||
concurrency
|
||||
Parallel workers.
|
||||
"""
|
||||
|
||||
# To save bandwidth
|
||||
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
|
||||
|
||||
def env_factory():
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
# I'll rethink about this when designing the trainer.
|
||||
|
||||
if finite_env_type == "dummy":
|
||||
# We could only experience the "threading-unsafe" problem in dummy.
|
||||
state = copy.deepcopy(state_interpreter)
|
||||
action = copy.deepcopy(action_interpreter)
|
||||
rew = copy.deepcopy(reward)
|
||||
else:
|
||||
state, action, rew = state_interpreter, action_interpreter, reward
|
||||
|
||||
return EnvWrapper(
|
||||
simulator_fn,
|
||||
state,
|
||||
action,
|
||||
seed_iterator,
|
||||
rew,
|
||||
logger=LogCollector(min_loglevel=min_loglevel),
|
||||
)
|
||||
|
||||
with DataQueue(initial_states) as seed_iterator:
|
||||
vector_env = vectorize_env(
|
||||
env_factory,
|
||||
finite_env_type,
|
||||
concurrency,
|
||||
logger,
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
test_collector = Collector(policy, vector_env)
|
||||
_logger.info("All ready. Start backtest.")
|
||||
test_collector.collect(n_step=INF * len(vector_env))
|
||||
@@ -1,4 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TBD
|
||||
@@ -9,4 +9,5 @@ Multi-asset is on the way.
|
||||
from .interpreter import *
|
||||
from .network import *
|
||||
from .policy import *
|
||||
from .reward import *
|
||||
from .simulator_simple import *
|
||||
|
||||
46
qlib/rl/order_execution/reward.py
Normal file
46
qlib/rl/order_execution/reward.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
from qlib.rl.reward import Reward
|
||||
|
||||
from .simulator_simple import SAOEState, SAOEMetrics
|
||||
|
||||
__all__ = ["PAPenaltyReward"]
|
||||
|
||||
|
||||
class PAPenaltyReward(Reward[SAOEState]):
|
||||
"""Encourage higher PAs, but penalize stacking all the amounts within a very short time.
|
||||
Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
penalty
|
||||
The penalty for large volume in a short time.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: float = 100.0):
|
||||
self.penalty = penalty
|
||||
|
||||
def reward(self, simulator_state: SAOEState) -> float:
|
||||
whole_order = simulator_state.order.amount
|
||||
assert whole_order > 0
|
||||
last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
|
||||
pa = last_step["pa"] * last_step["amount"] / whole_order
|
||||
|
||||
# Inspect the "break-down" of the latest step: trading amount at every tick
|
||||
last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :]
|
||||
penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()
|
||||
|
||||
reward = pa + penalty
|
||||
|
||||
# Throw error in case of NaN
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"
|
||||
|
||||
self.log("reward/pa", pa)
|
||||
self.log("reward/penalty", penalty)
|
||||
return reward
|
||||
@@ -131,11 +131,14 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""
|
||||
|
||||
history_exec: pd.DataFrame
|
||||
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
|
||||
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
|
||||
Index is ``datetime``.
|
||||
"""
|
||||
|
||||
history_steps: pd.DataFrame
|
||||
"""Positions at each step. The position before first step is also recorded.
|
||||
See :class:`SAOEMetrics` for available columns."""
|
||||
See :class:`SAOEMetrics` for available columns.
|
||||
Index is ``datetime``, which is the **starting** time of each step."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
"""Metrics. Only available when done."""
|
||||
|
||||
9
qlib/rl/trainer/__init__.py
Normal file
9
qlib/rl/trainer/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Train, test, inference utilities."""
|
||||
|
||||
from .api import backtest, train
|
||||
from .callbacks import EarlyStopping, Checkpoint
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVessel, TrainingVesselBase
|
||||
118
qlib/rl/trainer/api.py
Normal file
118
qlib/rl/trainer/api.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Sequence, cast, Any
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import FiniteEnvType, LogWriter
|
||||
|
||||
from .vessel import TrainingVessel
|
||||
from .trainer import Trainer
|
||||
|
||||
|
||||
def train(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
vessel_kwargs: dict[str, Any],
|
||||
trainer_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Train a policy with the parallelism provided by RL framework.
|
||||
|
||||
Experimental API. Parameters might change shortly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_fn
|
||||
Callable receiving initial seed, returning a simulator.
|
||||
state_interpreter
|
||||
Interprets the state of simulators.
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to train against.
|
||||
reward
|
||||
Reward function.
|
||||
vessel_kwargs
|
||||
Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
|
||||
trainer_kwargs
|
||||
Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
|
||||
"""
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=simulator_fn,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
train_initial_states=initial_states,
|
||||
reward=reward, # ignore none
|
||||
**vessel_kwargs,
|
||||
)
|
||||
trainer = Trainer(**trainer_kwargs)
|
||||
trainer.fit(vessel)
|
||||
|
||||
|
||||
def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
) -> None:
|
||||
"""Backtest with the parallelism provided by RL framework.
|
||||
|
||||
Experimental API. Parameters might change shortly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_fn
|
||||
Callable receiving initial seed, returning a simulator.
|
||||
state_interpreter
|
||||
Interprets the state of simulators.
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
Logger to record the backtest results. Logger must be present because
|
||||
without logger, all information will be lost.
|
||||
reward
|
||||
Optional reward function. For backtest, this is for testing the rewards
|
||||
and logging them only.
|
||||
finite_env_type
|
||||
Type of finite env implementation.
|
||||
concurrency
|
||||
Parallel workers.
|
||||
"""
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=simulator_fn,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
test_initial_states=initial_states,
|
||||
reward=cast(Reward, reward), # ignore none
|
||||
)
|
||||
trainer = Trainer(
|
||||
finite_env_type=finite_env_type,
|
||||
concurrency=concurrency,
|
||||
loggers=logger,
|
||||
)
|
||||
trainer.test(vessel)
|
||||
267
qlib/rl/trainer/callbacks.py
Normal file
267
qlib/rl/trainer/callbacks.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Callbacks to insert customized recipes during the training.
|
||||
Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.typehint import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVesselBase
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class Callback:
|
||||
"""Base class of all callbacks."""
|
||||
|
||||
def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called before the whole fit process begins."""
|
||||
|
||||
def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called after the whole fit process ends."""
|
||||
|
||||
def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when each collect for training begins."""
|
||||
|
||||
def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when the training ends.
|
||||
To access all outputs produced during training, cache the data in either trainer and vessel,
|
||||
and post-process them in this hook.
|
||||
"""
|
||||
|
||||
def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when every run for validation begins."""
|
||||
|
||||
def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when the validation ends."""
|
||||
|
||||
def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when every run of testing begins."""
|
||||
|
||||
def on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when the testing ends."""
|
||||
|
||||
def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when every iteration (i.e., collect) starts."""
|
||||
|
||||
def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called upon every end of iteration.
|
||||
This is called **after** the bump of ``current_iter``,
|
||||
when the previous iteration is considered complete.
|
||||
"""
|
||||
|
||||
def state_dict(self) -> Any:
|
||||
"""Get a state dict of the callback for pause and resume."""
|
||||
|
||||
def load_state_dict(self, state_dict: Any) -> None:
|
||||
"""Resume the callback from a saved state dict."""
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
"""Stop training when a monitored metric has stopped improving.
|
||||
|
||||
The earlystopping callback will be triggered each time validation ends.
|
||||
It will examine the metrics produced in validation,
|
||||
and get the metric with name ``monitor` (``monitor`` is ``reward`` by default),
|
||||
to check whether it's no longer increasing / decreasing.
|
||||
It takes ``min_delta`` and ``patience`` if applicable.
|
||||
If it's found to be not increasing / decreasing any more.
|
||||
``trainer.should_stop`` will be set to true,
|
||||
and the training terminates.
|
||||
|
||||
Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
monitor: str = "reward",
|
||||
min_delta: float = 0.0,
|
||||
patience: int = 0,
|
||||
mode: Literal["min", "max"] = "max",
|
||||
baseline: float | None = None,
|
||||
restore_best_weights: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.monitor = monitor
|
||||
self.patience = patience
|
||||
self.baseline = baseline
|
||||
self.min_delta = abs(min_delta)
|
||||
self.restore_best_weights = restore_best_weights
|
||||
self.best_weights: Any | None = None
|
||||
|
||||
if mode not in ["min", "max"]:
|
||||
raise ValueError("Unsupported earlystopping mode: " + mode)
|
||||
|
||||
if mode == "min":
|
||||
self.monitor_op = np.less
|
||||
elif mode == "max":
|
||||
self.monitor_op = np.greater
|
||||
|
||||
if self.monitor_op == np.greater:
|
||||
self.min_delta *= 1
|
||||
else:
|
||||
self.min_delta *= -1
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.wait = state_dict["wait"]
|
||||
self.best = state_dict["best"]
|
||||
self.best_weights = state_dict["best_weights"]
|
||||
self.best_iter = state_dict["best_iter"]
|
||||
|
||||
def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
# Allow instances to be re-used
|
||||
self.wait = 0
|
||||
self.best = np.inf if self.monitor_op == np.less else -np.inf
|
||||
self.best_weights = None
|
||||
self.best_iter = 0
|
||||
|
||||
def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
current = self.get_monitor_value(trainer)
|
||||
if current is None:
|
||||
return
|
||||
if self.restore_best_weights and self.best_weights is None:
|
||||
# Restore the weights after first iteration if no progress is ever made.
|
||||
self.best_weights = copy.deepcopy(vessel.state_dict())
|
||||
|
||||
self.wait += 1
|
||||
if self._is_improvement(current, self.best):
|
||||
self.best = current
|
||||
self.best_iter = trainer.current_iter
|
||||
if self.restore_best_weights:
|
||||
self.best_weights = copy.deepcopy(vessel.state_dict())
|
||||
# Only restart wait if we beat both the baseline and our previous best.
|
||||
if self.baseline is None or self._is_improvement(current, self.baseline):
|
||||
self.wait = 0
|
||||
|
||||
# Only check after the first epoch.
|
||||
if self.wait >= self.patience and trainer.current_iter > 0:
|
||||
trainer.should_stop = True
|
||||
_logger.info(f"On iteration %d: early stopping", trainer.current_iter + 1)
|
||||
if self.restore_best_weights and self.best_weights is not None:
|
||||
_logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1)
|
||||
vessel.load_state_dict(self.best_weights)
|
||||
|
||||
def get_monitor_value(self, trainer: Trainer) -> Any:
|
||||
monitor_value = trainer.metrics.get(self.monitor)
|
||||
if monitor_value is None:
|
||||
_logger.warning(
|
||||
"Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
|
||||
self.monitor,
|
||||
",".join(list(trainer.metrics.keys())),
|
||||
)
|
||||
return monitor_value
|
||||
|
||||
def _is_improvement(self, monitor_value, reference_value):
|
||||
return self.monitor_op(monitor_value - self.min_delta, reference_value)
|
||||
|
||||
|
||||
class Checkpoint(Callback):
|
||||
"""Save checkpoints periodically for persistence and recovery.
|
||||
|
||||
Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dirpath
|
||||
Directory to save the checkpoint file.
|
||||
filename
|
||||
Checkpoint filename. Can contain named formatting options to be auto-filled.
|
||||
For example: ``{iter:03d}-{reward:.2f}.pth``.
|
||||
Supported argument names are:
|
||||
|
||||
- iter (int)
|
||||
- metrics in ``trainer.metrics``
|
||||
- time string, in the format of ``%Y%m%d%H%M%S``
|
||||
save_latest
|
||||
Save the latest checkpoint in ``latest.pth``.
|
||||
If ``link``, ``latest.pth`` will be created as a softlink.
|
||||
If ``copy``, ``latest.pth`` will be stored as an individual copy.
|
||||
Set to none to disable this.
|
||||
every_n_iters
|
||||
Checkpoints are saved at the end of every n iterations of training,
|
||||
after validation if applicable.
|
||||
time_interval
|
||||
Maximum time (seconds) before checkpoints save again.
|
||||
save_on_fit_end
|
||||
Save one last checkpoint at the end to fit.
|
||||
Do nothing if a checkpoint is already saved there.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirpath: Path,
|
||||
filename: str = "{iter:03d}.pth",
|
||||
save_latest: Literal["link", "copy"] | None = "link",
|
||||
every_n_iters: int | None = None,
|
||||
time_interval: int | None = None,
|
||||
save_on_fit_end: bool = True,
|
||||
):
|
||||
self.dirpath = Path(dirpath)
|
||||
self.filename = filename
|
||||
self.save_latest = save_latest
|
||||
self.every_n_iters = every_n_iters
|
||||
self.time_interval = time_interval
|
||||
self.save_on_fit_end = save_on_fit_end
|
||||
|
||||
self._last_checkpoint_name: str | None = None
|
||||
self._last_checkpoint_iter: int | None = None
|
||||
self._last_checkpoint_time: float | None = None
|
||||
|
||||
def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):
|
||||
self._save_checkpoint(trainer)
|
||||
|
||||
def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
should_save_ckpt = False
|
||||
if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0:
|
||||
should_save_ckpt = True
|
||||
if self.time_interval is not None and (
|
||||
self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval
|
||||
):
|
||||
should_save_ckpt = True
|
||||
if should_save_ckpt:
|
||||
self._save_checkpoint(trainer)
|
||||
|
||||
def _save_checkpoint(self, trainer: Trainer) -> None:
|
||||
self.dirpath.mkdir(exist_ok=True, parents=True)
|
||||
self._last_checkpoint_name = self._new_checkpoint_name(trainer)
|
||||
self._last_checkpoint_iter = trainer.current_iter
|
||||
self._last_checkpoint_time = time.time()
|
||||
torch.save(trainer.state_dict(), self.dirpath / self._last_checkpoint_name)
|
||||
|
||||
latest_pth = self.dirpath / "latest.pth"
|
||||
|
||||
# Remove first before saving
|
||||
if self.save_latest and latest_pth.exists():
|
||||
latest_pth.unlink()
|
||||
|
||||
if self.save_latest == "link":
|
||||
latest_pth.symlink_to(self.dirpath / self._last_checkpoint_name)
|
||||
elif self.save_latest == "copy":
|
||||
shutil.copyfile(self.dirpath / self._last_checkpoint_name, latest_pth)
|
||||
|
||||
def _new_checkpoint_name(self, trainer: Trainer) -> str:
|
||||
return self.filename.format(
|
||||
iter=trainer.current_iter, time=datetime.now().strftime("%Y%m%d%H%M%S"), **trainer.metrics
|
||||
)
|
||||
343
qlib/rl/trainer/trainer.py
Normal file
343
qlib/rl/trainer/trainer.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, TypeVar, Sequence, cast
|
||||
|
||||
import torch
|
||||
|
||||
from qlib.rl.simulator import InitialStateType
|
||||
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .callbacks import Callback
|
||||
from .vessel import TrainingVesselBase
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Utility to train a policy on a particular task.
|
||||
|
||||
Different from traditional DL trainer, the iteration of this trainer is "collect",
|
||||
rather than "epoch", or "mini-batch".
|
||||
In each collect, :class:`Collector` collects a number of policy-env interactions, and accumulates
|
||||
them into a replay buffer. This buffer is used as the "data" to train the policy.
|
||||
At the end of each collect, the policy is *updated* several times.
|
||||
|
||||
The API has some resemblence with `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__,
|
||||
but it's essentially different because this trainer is built for RL applications, and thus
|
||||
most configurations are under RL context.
|
||||
We are still looking for ways to incorporate existing trainer libraries, because it looks like
|
||||
big efforts to build a trainer as powerful as those libraries, and also, that's not our primary goal.
|
||||
|
||||
It's essentially different
|
||||
`tianshou's built-in trainers <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__,
|
||||
as it's far much more complicated than that.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_iters
|
||||
Maximum iterations before stopping.
|
||||
val_every_n_iters
|
||||
Perform validation every n iterations (i.e., training collects).
|
||||
logger
|
||||
Logger to record the backtest results. Logger must be present because
|
||||
without logger, all information will be lost.
|
||||
finite_env_type
|
||||
Type of finite env implementation.
|
||||
concurrency
|
||||
Parallel workers.
|
||||
fast_dev_run
|
||||
Create a subset for debugging.
|
||||
How this is implemented depends on the implementation of training vessel.
|
||||
For :class:`~qlib.rl.vessel.TrainingVessel`, if greater than zero,
|
||||
a random subset sized ``fast_dev_run`` will be used
|
||||
instead of ``train_initial_states`` and ``val_initial_states``.
|
||||
"""
|
||||
|
||||
should_stop: bool
|
||||
"""Set to stop the training."""
|
||||
|
||||
metrics: dict
|
||||
"""Numeric metrics of produced in train/val/test.
|
||||
In the middle of training / validation, metrics will be of the latest episode.
|
||||
When each iteration of training / validation finishes, metrics will be the aggregation
|
||||
of all episodes encountered in this iteration.
|
||||
|
||||
Cleared on every new iteration of training.
|
||||
|
||||
In fit, validation metrics will be prefixed with ``val/``.
|
||||
"""
|
||||
|
||||
current_iter: int
|
||||
"""Current iteration (collect) of training."""
|
||||
|
||||
loggers: list[LogWriter]
|
||||
"""A list of log writers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_iters: int | None = None,
|
||||
val_every_n_iters: int | None = None,
|
||||
loggers: LogWriter | list[LogWriter] | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
fast_dev_run: int | None = None,
|
||||
):
|
||||
self.max_iters = max_iters
|
||||
self.val_every_n_iters = val_every_n_iters
|
||||
|
||||
if isinstance(loggers, list):
|
||||
self.loggers = loggers
|
||||
elif isinstance(loggers, LogWriter):
|
||||
self.loggers = [loggers]
|
||||
else:
|
||||
self.loggers = []
|
||||
|
||||
self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))
|
||||
|
||||
self.callbacks: list[Callback] = callbacks if callbacks is not None else []
|
||||
self.finite_env_type = finite_env_type
|
||||
self.concurrency = concurrency
|
||||
self.fast_dev_run = fast_dev_run
|
||||
|
||||
self.current_stage: Literal["train", "val", "test"] = "train"
|
||||
|
||||
self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None)
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize the whole training process.
|
||||
|
||||
The states here should be synchronized with state_dict.
|
||||
"""
|
||||
self.should_stop = False
|
||||
self.current_iter = 0
|
||||
self.current_episode = 0
|
||||
self.current_stage = "train"
|
||||
|
||||
def initialize_iter(self):
|
||||
"""Initialize one iteration / collect."""
|
||||
self.metrics = {}
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Putting every states of current training into a dict, at best effort.
|
||||
|
||||
It doesn't try to handle all the possible kinds of states in the middle of one training collect.
|
||||
For most cases at the end of each iteration, things should be usually correct.
|
||||
|
||||
Note that it's also intended behavior that replay buffer data in the collector will be lost.
|
||||
"""
|
||||
return {
|
||||
"vessel": self.vessel.state_dict(),
|
||||
"callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()},
|
||||
"loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()},
|
||||
"should_stop": self.should_stop,
|
||||
"current_iter": self.current_iter,
|
||||
"current_episode": self.current_episode,
|
||||
"current_stage": self.current_stage,
|
||||
"metrics": self.metrics,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load all states into current trainer."""
|
||||
self.vessel.load_state_dict(state_dict["vessel"])
|
||||
for name, callback in self.named_callbacks().items():
|
||||
callback.load_state_dict(state_dict["callbacks"][name])
|
||||
for name, logger in self.named_loggers().items():
|
||||
logger.load_state_dict(state_dict["loggers"][name])
|
||||
self.should_stop = state_dict["should_stop"]
|
||||
self.current_iter = state_dict["current_iter"]
|
||||
self.current_episode = state_dict["current_episode"]
|
||||
self.current_stage = state_dict["current_stage"]
|
||||
self.metrics = state_dict["metrics"]
|
||||
|
||||
def named_callbacks(self) -> dict[str, Callback]:
|
||||
"""Retrieve a collection of callbacks where each one has a name.
|
||||
Useful when saving checkpoints.
|
||||
"""
|
||||
return _named_collection(self.callbacks)
|
||||
|
||||
def named_loggers(self) -> dict[str, LogWriter]:
|
||||
"""Retrieve a collection of loggers where each one has a name.
|
||||
Useful when saving checkpoints.
|
||||
"""
|
||||
return _named_collection(self.loggers)
|
||||
|
||||
def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None:
|
||||
"""Train the RL policy upon the defined simulator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vessel
|
||||
A bundle of all elements used in training.
|
||||
ckpt_path
|
||||
Load a pre-trained / paused training checkpoint.
|
||||
"""
|
||||
self.vessel = vessel
|
||||
vessel.assign_trainer(self)
|
||||
|
||||
if ckpt_path is not None:
|
||||
_logger.info("Resuming states from %s", str(ckpt_path))
|
||||
self.load_state_dict(torch.load(ckpt_path))
|
||||
else:
|
||||
self.initialize()
|
||||
|
||||
self._call_callback_hooks("on_fit_start")
|
||||
|
||||
while not self.should_stop:
|
||||
self.initialize_iter()
|
||||
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
with _wrap_context(vessel.train_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.train(vector_env)
|
||||
|
||||
self._call_callback_hooks("on_train_end")
|
||||
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
with _wrap_context(vessel.val_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.validate(vector_env)
|
||||
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
|
||||
self._call_callback_hooks("on_fit_end")
|
||||
|
||||
def test(self, vessel: TrainingVesselBase) -> None:
|
||||
"""Test the RL policy against the simulator.
|
||||
|
||||
The simulator will be fed with data generated in ``test_seed_iterator``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vessel
|
||||
A bundle of all related elements.
|
||||
"""
|
||||
self.vessel = vessel
|
||||
vessel.assign_trainer(self)
|
||||
|
||||
self.initialize_iter()
|
||||
|
||||
self.current_stage = "test"
|
||||
self._call_callback_hooks("on_test_start")
|
||||
with _wrap_context(vessel.test_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.test(vector_env)
|
||||
self._call_callback_hooks("on_test_end")
|
||||
|
||||
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
|
||||
"""Create a vectorized environment from iterator and the training vessel."""
|
||||
|
||||
def env_factory():
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
# I'll rethink about this when designing the trainer.
|
||||
|
||||
if self.finite_env_type == "dummy":
|
||||
# We could only experience the "threading-unsafe" problem in dummy.
|
||||
state = copy.deepcopy(self.vessel.state_interpreter)
|
||||
action = copy.deepcopy(self.vessel.action_interpreter)
|
||||
rew = copy.deepcopy(self.vessel.reward)
|
||||
else:
|
||||
state = self.vessel.state_interpreter
|
||||
action = self.vessel.action_interpreter
|
||||
rew = self.vessel.reward
|
||||
|
||||
return EnvWrapper(
|
||||
self.vessel.simulator_fn,
|
||||
state,
|
||||
action,
|
||||
iterator,
|
||||
rew,
|
||||
logger=LogCollector(min_loglevel=self._min_loglevel()),
|
||||
)
|
||||
|
||||
return vectorize_env(
|
||||
env_factory,
|
||||
self.finite_env_type,
|
||||
self.concurrency,
|
||||
self.loggers,
|
||||
)
|
||||
|
||||
def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None:
|
||||
if on_episode:
|
||||
# Update the global counter.
|
||||
self.current_episode = log_buffer.global_episode
|
||||
metrics = log_buffer.episode_metrics()
|
||||
elif on_collect:
|
||||
# Update the latest metrics.
|
||||
metrics = log_buffer.collect_metrics()
|
||||
if self.current_stage == "val":
|
||||
metrics = {"val/" + name: value for name, value in metrics.items()}
|
||||
self.metrics.update(metrics)
|
||||
|
||||
def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
for callback in self.callbacks:
|
||||
fn = getattr(callback, hook_name)
|
||||
fn(self, self.vessel, *args, **kwargs)
|
||||
|
||||
def _min_loglevel(self):
|
||||
if not self.loggers:
|
||||
return LogLevel.PERIODIC
|
||||
else:
|
||||
# To save bandwidth
|
||||
return min(lg.loglevel for lg in self.loggers)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _wrap_context(obj):
|
||||
"""Make any object a (possibly dummy) context manager."""
|
||||
|
||||
if isinstance(obj, AbstractContextManager):
|
||||
# obj has __enter__ and __exit__
|
||||
with obj as ctx:
|
||||
yield ctx
|
||||
else:
|
||||
yield obj
|
||||
|
||||
|
||||
def _named_collection(seq: Sequence[T]) -> dict[str, T]:
|
||||
"""Convert a list into a dict, where each item is named with its type."""
|
||||
res = {}
|
||||
for item in seq:
|
||||
typename = type(item).__name__.lower()
|
||||
if typename not in res:
|
||||
res[typename] = item
|
||||
else:
|
||||
# names are auto-labelled as earlystop1, earlystop2, ...
|
||||
for retry in range(1, 1000):
|
||||
if f"{typename}{retry}" not in res:
|
||||
res[f"{typename}{retry}"] = item
|
||||
return res
|
||||
214
qlib/rl/trainer/vessel.py
Normal file
214
qlib/rl/trainer/vessel.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import DataQueue
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .trainer import Trainer
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class SeedIteratorNotAvailable(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
|
||||
"""A ship that contains simulator, interpreter, and policy, will be sent to trainer.
|
||||
This class controls algorithm-related parts of training, while trainer is responsible for runtime part.
|
||||
|
||||
The ship also defines the most important logic of the core training part,
|
||||
and (optionally) some callbacks to insert customized logics at specific events.
|
||||
"""
|
||||
|
||||
simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]]
|
||||
state_interpreter: StateInterpreter[StateType, ObsType]
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType]
|
||||
policy: BasePolicy
|
||||
reward: Reward
|
||||
trainer: Trainer
|
||||
|
||||
def assign_trainer(self, trainer: Trainer) -> None:
|
||||
self.trainer = weakref.proxy(trainer) # type: ignore
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for training.
|
||||
If the iterable is a context manager, the whole training will be invoked in the with-block,
|
||||
and the iterator will be automatically closed after the training is done."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
|
||||
|
||||
def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]:
|
||||
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
"""Implement this to validate the policy once."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
"""Implement this to evaluate the policy on test environment once."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def log(self, name: str, value: Any) -> None:
|
||||
# FIXME: this is a workaround to make the log at least show somewhere.
|
||||
# Need a refactor in logger to formalize this.
|
||||
if isinstance(value, (np.ndarray, list)):
|
||||
value = np.mean(value)
|
||||
_logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}")
|
||||
|
||||
def log_dict(self, data: dict[str, Any]) -> None:
|
||||
for name, value in data.items():
|
||||
self.log(name, value)
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Return a checkpoint of current vessel state."""
|
||||
return {"policy": self.policy.state_dict()}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Restore a checkpoint from a previously saved state dict."""
|
||||
self.policy.load_state_dict(state_dict["policy"])
|
||||
|
||||
|
||||
class TrainingVessel(TrainingVesselBase):
|
||||
"""The default implementation of training vessel.
|
||||
|
||||
``__init__`` accepts a sequence of initial states so that iterator can be created.
|
||||
``train``, ``validate``, ``test`` each do one collect (and also update in train).
|
||||
By default, the train initial states will be repeated infinitely during training,
|
||||
and collector will control the number of episodes for each iteration.
|
||||
In validation and testing, the val / test initial states will be used exactly once.
|
||||
|
||||
Extra hyper-parameters (only used in train) include:
|
||||
|
||||
- ``buffer_size``: Size of replay buffer.
|
||||
- ``episode_per_iter``: Episodes per collect at training. Can be overridden by fast dev run.
|
||||
- ``update_kwargs``: Keyword arguments appearing in ``policy.update``.
|
||||
For example, ``dict(repeat=10, batch_size=64)``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]],
|
||||
state_interpreter: StateInterpreter[StateType, ObsType],
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
train_initial_states: Sequence[InitialStateType] | None = None,
|
||||
val_initial_states: Sequence[InitialStateType] | None = None,
|
||||
test_initial_states: Sequence[InitialStateType] | None = None,
|
||||
buffer_size: int = 20000,
|
||||
episode_per_iter: int = 1000,
|
||||
update_kwargs: dict[str, Any] = cast(Dict[str, Any], None),
|
||||
):
|
||||
self.simulator_fn = simulator_fn # type: ignore
|
||||
self.state_interpreter = state_interpreter
|
||||
self.action_interpreter = action_interpreter
|
||||
self.policy = policy
|
||||
self.reward = reward
|
||||
self.train_initial_states = train_initial_states
|
||||
self.val_initial_states = val_initial_states
|
||||
self.test_initial_states = test_initial_states
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_per_iter = episode_per_iter
|
||||
self.update_kwargs = update_kwargs or {}
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.train_initial_states is not None:
|
||||
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
|
||||
# Implement fast_dev_run here.
|
||||
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
|
||||
return super().train_seed_iterator()
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.val_initial_states is not None:
|
||||
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
|
||||
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(val_initial_states, repeat=1)
|
||||
return super().val_seed_iterator()
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.test_initial_states is not None:
|
||||
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
|
||||
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(test_initial_states, repeat=1)
|
||||
return super().test_seed_iterator()
|
||||
|
||||
def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
"""Create a collector and collects ``episode_per_iter`` episodes.
|
||||
Update the policy on the collected replay buffer.
|
||||
"""
|
||||
self.policy.train()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
|
||||
|
||||
# Number of episodes collected in each training iteration can be overridden by fast dev run.
|
||||
if self.trainer.fast_dev_run is not None:
|
||||
episodes = self.trainer.fast_dev_run
|
||||
else:
|
||||
episodes = self.episode_per_iter
|
||||
|
||||
col_result = collector.collect(n_episode=episodes)
|
||||
update_result = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs)
|
||||
res = {**col_result, **update_result}
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
self.policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
test_collector = Collector(self.policy, vector_env)
|
||||
res = test_collector.collect(n_step=INF * len(vector_env))
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
self.policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
test_collector = Collector(self.policy, vector_env)
|
||||
res = test_collector.collect(n_step=INF * len(vector_env))
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]:
|
||||
if size is None:
|
||||
# Size = None -> original collection
|
||||
return collection
|
||||
order = np.random.permutation(len(collection))
|
||||
res = [collection[o] for o in order[:size]]
|
||||
_logger.info(
|
||||
"Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res)
|
||||
)
|
||||
return res
|
||||
@@ -145,7 +145,9 @@ class DataQueue(Generic[T]):
|
||||
def __iter__(self):
|
||||
if not self._activated:
|
||||
raise ValueError(
|
||||
"Need to call activate() to launch a daemon worker to produce data into data queue before using it."
|
||||
"Need to call activate() to launch a daemon worker "
|
||||
"to produce data into data queue before using it. "
|
||||
"You probably have forgotten to use the DataQueue in a with block."
|
||||
)
|
||||
return self._consumer()
|
||||
|
||||
@@ -161,19 +163,21 @@ class DataQueue(Generic[T]):
|
||||
# pytorch dataloader is used here only because we need its sampler and multi-processing
|
||||
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
|
||||
|
||||
dataloader = DataLoader(
|
||||
cast(Dataset[T], self.dataset),
|
||||
batch_size=None,
|
||||
num_workers=self.producer_num_workers,
|
||||
shuffle=self.shuffle,
|
||||
collate_fn=lambda t: t, # identity collate fn
|
||||
)
|
||||
repeat = 10**18 if self.repeat == -1 else self.repeat
|
||||
for _rep in range(repeat):
|
||||
for data in dataloader:
|
||||
if self._done.value:
|
||||
# Already done.
|
||||
return
|
||||
self._queue.put(data)
|
||||
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
|
||||
self.mark_as_done()
|
||||
try:
|
||||
dataloader = DataLoader(
|
||||
cast(Dataset[T], self.dataset),
|
||||
batch_size=None,
|
||||
num_workers=self.producer_num_workers,
|
||||
shuffle=self.shuffle,
|
||||
collate_fn=lambda t: t, # identity collate fn
|
||||
)
|
||||
repeat = 10**18 if self.repeat == -1 else self.repeat
|
||||
for _rep in range(repeat):
|
||||
for data in dataloader:
|
||||
if self._done.value:
|
||||
# Already done.
|
||||
return
|
||||
self._queue.put(data)
|
||||
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
|
||||
finally:
|
||||
self.mark_as_done()
|
||||
|
||||
@@ -120,12 +120,19 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
from child workers. See :class:`qlib.rl.utils.LogWriter`.
|
||||
"""
|
||||
|
||||
_logger: list[LogWriter]
|
||||
|
||||
def __init__(
|
||||
self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
|
||||
self, logger: LogWriter | list[LogWriter] | None, env_fns: list[Callable[..., gym.Env]], **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(env_fns, **kwargs)
|
||||
|
||||
self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
|
||||
if isinstance(logger, list):
|
||||
self._logger = logger
|
||||
elif isinstance(logger, LogWriter):
|
||||
self._logger = [logger]
|
||||
else:
|
||||
self._logger = []
|
||||
self._alive_env_ids: Set[int] = set()
|
||||
self._reset_alive_envs()
|
||||
self._default_obs = self._default_info = self._default_rew = None
|
||||
@@ -177,7 +184,7 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
1. Catch and ignore the StopIteration exception, which is the stopping signal
|
||||
thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
|
||||
2. Notify the loggers that the collect is done what it's done.
|
||||
2. Notify the loggers that the collect is ready / done what it's ready / done.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -186,6 +193,9 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
"""
|
||||
self._collector_guarded = True
|
||||
|
||||
for logger in self._logger:
|
||||
logger.on_env_all_ready()
|
||||
|
||||
try:
|
||||
yield self
|
||||
except StopIteration:
|
||||
@@ -298,7 +308,21 @@ def vectorize_env(
|
||||
concurrency: int,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
) -> FiniteVectorEnv:
|
||||
"""Helper function to create a vector env.
|
||||
"""Helper function to create a vector env. Can be used to replace usual VectorEnv.
|
||||
|
||||
For example, once you wrote: ::
|
||||
|
||||
DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
|
||||
|
||||
Now you can replace it with: ::
|
||||
|
||||
finite_env_factory(lambda: gym.make(task), "dummy", env_num, my_logger)
|
||||
|
||||
By doing such replacement, you have two additional features enabled (compared to normal VectorEnv):
|
||||
|
||||
1. The vector env will check for NaN observation and kill the worker when its found.
|
||||
See :class:`FiniteVectorEnv` for why we need this.
|
||||
2. A logger to explicit collect logs from environment workers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -12,13 +12,16 @@ in each worker, and writes them to console, log files, or tensorboard...
|
||||
The two modules communicate by the "log" field in "info" returned by ``env.step()``.
|
||||
"""
|
||||
|
||||
# NOTE: This file contains many hardcoded / ad-hoc rules.
|
||||
# Refactoring it will be one of the future tasks.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
|
||||
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -29,7 +32,7 @@ if TYPE_CHECKING:
|
||||
from .env_wrapper import InfoDict
|
||||
|
||||
|
||||
__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
|
||||
__all__ = ["LogCollector", "LogWriter", "LogLevel", "LogBuffer", "ConsoleWriter", "CsvWriter"]
|
||||
|
||||
ObsType = TypeVar("ObsType")
|
||||
ActType = TypeVar("ActType")
|
||||
@@ -175,18 +178,53 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
"""Clear all the metrics for a fresh start.
|
||||
To make the logger instance reusable.
|
||||
"""
|
||||
self.episode_count = self.step_count = 0
|
||||
self.active_env_ids = set()
|
||||
self.logs = []
|
||||
|
||||
def aggregation(self, array: Sequence[Any]) -> Any:
|
||||
def state_dict(self) -> dict:
|
||||
"""Save the states of the logger to a dict."""
|
||||
return {
|
||||
"episode_count": self.episode_count,
|
||||
"step_count": self.step_count,
|
||||
"global_step": self.global_step,
|
||||
"global_episode": self.global_episode,
|
||||
"active_env_ids": self.active_env_ids,
|
||||
"episode_lengths": self.episode_lengths,
|
||||
"episode_rewards": self.episode_rewards,
|
||||
"episode_logs": self.episode_logs,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load the states of current logger from a dict."""
|
||||
self.episode_count = state_dict["episode_count"]
|
||||
self.step_count = state_dict["step_count"]
|
||||
self.global_step = state_dict["global_step"]
|
||||
self.global_episode = state_dict["global_episode"]
|
||||
|
||||
# These are runtime infos.
|
||||
# Though they are loaded, I don't think it really helps.
|
||||
self.active_env_ids = state_dict["active_env_ids"]
|
||||
self.episode_lenghts = state_dict["episode_lengths"]
|
||||
self.episode_rewards = state_dict["episode_rewards"]
|
||||
self.episode_logs = state_dict["episode_logs"]
|
||||
|
||||
def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any:
|
||||
"""Aggregation function from step-wise to episode-wise.
|
||||
|
||||
If it's a sequence of float, take the mean.
|
||||
Otherwise, take the first element.
|
||||
|
||||
If a name is specified and,
|
||||
|
||||
- if it's ``reward``, the reduction will be sum.
|
||||
"""
|
||||
assert len(array) > 0, "The aggregated array must be not empty."
|
||||
if all(isinstance(v, float) for v in array):
|
||||
if name == "reward":
|
||||
return np.sum(array)
|
||||
return np.mean(array)
|
||||
else:
|
||||
return array[0]
|
||||
@@ -253,10 +291,93 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
self.episode_rewards[env_id] = []
|
||||
self.episode_logs[env_id] = []
|
||||
|
||||
def on_env_all_ready(self) -> None:
|
||||
"""When all environments are ready to run.
|
||||
Usually, loggers should be reset here.
|
||||
"""
|
||||
self.clear()
|
||||
|
||||
def on_env_all_done(self) -> None:
|
||||
"""All done. Time for cleanup."""
|
||||
|
||||
|
||||
class LogBuffer(LogWriter):
|
||||
"""Keep all numbers in memory.
|
||||
|
||||
Objects that can't be aggregated like strings, tensors, images can't be stored in the buffer.
|
||||
To persist them, please use :class:`PickleWriter`.
|
||||
|
||||
Every time, Log buffer receives a new metric, the callback is triggered,
|
||||
which is useful when tracking metrics inside a trainer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
callback
|
||||
A callback receiving three arguments:
|
||||
|
||||
- on_episode: Whether it's called at the end of an episode
|
||||
- on_collect: Whether it's called at the end of a collect
|
||||
- log_buffer: the :class:`LogBbuffer`object
|
||||
|
||||
No return value is expected.
|
||||
"""
|
||||
|
||||
# FIXME: needs a metric count
|
||||
|
||||
def __init__(self, callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
super().__init__(loglevel)
|
||||
self.callback = callback
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {
|
||||
**super().state_dict(),
|
||||
"latest_metrics": self._latest_metrics,
|
||||
"aggregated_metrics": self._aggregated_metrics,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self._latest_metrics = state_dict["latest_metrics"]
|
||||
self._aggregated_metrics = state_dict["aggregated_metrics"]
|
||||
return super().load_state_dict(state_dict)
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
self._latest_metrics: dict[str, float] | None = None
|
||||
self._aggregated_metrics: dict[str, float] = defaultdict(float)
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
# FIXME Dup of ConsoleWriter
|
||||
episode_wise_contents: dict[str, list] = defaultdict(list)
|
||||
for step_contents in contents:
|
||||
for name, value in step_contents.items():
|
||||
# FIXME This could be false-negative for some numpy types
|
||||
if isinstance(value, float):
|
||||
episode_wise_contents[name].append(value)
|
||||
|
||||
logs: dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values, name) # type: ignore
|
||||
self._aggregated_metrics[name] += logs[name]
|
||||
|
||||
self._latest_metrics = logs
|
||||
|
||||
self.callback(True, False, self)
|
||||
|
||||
def on_env_all_done(self) -> None:
|
||||
# This happens when collect exits
|
||||
self.callback(False, True, self)
|
||||
|
||||
def episode_metrics(self) -> dict[str, float]:
|
||||
"""Retrieve the numeric metrics of the latest episode."""
|
||||
if self._latest_metrics is None:
|
||||
raise ValueError("No episode metrics available yet.")
|
||||
return self._latest_metrics
|
||||
|
||||
def collect_metrics(self) -> dict[str, float]:
|
||||
"""Retrieve the aggregated metrics of the latest collect."""
|
||||
return {name: value / self.episode_count for name, value in self._aggregated_metrics.items()}
|
||||
|
||||
|
||||
class ConsoleWriter(LogWriter):
|
||||
"""Write log messages to console periodically.
|
||||
|
||||
@@ -289,6 +410,8 @@ class ConsoleWriter(LogWriter):
|
||||
|
||||
self.console_logger = get_module_logger(__name__, level=logging.INFO)
|
||||
|
||||
# FIXME: save & reload
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
# Clear average meters
|
||||
@@ -308,7 +431,7 @@ class ConsoleWriter(LogWriter):
|
||||
# This should be done at every step, regardless of periodic or not.
|
||||
logs: dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values) # type: ignore
|
||||
logs[name] = self.aggregation(values, name) # type: ignore
|
||||
|
||||
for name, value in logs.items():
|
||||
self.metric_counts[name] += 1
|
||||
@@ -350,6 +473,8 @@ class CsvWriter(LogWriter):
|
||||
|
||||
all_records: list[dict[str, Any]]
|
||||
|
||||
# FIXME: save & reload
|
||||
|
||||
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
super().__init__(loglevel)
|
||||
self.output_dir = output_dir
|
||||
@@ -370,7 +495,7 @@ class CsvWriter(LogWriter):
|
||||
|
||||
logs: dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values) # type: ignore
|
||||
logs[name] = self.aggregation(values, name) # type: ignore
|
||||
|
||||
self.all_records.append(logs)
|
||||
|
||||
@@ -392,7 +517,3 @@ class TensorboardWriter(LogWriter):
|
||||
|
||||
class MlflowWriter(LogWriter):
|
||||
"""Add logs to mlflow."""
|
||||
|
||||
|
||||
class LogBuffer(LogWriter):
|
||||
"""Keep everything in memory."""
|
||||
|
||||
@@ -81,7 +81,7 @@ def test_simple_env_logger(caplog):
|
||||
line = line.strip()
|
||||
if line:
|
||||
line_counter += 1
|
||||
assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
|
||||
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
|
||||
assert line_counter >= 3
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from qlib.backtest import Order
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.entries.test import backtest
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
|
||||
@@ -306,3 +306,26 @@ def test_cn_ppo_strategy():
|
||||
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
|
||||
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
|
||||
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
|
||||
|
||||
|
||||
def test_ppo_train():
|
||||
set_log_with_config(C.logging_config)
|
||||
# The data starts with 9:31 and ends with 15:00
|
||||
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
|
||||
assert len(orders) == 40
|
||||
|
||||
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
|
||||
action_interp = CategoricalActionInterpreter(4)
|
||||
network = Recurrent(state_interp.observation_space)
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
|
||||
train(
|
||||
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
policy,
|
||||
PAPenaltyReward(),
|
||||
vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}},
|
||||
trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)},
|
||||
)
|
||||
|
||||
202
tests/rl/test_trainer.py
Normal file
202
tests/rl/test_trainer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym import spaces
|
||||
from tianshou.policy import PPOPolicy
|
||||
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
|
||||
|
||||
class ZeroSimulator(Simulator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.action = self.correct = 0
|
||||
|
||||
def step(self, action):
|
||||
self.action = action
|
||||
self.correct = action == 0
|
||||
self._done = random.choice([False, True])
|
||||
if self._done:
|
||||
self.env.logger.add_scalar("acc", self.correct * 100)
|
||||
|
||||
def get_state(self):
|
||||
return {
|
||||
"acc": self.correct * 100,
|
||||
"action": self.action,
|
||||
}
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
|
||||
class NoopStateInterpreter(StateInterpreter):
|
||||
observation_space = spaces.Dict(
|
||||
{
|
||||
"acc": spaces.Discrete(200),
|
||||
"action": spaces.Discrete(2),
|
||||
}
|
||||
)
|
||||
|
||||
def interpret(self, simulator_state):
|
||||
return simulator_state
|
||||
|
||||
|
||||
class NoopActionInterpreter(ActionInterpreter):
|
||||
action_space = spaces.Discrete(2)
|
||||
|
||||
def interpret(self, simulator_state, action):
|
||||
return action
|
||||
|
||||
|
||||
class AccReward(Reward):
|
||||
def reward(self, simulator_state):
|
||||
if self.env.status["done"]:
|
||||
return simulator_state["acc"] / 100
|
||||
return 0.0
|
||||
|
||||
|
||||
class PolicyNet(nn.Module):
|
||||
def __init__(self, out_features=1, return_state=False):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(32, out_features)
|
||||
self.return_state = return_state
|
||||
|
||||
def forward(self, obs, state=None, **kwargs):
|
||||
res = self.fc(torch.randn(obs["acc"].shape[0], 32))
|
||||
if self.return_state:
|
||||
return nn.functional.softmax(res, dim=-1), state
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def _ppo_policy():
|
||||
actor = PolicyNet(2, True)
|
||||
critic = PolicyNet()
|
||||
policy = PPOPolicy(
|
||||
actor,
|
||||
critic,
|
||||
torch.optim.Adam(tuple(actor.parameters()) + tuple(critic.parameters())),
|
||||
torch.distributions.Categorical,
|
||||
action_space=NoopActionInterpreter().action_space,
|
||||
)
|
||||
return policy
|
||||
|
||||
|
||||
def test_trainer():
|
||||
set_log_with_config(C.logging_config)
|
||||
trainer = Trainer(max_iters=10, finite_env_type="subproc")
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=500,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
assert trainer.current_iter == 10
|
||||
assert trainer.current_episode == 5000
|
||||
assert abs(trainer.metrics["acc"] - trainer.metrics["reward"] * 100) < 1e-4
|
||||
assert trainer.metrics["acc"] > 80
|
||||
trainer.test(vessel)
|
||||
assert trainer.metrics["acc"] > 60
|
||||
|
||||
|
||||
def test_trainer_fast_dev_run():
|
||||
set_log_with_config(C.logging_config)
|
||||
trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=500,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
assert trainer.current_episode == 4
|
||||
|
||||
|
||||
def test_trainer_earlystop():
|
||||
# TODO this is just sanity check.
|
||||
# need to see the logs to check whether it works.
|
||||
set_log_with_config(C.logging_config)
|
||||
trainer = Trainer(
|
||||
max_iters=10,
|
||||
val_every_n_iters=1,
|
||||
finite_env_type="dummy",
|
||||
callbacks=[EarlyStopping("val/reward", restore_best_weights=True)],
|
||||
)
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=500,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
assert trainer.metrics["val/acc"] > 30
|
||||
assert trainer.current_iter == 2 # second iteration
|
||||
|
||||
|
||||
def test_trainer_checkpoint():
|
||||
set_log_with_config(C.logging_config)
|
||||
output_dir = Path(__file__).parent / ".output"
|
||||
trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
|
||||
policy = _ppo_policy()
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=lambda init: ZeroSimulator(init),
|
||||
state_interpreter=NoopStateInterpreter(),
|
||||
action_interpreter=NoopActionInterpreter(),
|
||||
policy=policy,
|
||||
train_initial_states=list(range(100)),
|
||||
val_initial_states=list(range(10)),
|
||||
test_initial_states=list(range(10)),
|
||||
reward=AccReward(),
|
||||
episode_per_iter=100,
|
||||
update_kwargs=dict(repeat=10, batch_size=64),
|
||||
)
|
||||
trainer.fit(vessel)
|
||||
|
||||
assert (output_dir / "001.pth").exists()
|
||||
assert (output_dir / "002.pth").exists()
|
||||
assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")
|
||||
|
||||
trainer.load_state_dict(torch.load(output_dir / "001.pth"))
|
||||
assert trainer.current_iter == 1
|
||||
assert trainer.current_episode == 100
|
||||
|
||||
# Reload the checkpoint at first iteration
|
||||
trainer.fit(vessel, ckpt_path=output_dir / "001.pth")
|
||||
Reference in New Issue
Block a user