mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
RL Training pipeline on 5-min data (#1415)
* Workflow runnable * CI * Slight changes to make the workflow runnable. The changes of handler/provider should be reverted before merging. * Train experiment successful * Refine handler & provider * CI issues * Resolve PR comments * Resolve PR comments * CI issues * Fix test issue * Black
This commit is contained in:
@@ -113,8 +113,11 @@ class HighFreqGeneralHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
day_length=240,
|
||||
freq="1min",
|
||||
columns=["$open", "$high", "$low", "$close", "$vwap"],
|
||||
):
|
||||
self.day_length = day_length
|
||||
self.columns = columns
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -124,7 +127,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
"freq": freq,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
@@ -160,19 +163,13 @@ class HighFreqGeneralHandler(DataHandlerLP):
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_price_feature("$vwap", 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
for column_name in self.columns:
|
||||
fields.append(get_normalized_price_feature(column_name, 0))
|
||||
names.append(column_name)
|
||||
|
||||
fields += [get_normalized_price_feature("$open", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$high", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$low", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$close", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$vwap", self.day_length)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
for column_name in self.columns:
|
||||
fields.append(get_normalized_price_feature(column_name, self.day_length))
|
||||
names.append(column_name + "_1")
|
||||
|
||||
# calculate and fill nan with 0
|
||||
fields += [
|
||||
@@ -258,14 +255,17 @@ class HighFreqGeneralBacktestHandler(DataHandler):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
day_length=240,
|
||||
freq="1min",
|
||||
columns=["$close", "$vwap", "$volume"],
|
||||
):
|
||||
self.day_length = day_length
|
||||
self.columns = set(columns)
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
"freq": freq,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
@@ -279,21 +279,24 @@ class HighFreqGeneralBacktestHandler(DataHandler):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
fields += [
|
||||
template_paused.format(template_fillnan.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
if "$close" in self.columns:
|
||||
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
fields += [
|
||||
template_paused.format(template_fillnan.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
if "$vwap" in self.columns:
|
||||
fields += [
|
||||
template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
if "$volume" in self.columns:
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ class HighFreqProvider:
|
||||
feature_conf: dict,
|
||||
label_conf: Optional[dict] = None,
|
||||
backtest_conf: dict = None,
|
||||
freq: str = "1min",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.start_time = start_time
|
||||
@@ -42,6 +43,7 @@ class HighFreqProvider:
|
||||
self.backtest_conf = backtest_conf
|
||||
self.qlib_conf = qlib_conf
|
||||
self.logger = get_module_logger("HighFreqProvider")
|
||||
self.freq = freq
|
||||
|
||||
def get_pre_datasets(self):
|
||||
"""Generate the training, validation and test datasets for prediction
|
||||
@@ -116,8 +118,8 @@ class HighFreqProvider:
|
||||
# This code used the copy-on-write feature of Linux
|
||||
# to avoid calculating the calendar multiple times in the subprocess.
|
||||
# This may speed things up, but it may not be useful on Windows and macOS.
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
Cal.calendar(freq=self.freq)
|
||||
get_calendar_day(freq=self.freq)
|
||||
|
||||
def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
@@ -240,7 +242,7 @@ class HighFreqProvider:
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240]
|
||||
|
||||
def generate_dataset(times):
|
||||
if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
|
||||
@@ -283,7 +285,7 @@ class HighFreqProvider:
|
||||
|
||||
instruments = D.instruments(market="all")
|
||||
stock_list = D.list_instruments(
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True
|
||||
)
|
||||
|
||||
def generate_dataset(stock):
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
@@ -17,7 +18,9 @@ from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, train
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch import nn
|
||||
@@ -98,40 +101,54 @@ def train_and_test(
|
||||
action_interpreter: ActionInterpreter,
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
qlib.init()
|
||||
|
||||
order_root_path = Path(data_config["source"]["order_dir"])
|
||||
|
||||
data_granularity = simulator_config.get("data_granularity", 1)
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
return SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
data_granularity=data_granularity,
|
||||
deal_price_type=data_config["source"].get("deal_price_column", "close"),
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
|
||||
train_dataset = LazyLoadDataset(
|
||||
order_file_path=order_root_path / "train",
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time"],
|
||||
default_end_time_index=data_config["source"]["default_end_time"],
|
||||
)
|
||||
valid_dataset = LazyLoadDataset(
|
||||
order_file_path=order_root_path / "valid",
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time"],
|
||||
default_end_time_index=data_config["source"]["default_end_time"],
|
||||
)
|
||||
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
|
||||
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
|
||||
|
||||
train_dataset, valid_dataset, test_dataset = [
|
||||
LazyLoadDataset(
|
||||
order_file_path=order_root_path / tag,
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
for tag in ("train", "valid", "test")
|
||||
]
|
||||
|
||||
callbacks = []
|
||||
if "checkpoint_path" in trainer_config:
|
||||
callbacks: List[Callback] = []
|
||||
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
dirpath=Path(trainer_config["checkpoint_path"]),
|
||||
every_n_iters=trainer_config["checkpoint_every_n_iters"],
|
||||
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
|
||||
save_latest="copy",
|
||||
),
|
||||
)
|
||||
if "earlystop_patience" in trainer_config:
|
||||
callbacks.append(
|
||||
EarlyStopping(
|
||||
patience=trainer_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
)
|
||||
)
|
||||
|
||||
trainer_kwargs = {
|
||||
"max_iters": trainer_config["max_epoch"],
|
||||
@@ -160,8 +177,21 @@ def train_and_test(
|
||||
vessel_kwargs=vessel_kwargs,
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=test_dataset,
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=trainer_kwargs["finite_env_type"],
|
||||
concurrency=trainer_kwargs["concurrency"],
|
||||
)
|
||||
|
||||
def main(config: dict) -> None:
|
||||
|
||||
def main(config: dict, run_backtest: bool) -> None:
|
||||
if "seed" in config["runtime"]:
|
||||
seed_everything(config["runtime"]["seed"])
|
||||
|
||||
@@ -200,6 +230,7 @@ def main(config: dict) -> None:
|
||||
state_interpreter=state_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
@@ -211,9 +242,10 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
|
||||
main(config)
|
||||
main(config, run_backtest=args.run_backtest)
|
||||
|
||||
@@ -83,7 +83,16 @@ def _find_pickle(filename_without_suffix: Path) -> Path:
|
||||
|
||||
@lru_cache(maxsize=10)  # 10 * 40M = 400MB
def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
    """Load a pickled DataFrame and make its ``date`` / ``datetime`` fields real datetimes.

    Parameters
    ----------
    filename_without_suffix
        Path of the pickle file without its suffix; the actual file is resolved
        by ``_find_pickle``.

    Returns
    -------
    pd.DataFrame
        The loaded frame. Any column or named index level called ``date`` or
        ``datetime`` has been passed through ``pd.to_datetime``. The result is
        cached by ``lru_cache``, so repeated calls with the same path return the
        same object.
    """
    # NOTE(review): unpickling can execute arbitrary code; only point this at trusted data files.
    df = pd.read_pickle(_find_pickle(filename_without_suffix))
    index_cols = list(df.index.names)

    # Moving the index into columns and back only works when every level has a
    # name: an unnamed level is reported as ``None`` and ``set_index([None])``
    # raises ``KeyError``. In that case only the regular columns are converted
    # and the index is left exactly as it was loaded.
    can_round_trip = all(name is not None for name in index_cols)

    if can_round_trip:
        df = df.reset_index()
    for date_col_name in ("date", "datetime"):
        if date_col_name in df:
            df[date_col_name] = pd.to_datetime(df[date_col_name])
    if can_round_trip:
        df = df.set_index(index_cols)

    return df
|
||||
|
||||
|
||||
class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
@@ -161,6 +170,7 @@ class IntradayProcessedData(BaseIntradayProcessedData):
|
||||
time_index: pd.Index,
|
||||
) -> None:
|
||||
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
|
||||
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
cnames = _infer_processed_data_column_names(feature_dim)
|
||||
|
||||
@@ -21,10 +21,13 @@ class PAPenaltyReward(Reward[SAOEState]):
|
||||
----------
|
||||
penalty
|
||||
The penalty for large volume in a short time.
|
||||
scale
|
||||
The weight used to scale up or down the reward.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: float = 100.0):
|
||||
def __init__(self, penalty: float = 100.0, scale: float = 1.0) -> None:
|
||||
self.penalty = penalty
|
||||
self.scale = scale
|
||||
|
||||
def reward(self, simulator_state: SAOEState) -> float:
|
||||
whole_order = simulator_state.order.amount
|
||||
@@ -43,4 +46,4 @@ class PAPenaltyReward(Reward[SAOEState]):
|
||||
|
||||
self.log("reward/pa", pa)
|
||||
self.log("reward/penalty", penalty)
|
||||
return reward
|
||||
return reward * self.scale
|
||||
|
||||
@@ -36,6 +36,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
----------
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
data_granularity
|
||||
Number of ticks between consecutive data entries.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
data_dir
|
||||
@@ -71,14 +73,17 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self,
|
||||
order: Order,
|
||||
data_dir: Path,
|
||||
data_granularity: int = 1,
|
||||
ticks_per_step: int = 30,
|
||||
deal_price_type: DealPriceType = "close",
|
||||
vol_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
|
||||
assert ticks_per_step % data_granularity == 0
|
||||
|
||||
self.order = order
|
||||
self.ticks_per_step: int = ticks_per_step
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.deal_price_type = deal_price_type
|
||||
self.vol_threshold = vol_threshold
|
||||
self.data_dir = data_dir
|
||||
@@ -132,6 +137,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
ticks_position = self.position - np.cumsum(exec_vol)
|
||||
|
||||
self.position -= exec_vol.sum()
|
||||
if abs(self.position) < 1e-6:
|
||||
self.position = 0.0
|
||||
if self.position < -EPS or (exec_vol < -EPS).any():
|
||||
raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")
|
||||
|
||||
|
||||
@@ -4,8 +4,17 @@
|
||||
"""Train, test, inference utilities."""
|
||||
|
||||
from .api import backtest, train
|
||||
from .callbacks import Checkpoint, EarlyStopping
|
||||
from .callbacks import Checkpoint, EarlyStopping, MetricsWriter
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVessel, TrainingVesselBase
|
||||
|
||||
__all__ = ["Trainer", "TrainingVessel", "TrainingVesselBase", "Checkpoint", "EarlyStopping", "train", "backtest"]
|
||||
__all__ = [
|
||||
"Trainer",
|
||||
"TrainingVessel",
|
||||
"TrainingVesselBase",
|
||||
"Checkpoint",
|
||||
"EarlyStopping",
|
||||
"MetricsWriter",
|
||||
"train",
|
||||
"backtest",
|
||||
]
|
||||
|
||||
@@ -13,9 +13,10 @@ import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any, List, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
@@ -25,7 +26,6 @@ if TYPE_CHECKING:
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVesselBase
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
@@ -155,6 +155,11 @@ class EarlyStopping(Callback):
|
||||
if self.baseline is None or self._is_improvement(current, self.baseline):
|
||||
self.wait = 0
|
||||
|
||||
msg = (
|
||||
f"#{trainer.current_iter} current reward: {current:.4f}, best reward: {self.best:.4f} in #{self.best_iter}"
|
||||
)
|
||||
_logger.info(msg)
|
||||
|
||||
# Only check after the first epoch.
|
||||
if self.wait >= self.patience and trainer.current_iter > 0:
|
||||
trainer.should_stop = True
|
||||
@@ -177,6 +182,24 @@ class EarlyStopping(Callback):
|
||||
return self.monitor_op(monitor_value - self.min_delta, reference_value)
|
||||
|
||||
|
||||
class MetricsWriter(Callback):
    """Callback that dumps the trainer's metrics to CSV files under ``dirpath``.

    Each time ``on_train_end`` fires, the trainer metrics whose key does not
    start with ``val/`` are appended to an in-memory history, and the whole
    history is rewritten to ``train_result.csv``. ``on_validate_end`` does the
    same for the ``val/``-prefixed metrics, writing ``validation_result.csv``.
    """

    def __init__(self, dirpath: Path) -> None:
        # Make sure the output directory (and any missing parents) exists.
        dirpath.mkdir(exist_ok=True, parents=True)
        self.dirpath = dirpath
        self.train_records: List[dict] = []
        self.valid_records: List[dict] = []

    def _dump(self, history: List[dict], filename: str) -> None:
        # Rewrite the full history so the file always reflects every record so far.
        table = pd.DataFrame.from_records(history)
        table.to_csv(self.dirpath / filename, index=True)

    def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        snapshot = {}
        for key, value in trainer.metrics.items():
            if key.startswith("val/"):
                continue
            snapshot[key] = value
        self.train_records.append(snapshot)
        self._dump(self.train_records, "train_result.csv")

    def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
        snapshot = {}
        for key, value in trainer.metrics.items():
            if key.startswith("val/"):
                snapshot[key] = value
        self.valid_records.append(snapshot)
        self._dump(self.valid_records, "validation_result.csv")
|
||||
|
||||
|
||||
class Checkpoint(Callback):
|
||||
"""Save checkpoints periodically for persistence and recovery.
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import annotations
|
||||
import collections
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
|
||||
|
||||
@@ -206,6 +207,9 @@ class Trainer:
|
||||
self._call_callback_hooks("on_fit_start")
|
||||
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
_logger.info(msg)
|
||||
|
||||
self.initialize_iter()
|
||||
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
@@ -218,6 +222,7 @@ class Trainer:
|
||||
with _wrap_context(vessel.train_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.train(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
|
||||
self._call_callback_hooks("on_train_end")
|
||||
|
||||
@@ -228,6 +233,7 @@ class Trainer:
|
||||
with _wrap_context(vessel.val_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.validate(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
|
||||
@@ -262,6 +268,7 @@ class Trainer:
|
||||
with _wrap_context(vessel.test_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.test(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_test_end")
|
||||
|
||||
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
|
||||
|
||||
Reference in New Issue
Block a user