diff --git a/qlib/rl/entries/__init__.py b/qlib/rl/entries/__init__.py deleted file mode 100644 index 169fa985c..000000000 --- a/qlib/rl/entries/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -"""Train, test, inference utilities. - -The APIs in this directory are NOT considered final and are subject to change! -""" diff --git a/qlib/rl/entries/test.py b/qlib/rl/entries/test.py deleted file mode 100644 index ca311407b..000000000 --- a/qlib/rl/entries/test.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from __future__ import annotations - -import copy -from typing import Callable, Sequence - -from tianshou.data import Collector -from tianshou.policy import BasePolicy - -from qlib.constant import INF -from qlib.log import get_module_logger -from qlib.rl.simulator import InitialStateType, Simulator -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter -from qlib.rl.reward import Reward -from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env - - -_logger = get_module_logger(__name__) - - -def backtest( - simulator_fn: Callable[[InitialStateType], Simulator], - state_interpreter: StateInterpreter, - action_interpreter: ActionInterpreter, - initial_states: Sequence[InitialStateType], - policy: BasePolicy, - logger: LogWriter | list[LogWriter], - reward: Reward | None = None, - finite_env_type: FiniteEnvType = "subproc", - concurrency: int = 2, -) -> None: - """Backtest with the parallelism provided by RL framework. - - Parameters - ---------- - simulator_fn - Callable receiving initial seed, returning a simulator. - state_interpreter - Interprets the state of simulators. - action_interpreter - Interprets the policy actions. - initial_states - Initial states to iterate over. Every state will be run exactly once. - policy - Policy to test against. 
- logger - Logger to record the backtest results. Logger must be present because - without logger, all information will be lost. - reward - Optional reward function. For backtest, this is for testing the rewards - and logging them only. - finite_env_type - Type of finite env implementation. - concurrency - Parallel workers. - """ - - # To save bandwidth - min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel - - def env_factory(): - # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env), - # and could be thread unsafe. - # I'm not sure whether it's a design flaw. - # I'll rethink about this when designing the trainer. - - if finite_env_type == "dummy": - # We could only experience the "threading-unsafe" problem in dummy. - state = copy.deepcopy(state_interpreter) - action = copy.deepcopy(action_interpreter) - rew = copy.deepcopy(reward) - else: - state, action, rew = state_interpreter, action_interpreter, reward - - return EnvWrapper( - simulator_fn, - state, - action, - seed_iterator, - rew, - logger=LogCollector(min_loglevel=min_loglevel), - ) - - with DataQueue(initial_states) as seed_iterator: - vector_env = vectorize_env( - env_factory, - finite_env_type, - concurrency, - logger, - ) - - policy.eval() - - with vector_env.collector_guard(): - test_collector = Collector(policy, vector_env) - _logger.info("All ready. Start backtest.") - test_collector.collect(n_step=INF * len(vector_env)) diff --git a/qlib/rl/entries/train.py b/qlib/rl/entries/train.py deleted file mode 100644 index c852e6235..000000000 --- a/qlib/rl/entries/train.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
class PAPenaltyReward(Reward[SAOEState]):
    """Reward a high price advantage (PA) while discouraging bursts of volume.

    For the latest step, the reward is the step's PA weighted by the fraction of the
    whole order it traded, minus ``penalty`` times the sum of squared per-tick traded
    fractions inside that step, i.e. roughly
    :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.

    Parameters
    ----------
    penalty
        The penalty for large volume in a short time.
    """

    def __init__(self, penalty: float = 100.0):
        self.penalty = penalty

    def reward(self, simulator_state: SAOEState) -> float:
        """Compute the reward of the most recent step recorded in ``simulator_state``.

        Logs the two components as ``reward/pa`` and ``reward/penalty`` and returns their sum.
        """
        target_amount = simulator_state.order.amount
        assert target_amount > 0

        # The newest row of ``history_steps``; ``reset_index`` turns the index into a regular
        # column so that ``datetime`` is available as a key of the resulting dict.
        step_table = simulator_state.history_steps.reset_index()
        latest = cast(SAOEMetrics, step_table.iloc[-1].to_dict())
        pa_term = latest["pa"] * latest["amount"] / target_amount

        # Inspect the "break-down" of the latest step: trading amount at every tick
        tick_rows = simulator_state.history_exec.loc[latest["datetime"] :]
        tick_fraction = tick_rows["amount"] / target_amount
        penalty_term = -self.penalty * (tick_fraction**2).sum()

        total = pa_term + penalty_term

        # Throw error in case of NaN
        assert not (np.isnan(total) or np.isinf(total)), f"Invalid reward for simulator state: {simulator_state}"

        self.log("reward/pa", pa_term)
        self.log("reward/penalty", penalty_term)
        return total
Only available when done.""" diff --git a/qlib/rl/trainer/__init__.py b/qlib/rl/trainer/__init__.py new file mode 100644 index 000000000..efce804c4 --- /dev/null +++ b/qlib/rl/trainer/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Train, test, inference utilities.""" + +from .api import backtest, train +from .callbacks import EarlyStopping, Checkpoint +from .trainer import Trainer +from .vessel import TrainingVessel, TrainingVesselBase diff --git a/qlib/rl/trainer/api.py b/qlib/rl/trainer/api.py new file mode 100644 index 000000000..65abbd88d --- /dev/null +++ b/qlib/rl/trainer/api.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Callable, Sequence, cast, Any + +from tianshou.policy import BasePolicy + +from qlib.rl.simulator import InitialStateType, Simulator +from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.reward import Reward +from qlib.rl.utils import FiniteEnvType, LogWriter + +from .vessel import TrainingVessel +from .trainer import Trainer + + +def train( + simulator_fn: Callable[[InitialStateType], Simulator], + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + initial_states: Sequence[InitialStateType], + policy: BasePolicy, + reward: Reward, + vessel_kwargs: dict[str, Any], + trainer_kwargs: dict[str, Any], +) -> None: + """Train a policy with the parallelism provided by RL framework. + + Experimental API. Parameters might change shortly. + + Parameters + ---------- + simulator_fn + Callable receiving initial seed, returning a simulator. + state_interpreter + Interprets the state of simulators. + action_interpreter + Interprets the policy actions. + initial_states + Initial states to iterate over. Every state will be run exactly once. + policy + Policy to train against. + reward + Reward function. 
+ vessel_kwargs + Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``. + trainer_kwargs + Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``. + """ + + vessel = TrainingVessel( + simulator_fn=simulator_fn, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + policy=policy, + train_initial_states=initial_states, + reward=reward, # ignore none + **vessel_kwargs, + ) + trainer = Trainer(**trainer_kwargs) + trainer.fit(vessel) + + +def backtest( + simulator_fn: Callable[[InitialStateType], Simulator], + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + initial_states: Sequence[InitialStateType], + policy: BasePolicy, + logger: LogWriter | list[LogWriter], + reward: Reward | None = None, + finite_env_type: FiniteEnvType = "subproc", + concurrency: int = 2, +) -> None: + """Backtest with the parallelism provided by RL framework. + + Experimental API. Parameters might change shortly. + + Parameters + ---------- + simulator_fn + Callable receiving initial seed, returning a simulator. + state_interpreter + Interprets the state of simulators. + action_interpreter + Interprets the policy actions. + initial_states + Initial states to iterate over. Every state will be run exactly once. + policy + Policy to test against. + logger + Logger to record the backtest results. Logger must be present because + without logger, all information will be lost. + reward + Optional reward function. For backtest, this is for testing the rewards + and logging them only. + finite_env_type + Type of finite env implementation. + concurrency + Parallel workers. 
class Callback:
    """Base class of all callbacks.

    Every hook is a no-op here; subclasses override only the hooks they need.
    Each hook receives the running trainer and the training vessel.
    """

    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called before the whole fit process begins."""

    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called after the whole fit process ends."""

    def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called when each collect for training begins."""

    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called when the training ends.
        To access all outputs produced during training, cache the data in either trainer and vessel,
        and post-process them in this hook.
        """

    def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called when every run for validation begins."""

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called when the validation ends."""

    def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called when every run of testing begins."""

    def on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called when the testing ends."""

    def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called when every iteration (i.e., collect) starts."""

    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Called upon every end of iteration.
        This is called **after** the bump of ``current_iter``,
        when the previous iteration is considered complete.
        """

    def state_dict(self) -> Any:
        """Get a state dict of the callback for pause and resume.

        The base implementation has no state and returns ``None``.
        """

    def load_state_dict(self, state_dict: Any) -> None:
        """Resume the callback from a saved state dict."""


class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    The early-stopping callback is triggered each time validation ends.
    It examines the metrics produced in validation and takes the metric named
    ``monitor`` (``reward`` by default) to check whether it is no longer
    increasing / decreasing, taking ``min_delta`` and ``patience`` into account.
    Once the metric is found to have stopped improving,
    ``trainer.should_stop`` is set to true and the training terminates.

    Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893

    Parameters
    ----------
    monitor
        Key looked up in ``trainer.metrics``.
    min_delta
        Minimum change that counts as an improvement. Only the absolute value is used.
    patience
        Number of validations without improvement after which training is stopped.
    mode
        ``"max"`` if a larger metric is better, ``"min"`` if a smaller one is.
    baseline
        If given, the wait counter is only reset when the metric also beats this value.
    restore_best_weights
        Whether to load the best vessel state back into the vessel when stopping.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        monitor: str = "reward",
        min_delta: float = 0.0,
        patience: int = 0,
        mode: Literal["min", "max"] = "max",
        baseline: float | None = None,
        restore_best_weights: bool = False,
    ):
        super().__init__()

        if mode not in ["min", "max"]:
            raise ValueError(f"Unsupported earlystopping mode: {mode}")

        self.monitor = monitor
        self.patience = patience
        self.baseline = baseline
        self.restore_best_weights = restore_best_weights

        # ``min_delta`` is stored signed so that ``_is_improvement`` can always subtract it:
        # positive when maximizing, negative when minimizing.
        if mode == "min":
            self.monitor_op = np.less
            self.min_delta = abs(min_delta) * -1
        else:
            self.monitor_op = np.greater
            self.min_delta = abs(min_delta)

        # Initialise the running state right away, so that ``state_dict`` is usable
        # even before ``on_fit_start`` has been called.
        self._reset()

    def _reset(self) -> None:
        """Put the running state (wait counter, best value / weights / iteration) back to its start."""
        self.wait = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf
        self.best_weights: Any | None = None
        self.best_iter = 0

    def state_dict(self) -> dict:
        """Return the running state needed to resume early stopping."""
        return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the running state produced by :meth:`state_dict`."""
        self.wait = state_dict["wait"]
        self.best = state_dict["best"]
        self.best_weights = state_dict["best_weights"]
        self.best_iter = state_dict["best_iter"]

    def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Reset the running state so that one instance can serve several fits."""
        # NOTE(review): this also discards any state restored through ``load_state_dict``
        # if the trainer fires ``on_fit_start`` after loading a checkpoint — confirm that
        # this is the intended behavior when resuming.
        self._reset()

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Update the best value / wait counter and request a stop when patience runs out."""
        current = self.get_monitor_value(trainer)
        if current is None:
            return
        if self.restore_best_weights and self.best_weights is None:
            # Restore the weights after first iteration if no progress is ever made.
            self.best_weights = copy.deepcopy(vessel.state_dict())

        self.wait += 1
        if self._is_improvement(current, self.best):
            self.best = current
            self.best_iter = trainer.current_iter
            if self.restore_best_weights:
                self.best_weights = copy.deepcopy(vessel.state_dict())
            # Only restart wait if we beat both the baseline and our previous best.
            if self.baseline is None or self._is_improvement(current, self.baseline):
                self.wait = 0

        # Only check after the first epoch.
        if self.wait >= self.patience and trainer.current_iter > 0:
            trainer.should_stop = True
            _logger.info("On iteration %d: early stopping", trainer.current_iter + 1)
            if self.restore_best_weights and self.best_weights is not None:
                _logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1)
                vessel.load_state_dict(self.best_weights)

    def get_monitor_value(self, trainer: Trainer) -> Any:
        """Return ``trainer.metrics[self.monitor]``, or ``None`` (with a warning) if it is missing."""
        monitor_value = trainer.metrics.get(self.monitor)
        if monitor_value is None:
            _logger.warning(
                "Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
                self.monitor,
                ",".join(list(trainer.metrics.keys())),
            )
        return monitor_value

    def _is_improvement(self, monitor_value: Any, reference_value: Any) -> Any:
        """Whether ``monitor_value`` beats ``reference_value`` by more than ``min_delta``."""
        return self.monitor_op(monitor_value - self.min_delta, reference_value)


class Checkpoint(Callback):
    """Save checkpoints periodically for persistence and recovery.

    Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py

    Parameters
    ----------
    dirpath
        Directory to save the checkpoint file.
    filename
        Checkpoint filename. Can contain named formatting options to be auto-filled.
        For example: ``{iter:03d}-{reward:.2f}.pth``.
        Supported argument names are:

        - iter (int)
        - metrics in ``trainer.metrics``
        - time string, in the format of ``%Y%m%d%H%M%S``
    save_latest
        Save the latest checkpoint in ``latest.pth``.
        If ``link``, ``latest.pth`` will be created as a softlink.
        If ``copy``, ``latest.pth`` will be stored as an individual copy.
        Set to none to disable this.
    every_n_iters
        Checkpoints are saved at the end of every n iterations of training,
        after validation if applicable.
    time_interval
        Maximum time (seconds) before checkpoints save again.
    save_on_fit_end
        Save one last checkpoint at the end of fit.
        Do nothing if a checkpoint is already saved there.
    """

    def __init__(
        self,
        dirpath: Path,
        filename: str = "{iter:03d}.pth",
        save_latest: Literal["link", "copy"] | None = "link",
        every_n_iters: int | None = None,
        time_interval: int | None = None,
        save_on_fit_end: bool = True,
    ):
        self.dirpath = Path(dirpath)
        self.filename = filename
        self.save_latest = save_latest
        self.every_n_iters = every_n_iters
        self.time_interval = time_interval
        self.save_on_fit_end = save_on_fit_end

        # Bookkeeping of the most recent save; all ``None`` until the first checkpoint is written.
        self._last_checkpoint_name: str | None = None
        self._last_checkpoint_iter: int | None = None
        self._last_checkpoint_time: float | None = None

    def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Write a final checkpoint unless one was already written for the current iteration."""
        if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):
            self._save_checkpoint(trainer)

    def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Write a checkpoint if either the iteration-count or the wall-clock criterion is met."""
        should_save_ckpt = False
        # NOTE(review): the base-class docstring says this hook runs after ``current_iter`` has
        # been bumped, so the ``+ 1`` here may shift the schedule by one iteration — confirm.
        if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0:
            should_save_ckpt = True
        if self.time_interval is not None and (
            self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval
        ):
            should_save_ckpt = True
        if should_save_ckpt:
            self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer: Trainer) -> None:
        """Serialize ``trainer.state_dict()`` into ``dirpath`` and refresh ``latest.pth`` if enabled."""
        self.dirpath.mkdir(exist_ok=True, parents=True)
        self._last_checkpoint_name = self._new_checkpoint_name(trainer)
        self._last_checkpoint_iter = trainer.current_iter
        self._last_checkpoint_time = time.time()
        checkpoint_path = self.dirpath / self._last_checkpoint_name
        torch.save(trainer.state_dict(), checkpoint_path)

        if not self.save_latest:
            return

        latest_pth = self.dirpath / "latest.pth"

        # Remove first before saving.
        # ``exists()`` follows symlinks and is therefore False for a dangling link,
        # so ``is_symlink()`` is checked as well; otherwise ``symlink_to`` below would fail.
        if latest_pth.is_symlink() or latest_pth.exists():
            latest_pth.unlink()

        if self.save_latest == "link":
            # A relative link target is resolved from the directory containing the link.
            # ``latest.pth`` lives in ``dirpath`` itself, so the target must be the checkpoint
            # name relative to ``dirpath``; prefixing ``dirpath`` again would yield a dangling
            # link whenever ``dirpath`` is a relative path.
            latest_pth.symlink_to(self._last_checkpoint_name)
        elif self.save_latest == "copy":
            shutil.copyfile(checkpoint_path, latest_pth)

    def _new_checkpoint_name(self, trainer: Trainer) -> str:
        """Fill the ``filename`` template with the iteration, a timestamp and the trainer metrics."""
        return self.filename.format(
            iter=trainer.current_iter, time=datetime.now().strftime("%Y%m%d%H%M%S"), **trainer.metrics
        )
+ + The API has some resemblence with `PyTorch Lightning `__, + but it's essentially different because this trainer is built for RL applications, and thus + most configurations are under RL context. + We are still looking for ways to incorporate existing trainer libraries, because it looks like + big efforts to build a trainer as powerful as those libraries, and also, that's not our primary goal. + + It's essentially different + `tianshou's built-in trainers `__, + as it's far much more complicated than that. + + Parameters + ---------- + max_iters + Maximum iterations before stopping. + val_every_n_iters + Perform validation every n iterations (i.e., training collects). + logger + Logger to record the backtest results. Logger must be present because + without logger, all information will be lost. + finite_env_type + Type of finite env implementation. + concurrency + Parallel workers. + fast_dev_run + Create a subset for debugging. + How this is implemented depends on the implementation of training vessel. + For :class:`~qlib.rl.vessel.TrainingVessel`, if greater than zero, + a random subset sized ``fast_dev_run`` will be used + instead of ``train_initial_states`` and ``val_initial_states``. + """ + + should_stop: bool + """Set to stop the training.""" + + metrics: dict + """Numeric metrics of produced in train/val/test. + In the middle of training / validation, metrics will be of the latest episode. + When each iteration of training / validation finishes, metrics will be the aggregation + of all episodes encountered in this iteration. + + Cleared on every new iteration of training. + + In fit, validation metrics will be prefixed with ``val/``. 
+ """ + + current_iter: int + """Current iteration (collect) of training.""" + + loggers: list[LogWriter] + """A list of log writers.""" + + def __init__( + self, + *, + max_iters: int | None = None, + val_every_n_iters: int | None = None, + loggers: LogWriter | list[LogWriter] | None = None, + callbacks: list[Callback] | None = None, + finite_env_type: FiniteEnvType = "subproc", + concurrency: int = 2, + fast_dev_run: int | None = None, + ): + self.max_iters = max_iters + self.val_every_n_iters = val_every_n_iters + + if isinstance(loggers, list): + self.loggers = loggers + elif isinstance(loggers, LogWriter): + self.loggers = [loggers] + else: + self.loggers = [] + + self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel())) + + self.callbacks: list[Callback] = callbacks if callbacks is not None else [] + self.finite_env_type = finite_env_type + self.concurrency = concurrency + self.fast_dev_run = fast_dev_run + + self.current_stage: Literal["train", "val", "test"] = "train" + + self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None) + + def initialize(self): + """Initialize the whole training process. + + The states here should be synchronized with state_dict. + """ + self.should_stop = False + self.current_iter = 0 + self.current_episode = 0 + self.current_stage = "train" + + def initialize_iter(self): + """Initialize one iteration / collect.""" + self.metrics = {} + + def state_dict(self) -> dict: + """Putting every states of current training into a dict, at best effort. + + It doesn't try to handle all the possible kinds of states in the middle of one training collect. + For most cases at the end of each iteration, things should be usually correct. + + Note that it's also intended behavior that replay buffer data in the collector will be lost. 
+ """ + return { + "vessel": self.vessel.state_dict(), + "callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()}, + "loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()}, + "should_stop": self.should_stop, + "current_iter": self.current_iter, + "current_episode": self.current_episode, + "current_stage": self.current_stage, + "metrics": self.metrics, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load all states into current trainer.""" + self.vessel.load_state_dict(state_dict["vessel"]) + for name, callback in self.named_callbacks().items(): + callback.load_state_dict(state_dict["callbacks"][name]) + for name, logger in self.named_loggers().items(): + logger.load_state_dict(state_dict["loggers"][name]) + self.should_stop = state_dict["should_stop"] + self.current_iter = state_dict["current_iter"] + self.current_episode = state_dict["current_episode"] + self.current_stage = state_dict["current_stage"] + self.metrics = state_dict["metrics"] + + def named_callbacks(self) -> dict[str, Callback]: + """Retrieve a collection of callbacks where each one has a name. + Useful when saving checkpoints. + """ + return _named_collection(self.callbacks) + + def named_loggers(self) -> dict[str, LogWriter]: + """Retrieve a collection of loggers where each one has a name. + Useful when saving checkpoints. + """ + return _named_collection(self.loggers) + + def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None: + """Train the RL policy upon the defined simulator. + + Parameters + ---------- + vessel + A bundle of all elements used in training. + ckpt_path + Load a pre-trained / paused training checkpoint. 
+ """ + self.vessel = vessel + vessel.assign_trainer(self) + + if ckpt_path is not None: + _logger.info("Resuming states from %s", str(ckpt_path)) + self.load_state_dict(torch.load(ckpt_path)) + else: + self.initialize() + + self._call_callback_hooks("on_fit_start") + + while not self.should_stop: + self.initialize_iter() + + self._call_callback_hooks("on_iter_start") + + self.current_stage = "train" + self._call_callback_hooks("on_train_start") + + # TODO + # Add a feature that supports reloading the training environment every few iterations. + with _wrap_context(vessel.train_seed_iterator()) as iterator: + vector_env = self.venv_from_iterator(iterator) + self.vessel.train(vector_env) + + self._call_callback_hooks("on_train_end") + + if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0: + # Implementation of validation loop + self.current_stage = "val" + self._call_callback_hooks("on_validate_start") + with _wrap_context(vessel.val_seed_iterator()) as iterator: + vector_env = self.venv_from_iterator(iterator) + self.vessel.validate(vector_env) + + self._call_callback_hooks("on_validate_end") + + # This iteration is considered complete. + # Bumping the current iteration counter. + self.current_iter += 1 + + if self.max_iters is not None and self.current_iter >= self.max_iters: + self.should_stop = True + + self._call_callback_hooks("on_iter_end") + + self._call_callback_hooks("on_fit_end") + + def test(self, vessel: TrainingVesselBase) -> None: + """Test the RL policy against the simulator. + + The simulator will be fed with data generated in ``test_seed_iterator``. + + Parameters + ---------- + vessel + A bundle of all related elements. 
+ """ + self.vessel = vessel + vessel.assign_trainer(self) + + self.initialize_iter() + + self.current_stage = "test" + self._call_callback_hooks("on_test_start") + with _wrap_context(vessel.test_seed_iterator()) as iterator: + vector_env = self.venv_from_iterator(iterator) + self.vessel.test(vector_env) + self._call_callback_hooks("on_test_end") + + def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv: + """Create a vectorized environment from iterator and the training vessel.""" + + def env_factory(): + # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env), + # and could be thread unsafe. + # I'm not sure whether it's a design flaw. + # I'll rethink about this when designing the trainer. + + if self.finite_env_type == "dummy": + # We could only experience the "threading-unsafe" problem in dummy. + state = copy.deepcopy(self.vessel.state_interpreter) + action = copy.deepcopy(self.vessel.action_interpreter) + rew = copy.deepcopy(self.vessel.reward) + else: + state = self.vessel.state_interpreter + action = self.vessel.action_interpreter + rew = self.vessel.reward + + return EnvWrapper( + self.vessel.simulator_fn, + state, + action, + iterator, + rew, + logger=LogCollector(min_loglevel=self._min_loglevel()), + ) + + return vectorize_env( + env_factory, + self.finite_env_type, + self.concurrency, + self.loggers, + ) + + def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None: + if on_episode: + # Update the global counter. + self.current_episode = log_buffer.global_episode + metrics = log_buffer.episode_metrics() + elif on_collect: + # Update the latest metrics. 
+ metrics = log_buffer.collect_metrics() + if self.current_stage == "val": + metrics = {"val/" + name: value for name, value in metrics.items()} + self.metrics.update(metrics) + + def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None: + for callback in self.callbacks: + fn = getattr(callback, hook_name) + fn(self, self.vessel, *args, **kwargs) + + def _min_loglevel(self): + if not self.loggers: + return LogLevel.PERIODIC + else: + # To save bandwidth + return min(lg.loglevel for lg in self.loggers) + + +@contextmanager +def _wrap_context(obj): + """Make any object a (possibly dummy) context manager.""" + + if isinstance(obj, AbstractContextManager): + # obj has __enter__ and __exit__ + with obj as ctx: + yield ctx + else: + yield obj + + +def _named_collection(seq: Sequence[T]) -> dict[str, T]: + """Convert a list into a dict, where each item is named with its type.""" + res = {} + for item in seq: + typename = type(item).__name__.lower() + if typename not in res: + res[typename] = item + else: + # names are auto-labelled as earlystop1, earlystop2, ... + for retry in range(1, 1000): + if f"{typename}{retry}" not in res: + res[f"{typename}{retry}"] = item + return res diff --git a/qlib/rl/trainer/vessel.py b/qlib/rl/trainer/vessel.py new file mode 100644 index 000000000..9c0879ce0 --- /dev/null +++ b/qlib/rl/trainer/vessel.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import weakref +from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict + +import numpy as np +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import BaseVectorEnv +from tianshou.policy import BasePolicy + +from qlib.constant import INF +from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType +from qlib.rl.simulator import InitialStateType, Simulator +from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.reward import Reward +from qlib.rl.utils import DataQueue +from qlib.log import get_module_logger +from qlib.rl.utils.finite_env import FiniteVectorEnv + +if TYPE_CHECKING: + from .trainer import Trainer + + +T = TypeVar("T") +_logger = get_module_logger(__name__) + + +class SeedIteratorNotAvailable(BaseException): + pass + + +class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]): + """A ship that contains simulator, interpreter, and policy, will be sent to trainer. + This class controls algorithm-related parts of training, while trainer is responsible for runtime part. + + The ship also defines the most important logic of the core training part, + and (optionally) some callbacks to insert customized logics at specific events. + """ + + simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]] + state_interpreter: StateInterpreter[StateType, ObsType] + action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType] + policy: BasePolicy + reward: Reward + trainer: Trainer + + def assign_trainer(self, trainer: Trainer) -> None: + self.trainer = weakref.proxy(trainer) # type: ignore + + def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + """Override this to create a seed iterator for training. 
+ If the iterable is a context manager, the whole training will be invoked in the with-block, + and the iterator will be automatically closed after the training is done.""" + raise SeedIteratorNotAvailable("Seed iterator for training is not available.") + + def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + """Override this to create a seed iterator for validation.""" + raise SeedIteratorNotAvailable("Seed iterator for validation is not available.") + + def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + """Override this to create a seed iterator for testing.""" + raise SeedIteratorNotAvailable("Seed iterator for testing is not available.") + + def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]: + """Implement this to train one iteration. In RL, one iteration usually refers to one collect.""" + raise NotImplementedError() + + def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + """Implement this to validate the policy once.""" + raise NotImplementedError() + + def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + """Implement this to evaluate the policy on test environment once.""" + raise NotImplementedError() + + def log(self, name: str, value: Any) -> None: + # FIXME: this is a workaround to make the log at least show somewhere. + # Need a refactor in logger to formalize this. 
+ if isinstance(value, (np.ndarray, list)): + value = np.mean(value) + _logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}") + + def log_dict(self, data: dict[str, Any]) -> None: + for name, value in data.items(): + self.log(name, value) + + def state_dict(self) -> dict: + """Return a checkpoint of current vessel state.""" + return {"policy": self.policy.state_dict()} + + def load_state_dict(self, state_dict: dict) -> None: + """Restore a checkpoint from a previously saved state dict.""" + self.policy.load_state_dict(state_dict["policy"]) + + +class TrainingVessel(TrainingVesselBase): + """The default implementation of training vessel. + + ``__init__`` accepts a sequence of initial states so that iterator can be created. + ``train``, ``validate``, ``test`` each do one collect (and also update in train). + By default, the train initial states will be repeated infinitely during training, + and collector will control the number of episodes for each iteration. + In validation and testing, the val / test initial states will be used exactly once. + + Extra hyper-parameters (only used in train) include: + + - ``buffer_size``: Size of replay buffer. + - ``episode_per_iter``: Episodes per collect at training. Can be overridden by fast dev run. + - ``update_kwargs``: Keyword arguments appearing in ``policy.update``. + For example, ``dict(repeat=10, batch_size=64)``. 
+ """ + + def __init__( + self, + *, + simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]], + state_interpreter: StateInterpreter[StateType, ObsType], + action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], + policy: BasePolicy, + reward: Reward, + train_initial_states: Sequence[InitialStateType] | None = None, + val_initial_states: Sequence[InitialStateType] | None = None, + test_initial_states: Sequence[InitialStateType] | None = None, + buffer_size: int = 20000, + episode_per_iter: int = 1000, + update_kwargs: dict[str, Any] = cast(Dict[str, Any], None), + ): + self.simulator_fn = simulator_fn # type: ignore + self.state_interpreter = state_interpreter + self.action_interpreter = action_interpreter + self.policy = policy + self.reward = reward + self.train_initial_states = train_initial_states + self.val_initial_states = val_initial_states + self.test_initial_states = test_initial_states + self.buffer_size = buffer_size + self.episode_per_iter = episode_per_iter + self.update_kwargs = update_kwargs or {} + + def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + if self.train_initial_states is not None: + _logger.info("Training initial states collection size: %d", len(self.train_initial_states)) + # Implement fast_dev_run here. 
+ train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run) + return DataQueue(train_initial_states, repeat=-1, shuffle=True) + return super().train_seed_iterator() + + def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + if self.val_initial_states is not None: + _logger.info("Validation initial states collection size: %d", len(self.val_initial_states)) + val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run) + return DataQueue(val_initial_states, repeat=1) + return super().val_seed_iterator() + + def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + if self.test_initial_states is not None: + _logger.info("Testing initial states collection size: %d", len(self.test_initial_states)) + test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run) + return DataQueue(test_initial_states, repeat=1) + return super().test_seed_iterator() + + def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + """Create a collector and collects ``episode_per_iter`` episodes. + Update the policy on the collected replay buffer. + """ + self.policy.train() + + with vector_env.collector_guard(): + collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env))) + + # Number of episodes collected in each training iteration can be overridden by fast dev run. 
+ if self.trainer.fast_dev_run is not None: + episodes = self.trainer.fast_dev_run + else: + episodes = self.episode_per_iter + + col_result = collector.collect(n_episode=episodes) + update_result = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs) + res = {**col_result, **update_result} + self.log_dict(res) + return res + + def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + self.policy.eval() + + with vector_env.collector_guard(): + test_collector = Collector(self.policy, vector_env) + res = test_collector.collect(n_step=INF * len(vector_env)) + self.log_dict(res) + return res + + def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]: + self.policy.eval() + + with vector_env.collector_guard(): + test_collector = Collector(self.policy, vector_env) + res = test_collector.collect(n_step=INF * len(vector_env)) + self.log_dict(res) + return res + + @staticmethod + def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]: + if size is None: + # Size = None -> original collection + return collection + order = np.random.permutation(len(collection)) + res = [collection[o] for o in order[:size]] + _logger.info( + "Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res) + ) + return res diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py index 32041abef..c1f7f3ab0 100644 --- a/qlib/rl/utils/data_queue.py +++ b/qlib/rl/utils/data_queue.py @@ -145,7 +145,9 @@ class DataQueue(Generic[T]): def __iter__(self): if not self._activated: raise ValueError( - "Need to call activate() to launch a daemon worker to produce data into data queue before using it." + "Need to call activate() to launch a daemon worker " + "to produce data into data queue before using it. " + "You probably have forgotten to use the DataQueue in a with block." 
) return self._consumer() @@ -161,19 +163,21 @@ class DataQueue(Generic[T]): # pytorch dataloader is used here only because we need its sampler and multi-processing from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel - dataloader = DataLoader( - cast(Dataset[T], self.dataset), - batch_size=None, - num_workers=self.producer_num_workers, - shuffle=self.shuffle, - collate_fn=lambda t: t, # identity collate fn - ) - repeat = 10**18 if self.repeat == -1 else self.repeat - for _rep in range(repeat): - for data in dataloader: - if self._done.value: - # Already done. - return - self._queue.put(data) - _logger.debug(f"Dataloader loop done. Repeat {_rep}.") - self.mark_as_done() + try: + dataloader = DataLoader( + cast(Dataset[T], self.dataset), + batch_size=None, + num_workers=self.producer_num_workers, + shuffle=self.shuffle, + collate_fn=lambda t: t, # identity collate fn + ) + repeat = 10**18 if self.repeat == -1 else self.repeat + for _rep in range(repeat): + for data in dataloader: + if self._done.value: + # Already done. + return + self._queue.put(data) + _logger.debug(f"Dataloader loop done. Repeat {_rep}.") + finally: + self.mark_as_done() diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index fc9c2c75e..6d7b0e209 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -120,12 +120,19 @@ class FiniteVectorEnv(BaseVectorEnv): from child workers. See :class:`qlib.rl.utils.LogWriter`. 
""" + _logger: list[LogWriter] + def __init__( - self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any + self, logger: LogWriter | list[LogWriter] | None, env_fns: list[Callable[..., gym.Env]], **kwargs: Any ) -> None: super().__init__(env_fns, **kwargs) - self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger] + if isinstance(logger, list): + self._logger = logger + elif isinstance(logger, LogWriter): + self._logger = [logger] + else: + self._logger = [] self._alive_env_ids: Set[int] = set() self._reset_alive_envs() self._default_obs = self._default_info = self._default_rew = None @@ -177,7 +184,7 @@ class FiniteVectorEnv(BaseVectorEnv): 1. Catch and ignore the StopIteration exception, which is the stopping signal thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit. - 2. Notify the loggers that the collect is done what it's done. + 2. Notify the loggers that the collect is ready / done what it's ready / done. Examples -------- @@ -186,6 +193,9 @@ class FiniteVectorEnv(BaseVectorEnv): """ self._collector_guarded = True + for logger in self._logger: + logger.on_env_all_ready() + try: yield self except StopIteration: @@ -298,7 +308,21 @@ def vectorize_env( concurrency: int, logger: LogWriter | list[LogWriter], ) -> FiniteVectorEnv: - """Helper function to create a vector env. + """Helper function to create a vector env. Can be used to replace usual VectorEnv. + + For example, once you wrote: :: + + DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)]) + + Now you can replace it with: :: + + finite_env_factory(lambda: gym.make(task), "dummy", env_num, my_logger) + + By doing such replacement, you have two additional features enabled (compared to normal VectorEnv): + + 1. The vector env will check for NaN observation and kill the worker when its found. + See :class:`FiniteVectorEnv` for why we need this. + 2. 
A logger to explicit collect logs from environment workers. Parameters ---------- diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index 3d495b11d..409a48a76 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -12,13 +12,16 @@ in each worker, and writes them to console, log files, or tensorboard... The two modules communicate by the "log" field in "info" returned by ``env.step()``. """ +# NOTE: This file contains many hardcoded / ad-hoc rules. +# Refactoring it will be one of the future tasks. + from __future__ import annotations import logging from collections import defaultdict from enum import IntEnum from pathlib import Path -from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence +from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable import numpy as np import pandas as pd @@ -29,7 +32,7 @@ if TYPE_CHECKING: from .env_wrapper import InfoDict -__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"] +__all__ = ["LogCollector", "LogWriter", "LogLevel", "LogBuffer", "ConsoleWriter", "CsvWriter"] ObsType = TypeVar("ObsType") ActType = TypeVar("ActType") @@ -175,18 +178,53 @@ class LogWriter(Generic[ObsType, ActType]): self.clear() def clear(self): + """Clear all the metrics for a fresh start. + To make the logger instance reusable. 
+ """ self.episode_count = self.step_count = 0 self.active_env_ids = set() - self.logs = [] - def aggregation(self, array: Sequence[Any]) -> Any: + def state_dict(self) -> dict: + """Save the states of the logger to a dict.""" + return { + "episode_count": self.episode_count, + "step_count": self.step_count, + "global_step": self.global_step, + "global_episode": self.global_episode, + "active_env_ids": self.active_env_ids, + "episode_lengths": self.episode_lengths, + "episode_rewards": self.episode_rewards, + "episode_logs": self.episode_logs, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the states of current logger from a dict.""" + self.episode_count = state_dict["episode_count"] + self.step_count = state_dict["step_count"] + self.global_step = state_dict["global_step"] + self.global_episode = state_dict["global_episode"] + + # These are runtime infos. + # Though they are loaded, I don't think it really helps. + self.active_env_ids = state_dict["active_env_ids"] + self.episode_lenghts = state_dict["episode_lengths"] + self.episode_rewards = state_dict["episode_rewards"] + self.episode_logs = state_dict["episode_logs"] + + def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any: """Aggregation function from step-wise to episode-wise. If it's a sequence of float, take the mean. Otherwise, take the first element. + + If a name is specified and, + + - if it's ``reward``, the reduction will be sum. """ assert len(array) > 0, "The aggregated array must be not empty." if all(isinstance(v, float) for v in array): + if name == "reward": + return np.sum(array) return np.mean(array) else: return array[0] @@ -253,10 +291,93 @@ class LogWriter(Generic[ObsType, ActType]): self.episode_rewards[env_id] = [] self.episode_logs[env_id] = [] + def on_env_all_ready(self) -> None: + """When all environments are ready to run. + Usually, loggers should be reset here. + """ + self.clear() + def on_env_all_done(self) -> None: """All done. 
Time for cleanup.""" +class LogBuffer(LogWriter): + """Keep all numbers in memory. + + Objects that can't be aggregated like strings, tensors, images can't be stored in the buffer. + To persist them, please use :class:`PickleWriter`. + + Every time, Log buffer receives a new metric, the callback is triggered, + which is useful when tracking metrics inside a trainer. + + Parameters + ---------- + callback + A callback receiving three arguments: + + - on_episode: Whether it's called at the end of an episode + - on_collect: Whether it's called at the end of a collect + - log_buffer: the :class:`LogBbuffer`object + + No return value is expected. + """ + + # FIXME: needs a metric count + + def __init__(self, callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC): + super().__init__(loglevel) + self.callback = callback + + def state_dict(self) -> dict: + return { + **super().state_dict(), + "latest_metrics": self._latest_metrics, + "aggregated_metrics": self._aggregated_metrics, + } + + def load_state_dict(self, state_dict: dict) -> None: + self._latest_metrics = state_dict["latest_metrics"] + self._aggregated_metrics = state_dict["aggregated_metrics"] + return super().load_state_dict(state_dict) + + def clear(self): + super().clear() + self._latest_metrics: dict[str, float] | None = None + self._aggregated_metrics: dict[str, float] = defaultdict(float) + + def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + # FIXME Dup of ConsoleWriter + episode_wise_contents: dict[str, list] = defaultdict(list) + for step_contents in contents: + for name, value in step_contents.items(): + # FIXME This could be false-negative for some numpy types + if isinstance(value, float): + episode_wise_contents[name].append(value) + + logs: dict[str, float] = {} + for name, values in episode_wise_contents.items(): + logs[name] = self.aggregation(values, name) # type: ignore + self._aggregated_metrics[name] += 
logs[name] + + self._latest_metrics = logs + + self.callback(True, False, self) + + def on_env_all_done(self) -> None: + # This happens when collect exits + self.callback(False, True, self) + + def episode_metrics(self) -> dict[str, float]: + """Retrieve the numeric metrics of the latest episode.""" + if self._latest_metrics is None: + raise ValueError("No episode metrics available yet.") + return self._latest_metrics + + def collect_metrics(self) -> dict[str, float]: + """Retrieve the aggregated metrics of the latest collect.""" + return {name: value / self.episode_count for name, value in self._aggregated_metrics.items()} + + class ConsoleWriter(LogWriter): """Write log messages to console periodically. @@ -289,6 +410,8 @@ class ConsoleWriter(LogWriter): self.console_logger = get_module_logger(__name__, level=logging.INFO) + # FIXME: save & reload + def clear(self): super().clear() # Clear average meters @@ -308,7 +431,7 @@ class ConsoleWriter(LogWriter): # This should be done at every step, regardless of periodic or not. 
logs: dict[str, float] = {} for name, values in episode_wise_contents.items(): - logs[name] = self.aggregation(values) # type: ignore + logs[name] = self.aggregation(values, name) # type: ignore for name, value in logs.items(): self.metric_counts[name] += 1 @@ -350,6 +473,8 @@ class CsvWriter(LogWriter): all_records: list[dict[str, Any]] + # FIXME: save & reload + def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC): super().__init__(loglevel) self.output_dir = output_dir @@ -370,7 +495,7 @@ class CsvWriter(LogWriter): logs: dict[str, float] = {} for name, values in episode_wise_contents.items(): - logs[name] = self.aggregation(values) # type: ignore + logs[name] = self.aggregation(values, name) # type: ignore self.all_records.append(logs) @@ -392,7 +517,3 @@ class TensorboardWriter(LogWriter): class MlflowWriter(LogWriter): """Add logs to mlflow.""" - - -class LogBuffer(LogWriter): - """Keep everything in memory.""" diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py index 240ffc1e1..2cf149a75 100644 --- a/tests/rl/test_logger.py +++ b/tests/rl/test_logger.py @@ -81,7 +81,7 @@ def test_simple_env_logger(caplog): line = line.strip() if line: line_counter += 1 - assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line) + assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line) assert line_counter >= 3 diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 2ac0d9cbd..98e5dd981 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -17,7 +17,7 @@ from qlib.backtest import Order from qlib.config import C from qlib.log import set_log_with_config from qlib.rl.data import pickle_styled -from qlib.rl.entries.test import backtest +from qlib.rl.trainer import backtest, train from qlib.rl.order_execution import * from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus @@ -306,3 +306,26 @@ def 
test_cn_ppo_strategy(): assert np.isclose(metrics["pa"].mean(), -16.21578303474833) assert np.isclose(metrics["market_price"].mean(), 58.68277690875527) assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002) + + +def test_ppo_train(): + set_log_with_config(C.logging_config) + # The data starts with 9:31 and ends with 15:00 + orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58")) + assert len(orders) == 40 + + state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6) + action_interp = CategoricalActionInterpreter(4) + network = Recurrent(state_interp.observation_space) + policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) + + train( + partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), + state_interp, + action_interp, + orders, + policy, + PAPenaltyReward(), + vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}}, + trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)}, + ) diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py new file mode 100644 index 000000000..751fbd387 --- /dev/null +++ b/tests/rl/test_trainer.py @@ -0,0 +1,202 @@ +import os +import random +import sys +from pathlib import Path + +import pytest + +import torch +import torch.nn as nn +from gym import spaces +from tianshou.policy import PPOPolicy + +from qlib.config import C +from qlib.log import set_log_with_config +from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.simulator import Simulator +from qlib.rl.reward import Reward +from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint + +pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8") + + +class ZeroSimulator(Simulator): + def __init__(self, *args, **kwargs): + self.action = self.correct = 0 + + def 
step(self, action): + self.action = action + self.correct = action == 0 + self._done = random.choice([False, True]) + if self._done: + self.env.logger.add_scalar("acc", self.correct * 100) + + def get_state(self): + return { + "acc": self.correct * 100, + "action": self.action, + } + + def done(self) -> bool: + return self._done + + +class NoopStateInterpreter(StateInterpreter): + observation_space = spaces.Dict( + { + "acc": spaces.Discrete(200), + "action": spaces.Discrete(2), + } + ) + + def interpret(self, simulator_state): + return simulator_state + + +class NoopActionInterpreter(ActionInterpreter): + action_space = spaces.Discrete(2) + + def interpret(self, simulator_state, action): + return action + + +class AccReward(Reward): + def reward(self, simulator_state): + if self.env.status["done"]: + return simulator_state["acc"] / 100 + return 0.0 + + +class PolicyNet(nn.Module): + def __init__(self, out_features=1, return_state=False): + super().__init__() + self.fc = nn.Linear(32, out_features) + self.return_state = return_state + + def forward(self, obs, state=None, **kwargs): + res = self.fc(torch.randn(obs["acc"].shape[0], 32)) + if self.return_state: + return nn.functional.softmax(res, dim=-1), state + else: + return res + + +def _ppo_policy(): + actor = PolicyNet(2, True) + critic = PolicyNet() + policy = PPOPolicy( + actor, + critic, + torch.optim.Adam(tuple(actor.parameters()) + tuple(critic.parameters())), + torch.distributions.Categorical, + action_space=NoopActionInterpreter().action_space, + ) + return policy + + +def test_trainer(): + set_log_with_config(C.logging_config) + trainer = Trainer(max_iters=10, finite_env_type="subproc") + policy = _ppo_policy() + + vessel = TrainingVessel( + simulator_fn=lambda init: ZeroSimulator(init), + state_interpreter=NoopStateInterpreter(), + action_interpreter=NoopActionInterpreter(), + policy=policy, + train_initial_states=list(range(100)), + val_initial_states=list(range(10)), + 
test_initial_states=list(range(10)), + reward=AccReward(), + episode_per_iter=500, + update_kwargs=dict(repeat=10, batch_size=64), + ) + trainer.fit(vessel) + assert trainer.current_iter == 10 + assert trainer.current_episode == 5000 + assert abs(trainer.metrics["acc"] - trainer.metrics["reward"] * 100) < 1e-4 + assert trainer.metrics["acc"] > 80 + trainer.test(vessel) + assert trainer.metrics["acc"] > 60 + + +def test_trainer_fast_dev_run(): + set_log_with_config(C.logging_config) + trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem") + policy = _ppo_policy() + + vessel = TrainingVessel( + simulator_fn=lambda init: ZeroSimulator(init), + state_interpreter=NoopStateInterpreter(), + action_interpreter=NoopActionInterpreter(), + policy=policy, + train_initial_states=list(range(100)), + val_initial_states=list(range(10)), + test_initial_states=list(range(10)), + reward=AccReward(), + episode_per_iter=500, + update_kwargs=dict(repeat=10, batch_size=64), + ) + trainer.fit(vessel) + assert trainer.current_episode == 4 + + +def test_trainer_earlystop(): + # TODO this is just sanity check. + # need to see the logs to check whether it works. 
+ set_log_with_config(C.logging_config) + trainer = Trainer( + max_iters=10, + val_every_n_iters=1, + finite_env_type="dummy", + callbacks=[EarlyStopping("val/reward", restore_best_weights=True)], + ) + policy = _ppo_policy() + + vessel = TrainingVessel( + simulator_fn=lambda init: ZeroSimulator(init), + state_interpreter=NoopStateInterpreter(), + action_interpreter=NoopActionInterpreter(), + policy=policy, + train_initial_states=list(range(100)), + val_initial_states=list(range(10)), + test_initial_states=list(range(10)), + reward=AccReward(), + episode_per_iter=500, + update_kwargs=dict(repeat=10, batch_size=64), + ) + trainer.fit(vessel) + assert trainer.metrics["val/acc"] > 30 + assert trainer.current_iter == 2 # second iteration + + +def test_trainer_checkpoint(): + set_log_with_config(C.logging_config) + output_dir = Path(__file__).parent / ".output" + trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)]) + policy = _ppo_policy() + + vessel = TrainingVessel( + simulator_fn=lambda init: ZeroSimulator(init), + state_interpreter=NoopStateInterpreter(), + action_interpreter=NoopActionInterpreter(), + policy=policy, + train_initial_states=list(range(100)), + val_initial_states=list(range(10)), + test_initial_states=list(range(10)), + reward=AccReward(), + episode_per_iter=100, + update_kwargs=dict(repeat=10, batch_size=64), + ) + trainer.fit(vessel) + + assert (output_dir / "001.pth").exists() + assert (output_dir / "002.pth").exists() + assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth") + + trainer.load_state_dict(torch.load(output_dir / "001.pth")) + assert trainer.current_iter == 1 + assert trainer.current_episode == 100 + + # Reload the checkpoint at first iteration + trainer.fit(vessel, ckpt_path=output_dir / "001.pth")