diff --git a/qlib/rl/entries/__init__.py b/qlib/rl/entries/__init__.py
deleted file mode 100644
index 169fa985c..000000000
--- a/qlib/rl/entries/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-"""Train, test, inference utilities.
-
-The APIs in this directory are NOT considered final and are subject to change!
-"""
diff --git a/qlib/rl/entries/test.py b/qlib/rl/entries/test.py
deleted file mode 100644
index ca311407b..000000000
--- a/qlib/rl/entries/test.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-from __future__ import annotations
-
-import copy
-from typing import Callable, Sequence
-
-from tianshou.data import Collector
-from tianshou.policy import BasePolicy
-
-from qlib.constant import INF
-from qlib.log import get_module_logger
-from qlib.rl.simulator import InitialStateType, Simulator
-from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
-from qlib.rl.reward import Reward
-from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
-
-
-_logger = get_module_logger(__name__)
-
-
-def backtest(
- simulator_fn: Callable[[InitialStateType], Simulator],
- state_interpreter: StateInterpreter,
- action_interpreter: ActionInterpreter,
- initial_states: Sequence[InitialStateType],
- policy: BasePolicy,
- logger: LogWriter | list[LogWriter],
- reward: Reward | None = None,
- finite_env_type: FiniteEnvType = "subproc",
- concurrency: int = 2,
-) -> None:
- """Backtest with the parallelism provided by RL framework.
-
- Parameters
- ----------
- simulator_fn
- Callable receiving initial seed, returning a simulator.
- state_interpreter
- Interprets the state of simulators.
- action_interpreter
- Interprets the policy actions.
- initial_states
- Initial states to iterate over. Every state will be run exactly once.
- policy
- Policy to test against.
- logger
- Logger to record the backtest results. Logger must be present because
- without logger, all information will be lost.
- reward
- Optional reward function. For backtest, this is for testing the rewards
- and logging them only.
- finite_env_type
- Type of finite env implementation.
- concurrency
- Parallel workers.
- """
-
- # To save bandwidth
- min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
-
- def env_factory():
- # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
- # and could be thread unsafe.
- # I'm not sure whether it's a design flaw.
- # I'll rethink about this when designing the trainer.
-
- if finite_env_type == "dummy":
- # We could only experience the "threading-unsafe" problem in dummy.
- state = copy.deepcopy(state_interpreter)
- action = copy.deepcopy(action_interpreter)
- rew = copy.deepcopy(reward)
- else:
- state, action, rew = state_interpreter, action_interpreter, reward
-
- return EnvWrapper(
- simulator_fn,
- state,
- action,
- seed_iterator,
- rew,
- logger=LogCollector(min_loglevel=min_loglevel),
- )
-
- with DataQueue(initial_states) as seed_iterator:
- vector_env = vectorize_env(
- env_factory,
- finite_env_type,
- concurrency,
- logger,
- )
-
- policy.eval()
-
- with vector_env.collector_guard():
- test_collector = Collector(policy, vector_env)
- _logger.info("All ready. Start backtest.")
- test_collector.collect(n_step=INF * len(vector_env))
diff --git a/qlib/rl/entries/train.py b/qlib/rl/entries/train.py
deleted file mode 100644
index c852e6235..000000000
--- a/qlib/rl/entries/train.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# TBD
diff --git a/qlib/rl/order_execution/__init__.py b/qlib/rl/order_execution/__init__.py
index 048dfecac..b7b47c3d1 100644
--- a/qlib/rl/order_execution/__init__.py
+++ b/qlib/rl/order_execution/__init__.py
@@ -9,4 +9,5 @@ Multi-asset is on the way.
from .interpreter import *
from .network import *
from .policy import *
+from .reward import *
from .simulator_simple import *
diff --git a/qlib/rl/order_execution/reward.py b/qlib/rl/order_execution/reward.py
new file mode 100644
index 000000000..43015407d
--- /dev/null
+++ b/qlib/rl/order_execution/reward.py
@@ -0,0 +1,46 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import cast
+
+import numpy as np
+from qlib.rl.reward import Reward
+
+from .simulator_simple import SAOEState, SAOEMetrics
+
+__all__ = ["PAPenaltyReward"]
+
+
+class PAPenaltyReward(Reward[SAOEState]):
+    """Reward high price advantage (PA) while discouraging bursts of volume in a short time.
+
+    For every step the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.
+
+    Parameters
+    ----------
+    penalty
+        Coefficient of the quadratic penalty on the per-tick traded fraction.
+    """
+
+    def __init__(self, penalty: float = 100.0):
+        self.penalty = penalty
+
+    def reward(self, simulator_state: SAOEState) -> float:
+        """Return the PA term plus the volume penalty term of the latest step."""
+        total_amount = simulator_state.order.amount
+        assert total_amount > 0
+        latest = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
+        pa = latest["pa"] * latest["amount"] / total_amount
+
+        # Tick-level records from the start of the latest step onwards.
+        ticks = simulator_state.history_exec.loc[latest["datetime"] :]
+        tick_fraction = ticks["amount"] / total_amount
+        penalty = -self.penalty * (tick_fraction**2).sum()
+        reward = pa + penalty
+        # NaN / inf rewards are rejected right here.
+        assert np.isfinite(reward), f"Invalid reward for simulator state: {simulator_state}"
+        self.log("reward/pa", pa)
+        self.log("reward/penalty", penalty)
+        return reward
diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py
index 8022c34ce..51357dfdf 100644
--- a/qlib/rl/order_execution/simulator_simple.py
+++ b/qlib/rl/order_execution/simulator_simple.py
@@ -131,11 +131,14 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""
history_exec: pd.DataFrame
- """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
+ """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
+ Index is ``datetime``.
+ """
history_steps: pd.DataFrame
"""Positions at each step. The position before first step is also recorded.
- See :class:`SAOEMetrics` for available columns."""
+ See :class:`SAOEMetrics` for available columns.
+ Index is ``datetime``, which is the **starting** time of each step."""
metrics: SAOEMetrics | None
"""Metrics. Only available when done."""
diff --git a/qlib/rl/trainer/__init__.py b/qlib/rl/trainer/__init__.py
new file mode 100644
index 000000000..efce804c4
--- /dev/null
+++ b/qlib/rl/trainer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Train, test, inference utilities."""
+
+from .api import backtest, train
+from .callbacks import EarlyStopping, Checkpoint
+from .trainer import Trainer
+from .vessel import TrainingVessel, TrainingVesselBase
diff --git a/qlib/rl/trainer/api.py b/qlib/rl/trainer/api.py
new file mode 100644
index 000000000..65abbd88d
--- /dev/null
+++ b/qlib/rl/trainer/api.py
@@ -0,0 +1,118 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import Callable, Sequence, cast, Any
+
+from tianshou.policy import BasePolicy
+
+from qlib.rl.simulator import InitialStateType, Simulator
+from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
+from qlib.rl.reward import Reward
+from qlib.rl.utils import FiniteEnvType, LogWriter
+
+from .vessel import TrainingVessel
+from .trainer import Trainer
+
+
+def train(
+    simulator_fn: Callable[[InitialStateType], Simulator],
+    state_interpreter: StateInterpreter,
+    action_interpreter: ActionInterpreter,
+    initial_states: Sequence[InitialStateType],
+    policy: BasePolicy,
+    reward: Reward,
+    vessel_kwargs: dict[str, Any],
+    trainer_kwargs: dict[str, Any],
+) -> None:
+    """Train ``policy``: the arguments are bundled into a vessel, which is handed to a trainer.
+
+    Experimental API. Parameters might change shortly.
+
+    Parameters
+    ----------
+    simulator_fn
+        Callable that receives an initial seed and returns a simulator.
+    state_interpreter
+        Interprets the state of simulators.
+    action_interpreter
+        Interprets the policy actions.
+    initial_states
+        Initial states used for training; forwarded as ``train_initial_states``.
+    policy
+        Policy to be trained.
+    reward
+        Reward function.
+    vessel_kwargs
+        Extra keyword arguments for :class:`TrainingVessel`, e.g. ``episode_per_iter``.
+    trainer_kwargs
+        Keyword arguments for :class:`Trainer`, e.g. ``finite_env_type``, ``concurrency``.
+    """
+
+    vessel_args: dict[str, Any] = dict(
+        simulator_fn=simulator_fn,
+        state_interpreter=state_interpreter,
+        action_interpreter=action_interpreter,
+        policy=policy,
+        train_initial_states=initial_states,
+        reward=reward,
+    )
+    # The explicit arguments and ``vessel_kwargs`` must not overlap; a duplicate raises TypeError.
+    vessel = TrainingVessel(**vessel_args, **vessel_kwargs)
+    Trainer(**trainer_kwargs).fit(vessel)
+
+
+def backtest(
+    simulator_fn: Callable[[InitialStateType], Simulator],
+    state_interpreter: StateInterpreter,
+    action_interpreter: ActionInterpreter,
+    initial_states: Sequence[InitialStateType],
+    policy: BasePolicy,
+    logger: LogWriter | list[LogWriter],
+    reward: Reward | None = None,
+    finite_env_type: FiniteEnvType = "subproc",
+    concurrency: int = 2,
+) -> None:
+    """Backtest ``policy`` with the parallelism provided by RL framework.
+
+    Experimental API. Parameters might change shortly.
+
+    Parameters
+    ----------
+    simulator_fn
+        Callable that receives an initial seed and returns a simulator.
+    state_interpreter
+        Interprets the state of simulators.
+    action_interpreter
+        Interprets the policy actions.
+    initial_states
+        Initial states to iterate over. Every state will be run exactly once.
+    policy
+        Policy to test against.
+    logger
+        Logger(s) recording the backtest results. A logger must be given,
+        because without one all information would be lost.
+    reward
+        Optional reward function. In a backtest it only serves to
+        evaluate the rewards and log them.
+    finite_env_type
+        Type of finite env implementation.
+    concurrency
+        Parallel workers.
+    """
+
+    # ``reward`` may be None here; the cast only satisfies the type checker.
+    test_reward = cast(Reward, reward)
+    vessel = TrainingVessel(
+        simulator_fn=simulator_fn,
+        state_interpreter=state_interpreter,
+        action_interpreter=action_interpreter,
+        policy=policy,
+        test_initial_states=initial_states,
+        reward=test_reward,
+    )
+
+    # The trainer is created solely to drive the test loop.
+    tester = Trainer(finite_env_type=finite_env_type, concurrency=concurrency, loggers=logger)
+    tester.test(vessel)
diff --git a/qlib/rl/trainer/callbacks.py b/qlib/rl/trainer/callbacks.py
new file mode 100644
index 000000000..72e2df99a
--- /dev/null
+++ b/qlib/rl/trainer/callbacks.py
@@ -0,0 +1,267 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Callbacks to insert customized recipes during the training.
+Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
+"""
+
+from __future__ import annotations
+
+import copy
+import shutil
+import time
+from datetime import datetime
+from pathlib import Path
+from typing import Any, TYPE_CHECKING
+
+import numpy as np
+import torch
+
+from qlib.log import get_module_logger
+from qlib.typehint import Literal
+
+if TYPE_CHECKING:
+ from .trainer import Trainer
+ from .vessel import TrainingVesselBase
+
+
+_logger = get_module_logger(__name__)
+
+
+class Callback:
+    """Base class of all callbacks; every hook is a no-op unless overridden."""
+
+    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked before the whole fit process begins."""
+
+    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked after the whole fit process has ended."""
+
+    def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked right before each collect for training."""
+
+    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked when the training ends.
+        To post-process the outputs produced during training in this hook,
+        cache them on either the trainer or the vessel beforehand.
+        """
+
+    def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked before every run of validation."""
+
+    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked after every run of validation."""
+
+    def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked before every run of testing."""
+
+    def on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked after every run of testing."""
+
+    def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked at the beginning of every iteration (i.e., collect)."""
+
+    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Hook invoked at the end of every iteration.
+        It runs **after** ``current_iter`` has been bumped,
+        i.e., once the previous iteration is considered complete.
+        """
+
+    def state_dict(self) -> Any:
+        """Return the state of the callback for pause and resume. The base implementation returns None."""
+
+    def load_state_dict(self, state_dict: Any) -> None:
+        """Resume the callback from a state previously produced by :meth:`state_dict`."""
+
+
+class EarlyStopping(Callback):
+    """Stop training when a monitored metric has stopped improving.
+
+    The earlystopping callback will be triggered each time validation ends.
+    It will examine the metrics produced in validation,
+    and get the metric with name ``monitor`` (``monitor`` is ``reward`` by default),
+    to check whether it's no longer increasing / decreasing.
+    It takes ``min_delta`` and ``patience`` into account if applicable.
+    If the metric is found to be not increasing / decreasing any more,
+    ``trainer.should_stop`` will be set to true,
+    and the training terminates.
+
+    Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893
+    """
+
+    def __init__(
+        self,
+        monitor: str = "reward",
+        min_delta: float = 0.0,
+        patience: int = 0,
+        mode: Literal["min", "max"] = "max",
+        baseline: float | None = None,
+        restore_best_weights: bool = False,
+    ):
+        super().__init__()
+        if mode not in ["min", "max"]:
+            # Formatted rather than concatenated: a non-string mode must still end up as ValueError.
+            raise ValueError(f"Unsupported earlystopping mode: {mode}")
+
+        self.monitor = monitor
+        self.patience = patience
+        self.baseline = baseline
+        self.restore_best_weights = restore_best_weights
+        # ``min_delta`` is signed so that ``value - min_delta`` always handicaps the new value.
+        self.monitor_op = np.less if mode == "min" else np.greater
+        self.min_delta = -abs(min_delta) if mode == "min" else abs(min_delta)
+
+        # Also initialized here, so that ``state_dict`` is usable before ``fit`` starts.
+        self._reset()
+
+    def _reset(self) -> None:
+        """Forget every record of previous validations."""
+        self.wait = 0
+        self.best = np.inf if self.monitor_op == np.less else -np.inf
+        self.best_weights: Any | None = None
+        self.best_iter = 0
+
+    def state_dict(self) -> dict:
+        return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter}
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.wait = state_dict["wait"]
+        self.best = state_dict["best"]
+        self.best_weights = state_dict["best_weights"]
+        self.best_iter = state_dict["best_iter"]
+
+    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        # Allow instances to be re-used. NOTE(review): Trainer.fit fires this after loading a checkpoint too.
+        self._reset()
+
+    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        """Update the records with the latest validation metric; request a stop if patience runs out."""
+        current = self.get_monitor_value(trainer)
+        if current is None:
+            return
+        if self.restore_best_weights and self.best_weights is None:
+            # Restore the weights after first iteration if no progress is ever made.
+            self.best_weights = copy.deepcopy(vessel.state_dict())
+
+        self.wait += 1
+        if self._is_improvement(current, self.best):
+            self.best = current
+            self.best_iter = trainer.current_iter
+            if self.restore_best_weights:
+                self.best_weights = copy.deepcopy(vessel.state_dict())
+            # Only restart wait if we beat both the baseline and our previous best.
+            if self.baseline is None or self._is_improvement(current, self.baseline):
+                self.wait = 0
+
+        # Only check after the first epoch.
+        if self.wait >= self.patience and trainer.current_iter > 0:
+            trainer.should_stop = True
+            _logger.info("On iteration %d: early stopping", trainer.current_iter + 1)
+            if self.restore_best_weights and self.best_weights is not None:
+                _logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1)
+                vessel.load_state_dict(self.best_weights)
+
+    def get_monitor_value(self, trainer: Trainer) -> Any:
+        """Look up the monitored metric in ``trainer.metrics``; warn and return None if it is absent."""
+        monitor_value = trainer.metrics.get(self.monitor)
+        if monitor_value is None:
+            _logger.warning(
+                "Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
+                self.monitor,
+                ",".join(list(trainer.metrics.keys())),
+            )
+        return monitor_value
+
+    def _is_improvement(self, monitor_value, reference_value):
+        return self.monitor_op(monitor_value - self.min_delta, reference_value)
+
+
+class Checkpoint(Callback):
+    """Save checkpoints periodically for persistence and recovery.
+
+    Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py
+
+    Parameters
+    ----------
+    dirpath
+        Directory to save the checkpoint file.
+    filename
+        Checkpoint filename. Can contain named formatting options to be auto-filled.
+        For example: ``{iter:03d}-{reward:.2f}.pth``.
+        Supported argument names are:
+
+        - iter (int)
+        - metrics in ``trainer.metrics``
+        - time string, in the format of ``%Y%m%d%H%M%S``
+    save_latest
+        Save the latest checkpoint in ``latest.pth``.
+        If ``link``, ``latest.pth`` will be created as a softlink.
+        If ``copy``, ``latest.pth`` will be stored as an individual copy.
+        Set to none to disable this.
+    every_n_iters
+        Checkpoints are saved at the end of every n iterations of training,
+        after validation if applicable.
+    time_interval
+        Maximum time (seconds) before checkpoints save again.
+    save_on_fit_end
+        Save one last checkpoint at the end of fit.
+        Do nothing if a checkpoint is already saved there.
+    """
+
+    def __init__(
+        self,
+        dirpath: Path,
+        filename: str = "{iter:03d}.pth",
+        save_latest: Literal["link", "copy"] | None = "link",
+        every_n_iters: int | None = None,
+        time_interval: int | None = None,
+        save_on_fit_end: bool = True,
+    ):
+        self.dirpath = Path(dirpath)
+        self.filename = filename
+        self.save_latest = save_latest
+        self.every_n_iters = every_n_iters
+        self.time_interval = time_interval
+        self.save_on_fit_end = save_on_fit_end
+
+        self._last_checkpoint_name: str | None = None
+        self._last_checkpoint_iter: int | None = None
+        self._last_checkpoint_time: float | None = None
+
+    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):
+            self._save_checkpoint(trainer)
+
+    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
+        should_save_ckpt = False
+        if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0:
+            should_save_ckpt = True
+        if self.time_interval is not None and (
+            self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval
+        ):
+            should_save_ckpt = True
+        if should_save_ckpt:
+            self._save_checkpoint(trainer)
+
+    def _save_checkpoint(self, trainer: Trainer) -> None:
+        self.dirpath.mkdir(exist_ok=True, parents=True)
+        self._last_checkpoint_name = self._new_checkpoint_name(trainer)
+        self._last_checkpoint_iter = trainer.current_iter
+        self._last_checkpoint_time = time.time()
+        torch.save(trainer.state_dict(), self.dirpath / self._last_checkpoint_name)
+
+        latest_pth = self.dirpath / "latest.pth"
+        # Remove the stale one first; ``exists()`` follows links, so dangling links need ``is_symlink()``.
+        if self.save_latest and (latest_pth.exists() or latest_pth.is_symlink()):
+            latest_pth.unlink()
+
+        if self.save_latest == "link":
+            # A relative target is resolved against the directory of the link, hence the bare name.
+            latest_pth.symlink_to(self._last_checkpoint_name)
+        elif self.save_latest == "copy":
+            shutil.copyfile(self.dirpath / self._last_checkpoint_name, latest_pth)
+
+    def _new_checkpoint_name(self, trainer: Trainer) -> str:
+        return self.filename.format(
+            iter=trainer.current_iter, time=datetime.now().strftime("%Y%m%d%H%M%S"), **trainer.metrics
+        )
diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py
new file mode 100644
index 000000000..c44419e05
--- /dev/null
+++ b/qlib/rl/trainer/trainer.py
@@ -0,0 +1,343 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import copy
+from contextlib import AbstractContextManager, contextmanager
+from pathlib import Path
+from typing import Any, Iterable, TypeVar, Sequence, cast
+
+import torch
+
+from qlib.rl.simulator import InitialStateType
+from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel
+from qlib.log import get_module_logger
+from qlib.rl.utils.finite_env import FiniteVectorEnv
+from qlib.typehint import Literal
+
+from .callbacks import Callback
+from .vessel import TrainingVesselBase
+
+_logger = get_module_logger(__name__)
+
+
+T = TypeVar("T")
+
+
+class Trainer:
+    """
+    Utility to train a policy on a particular task.
+
+    Different from traditional DL trainer, the iteration of this trainer is "collect",
+    rather than "epoch", or "mini-batch".
+    In each collect, :class:`Collector` collects a number of policy-env interactions, and accumulates
+    them into a replay buffer. This buffer is used as the "data" to train the policy.
+    At the end of each collect, the policy is *updated* several times.
+
+    The API has some resemblance with `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__,
+    but it's essentially different because this trainer is built for RL applications, and thus
+    most configurations are under RL context.
+    We are still looking for ways to incorporate existing trainer libraries, because it looks like
+    big efforts to build a trainer as powerful as those libraries, and also, that's not our primary goal.
+
+    It's also essentially different from
+    `tianshou's built-in trainers <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__,
+    as it's far more complicated than that.
+
+    Parameters
+    ----------
+    max_iters
+        Maximum iterations before stopping.
+    val_every_n_iters
+        Perform validation every n iterations (i.e., training collects).
+    loggers
+        Log writer(s) handed to the vectorized env. Optional: a :class:`LogBuffer` is always appended.
+    callbacks
+        Callbacks whose hooks are called, in the given order, at the events of fit / test.
+    finite_env_type
+        Type of finite env implementation.
+    concurrency
+        Parallel workers.
+    fast_dev_run
+        Create a subset for debugging.
+        How this is implemented depends on the implementation of training vessel.
+        For :class:`~qlib.rl.trainer.TrainingVessel`, if greater than zero,
+        a random subset sized ``fast_dev_run`` will be used
+        instead of ``train_initial_states`` and ``val_initial_states``.
+    """
+
+    should_stop: bool
+    """Set to stop the training."""
+
+    metrics: dict
+    """Numeric metrics produced in train/val/test.
+    In the middle of training / validation, metrics will be of the latest episode.
+    When each iteration of training / validation finishes, metrics will be the aggregation
+    of all episodes encountered in this iteration.
+
+    Cleared on every new iteration of training.
+
+    In fit, validation metrics will be prefixed with ``val/``.
+    """
+
+    current_iter: int
+    """Current iteration (collect) of training."""
+
+    loggers: list[LogWriter]
+    """A list of log writers."""
+
+    def __init__(
+        self,
+        *,
+        max_iters: int | None = None,
+        val_every_n_iters: int | None = None,
+        loggers: LogWriter | list[LogWriter] | None = None,
+        callbacks: list[Callback] | None = None,
+        finite_env_type: FiniteEnvType = "subproc",
+        concurrency: int = 2,
+        fast_dev_run: int | None = None,
+    ):
+        self.max_iters = max_iters
+        self.val_every_n_iters = val_every_n_iters
+
+        # A list is copied, because a LogBuffer is appended below and the caller's list must stay intact.
+        if isinstance(loggers, list):
+            self.loggers = list(loggers)
+        elif isinstance(loggers, LogWriter):
+            self.loggers = [loggers]
+        else:
+            self.loggers = []
+
+        self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))
+
+        self.callbacks: list[Callback] = callbacks if callbacks is not None else []
+        self.finite_env_type = finite_env_type
+        self.concurrency = concurrency
+        self.fast_dev_run = fast_dev_run
+
+        self.current_stage: Literal["train", "val", "test"] = "train"
+
+        self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None)
+
+    def initialize(self):
+        """Initialize the whole training process.
+
+        The states here should be synchronized with state_dict.
+        """
+        self.should_stop = False
+        self.current_iter = 0
+        self.current_episode = 0
+        self.current_stage = "train"
+
+    def initialize_iter(self):
+        """Initialize one iteration / collect."""
+        self.metrics = {}
+
+    def state_dict(self) -> dict:
+        """Putting every states of current training into a dict, at best effort.
+
+        It doesn't try to handle all the possible kinds of states in the middle of one training collect.
+        For most cases at the end of each iteration, things should be usually correct.
+
+        Note that it's also intended behavior that replay buffer data in the collector will be lost.
+        """
+        return {
+            "vessel": self.vessel.state_dict(),
+            "callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()},
+            "loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()},
+            "should_stop": self.should_stop,
+            "current_iter": self.current_iter,
+            "current_episode": self.current_episode,
+            "current_stage": self.current_stage,
+            "metrics": self.metrics,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Load all states into current trainer."""
+        self.vessel.load_state_dict(state_dict["vessel"])
+        for name, callback in self.named_callbacks().items():
+            callback.load_state_dict(state_dict["callbacks"][name])
+        for name, logger in self.named_loggers().items():
+            logger.load_state_dict(state_dict["loggers"][name])
+        self.should_stop = state_dict["should_stop"]
+        self.current_iter = state_dict["current_iter"]
+        self.current_episode = state_dict["current_episode"]
+        self.current_stage = state_dict["current_stage"]
+        self.metrics = state_dict["metrics"]
+
+    def named_callbacks(self) -> dict[str, Callback]:
+        """Retrieve a collection of callbacks where each one has a name.
+        Useful when saving checkpoints.
+        """
+        return _named_collection(self.callbacks)
+
+    def named_loggers(self) -> dict[str, LogWriter]:
+        """Retrieve a collection of loggers where each one has a name.
+        Useful when saving checkpoints.
+        """
+        return _named_collection(self.loggers)
+
+    def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None:
+        """Train the RL policy upon the defined simulator.
+
+        Parameters
+        ----------
+        vessel
+            A bundle of all elements used in training.
+        ckpt_path
+            Load a pre-trained / paused training checkpoint.
+        """
+        self.vessel = vessel
+        vessel.assign_trainer(self)
+
+        if ckpt_path is not None:
+            _logger.info("Resuming states from %s", str(ckpt_path))
+            self.load_state_dict(torch.load(ckpt_path))  # NOTE: unpickles the file; trusted checkpoints only.
+        else:
+            self.initialize()
+
+        self._call_callback_hooks("on_fit_start")
+
+        while not self.should_stop:
+            self.initialize_iter()
+
+            self._call_callback_hooks("on_iter_start")
+
+            self.current_stage = "train"
+            self._call_callback_hooks("on_train_start")
+
+            # TODO: support reloading the training environment every few iterations.
+            with _wrap_context(vessel.train_seed_iterator()) as iterator:
+                vector_env = self.venv_from_iterator(iterator)
+                self.vessel.train(vector_env)
+
+            self._call_callback_hooks("on_train_end")
+
+            if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
+                # Implementation of validation loop
+                self.current_stage = "val"
+                self._call_callback_hooks("on_validate_start")
+                with _wrap_context(vessel.val_seed_iterator()) as iterator:
+                    vector_env = self.venv_from_iterator(iterator)
+                    self.vessel.validate(vector_env)
+
+                self._call_callback_hooks("on_validate_end")
+
+            # This iteration is considered complete; bump the current iteration counter.
+            self.current_iter += 1
+
+            if self.max_iters is not None and self.current_iter >= self.max_iters:
+                self.should_stop = True
+
+            self._call_callback_hooks("on_iter_end")
+
+        self._call_callback_hooks("on_fit_end")
+
+    def test(self, vessel: TrainingVesselBase) -> None:
+        """Test the RL policy against the simulator.
+
+        The simulator will be fed with data generated in ``test_seed_iterator``.
+
+        Parameters
+        ----------
+        vessel
+            A bundle of all related elements.
+        """
+        self.vessel = vessel
+        vessel.assign_trainer(self)
+
+        self.initialize_iter()
+
+        self.current_stage = "test"
+        self._call_callback_hooks("on_test_start")
+        with _wrap_context(vessel.test_seed_iterator()) as iterator:
+            vector_env = self.venv_from_iterator(iterator)
+            self.vessel.test(vector_env)
+        self._call_callback_hooks("on_test_end")
+
+    def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
+        """Create a vectorized environment from iterator and the training vessel."""
+
+        def env_factory():
+            # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
+            # and could be thread unsafe.
+            # It is unclear whether this is a design flaw.
+
+            if self.finite_env_type == "dummy":
+                # We could only experience the "threading-unsafe" problem in dummy.
+                state = copy.deepcopy(self.vessel.state_interpreter)
+                action = copy.deepcopy(self.vessel.action_interpreter)
+                rew = copy.deepcopy(self.vessel.reward)
+            else:
+                state = self.vessel.state_interpreter
+                action = self.vessel.action_interpreter
+                rew = self.vessel.reward
+
+            return EnvWrapper(
+                self.vessel.simulator_fn,
+                state,
+                action,
+                iterator,
+                rew,
+                logger=LogCollector(min_loglevel=self._min_loglevel()),
+            )
+
+        return vectorize_env(
+            env_factory,
+            self.finite_env_type,
+            self.concurrency,
+            self.loggers,
+        )
+
+    def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None:
+        if not (on_episode or on_collect):
+            return  # Neither event happened: there is nothing to record.
+        if on_episode:
+            # Update the global counter.
+            self.current_episode = log_buffer.global_episode
+            metrics = log_buffer.episode_metrics()
+        else:
+            # Update the latest metrics.
+            metrics = log_buffer.collect_metrics()
+        if self.current_stage == "val":
+            metrics = {"val/" + name: value for name, value in metrics.items()}
+        self.metrics.update(metrics)
+
+    def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
+        for callback in self.callbacks:
+            fn = getattr(callback, hook_name)
+            fn(self, self.vessel, *args, **kwargs)
+
+    def _min_loglevel(self):
+        if not self.loggers:
+            return LogLevel.PERIODIC
+        # To save bandwidth
+        return min(lg.loglevel for lg in self.loggers)
+
+
+@contextmanager
+def _wrap_context(obj):
+    """Enter ``obj`` if it is a context manager; otherwise yield ``obj`` itself unchanged."""
+    if not isinstance(obj, AbstractContextManager):
+        # Plain objects are passed through: there is nothing to enter or exit.
+        yield obj
+        return
+    # obj provides __enter__ and __exit__.
+    with obj as entered:
+        yield entered
+
+
+def _named_collection(seq: Sequence[T]) -> dict[str, T]:
+    """Convert a list into a dict, where each item is named with its lower-cased type name.
+    Later items of an already seen type are auto-labelled as earlystop1, earlystop2, ...
+    """
+    res: dict[str, T] = {}
+    for item in seq:
+        typename = name = type(item).__name__.lower()
+        retry = 0
+        while name in res:  # Advance the suffix until the name is free; each item is stored exactly once.
+            retry += 1
+            name = f"{typename}{retry}"
+        res[name] = item
+    return res
diff --git a/qlib/rl/trainer/vessel.py b/qlib/rl/trainer/vessel.py
new file mode 100644
index 000000000..9c0879ce0
--- /dev/null
+++ b/qlib/rl/trainer/vessel.py
@@ -0,0 +1,214 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import weakref
+from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict
+
+import numpy as np
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import BaseVectorEnv
+from tianshou.policy import BasePolicy
+
+from qlib.constant import INF
+from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType
+from qlib.rl.simulator import InitialStateType, Simulator
+from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
+from qlib.rl.reward import Reward
+from qlib.rl.utils import DataQueue
+from qlib.log import get_module_logger
+from qlib.rl.utils.finite_env import FiniteVectorEnv
+
+if TYPE_CHECKING:
+ from .trainer import Trainer
+
+
+T = TypeVar("T")
+_logger = get_module_logger(__name__)
+
+
+class SeedIteratorNotAvailable(BaseException):
+    """Raised by the default ``*_seed_iterator`` methods when no seed iterator is available.
+
+    NOTE(review): this derives from ``BaseException``, so a plain ``except Exception`` does not
+    catch it -- presumably intentional; confirm against how the trainer handles it.
+    """
+
+    pass
+
+
+class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
+    """A ship carrying the simulator factory, interpreters, policy and reward to the trainer.
+
+    The vessel is in charge of the algorithm-related parts of training,
+    whereas the trainer takes care of the runtime part.
+
+    Subclasses implement :meth:`train`, :meth:`validate` and :meth:`test`,
+    and override the ``*_seed_iterator`` methods for the stages they support.
+    """
+
+    simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]]
+    state_interpreter: StateInterpreter[StateType, ObsType]
+    action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType]
+    policy: BasePolicy
+    reward: Reward
+    trainer: Trainer
+
+    def assign_trainer(self, trainer: Trainer) -> None:
+        """Attach the trainer, held through a weak proxy rather than a strong reference."""
+        self.trainer = weakref.proxy(trainer)  # type: ignore
+
+    def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
+        """Override this to create a seed iterator for training.
+        If the iterable is a context manager, the whole training will be invoked in the with-block,
+        and the iterator will be automatically closed after the training is done."""
+        message = "Seed iterator for training is not available."
+        raise SeedIteratorNotAvailable(message)
+
+    def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
+        """Override this to create a seed iterator for validation."""
+        message = "Seed iterator for validation is not available."
+        raise SeedIteratorNotAvailable(message)
+
+    def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
+        """Override this to create a seed iterator for testing."""
+        message = "Seed iterator for testing is not available."
+        raise SeedIteratorNotAvailable(message)
+
+    def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]:
+        """Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
+        raise NotImplementedError()
+
+    def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
+        """Implement this to validate the policy once."""
+        raise NotImplementedError()
+
+    def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
+        """Implement this to evaluate the policy on test environment once."""
+        raise NotImplementedError()
+
+    def log(self, name: str, value: Any) -> None:
+        """Write ``name = value`` to the module logger, tagged with the 1-based iteration number.
+
+        Arrays and lists are reduced to their mean before being written.
+        """
+        # FIXME: this is a workaround to make the log at least show somewhere.
+        # Need a refactor in logger to formalize this.
+        value = np.mean(value) if isinstance(value, (np.ndarray, list)) else value
+        _logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}")
+
+    def log_dict(self, data: dict[str, Any]) -> None:
+        """Call :meth:`log` on every ``(name, value)`` pair of ``data``."""
+        for entry in data.items():
+            self.log(*entry)
+
+    def state_dict(self) -> dict:
+        """Return a checkpoint of current vessel state."""
+        policy_state = self.policy.state_dict()
+        return {"policy": policy_state}
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Restore a checkpoint from a previously saved state dict."""
+        policy_state = state_dict["policy"]
+        self.policy.load_state_dict(policy_state)
+
+
+class TrainingVessel(TrainingVesselBase):
+    """The default implementation of training vessel.
+
+    ``__init__`` accepts a sequence of initial states so that iterator can be created.
+    ``train``, ``validate``, ``test`` each do one collect (and also update in train).
+    By default, the train initial states will be repeated infinitely during training,
+    and collector will control the number of episodes for each iteration.
+    In validation and testing, the val / test initial states will be used exactly once.
+
+    Extra hyper-parameters (only used in train) include:
+
+    - ``buffer_size``: Size of replay buffer.
+    - ``episode_per_iter``: Episodes per collect at training. Can be overridden by fast dev run.
+    - ``update_kwargs``: Keyword arguments appearing in ``policy.update``.
+      For example, ``dict(repeat=10, batch_size=64)``. ``None`` means no extra arguments.
+    """
+
+    def __init__(
+        self,
+        *,
+        simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]],
+        state_interpreter: StateInterpreter[StateType, ObsType],
+        action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
+        policy: BasePolicy,
+        reward: Reward,
+        train_initial_states: Sequence[InitialStateType] | None = None,
+        val_initial_states: Sequence[InitialStateType] | None = None,
+        test_initial_states: Sequence[InitialStateType] | None = None,
+        buffer_size: int = 20000,
+        episode_per_iter: int = 1000,
+        update_kwargs: dict[str, Any] | None = None,
+    ):
+        self.simulator_fn = simulator_fn  # type: ignore
+        self.state_interpreter = state_interpreter
+        self.action_interpreter = action_interpreter
+        self.policy = policy
+        self.reward = reward
+        self.train_initial_states = train_initial_states
+        self.val_initial_states = val_initial_states
+        self.test_initial_states = test_initial_states
+        self.buffer_size = buffer_size
+        self.episode_per_iter = episode_per_iter
+        # Normalize a missing value to an empty dict so it can always be **-expanded.
+        self.update_kwargs = update_kwargs or {}
+
+    def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
+        """Return a shuffled, endlessly repeating queue over the training initial states."""
+        if self.train_initial_states is not None:
+            _logger.info("Training initial states collection size: %d", len(self.train_initial_states))
+            # Implement fast_dev_run here.
+            train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
+            return DataQueue(train_initial_states, repeat=-1, shuffle=True)
+        return super().train_seed_iterator()
+
+    def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
+        """Return a queue that goes through the validation initial states once."""
+        if self.val_initial_states is not None:
+            _logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
+            val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
+            return DataQueue(val_initial_states, repeat=1)
+        return super().val_seed_iterator()
+
+    def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
+        """Return a queue that goes through the testing initial states once."""
+        if self.test_initial_states is not None:
+            _logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
+            test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
+            return DataQueue(test_initial_states, repeat=1)
+        return super().test_seed_iterator()
+
+    def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
+        """Create a collector and collects ``episode_per_iter`` episodes.
+        Update the policy on the collected replay buffer.
+
+        Returns the collect result merged with the update result; the merged dict is also logged.
+        """
+        self.policy.train()
+
+        with vector_env.collector_guard():
+            collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
+
+            # Number of episodes collected in each training iteration can be overridden by fast dev run.
+            if self.trainer.fast_dev_run is not None:
+                episodes = self.trainer.fast_dev_run
+            else:
+                episodes = self.episode_per_iter
+
+            col_result = collector.collect(n_episode=episodes)
+            update_result = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs)
+            res = {**col_result, **update_result}
+            self.log_dict(res)
+            return res
+
+    def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
+        """Do one collect over ``vector_env`` with the policy in eval mode; log and return the result."""
+        return self._evaluate(vector_env)
+
+    def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
+        """Do one collect over ``vector_env`` with the policy in eval mode; log and return the result."""
+        return self._evaluate(vector_env)
+
+    def _evaluate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
+        """Shared body of ``validate`` and ``test``: one collect with the policy in eval mode."""
+        self.policy.eval()
+
+        with vector_env.collector_guard():
+            test_collector = Collector(self.policy, vector_env)
+            # The step budget is effectively unbounded (INF per environment).
+            # NOTE(review): presumably the finite env ends the collect once the seeds run out -- confirm.
+            res = test_collector.collect(n_step=INF * len(vector_env))
+            self.log_dict(res)
+            return res
+
+    @staticmethod
+    def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]:
+        """Pick at most ``size`` random items from ``collection``.
+
+        ``size=None`` returns ``collection`` itself. ``name`` is only used in the log message.
+        """
+        if size is None:
+            # Size = None -> original collection
+            return collection
+        order = np.random.permutation(len(collection))
+        res = [collection[o] for o in order[:size]]
+        _logger.info(
+            "Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res)
+        )
+        return res
diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py
index 32041abef..c1f7f3ab0 100644
--- a/qlib/rl/utils/data_queue.py
+++ b/qlib/rl/utils/data_queue.py
@@ -145,7 +145,9 @@ class DataQueue(Generic[T]):
def __iter__(self):
if not self._activated:
raise ValueError(
- "Need to call activate() to launch a daemon worker to produce data into data queue before using it."
+ "Need to call activate() to launch a daemon worker "
+ "to produce data into data queue before using it. "
+ "You probably have forgotten to use the DataQueue in a with block."
)
return self._consumer()
@@ -161,19 +163,21 @@ class DataQueue(Generic[T]):
# pytorch dataloader is used here only because we need its sampler and multi-processing
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
- dataloader = DataLoader(
- cast(Dataset[T], self.dataset),
- batch_size=None,
- num_workers=self.producer_num_workers,
- shuffle=self.shuffle,
- collate_fn=lambda t: t, # identity collate fn
- )
- repeat = 10**18 if self.repeat == -1 else self.repeat
- for _rep in range(repeat):
- for data in dataloader:
- if self._done.value:
- # Already done.
- return
- self._queue.put(data)
- _logger.debug(f"Dataloader loop done. Repeat {_rep}.")
- self.mark_as_done()
+ try:
+ dataloader = DataLoader(
+ cast(Dataset[T], self.dataset),
+ batch_size=None,
+ num_workers=self.producer_num_workers,
+ shuffle=self.shuffle,
+ collate_fn=lambda t: t, # identity collate fn
+ )
+ repeat = 10**18 if self.repeat == -1 else self.repeat
+ for _rep in range(repeat):
+ for data in dataloader:
+ if self._done.value:
+ # Already done.
+ return
+ self._queue.put(data)
+ _logger.debug(f"Dataloader loop done. Repeat {_rep}.")
+ finally:
+ self.mark_as_done()
diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py
index fc9c2c75e..6d7b0e209 100644
--- a/qlib/rl/utils/finite_env.py
+++ b/qlib/rl/utils/finite_env.py
@@ -120,12 +120,19 @@ class FiniteVectorEnv(BaseVectorEnv):
from child workers. See :class:`qlib.rl.utils.LogWriter`.
"""
+ _logger: list[LogWriter]
+
def __init__(
- self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
+ self, logger: LogWriter | list[LogWriter] | None, env_fns: list[Callable[..., gym.Env]], **kwargs: Any
) -> None:
super().__init__(env_fns, **kwargs)
- self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
+ if isinstance(logger, list):
+ self._logger = logger
+ elif isinstance(logger, LogWriter):
+ self._logger = [logger]
+ else:
+ self._logger = []
self._alive_env_ids: Set[int] = set()
self._reset_alive_envs()
self._default_obs = self._default_info = self._default_rew = None
@@ -177,7 +184,7 @@ class FiniteVectorEnv(BaseVectorEnv):
1. Catch and ignore the StopIteration exception, which is the stopping signal
thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
- 2. Notify the loggers that the collect is done what it's done.
+ 2. Notify the loggers that the collect is ready / done when it's ready / done.
Examples
--------
@@ -186,6 +193,9 @@ class FiniteVectorEnv(BaseVectorEnv):
"""
self._collector_guarded = True
+ for logger in self._logger:
+ logger.on_env_all_ready()
+
try:
yield self
except StopIteration:
@@ -298,7 +308,21 @@ def vectorize_env(
concurrency: int,
logger: LogWriter | list[LogWriter],
) -> FiniteVectorEnv:
- """Helper function to create a vector env.
+ """Helper function to create a vector env. Can be used to replace usual VectorEnv.
+
+ For example, once you wrote: ::
+
+ DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
+
+ Now you can replace it with: ::
+
+ vectorize_env(lambda: gym.make(task), "dummy", env_num, my_logger)
+
+ By doing such replacement, you have two additional features enabled (compared to normal VectorEnv):
+
+ 1. The vector env will check for NaN observation and kill the worker when it's found.
+ See :class:`FiniteVectorEnv` for why we need this.
+ 2. A logger to explicitly collect logs from environment workers.
Parameters
----------
diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py
index 3d495b11d..409a48a76 100644
--- a/qlib/rl/utils/log.py
+++ b/qlib/rl/utils/log.py
@@ -12,13 +12,16 @@ in each worker, and writes them to console, log files, or tensorboard...
The two modules communicate by the "log" field in "info" returned by ``env.step()``.
"""
+# NOTE: This file contains many hardcoded / ad-hoc rules.
+# Refactoring it will be one of the future tasks.
+
from __future__ import annotations
import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
-from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
+from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable
import numpy as np
import pandas as pd
@@ -29,7 +32,7 @@ if TYPE_CHECKING:
from .env_wrapper import InfoDict
-__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
+__all__ = ["LogCollector", "LogWriter", "LogLevel", "LogBuffer", "ConsoleWriter", "CsvWriter"]
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
@@ -175,18 +178,53 @@ class LogWriter(Generic[ObsType, ActType]):
self.clear()
def clear(self):
+ """Clear all the metrics for a fresh start.
+ To make the logger instance reusable.
+ """
self.episode_count = self.step_count = 0
self.active_env_ids = set()
- self.logs = []
- def aggregation(self, array: Sequence[Any]) -> Any:
+    def state_dict(self) -> dict:
+        """Save the states of the logger to a dict."""
+        saved_attrs = (
+            "episode_count",
+            "step_count",
+            "global_step",
+            "global_episode",
+            "active_env_ids",
+            "episode_lengths",
+            "episode_rewards",
+            "episode_logs",
+        )
+        # Every entry is stored under the name of the attribute it comes from.
+        return {attr: getattr(self, attr) for attr in saved_attrs}
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Load the states of current logger from a dict.
+
+        ``state_dict`` must provide every key read below; a missing key raises ``KeyError``.
+        """
+        self.episode_count = state_dict["episode_count"]
+        self.step_count = state_dict["step_count"]
+        self.global_step = state_dict["global_step"]
+        self.global_episode = state_dict["global_episode"]
+
+        # These are runtime infos.
+        # Though they are loaded, I don't think it really helps.
+        self.active_env_ids = state_dict["active_env_ids"]
+        # Restore into ``episode_lengths``; the misspelt ``episode_lenghts`` left it stale.
+        self.episode_lengths = state_dict["episode_lengths"]
+        self.episode_rewards = state_dict["episode_rewards"]
+        self.episode_logs = state_dict["episode_logs"]
+
+ def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any:
"""Aggregation function from step-wise to episode-wise.
If it's a sequence of float, take the mean.
Otherwise, take the first element.
+
+ If a name is specified and,
+
+ - if it's ``reward``, the reduction will be sum.
"""
assert len(array) > 0, "The aggregated array must be not empty."
if all(isinstance(v, float) for v in array):
+ if name == "reward":
+ return np.sum(array)
return np.mean(array)
else:
return array[0]
@@ -253,10 +291,93 @@ class LogWriter(Generic[ObsType, ActType]):
self.episode_rewards[env_id] = []
self.episode_logs[env_id] = []
+    def on_env_all_ready(self) -> None:
+        """When all environments are ready to run.
+        Usually, loggers should be reset here.
+
+        The default implementation resets this logger by calling :meth:`clear`.
+        """
+        self.clear()
+
def on_env_all_done(self) -> None:
"""All done. Time for cleanup."""
+class LogBuffer(LogWriter):
+    """Keep all numbers in memory.
+
+    Objects that can't be aggregated like strings, tensors, images can't be stored in the buffer.
+    To persist them, please use :class:`PickleWriter`.
+
+    Every time the log buffer receives a new metric, the callback is triggered,
+    which is useful when tracking metrics inside a trainer.
+
+    Parameters
+    ----------
+    callback
+        A callback receiving three arguments:
+
+        - on_episode: Whether it's called at the end of an episode
+        - on_collect: Whether it's called at the end of a collect
+        - log_buffer: the :class:`LogBuffer` object
+
+        No return value is expected.
+    """
+
+    # FIXME: needs a metric count
+
+    def __init__(self, callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC):
+        super().__init__(loglevel)
+        self.callback = callback
+
+    def state_dict(self) -> dict:
+        """Extend the parent's state with the latest and the aggregated metrics."""
+        state = dict(super().state_dict())
+        state["latest_metrics"] = self._latest_metrics
+        state["aggregated_metrics"] = self._aggregated_metrics
+        return state
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Restore the buffered metrics first, then let the parent restore the rest."""
+        self._latest_metrics = state_dict["latest_metrics"]
+        self._aggregated_metrics = state_dict["aggregated_metrics"]
+        return super().load_state_dict(state_dict)
+
+    def clear(self):
+        """Reset the parent's state and drop all buffered metrics."""
+        super().clear()
+        self._latest_metrics: dict[str, float] | None = None
+        self._aggregated_metrics: dict[str, float] = defaultdict(float)
+
+    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
+        """Reduce the float entries of ``contents`` to one value per name, then fire the callback.
+
+        The reduced values become the latest episode metrics and are also added to the
+        per-name running totals. ``length`` and ``rewards`` are not used here.
+        """
+        # FIXME Dup of ConsoleWriter
+        per_metric: dict[str, list] = defaultdict(list)
+        for step_log in contents:
+            for key, val in step_log.items():
+                # FIXME This could be false-negative for some numpy types
+                if not isinstance(val, float):
+                    continue
+                per_metric[key].append(val)
+
+        summary: dict[str, float] = {}
+        for key, series in per_metric.items():
+            summary[key] = self.aggregation(series, key)  # type: ignore
+            self._aggregated_metrics[key] += summary[key]
+
+        self._latest_metrics = summary
+
+        self.callback(True, False, self)
+
+    def on_env_all_done(self) -> None:
+        """Fire the callback with ``on_collect=True``."""
+        # This happens when collect exits
+        self.callback(False, True, self)
+
+    def episode_metrics(self) -> dict[str, float]:
+        """Retrieve the numeric metrics of the latest episode."""
+        latest = self._latest_metrics
+        if latest is None:
+            raise ValueError("No episode metrics available yet.")
+        return latest
+
+    def collect_metrics(self) -> dict[str, float]:
+        """Retrieve the aggregated metrics of the latest collect."""
+        # Running totals divided by the episode count.
+        return {key: total / self.episode_count for key, total in self._aggregated_metrics.items()}
+
+
class ConsoleWriter(LogWriter):
"""Write log messages to console periodically.
@@ -289,6 +410,8 @@ class ConsoleWriter(LogWriter):
self.console_logger = get_module_logger(__name__, level=logging.INFO)
+ # FIXME: save & reload
+
def clear(self):
super().clear()
# Clear average meters
@@ -308,7 +431,7 @@ class ConsoleWriter(LogWriter):
# This should be done at every step, regardless of periodic or not.
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
- logs[name] = self.aggregation(values) # type: ignore
+ logs[name] = self.aggregation(values, name) # type: ignore
for name, value in logs.items():
self.metric_counts[name] += 1
@@ -350,6 +473,8 @@ class CsvWriter(LogWriter):
all_records: list[dict[str, Any]]
+ # FIXME: save & reload
+
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
super().__init__(loglevel)
self.output_dir = output_dir
@@ -370,7 +495,7 @@ class CsvWriter(LogWriter):
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
- logs[name] = self.aggregation(values) # type: ignore
+ logs[name] = self.aggregation(values, name) # type: ignore
self.all_records.append(logs)
@@ -392,7 +517,3 @@ class TensorboardWriter(LogWriter):
class MlflowWriter(LogWriter):
"""Add logs to mlflow."""
-
-
-class LogBuffer(LogWriter):
- """Keep everything in memory."""
diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py
index 240ffc1e1..2cf149a75 100644
--- a/tests/rl/test_logger.py
+++ b/tests/rl/test_logger.py
@@ -81,7 +81,7 @@ def test_simple_env_logger(caplog):
line = line.strip()
if line:
line_counter += 1
- assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
+ assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3
diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py
index 2ac0d9cbd..98e5dd981 100644
--- a/tests/rl/test_saoe_simple.py
+++ b/tests/rl/test_saoe_simple.py
@@ -17,7 +17,7 @@ from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
-from qlib.rl.entries.test import backtest
+from qlib.rl.trainer import backtest, train
from qlib.rl.order_execution import *
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
@@ -306,3 +306,26 @@ def test_cn_ppo_strategy():
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
+
+
+def test_ppo_train():
+    """Smoke-test ``train`` with a PPO policy on the CN order data for two iterations."""
+    set_log_with_config(C.logging_config)
+    # The data starts with 9:31 and ends with 15:00
+    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
+    assert len(orders) == 40
+
+    state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
+    action_interp = CategoricalActionInterpreter(4)
+    network = Recurrent(state_interp.observation_space)
+    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
+
+    simulator_fn = partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30)
+    reward = PAPenaltyReward()
+    vessel_kwargs = {"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}}
+    trainer_kwargs = {"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)}
+
+    train(
+        simulator_fn,
+        state_interp,
+        action_interp,
+        orders,
+        policy,
+        reward,
+        vessel_kwargs=vessel_kwargs,
+        trainer_kwargs=trainer_kwargs,
+    )
diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py
new file mode 100644
index 000000000..751fbd387
--- /dev/null
+++ b/tests/rl/test_trainer.py
@@ -0,0 +1,202 @@
+import os
+import random
+import sys
+from pathlib import Path
+
+import pytest
+
+import torch
+import torch.nn as nn
+from gym import spaces
+from tianshou.policy import PPOPolicy
+
+from qlib.config import C
+from qlib.log import set_log_with_config
+from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
+from qlib.rl.simulator import Simulator
+from qlib.rl.reward import Reward
+from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
+
+pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
+
+
+class ZeroSimulator(Simulator):
+    """Toy simulator: an action is "correct" when it equals 0, and every step may end the episode."""
+
+    def __init__(self, *args, **kwargs):
+        self.action = self.correct = 0
+        # Start as "not done" so that done() is safe to call before the first step().
+        self._done = False
+
+    def step(self, action):
+        """Record ``action``, randomly decide whether the episode ends, and log "acc" if it does."""
+        self.action = action
+        self.correct = action == 0
+        self._done = random.choice([False, True])
+        if self._done:
+            self.env.logger.add_scalar("acc", self.correct * 100)
+
+    def get_state(self):
+        """Return the accuracy (0 or 100) and the action of the latest step."""
+        return {
+            "acc": self.correct * 100,
+            "action": self.action,
+        }
+
+    def done(self) -> bool:
+        """Whether the latest step ended the episode."""
+        return self._done
+
+
+class NoopStateInterpreter(StateInterpreter):
+    """State interpreter that passes the simulator state through unchanged."""
+
+    observation_space = spaces.Dict(
+        {
+            "acc": spaces.Discrete(200),
+            "action": spaces.Discrete(2),
+        }
+    )
+
+    def interpret(self, simulator_state):
+        """Return ``simulator_state`` as the observation."""
+        return simulator_state
+
+
+class NoopActionInterpreter(ActionInterpreter):
+    """Action interpreter that passes the policy action through unchanged."""
+
+    action_space = spaces.Discrete(2)
+
+    def interpret(self, simulator_state, action):
+        """Return ``action`` as is; ``simulator_state`` is ignored."""
+        return action
+
+
+class AccReward(Reward):
+    """Reward equal to ``acc / 100`` once the episode is done, and 0 before that."""
+
+    def reward(self, simulator_state):
+        """Return ``simulator_state["acc"] / 100`` if the env reports done, else ``0.0``."""
+        return simulator_state["acc"] / 100 if self.env.status["done"] else 0.0
+
+
+class PolicyNet(nn.Module):
+    """Dummy network: a linear layer applied to random noise, one row per observation."""
+
+    def __init__(self, out_features=1, return_state=False):
+        super().__init__()
+        self.fc = nn.Linear(32, out_features)
+        self.return_state = return_state
+
+    def forward(self, obs, state=None, **kwargs):
+        """Return the raw output, or ``(softmax(output), state)`` when ``return_state`` is set.
+
+        Only the leading dimension of ``obs["acc"]`` is read from the observation.
+        """
+        batch_size = obs["acc"].shape[0]
+        out = self.fc(torch.randn(batch_size, 32))
+        if not self.return_state:
+            return out
+        return nn.functional.softmax(out, dim=-1), state
+
+
+def _ppo_policy():
+    """Build a PPO policy from a two-output actor and a one-output critic sharing one Adam optimizer."""
+    actor, critic = PolicyNet(2, True), PolicyNet()
+    optimizer = torch.optim.Adam(tuple(actor.parameters()) + tuple(critic.parameters()))
+    return PPOPolicy(
+        actor,
+        critic,
+        optimizer,
+        torch.distributions.Categorical,
+        action_space=NoopActionInterpreter().action_space,
+    )
+
+
+def test_trainer():
+    """Fit for 10 iterations of 500 episodes on "subproc" envs, then check counters and metrics."""
+    set_log_with_config(C.logging_config)
+    trainer = Trainer(max_iters=10, finite_env_type="subproc")
+    policy = _ppo_policy()
+
+    vessel = TrainingVessel(
+        simulator_fn=lambda init: ZeroSimulator(init),
+        state_interpreter=NoopStateInterpreter(),
+        action_interpreter=NoopActionInterpreter(),
+        policy=policy,
+        train_initial_states=list(range(100)),
+        val_initial_states=list(range(10)),
+        test_initial_states=list(range(10)),
+        reward=AccReward(),
+        episode_per_iter=500,
+        update_kwargs=dict(repeat=10, batch_size=64),
+    )
+    trainer.fit(vessel)
+    assert trainer.current_iter == 10
+    # 10 iterations x 500 episodes per iteration
+    assert trainer.current_episode == 5000
+    # AccReward returns acc / 100 at the end of an episode, so the two metrics must agree.
+    assert abs(trainer.metrics["acc"] - trainer.metrics["reward"] * 100) < 1e-4
+    assert trainer.metrics["acc"] > 80
+    trainer.test(vessel)
+    assert trainer.metrics["acc"] > 60
+
+
+def test_trainer_fast_dev_run():
+    """With ``fast_dev_run=2`` and ``max_iters=2``, only 4 episodes are expected in total."""
+    set_log_with_config(C.logging_config)
+    trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
+    policy = _ppo_policy()
+
+    vessel = TrainingVessel(
+        simulator_fn=lambda init: ZeroSimulator(init),
+        state_interpreter=NoopStateInterpreter(),
+        action_interpreter=NoopActionInterpreter(),
+        policy=policy,
+        train_initial_states=list(range(100)),
+        val_initial_states=list(range(10)),
+        test_initial_states=list(range(10)),
+        reward=AccReward(),
+        episode_per_iter=500,
+        update_kwargs=dict(repeat=10, batch_size=64),
+    )
+    trainer.fit(vessel)
+    # 2 iterations x 2 episodes: fast_dev_run takes precedence over episode_per_iter=500.
+    assert trainer.current_episode == 4
+
+
+def test_trainer_earlystop():
+    """Fit with an ``EarlyStopping`` callback on "val/reward" and expect a stop at the second iteration."""
+    # TODO this is just sanity check.
+    # need to see the logs to check whether it works.
+    set_log_with_config(C.logging_config)
+    trainer = Trainer(
+        max_iters=10,
+        val_every_n_iters=1,
+        finite_env_type="dummy",
+        callbacks=[EarlyStopping("val/reward", restore_best_weights=True)],
+    )
+    policy = _ppo_policy()
+
+    vessel = TrainingVessel(
+        simulator_fn=lambda init: ZeroSimulator(init),
+        state_interpreter=NoopStateInterpreter(),
+        action_interpreter=NoopActionInterpreter(),
+        policy=policy,
+        train_initial_states=list(range(100)),
+        val_initial_states=list(range(10)),
+        test_initial_states=list(range(10)),
+        reward=AccReward(),
+        episode_per_iter=500,
+        update_kwargs=dict(repeat=10, batch_size=64),
+    )
+    trainer.fit(vessel)
+    assert trainer.metrics["val/acc"] > 30
+    assert trainer.current_iter == 2  # second iteration
+
+
+def test_trainer_checkpoint():
+    """Fit for two iterations with a ``Checkpoint`` callback, then reload the first checkpoint.
+
+    Checkpoints go to a throw-away directory, so files left by an earlier run can
+    neither satisfy the ``exists()`` assertions nor pile up next to the tests.
+    """
+    import tempfile  # local import: only this test needs it
+
+    set_log_with_config(C.logging_config)
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # A sub-directory that does not exist yet, like ``.output`` on a first run.
+        output_dir = Path(tmp_dir) / ".output"
+        trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
+        policy = _ppo_policy()
+
+        vessel = TrainingVessel(
+            simulator_fn=lambda init: ZeroSimulator(init),
+            state_interpreter=NoopStateInterpreter(),
+            action_interpreter=NoopActionInterpreter(),
+            policy=policy,
+            train_initial_states=list(range(100)),
+            val_initial_states=list(range(10)),
+            test_initial_states=list(range(10)),
+            reward=AccReward(),
+            episode_per_iter=100,
+            update_kwargs=dict(repeat=10, batch_size=64),
+        )
+        trainer.fit(vessel)
+
+        assert (output_dir / "001.pth").exists()
+        assert (output_dir / "002.pth").exists()
+        assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")
+
+        trainer.load_state_dict(torch.load(output_dir / "001.pth"))
+        assert trainer.current_iter == 1
+        assert trainer.current_episode == 100
+
+        # Reload the checkpoint at first iteration
+        trainer.fit(vessel, ckpt_path=output_dir / "001.pth")