From d8fc9aea6b5c795afe35601ebcd41059547cf7d8 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Wed, 18 Jan 2023 16:17:06 +0800 Subject: [PATCH] RL Training pipeline on 5-min data (#1415) * Workflow runnable * CI * Slight changes to make the workflow runnable. The changes of handler/provider should be reverted before merging. * Train experiment successful * Refine handler & provider * CI issues * Resolve PR comments * Resolve PR comments * CI issues * Fix test issue * Black --- qlib/contrib/data/highfreq_handler.py | 57 +++++++++-------- qlib/contrib/data/highfreq_provider.py | 10 +-- qlib/rl/contrib/train_onpolicy.py | 68 +++++++++++++++------ qlib/rl/data/pickle_styled.py | 12 +++- qlib/rl/order_execution/reward.py | 7 ++- qlib/rl/order_execution/simulator_simple.py | 9 ++- qlib/rl/trainer/__init__.py | 13 +++- qlib/rl/trainer/callbacks.py | 27 +++++++- qlib/rl/trainer/trainer.py | 7 +++ 9 files changed, 153 insertions(+), 57 deletions(-) diff --git a/qlib/contrib/data/highfreq_handler.py b/qlib/contrib/data/highfreq_handler.py index 373b8e669..f69f8195f 100644 --- a/qlib/contrib/data/highfreq_handler.py +++ b/qlib/contrib/data/highfreq_handler.py @@ -113,8 +113,11 @@ class HighFreqGeneralHandler(DataHandlerLP): fit_end_time=None, drop_raw=True, day_length=240, + freq="1min", + columns=["$open", "$high", "$low", "$close", "$vwap"], ): self.day_length = day_length + self.columns = columns infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) @@ -124,7 +127,7 @@ class HighFreqGeneralHandler(DataHandlerLP): "kwargs": { "config": self.get_feature_config(), "swap_level": False, - "freq": "1min", + "freq": freq, }, } super().__init__( @@ -160,19 +163,13 @@ class HighFreqGeneralHandler(DataHandlerLP): ) return feature_ops - fields += [get_normalized_price_feature("$open", 0)] - fields += [get_normalized_price_feature("$high", 0)] - fields += [get_normalized_price_feature("$low", 0)] - fields += [get_normalized_price_feature("$close", 0)] - fields += [get_normalized_price_feature("$vwap", 0)] - names += ["$open", "$high", "$low", "$close", "$vwap"] + for column_name in self.columns: + fields.append(get_normalized_price_feature(column_name, 0)) + names.append(column_name) - fields += [get_normalized_price_feature("$open", self.day_length)] - fields += [get_normalized_price_feature("$high", self.day_length)] - fields += [get_normalized_price_feature("$low", self.day_length)] - fields += [get_normalized_price_feature("$close", self.day_length)] - fields += [get_normalized_price_feature("$vwap", self.day_length)] - names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"] + for column_name in self.columns: + fields.append(get_normalized_price_feature(column_name, self.day_length)) + names.append(column_name + "_1") # calculate and fill nan with 0 fields += [ @@ -258,14 +255,17 @@ class HighFreqGeneralBacktestHandler(DataHandler): start_time=None, end_time=None, day_length=240, + freq="1min", + columns=["$close", "$vwap", "$volume"], ): self.day_length = day_length + self.columns = set(columns) data_loader = { "class": "QlibDataLoader", "kwargs": { "config": self.get_feature_config(), "swap_level": False, - "freq": "1min", + "freq": freq, }, } super().__init__( @@ -279,21 +279,24 @@ class HighFreqGeneralBacktestHandler(DataHandler): fields = [] names = [] - template_paused = f"Cut({{0}}, {self.day_length * 2}, None)" - template_fillnan = "FFillNan({0})" - template_if = "If(IsNull({1}), {0}, {1})" - fields += [ - template_paused.format(template_fillnan.format("$close")), - ] - names += ["$close0"] + if "$close" in self.columns: + template_paused = f"Cut({{0}}, {self.day_length * 2}, None)" + template_fillnan = "FFillNan({0})" + template_if = "If(IsNull({1}), {0}, {1})" + fields += [ + template_paused.format(template_fillnan.format("$close")), + ] + names += ["$close0"] - fields += [ - template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")), - ] - names += ["$vwap0"] + if "$vwap" in self.columns: + fields += [ + template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")), + ] + names += ["$vwap0"] - fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))] - names += ["$volume0"] + if "$volume" in self.columns: + fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))] + names += ["$volume0"] return fields, names diff --git a/qlib/contrib/data/highfreq_provider.py b/qlib/contrib/data/highfreq_provider.py index 704b37f72..b499cc68e 100644 --- a/qlib/contrib/data/highfreq_provider.py +++ b/qlib/contrib/data/highfreq_provider.py @@ -28,6 +28,7 @@ class HighFreqProvider: feature_conf: dict, label_conf: Optional[dict] = None, backtest_conf: dict = None, + freq: str = "1min", **kwargs, ) -> None: self.start_time = start_time @@ -42,6 +43,7 @@ class HighFreqProvider: self.backtest_conf = backtest_conf self.qlib_conf = qlib_conf self.logger = get_module_logger("HighFreqProvider") + self.freq = freq def get_pre_datasets(self): """Generate the training, validation and test datasets for prediction @@ -116,8 +118,8 @@ class HighFreqProvider: # This code used the copy-on-write feature of Linux # to avoid calculating the calendar multiple times in the subprocess. # This code may accelerate, but may be not useful on Windows and Mac Os - Cal.calendar(freq="1min") - get_calendar_day(freq="1min") + Cal.calendar(freq=self.freq) + get_calendar_day(freq=self.freq) def _gen_dataframe(self, config, datasets=["train", "valid", "test"]): try: @@ -240,7 +242,7 @@ class HighFreqProvider: with open(path + "tmp_dataset.pkl", "rb") as f: new_dataset = pkl.load(f) - time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240] + time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240] def generate_dataset(times): if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"): @@ -283,7 +285,7 @@ class HighFreqProvider: instruments = D.instruments(market="all") stock_list = D.list_instruments( - instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True + instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True ) def generate_dataset(stock): diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py index f043dda64..d05994854 100644 --- a/qlib/rl/contrib/train_onpolicy.py +++ b/qlib/rl/contrib/train_onpolicy.py @@ -8,6 +8,7 @@ from typing import cast, List, Optional import numpy as np import pandas as pd +import qlib import torch import yaml from qlib.backtest import Order @@ -17,7 +18,9 @@ from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.order_execution import SingleAssetOrderExecutionSimple from qlib.rl.reward import Reward -from qlib.rl.trainer import Checkpoint, train +from qlib.rl.trainer import Checkpoint, backtest, train +from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter +from qlib.rl.utils.log import CsvWriter from qlib.utils import init_instance_by_config from tianshou.policy import BasePolicy from torch import nn @@ -98,40 +101,54 @@ def train_and_test( action_interpreter: ActionInterpreter, policy: BasePolicy, reward: Reward, + run_backtest: bool, ) -> None: + qlib.init() + order_root_path = Path(data_config["source"]["order_dir"]) + data_granularity = simulator_config.get("data_granularity", 1) + def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: return SingleAssetOrderExecutionSimple( order=order, data_dir=Path(data_config["source"]["data_dir"]), ticks_per_step=simulator_config["time_per_step"], + data_granularity=data_granularity, deal_price_type=data_config["source"].get("deal_price_column", "close"), vol_threshold=simulator_config["vol_limit"], ) - train_dataset = LazyLoadDataset( - order_file_path=order_root_path / "train", - data_dir=Path(data_config["source"]["data_dir"]), - default_start_time_index=data_config["source"]["default_start_time"], - default_end_time_index=data_config["source"]["default_end_time"], - ) - valid_dataset = LazyLoadDataset( - order_file_path=order_root_path / "valid", - data_dir=Path(data_config["source"]["data_dir"]), - default_start_time_index=data_config["source"]["default_start_time"], - default_end_time_index=data_config["source"]["default_end_time"], - ) + assert data_config["source"]["default_start_time_index"] % data_granularity == 0 + assert data_config["source"]["default_end_time_index"] % data_granularity == 0 + + train_dataset, valid_dataset, test_dataset = [ + LazyLoadDataset( + order_file_path=order_root_path / tag, + data_dir=Path(data_config["source"]["data_dir"]), + default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + ) + for tag in ("train", "valid", "test") + ] - callbacks = [] if "checkpoint_path" in trainer_config: + callbacks: List[Callback] = [] + callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) callbacks.append( Checkpoint( - dirpath=Path(trainer_config["checkpoint_path"]), - every_n_iters=trainer_config["checkpoint_every_n_iters"], + dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", + every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), save_latest="copy", ), ) + if "earlystop_patience" in trainer_config: + callbacks.append( + EarlyStopping( + patience=trainer_config["earlystop_patience"], + monitor="val/pa", + ) + ) trainer_kwargs = { "max_iters": trainer_config["max_epoch"], @@ -160,8 +177,21 @@ def train_and_test( vessel_kwargs=vessel_kwargs, ) + if run_backtest: + backtest( + simulator_fn=_simulator_factory_simple, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + initial_states=test_dataset, + policy=policy, + logger=CsvWriter(Path(trainer_config["checkpoint_path"])), + reward=reward, + finite_env_type=trainer_kwargs["finite_env_type"], + concurrency=trainer_kwargs["concurrency"], + ) -def main(config: dict) -> None: + +def main(config: dict, run_backtest: bool) -> None: if "seed" in config["runtime"]: seed_everything(config["runtime"]["seed"]) @@ -200,6 +230,7 @@ def main(config: dict) -> None: state_interpreter=state_interpreter, policy=policy, reward=reward, + run_backtest=run_backtest, ) @@ -211,9 +242,10 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished") args = parser.parse_args() with open(args.config_path, "r") as input_stream: config = yaml.safe_load(input_stream) - main(config) + main(config, run_backtest=args.run_backtest) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 3af1e2483..63b55d6e0 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -83,7 +83,16 @@ def _find_pickle(filename_without_suffix: Path) -> Path: @lru_cache(maxsize=10) # 10 * 40M = 400MB def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame: - return pd.read_pickle(_find_pickle(filename_without_suffix)) + df = pd.read_pickle(_find_pickle(filename_without_suffix)) + index_cols = df.index.names + + df = df.reset_index() + for date_col_name in ["date", "datetime"]: + if date_col_name in df: + df[date_col_name] = pd.to_datetime(df[date_col_name]) + df = df.set_index(index_cols) + + return df class SimpleIntradayBacktestData(BaseIntradayBacktestData): @@ -161,6 +170,7 @@ class IntradayProcessedData(BaseIntradayProcessedData): time_index: pd.Index, ) -> None: proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) + # We have to infer the names here because, # unfortunately they are not included in the original data. cnames = _infer_processed_data_column_names(feature_dim) diff --git a/qlib/rl/order_execution/reward.py b/qlib/rl/order_execution/reward.py index 99a88f8e4..e83066d85 100644 --- a/qlib/rl/order_execution/reward.py +++ b/qlib/rl/order_execution/reward.py @@ -21,10 +21,13 @@ class PAPenaltyReward(Reward[SAOEState]): ---------- penalty The penalty for large volume in a short time. + scale + The weight used to scale up or down the reward. """ - def __init__(self, penalty: float = 100.0): + def __init__(self, penalty: float = 100.0, scale: float = 1.0) -> None: self.penalty = penalty + self.scale = scale def reward(self, simulator_state: SAOEState) -> float: whole_order = simulator_state.order.amount @@ -43,4 +46,4 @@ class PAPenaltyReward(Reward[SAOEState]): self.log("reward/pa", pa) self.log("reward/penalty", penalty) - return reward + return reward * self.scale diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 9086e6047..f1c09c151 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -36,6 +36,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): ---------- order The seed to start an SAOE simulator is an order. + data_granularity + Number of ticks between consecutive data entries. ticks_per_step How many ticks per step. data_dir @@ -71,14 +73,17 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): self, order: Order, data_dir: Path, + data_granularity: int = 1, ticks_per_step: int = 30, deal_price_type: DealPriceType = "close", vol_threshold: Optional[float] = None, ) -> None: super().__init__(initial=order) + assert ticks_per_step % data_granularity == 0 + self.order = order - self.ticks_per_step: int = ticks_per_step + self.ticks_per_step: int = ticks_per_step // data_granularity self.deal_price_type = deal_price_type self.vol_threshold = vol_threshold self.data_dir = data_dir @@ -132,6 +137,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): ticks_position = self.position - np.cumsum(exec_vol) self.position -= exec_vol.sum() + if abs(self.position) < 1e-6: + self.position = 0.0 if self.position < -EPS or (exec_vol < -EPS).any(): raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})") diff --git a/qlib/rl/trainer/__init__.py b/qlib/rl/trainer/__init__.py index 4c5121ece..828ba7bd3 100644 --- a/qlib/rl/trainer/__init__.py +++ b/qlib/rl/trainer/__init__.py @@ -4,8 +4,17 @@ """Train, test, inference utilities.""" from .api import backtest, train -from .callbacks import Checkpoint, EarlyStopping +from .callbacks import Checkpoint, EarlyStopping, MetricsWriter from .trainer import Trainer from .vessel import TrainingVessel, TrainingVesselBase -__all__ = ["Trainer", "TrainingVessel", "TrainingVesselBase", "Checkpoint", "EarlyStopping", "train", "backtest"] +__all__ = [ + "Trainer", + "TrainingVessel", + "TrainingVesselBase", + "Checkpoint", + "EarlyStopping", + "MetricsWriter", + "train", + "backtest", +] diff --git a/qlib/rl/trainer/callbacks.py b/qlib/rl/trainer/callbacks.py index e5422075e..9d1bf4ba2 100644 --- a/qlib/rl/trainer/callbacks.py +++ b/qlib/rl/trainer/callbacks.py @@ -13,9 +13,10 @@ import shutil import time from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any, List, TYPE_CHECKING import numpy as np +import pandas as pd import torch from qlib.log import get_module_logger @@ -25,7 +26,6 @@ if TYPE_CHECKING: from .trainer import Trainer from .vessel import TrainingVesselBase - _logger = get_module_logger(__name__) @@ -155,6 +155,11 @@ class EarlyStopping(Callback): if self.baseline is None or self._is_improvement(current, self.baseline): self.wait = 0 + msg = ( + f"#{trainer.current_iter} current reward: {current:.4f}, best reward: {self.best:.4f} in #{self.best_iter}" + ) + _logger.info(msg) + # Only check after the first epoch. if self.wait >= self.patience and trainer.current_iter > 0: trainer.should_stop = True @@ -177,6 +182,24 @@ class EarlyStopping(Callback): return self.monitor_op(monitor_value - self.min_delta, reference_value) +class MetricsWriter(Callback): + """Dump training metrics to file.""" + + def __init__(self, dirpath: Path) -> None: + self.dirpath = dirpath + self.dirpath.mkdir(exist_ok=True, parents=True) + self.train_records: List[dict] = [] + self.valid_records: List[dict] = [] + + def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: + self.train_records.append({k: v for k, v in trainer.metrics.items() if not k.startswith("val/")}) + pd.DataFrame.from_records(self.train_records).to_csv(self.dirpath / "train_result.csv", index=True) + + def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: + self.valid_records.append({k: v for k, v in trainer.metrics.items() if k.startswith("val/")}) + pd.DataFrame.from_records(self.valid_records).to_csv(self.dirpath / "validation_result.csv", index=True) + + class Checkpoint(Callback): """Save checkpoints periodically for persistence and recovery. diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index 7573b3391..fb73dd549 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -6,6 +6,7 @@ from __future__ import annotations import collections import copy from contextlib import AbstractContextManager, contextmanager +from datetime import datetime from pathlib import Path from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast @@ -206,6 +207,9 @@ class Trainer: self._call_callback_hooks("on_fit_start") while not self.should_stop: + msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}" + _logger.info(msg) + self.initialize_iter() self._call_callback_hooks("on_iter_start") @@ -218,6 +222,7 @@ class Trainer: with _wrap_context(vessel.train_seed_iterator()) as iterator: vector_env = self.venv_from_iterator(iterator) self.vessel.train(vector_env) + del vector_env # FIXME: Explicitly delete this object to avoid memory leak. self._call_callback_hooks("on_train_end") @@ -228,6 +233,7 @@ class Trainer: with _wrap_context(vessel.val_seed_iterator()) as iterator: vector_env = self.venv_from_iterator(iterator) self.vessel.validate(vector_env) + del vector_env # FIXME: Explicitly delete this object to avoid memory leak. self._call_callback_hooks("on_validate_end") @@ -262,6 +268,7 @@ class Trainer: with _wrap_context(vessel.test_seed_iterator()) as iterator: vector_env = self.venv_from_iterator(iterator) self.vessel.test(vector_env) + del vector_env # FIXME: Explicitly delete this object to avoid memory leak. self._call_callback_hooks("on_test_end") def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv: