1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

RL Training pipeline on 5-min data (#1415)

* Workflow runnable

* CI

* Slight changes to make the workflow runnable. The changes of handler/provider should be reverted before merging.

* Train experiment successful

* Refine handler & provider

* CI issues

* Resolve PR comments

* Resolve PR comments

* CI issues

* Fix test issue

* Black
This commit is contained in:
Huoran Li
2023-01-18 16:17:06 +08:00
committed by GitHub
parent d8764660dc
commit d8fc9aea6b
9 changed files with 153 additions and 57 deletions

View File

@@ -113,8 +113,11 @@ class HighFreqGeneralHandler(DataHandlerLP):
fit_end_time=None,
drop_raw=True,
day_length=240,
freq="1min",
columns=["$open", "$high", "$low", "$close", "$vwap"],
):
self.day_length = day_length
self.columns = columns
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
@@ -124,7 +127,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
"freq": freq,
},
}
super().__init__(
@@ -160,19 +163,13 @@ class HighFreqGeneralHandler(DataHandlerLP):
)
return feature_ops
fields += [get_normalized_price_feature("$open", 0)]
fields += [get_normalized_price_feature("$high", 0)]
fields += [get_normalized_price_feature("$low", 0)]
fields += [get_normalized_price_feature("$close", 0)]
fields += [get_normalized_price_feature("$vwap", 0)]
names += ["$open", "$high", "$low", "$close", "$vwap"]
for column_name in self.columns:
fields.append(get_normalized_price_feature(column_name, 0))
names.append(column_name)
fields += [get_normalized_price_feature("$open", self.day_length)]
fields += [get_normalized_price_feature("$high", self.day_length)]
fields += [get_normalized_price_feature("$low", self.day_length)]
fields += [get_normalized_price_feature("$close", self.day_length)]
fields += [get_normalized_price_feature("$vwap", self.day_length)]
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
for column_name in self.columns:
fields.append(get_normalized_price_feature(column_name, self.day_length))
names.append(column_name + "_1")
# calculate and fill nan with 0
fields += [
@@ -258,14 +255,17 @@ class HighFreqGeneralBacktestHandler(DataHandler):
start_time=None,
end_time=None,
day_length=240,
freq="1min",
columns=["$close", "$vwap", "$volume"],
):
self.day_length = day_length
self.columns = set(columns)
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
"freq": freq,
},
}
super().__init__(
@@ -279,21 +279,24 @@ class HighFreqGeneralBacktestHandler(DataHandler):
fields = []
names = []
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
template_fillnan = "FFillNan({0})"
template_if = "If(IsNull({1}), {0}, {1})"
fields += [
template_paused.format(template_fillnan.format("$close")),
]
names += ["$close0"]
if "$close" in self.columns:
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
template_fillnan = "FFillNan({0})"
template_if = "If(IsNull({1}), {0}, {1})"
fields += [
template_paused.format(template_fillnan.format("$close")),
]
names += ["$close0"]
fields += [
template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
]
names += ["$vwap0"]
if "$vwap" in self.columns:
fields += [
template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
]
names += ["$vwap0"]
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
names += ["$volume0"]
if "$volume" in self.columns:
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
names += ["$volume0"]
return fields, names

View File

@@ -28,6 +28,7 @@ class HighFreqProvider:
feature_conf: dict,
label_conf: Optional[dict] = None,
backtest_conf: dict = None,
freq: str = "1min",
**kwargs,
) -> None:
self.start_time = start_time
@@ -42,6 +43,7 @@ class HighFreqProvider:
self.backtest_conf = backtest_conf
self.qlib_conf = qlib_conf
self.logger = get_module_logger("HighFreqProvider")
self.freq = freq
def get_pre_datasets(self):
"""Generate the training, validation and test datasets for prediction
@@ -116,8 +118,8 @@ class HighFreqProvider:
# This code uses the copy-on-write feature of Linux
# to avoid calculating the calendar multiple times in the subprocesses.
# It may speed things up, but may not be useful on Windows and macOS.
Cal.calendar(freq="1min")
get_calendar_day(freq="1min")
Cal.calendar(freq=self.freq)
get_calendar_day(freq=self.freq)
def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
try:
@@ -240,7 +242,7 @@ class HighFreqProvider:
with open(path + "tmp_dataset.pkl", "rb") as f:
new_dataset = pkl.load(f)
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240]
def generate_dataset(times):
if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
@@ -283,7 +285,7 @@ class HighFreqProvider:
instruments = D.instruments(market="all")
stock_list = D.list_instruments(
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True
)
def generate_dataset(stock):

View File

@@ -8,6 +8,7 @@ from typing import cast, List, Optional
import numpy as np
import pandas as pd
import qlib
import torch
import yaml
from qlib.backtest import Order
@@ -17,7 +18,9 @@ from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, train
from qlib.rl.trainer import Checkpoint, backtest, train
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch import nn
@@ -98,40 +101,54 @@ def train_and_test(
action_interpreter: ActionInterpreter,
policy: BasePolicy,
reward: Reward,
run_backtest: bool,
) -> None:
qlib.init()
order_root_path = Path(data_config["source"]["order_dir"])
data_granularity = simulator_config.get("data_granularity", 1)
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
return SingleAssetOrderExecutionSimple(
order=order,
data_dir=Path(data_config["source"]["data_dir"]),
ticks_per_step=simulator_config["time_per_step"],
data_granularity=data_granularity,
deal_price_type=data_config["source"].get("deal_price_column", "close"),
vol_threshold=simulator_config["vol_limit"],
)
train_dataset = LazyLoadDataset(
order_file_path=order_root_path / "train",
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time"],
default_end_time_index=data_config["source"]["default_end_time"],
)
valid_dataset = LazyLoadDataset(
order_file_path=order_root_path / "valid",
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time"],
default_end_time_index=data_config["source"]["default_end_time"],
)
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
train_dataset, valid_dataset, test_dataset = [
LazyLoadDataset(
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid", "test")
]
callbacks = []
if "checkpoint_path" in trainer_config:
callbacks: List[Callback] = []
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]),
every_n_iters=trainer_config["checkpoint_every_n_iters"],
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
trainer_kwargs = {
"max_iters": trainer_config["max_epoch"],
@@ -160,8 +177,21 @@ def train_and_test(
vessel_kwargs=vessel_kwargs,
)
if run_backtest:
backtest(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
initial_states=test_dataset,
policy=policy,
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
reward=reward,
finite_env_type=trainer_kwargs["finite_env_type"],
concurrency=trainer_kwargs["concurrency"],
)
def main(config: dict) -> None:
def main(config: dict, run_backtest: bool) -> None:
if "seed" in config["runtime"]:
seed_everything(config["runtime"]["seed"])
@@ -200,6 +230,7 @@ def main(config: dict) -> None:
state_interpreter=state_interpreter,
policy=policy,
reward=reward,
run_backtest=run_backtest,
)
@@ -211,9 +242,10 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
args = parser.parse_args()
with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream)
main(config)
main(config, run_backtest=args.run_backtest)

View File

@@ -83,7 +83,16 @@ def _find_pickle(filename_without_suffix: Path) -> Path:
@lru_cache(maxsize=10) # 10 * 40M = 400MB
def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
return pd.read_pickle(_find_pickle(filename_without_suffix))
df = pd.read_pickle(_find_pickle(filename_without_suffix))
index_cols = df.index.names
df = df.reset_index()
for date_col_name in ["date", "datetime"]:
if date_col_name in df:
df[date_col_name] = pd.to_datetime(df[date_col_name])
df = df.set_index(index_cols)
return df
class SimpleIntradayBacktestData(BaseIntradayBacktestData):
@@ -161,6 +170,7 @@ class IntradayProcessedData(BaseIntradayProcessedData):
time_index: pd.Index,
) -> None:
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
# We have to infer the names here because,
# unfortunately they are not included in the original data.
cnames = _infer_processed_data_column_names(feature_dim)

View File

@@ -21,10 +21,13 @@ class PAPenaltyReward(Reward[SAOEState]):
----------
penalty
The penalty for large volume in a short time.
scale
The weight used to scale up or down the reward.
"""
def __init__(self, penalty: float = 100.0):
def __init__(self, penalty: float = 100.0, scale: float = 1.0) -> None:
self.penalty = penalty
self.scale = scale
def reward(self, simulator_state: SAOEState) -> float:
whole_order = simulator_state.order.amount
@@ -43,4 +46,4 @@ class PAPenaltyReward(Reward[SAOEState]):
self.log("reward/pa", pa)
self.log("reward/penalty", penalty)
return reward
return reward * self.scale

View File

@@ -36,6 +36,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
----------
order
The seed to start an SAOE simulator is an order.
data_granularity
Number of ticks between consecutive data entries.
ticks_per_step
How many ticks per step.
data_dir
@@ -71,14 +73,17 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self,
order: Order,
data_dir: Path,
data_granularity: int = 1,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: Optional[float] = None,
) -> None:
super().__init__(initial=order)
assert ticks_per_step % data_granularity == 0
self.order = order
self.ticks_per_step: int = ticks_per_step
self.ticks_per_step: int = ticks_per_step // data_granularity
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
@@ -132,6 +137,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
ticks_position = self.position - np.cumsum(exec_vol)
self.position -= exec_vol.sum()
if abs(self.position) < 1e-6:
self.position = 0.0
if self.position < -EPS or (exec_vol < -EPS).any():
raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")

View File

@@ -4,8 +4,17 @@
"""Train, test, inference utilities."""
from .api import backtest, train
from .callbacks import Checkpoint, EarlyStopping
from .callbacks import Checkpoint, EarlyStopping, MetricsWriter
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase
__all__ = ["Trainer", "TrainingVessel", "TrainingVesselBase", "Checkpoint", "EarlyStopping", "train", "backtest"]
__all__ = [
"Trainer",
"TrainingVessel",
"TrainingVesselBase",
"Checkpoint",
"EarlyStopping",
"MetricsWriter",
"train",
"backtest",
]

View File

@@ -13,9 +13,10 @@ import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import Any, List, TYPE_CHECKING
import numpy as np
import pandas as pd
import torch
from qlib.log import get_module_logger
@@ -25,7 +26,6 @@ if TYPE_CHECKING:
from .trainer import Trainer
from .vessel import TrainingVesselBase
_logger = get_module_logger(__name__)
@@ -155,6 +155,11 @@ class EarlyStopping(Callback):
if self.baseline is None or self._is_improvement(current, self.baseline):
self.wait = 0
msg = (
f"#{trainer.current_iter} current reward: {current:.4f}, best reward: {self.best:.4f} in #{self.best_iter}"
)
_logger.info(msg)
# Only check after the first epoch.
if self.wait >= self.patience and trainer.current_iter > 0:
trainer.should_stop = True
@@ -177,6 +182,24 @@ class EarlyStopping(Callback):
return self.monitor_op(monitor_value - self.min_delta, reference_value)
class MetricsWriter(Callback):
    """Callback that dumps the trainer's metrics to CSV files.

    Two files are maintained under ``dirpath``: ``train_result.csv`` for
    metrics whose key does not start with ``"val/"`` and
    ``validation_result.csv`` for metrics whose key does. Each hook call
    appends one record and rewrites the corresponding file from the full
    list of records collected so far.

    Parameters
    ----------
    dirpath
        Directory the CSV files are written to. It is created (with parents)
        on construction if it does not exist yet.
    """

    def __init__(self, dirpath: Path) -> None:
        self.dirpath = dirpath
        self.dirpath.mkdir(exist_ok=True, parents=True)
        # One dict per hook invocation, in the order the hooks fired.
        self.train_records: List[dict] = []
        self.valid_records: List[dict] = []

    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Record the non-validation metrics and rewrite ``train_result.csv``."""
        record = {}
        for name, value in trainer.metrics.items():
            if not name.startswith("val/"):
                record[name] = value
        self.train_records.append(record)

        frame = pd.DataFrame.from_records(self.train_records)
        frame.to_csv(self.dirpath / "train_result.csv", index=True)

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        """Record the ``val/``-prefixed metrics and rewrite ``validation_result.csv``."""
        record = {}
        for name, value in trainer.metrics.items():
            if name.startswith("val/"):
                record[name] = value
        self.valid_records.append(record)

        frame = pd.DataFrame.from_records(self.valid_records)
        frame.to_csv(self.dirpath / "validation_result.csv", index=True)
class Checkpoint(Callback):
"""Save checkpoints periodically for persistence and recovery.

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import collections
import copy
from contextlib import AbstractContextManager, contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
@@ -206,6 +207,9 @@ class Trainer:
self._call_callback_hooks("on_fit_start")
while not self.should_stop:
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
_logger.info(msg)
self.initialize_iter()
self._call_callback_hooks("on_iter_start")
@@ -218,6 +222,7 @@ class Trainer:
with _wrap_context(vessel.train_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.train(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_train_end")
@@ -228,6 +233,7 @@ class Trainer:
with _wrap_context(vessel.val_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.validate(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_validate_end")
@@ -262,6 +268,7 @@ class Trainer:
with _wrap_context(vessel.test_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.test(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_test_end")
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv: