mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Train on full simulation
This commit is contained in:
@@ -56,6 +56,7 @@ def collect_data_loop(
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict | None = None,
|
||||
show_progress: bool = True,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
@@ -74,6 +75,8 @@ def collect_data_loop(
|
||||
the outermost executor
|
||||
return_value : dict
|
||||
used for backtest_loop
|
||||
show_progress: bool
|
||||
whether to show execution progress
|
||||
|
||||
Yields
|
||||
-------
|
||||
@@ -83,7 +86,8 @@ def collect_data_loop(
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
|
||||
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
disable = not show_progress
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
|
||||
@@ -15,7 +15,7 @@ import pandas as pd
|
||||
import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_exchange, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
@@ -250,8 +250,6 @@ def single_with_collect_data_loop(
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
trade_start_time = orders["datetime"].min()
|
||||
trade_end_time = orders["datetime"].max()
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
@@ -275,13 +273,13 @@ def single_with_collect_data_loop(
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
exchange_config = {
|
||||
**backtest_config["exchange"],
|
||||
**{
|
||||
"codes": stocks,
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
strategy, executor = get_strategy_executor(
|
||||
start_time=pd.Timestamp(trade_start_time),
|
||||
@@ -326,6 +324,8 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
|
||||
single = single_with_simulator if with_simulator else single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
|
||||
|
||||
init_qlib(backtest_config["qlib"])
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
backtest_config=backtest_config,
|
||||
|
||||
@@ -105,3 +105,100 @@ def get_backtest_config_fromfile(path: str) -> dict:
|
||||
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
|
||||
|
||||
return backtest_config
|
||||
|
||||
|
||||
class TrainingConfigParser:
    """Normalize a raw training config into a dict with defaults filled in.

    The raw config is loaded once in ``__init__`` via ``parse_backtest_config``;
    :meth:`parse` then assembles the ``general``, ``policy``, ``interpreter``,
    ``runtime``, ``training`` and ``simulator`` sections.
    """

    def __init__(self, path: str) -> None:
        self.raw_config = parse_backtest_config(path)

    def parse(self) -> dict:
        """Return the complete config; user values override the built-in defaults."""
        parsed = {"general": self._parse_general()}
        # "policy" and "interpreter" are passed through exactly as given.
        parsed["policy"] = self.raw_config["policy"]
        parsed["interpreter"] = self.raw_config["interpreter"]
        parsed["runtime"] = self._parse_runtime()
        parsed["training"] = self._parse_training()
        parsed["simulator"] = self._parse_simulator()
        return parsed

    def _parse_general(self) -> dict:
        """Return the ``general`` section layered over its defaults."""
        overrides = self.raw_config["general"]
        return {"freq": "1min", "extra_module_paths": [], **overrides}

    def _parse_runtime(self) -> dict:
        """Return the ``runtime`` section layered over its defaults."""
        overrides = self.raw_config["runtime"]
        return {
            "seed": None,
            "use_cuda": False,
            "concurrency": 1,
            "parallel_mode": "dummy",
            **overrides,
        }

    def _parse_training(self) -> dict:
        """Return the ``training`` section layered over its defaults.

        ``order_dir`` has no default and must be present in the raw config.
        """
        overrides = self.raw_config["training"]
        assert "order_dir" in overrides

        return {
            "max_epoch": 100,
            "repeat_per_collect": 2,
            "earlystop_patience": float("inf"),
            "episode_per_collect": 10000,
            "batch_size": 256,
            "val_every_n_epoch": None,
            "checkpoint_path": "./outputs",
            "checkpoint_every_n_iters": 10,
            **overrides,
        }

    def _parse_simulator(self) -> dict:
        """Return the simulator config for either the ``simple`` or ``full`` type."""
        sim_cfg = self.raw_config["simulator"]
        kind = sim_cfg["type"]
        assert kind in ("simple", "full")

        if kind == "full":
            # Exchange defaults first, then user overrides (lists converted to tuples).
            exchange = {
                "open_cost": 0.0005,
                "close_cost": 0.0015,
                "min_cost": 5.0,
                "trade_unit": 100.0,
                **_convert_all_list_to_tuple(sim_cfg["exchange"]),
            }
            # The exchange frequency always follows the general section.
            exchange["freq"] = self.raw_config["general"].get("freq", "1min")

            return {
                "type": kind,
                "data": {
                    "feature_root_dir": sim_cfg["data"]["feature_root_dir"],
                    "default_start_time_index": sim_cfg["data"].get("default_start_time_index", 0),
                    "default_end_time_index": sim_cfg["data"].get("default_end_time_index", 240),
                },
                "qlib": {
                    "provider_uri_1min": sim_cfg["qlib"]["provider_uri_1min"],
                },
                "exchange": exchange,
            }

        # kind == "simple"
        return {
            "type": kind,
            "data": {
                "feature_root_dir": sim_cfg["data"]["feature_root_dir"],
                "feature_columns_today": sim_cfg["data"]["feature_columns_today"],
                "default_start_time_index": sim_cfg["data"].get("default_start_time_index", 0),
                "default_end_time_index": sim_cfg["data"].get("default_end_time_index", 240),
            },
            "time_per_step": sim_cfg["time_per_step"],
            "vol_limit": sim_cfg["vol_limit"],
        }
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: parse a training config and pretty-print the result.
    import sys
    from pprint import pprint

    # The config path can be given as the first command-line argument. The
    # original developer-local path is kept as the fallback so running the
    # module without arguments behaves as before.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "/home/huoran/exp_configs/amc4th_training_refined.yml"
    parser = TrainingConfigParser(config_path)

    pprint(parser.parse())
|
||||
|
||||
333
qlib/rl/contrib/train.py
Normal file
333
qlib/rl/contrib/train.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.data.native import _load_handler_pickle, load_handler_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def get_executor_config(freq: int) -> dict:
    """Build the nested executor config (three executor levels) as a plain dict.

    The returned dict describes a ``NestedExecutor`` stepping at ``"1day"`` whose
    inner strategy is ``ProxySAOEStrategy``. Its inner executor is another
    ``NestedExecutor`` stepping at ``"30min"`` with a ``TWAPStrategy``, which in
    turn wraps a ``SimulatorExecutor`` stepping at ``f"{freq}min"``.

    Parameters
    ----------
    freq : int
        Step size, in minutes, of the innermost ``SimulatorExecutor``.
    """
    executor_module = "qlib.backtest.executor"

    innermost_executor = {
        "class": "SimulatorExecutor",
        "module_path": executor_module,
        "kwargs": {
            "generate_report": False,
            "time_per_step": f"{freq}min",
            "track_data": True,
            "trade_type": "serial",
            "verbose": False,
        },
    }
    twap_strategy = {
        "class": "TWAPStrategy",
        "kwargs": {},
        "module_path": "qlib.contrib.strategy.rule_strategy",
    }
    middle_executor = {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "inner_executor": innermost_executor,
            "inner_strategy": twap_strategy,
            "time_per_step": "30min",
            "track_data": True,
        },
    }
    proxy_strategy = {
        "class": "ProxySAOEStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {},
    }

    return {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "inner_executor": middle_executor,
            "inner_strategy": proxy_strategy,
            "time_per_step": "1day",
            "track_data": True,
        },
    }
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
    """Seed the torch (CPU and CUDA), NumPy and ``random`` generators with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
def _freq_str_to_int(freq: str) -> int:
|
||||
if freq.endswith("min"):
|
||||
return int(freq.replace("min", ""))
|
||||
elif freq.endswith("hour"):
|
||||
return int(freq.replace("hour", "") * 60)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized freq string: {freq}")
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
    """Dataset that turns one row of the order table into an ``Order`` on access.

    The order table is read eagerly in ``__init__``. The intraday tick offsets
    are loaded lazily on the first ``__getitem__`` call and reused afterwards.
    """

    def __init__(
        self,
        data_dir: str,
        order_file_path: Path,
        default_start_time_index: int,
        default_end_time_index: int,
    ) -> None:
        self._default_start_time_index = default_start_time_index
        self._default_end_time_index = default_end_time_index

        self._order_df = _read_orders(order_file_path).reset_index()
        # NOTE: once populated this holds a list of per-tick offsets from the date
        # (``t - date``), not an actual DatetimeIndex.
        self._ticks_index: Optional[pd.DatetimeIndex] = None
        self._data_dir = Path(data_dir)

    def __len__(self) -> int:
        return len(self._order_df)

    def _load_tick_offsets(self, stock_id: Any, date: pd.Timestamp) -> list:
        """Load one day of backtest data and return each tick's offset from ``date``."""
        intraday = load_handler_intraday_processed_data(
            data_dir=self._data_dir,
            stock_id=stock_id,
            date=date,
            feature_columns_today=[],
            feature_columns_yesterday=[],
            backtest=True,
        )
        return [tick - date for tick in intraday.today.index]

    def __getitem__(self, index: int) -> Order:
        record = self._order_df.iloc[index]
        day = pd.Timestamp(str(record["date"]))

        if self._ticks_index is None:
            # TODO: The ticks index is loaded only once, on the assumption that it is
            # TODO: identical for every date in one experiment. If that assumption does
            # TODO: not hold, the ticks index of every date has to be loaded.
            self._ticks_index = self._load_tick_offsets(record["instrument"], day)

        return Order(
            stock_id=record["instrument"],
            amount=record["amount"],
            direction=OrderDir(int(record["order_type"])),
            start_time=day + self._ticks_index[self._default_start_time_index],
            # The end is one minute past the last included tick.
            end_time=day + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
        )
|
||||
|
||||
|
||||
def train_and_test(
    freq: str,
    concurrency: int,
    parallel_mode: str,
    training_config: dict,
    simulator_config: dict,
    policy: BasePolicy,
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    reward: Reward,
    run_training: bool,
    run_backtest: bool,
) -> None:
    """Run training and/or backtest with either the simple or the full simulator.

    Parameters
    ----------
    freq : str
        Data frequency string (e.g. ``"1min"``), converted to minutes internally.
    concurrency : int
        Forwarded to ``train`` / ``backtest`` as ``concurrency``.
    parallel_mode : str
        Forwarded to ``train`` / ``backtest`` as ``finite_env_type``.
    training_config : dict
        Training section of the parsed config (``order_dir``, ``checkpoint_path``, ...).
    simulator_config : dict
        Simulator section of the parsed config; ``type`` selects ``"simple"`` or ``"full"``.
    policy, state_interpreter, action_interpreter, reward
        Components handed unchanged to ``train`` / ``backtest``.
    run_training : bool
        Whether to run ``train`` on the ``train`` / ``valid`` order sets.
    run_backtest : bool
        Whether to run ``backtest`` on the ``test`` order set.

    Raises
    ------
    ValueError
        If ``simulator_config["type"]`` is neither ``"simple"`` nor ``"full"``.
    """
    # Keep the string parameter and its integer form under separate names.
    freq_minutes = _freq_str_to_int(freq)
    order_root_path = Path(training_config["order_dir"])
    data_config = simulator_config["data"]
    feature_root_dir = data_config["feature_root_dir"]
    # The time indices are divided by the frequency below, so they must be multiples of it.
    assert data_config["default_start_time_index"] % freq_minutes == 0
    assert data_config["default_end_time_index"] % freq_minutes == 0

    sim_type = simulator_config["type"]
    if sim_type == "simple":

        def _simulator_factory(order: Order) -> SingleAssetOrderExecutionSimple:
            return SingleAssetOrderExecutionSimple(
                order=order,
                data_dir=feature_root_dir,
                feature_columns_today=simulator_config["data"]["feature_columns_today"],
                data_granularity=freq_minutes,
                ticks_per_step=simulator_config["time_per_step"],
                vol_threshold=simulator_config["vol_limit"],
            )

    elif sim_type == "full":
        init_qlib(simulator_config["qlib"])
        executor_config = get_executor_config(freq_minutes)
        exchange_config = simulator_config["exchange"]

        def _simulator_factory(order: Order) -> SingleAssetOrderExecution:
            return SingleAssetOrderExecution(
                order=order,
                executor_config=executor_config,
                # Restrict the exchange to the single instrument of this order.
                exchange_config={**exchange_config, **{"codes": [order.stock_id]}},
                qlib_config=None,
                cash_limit=None,
            )

    else:
        # Fail here with a clear message rather than with an unbound
        # ``_simulator_factory`` further down.
        raise ValueError(f"Unrecognized simulator type: {sim_type}")

    def _make_dataset(tag: str) -> LazyLoadDataset:
        """Build the lazily-loaded order dataset stored under ``order_root_path / tag``."""
        return LazyLoadDataset(
            data_dir=feature_root_dir,
            order_file_path=order_root_path / tag,
            default_start_time_index=data_config["default_start_time_index"] // freq_minutes,
            default_end_time_index=data_config["default_end_time_index"] // freq_minutes,
        )

    if run_training:
        train_dataset, valid_dataset = [_make_dataset(tag) for tag in ("train", "valid")]

        callbacks: List[Callback] = [
            MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
            Checkpoint(
                dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
                every_n_iters=training_config["checkpoint_every_n_iters"],
                save_latest="copy",
            ),
            EarlyStopping(
                patience=training_config["earlystop_patience"],
                monitor="val/pa",
            ),
        ]

        train(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            policy=policy,
            reward=reward,
            initial_states=cast(List[Order], train_dataset),
            trainer_kwargs={
                "max_iters": training_config["max_epoch"],
                "finite_env_type": parallel_mode,
                "concurrency": concurrency,
                "val_every_n_iters": training_config["val_every_n_epoch"],
                "callbacks": callbacks,
            },
            vessel_kwargs={
                "episode_per_iter": training_config["episode_per_collect"],
                "update_kwargs": {
                    "batch_size": training_config["batch_size"],
                    "repeat": training_config["repeat_per_collect"],
                },
                "val_initial_states": valid_dataset,
            },
        )

    if run_backtest:
        test_dataset = _make_dataset("test")

        backtest(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            initial_states=test_dataset,
            policy=policy,
            logger=CsvWriter(Path(training_config["checkpoint_path"])),
            reward=reward,
            finite_env_type=parallel_mode,
            concurrency=concurrency,
        )
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
    """Instantiate interpreters, reward and policy from ``config`` and run the job.

    Warns and returns immediately when both ``run_training`` and ``run_backtest``
    are false. Otherwise seeds the RNGs (if a seed is configured), extends
    ``sys.path`` with the configured extra module paths, builds the components
    with ``init_instance_by_config`` and hands them to ``train_and_test``.

    Note: the ``kwargs`` entries of the network and policy sub-configs are
    replaced in place on ``config``.
    """
    if not (run_training or run_backtest):
        warnings.warn("Skip the entire job since training and backtest are both skipped.")
        return

    seed = config["runtime"]["seed"]
    if seed is not None:
        seed_everything(seed)

    # Make user-supplied modules importable before any class is instantiated.
    sys.path.extend(config["general"]["extra_module_paths"])

    interpreter_config = config["interpreter"]
    state_interpreter: StateInterpreter = init_instance_by_config(interpreter_config["state"])
    action_interpreter: ActionInterpreter = init_instance_by_config(interpreter_config["action"])
    reward: Reward = init_instance_by_config(interpreter_config["reward"])

    additional_policy_kwargs = {
        "obs_space": state_interpreter.observation_space,
        "action_space": action_interpreter.action_space,
    }

    # Create torch network
    policy_section = config["policy"]
    if "network" in policy_section:
        network_config = policy_section["network"]
        network_kwargs = {**network_config.get("kwargs", {})}
        network_kwargs["obs_space"] = state_interpreter.observation_space
        network_config["kwargs"] = network_kwargs
        additional_policy_kwargs["network"] = init_instance_by_config(network_config)

    # Create policy
    policy_config = policy_section["policy"]
    policy_kwargs = {**policy_config.get("kwargs", {})}
    policy_kwargs.update(additional_policy_kwargs)
    policy_config["kwargs"] = policy_kwargs
    policy: BasePolicy = init_instance_by_config(policy_config)

    if config["runtime"]["use_cuda"]:
        policy.cuda()

    train_and_test(
        freq=config["general"]["freq"],
        concurrency=config["runtime"]["concurrency"],
        parallel_mode=config["runtime"]["parallel_mode"],
        training_config=config["training"],
        simulator_config=config["simulator"],
        policy=policy,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        reward=reward,
        run_training=run_training,
        run_backtest=run_backtest,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Suppress deprecation and runtime warnings for the whole run.
    for _category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=_category)

    # Command line: a mandatory config path plus two workflow switches.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
    parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
    args = parser.parse_args()

    # Training runs unless --no_training is given; backtest runs only with --run_backtest.
    config = TrainingConfigParser(args.config_path).parse()
    main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
331
qlib/rl/contrib/train_onpolicy_full_simulation.py
Normal file
331
qlib/rl/contrib/train_onpolicy_full_simulation.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.contrib.naive_config_parser import parse_backtest_config
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.data.native import load_handler_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def get_executor_config(data_granularity: int = 1) -> dict:
    """Build the nested executor config (three executor levels) as a plain dict.

    The returned dict describes a ``NestedExecutor`` stepping at ``"1day"`` whose
    inner strategy is ``ProxySAOEStrategy``. Its inner executor is another
    ``NestedExecutor`` stepping at ``"30min"`` with a ``TWAPStrategy``, which in
    turn wraps a ``SimulatorExecutor`` stepping at ``f"{data_granularity}min"``.

    Parameters
    ----------
    data_granularity : int
        Step size, in minutes, of the innermost ``SimulatorExecutor``.
    """
    executor_module = "qlib.backtest.executor"

    innermost_executor = {
        "class": "SimulatorExecutor",
        "module_path": executor_module,
        "kwargs": {
            "generate_report": False,
            "time_per_step": f"{data_granularity}min",
            "track_data": True,
            "trade_type": "serial",
            "verbose": False,
        },
    }
    twap_strategy = {
        "class": "TWAPStrategy",
        "kwargs": {},
        "module_path": "qlib.contrib.strategy.rule_strategy",
    }
    middle_executor = {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "inner_executor": innermost_executor,
            "inner_strategy": twap_strategy,
            "time_per_step": "30min",
            "track_data": True,
        },
    }
    proxy_strategy = {
        "class": "ProxySAOEStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {},
    }

    return {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "inner_executor": middle_executor,
            "inner_strategy": proxy_strategy,
            "time_per_step": "1day",
            "track_data": True,
        },
    }
|
||||
|
||||
|
||||
def _convert_list_to_tuple(obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
return {k: _convert_list_to_tuple(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return tuple(obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
    """Seed the torch (CPU and CUDA), NumPy and ``random`` generators with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
    """Dataset that turns one row of the order table into an ``Order`` on access.

    The order table is read eagerly in ``__init__``. The intraday tick offsets
    are loaded lazily on the first ``__getitem__`` call and reused afterwards.
    """

    def __init__(
        self,
        data_dir: str,
        order_file_path: Path,
        default_start_time_index: int,
        default_end_time_index: int,
    ) -> None:
        self._default_start_time_index = default_start_time_index
        self._default_end_time_index = default_end_time_index

        self._order_df = _read_orders(order_file_path).reset_index()
        # NOTE: once populated this holds a list of per-tick offsets from the date
        # (``t - date``), not an actual DatetimeIndex.
        self._ticks_index: Optional[pd.DatetimeIndex] = None
        self._data_dir = Path(data_dir)

    def __len__(self) -> int:
        return len(self._order_df)

    def _load_tick_offsets(self, stock_id: Any, date: pd.Timestamp) -> list:
        """Load one day of backtest data (index only) and return each tick's offset from ``date``."""
        intraday = load_handler_intraday_processed_data(
            data_dir=self._data_dir,
            stock_id=stock_id,
            date=date,
            feature_columns_today=[],
            feature_columns_yesterday=[],
            backtest=True,
            index_only=True,
        )
        return [tick - date for tick in intraday.today.index]

    def __getitem__(self, index: int) -> Order:
        record = self._order_df.iloc[index]
        day = pd.Timestamp(str(record["date"]))

        if self._ticks_index is None:
            # TODO: The ticks index is loaded only once, on the assumption that it is
            # TODO: identical for every date in one experiment. If that assumption does
            # TODO: not hold, the ticks index of every date has to be loaded.
            self._ticks_index = self._load_tick_offsets(record["instrument"], day)

        return Order(
            stock_id=record["instrument"],
            amount=record["amount"],
            direction=OrderDir(int(record["order_type"])),
            start_time=day + self._ticks_index[self._default_start_time_index],
            # The end is one minute past the last included tick.
            end_time=day + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
        )
|
||||
|
||||
|
||||
def train_and_test(
    env_config: dict,
    trainer_config: dict,
    data_config: dict,
    exchange_config: dict,
    qlib_config: dict,
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    policy: BasePolicy,
    reward: Reward,
    run_training: bool,
    run_backtest: bool,
) -> None:
    """Run training and/or backtest on the full (qlib-based) order-execution simulator.

    Initializes qlib with ``qlib_config``, merges ``exchange_config`` over the
    built-in exchange defaults, then calls ``train`` on the ``train`` / ``valid``
    order sets (if ``run_training``) and ``backtest`` on the ``test`` order set
    (if ``run_backtest``). ``env_config`` supplies ``parallel_mode`` and
    ``concurrency``; ``data_config["source"]`` supplies directories and time indices.
    """
    init_qlib(qlib_config)

    order_root_path = Path(data_config["source"]["order_dir"])

    # Data granularity in minutes; currently fixed rather than read from config.
    data_granularity = 1

    # User-provided exchange settings take precedence over these defaults.
    merged_exchange = {
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5.0,
        "trade_unit": 100.0,
        **exchange_config,
    }
    merged_exchange = _convert_list_to_tuple(merged_exchange)

    def _simulator_factory(order: Order) -> SingleAssetOrderExecution:
        return SingleAssetOrderExecution(
            order=order,
            executor_config=get_executor_config(data_granularity),
            # Restrict the exchange to the single instrument of this order.
            exchange_config={**merged_exchange, "codes": [order.stock_id]},
            qlib_config=None,
            cash_limit=None,
        )

    # The time indices are divided by the granularity below, so they must be multiples of it.
    assert data_config["source"]["default_start_time_index"] % data_granularity == 0
    assert data_config["source"]["default_end_time_index"] % data_granularity == 0

    def _load_dataset(tag: str) -> LazyLoadDataset:
        """Build the lazily-loaded order dataset stored under ``order_root_path / tag``."""
        return LazyLoadDataset(
            data_dir=data_config["source"]["feature_root_dir"],
            order_file_path=order_root_path / tag,
            default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
            default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
        )

    if run_training:
        train_dataset, valid_dataset = [_load_dataset(tag) for tag in ("train", "valid")]

        # Metrics writing/checkpointing and early stopping are each enabled only
        # when their config key is present.
        callbacks: List[Callback] = []
        if "checkpoint_path" in trainer_config:
            callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
            checkpoint_callback = Checkpoint(
                dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
                every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
                save_latest="copy",
            )
            callbacks.append(checkpoint_callback)
        if "earlystop_patience" in trainer_config:
            early_stopping = EarlyStopping(
                patience=trainer_config["earlystop_patience"],
                monitor="val/pa",
            )
            callbacks.append(early_stopping)

        trainer_kwargs = {
            "max_iters": trainer_config["max_epoch"],
            "finite_env_type": env_config["parallel_mode"],
            "concurrency": env_config["concurrency"],
            "val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
            "callbacks": callbacks,
        }
        vessel_kwargs = {
            "episode_per_iter": trainer_config["episode_per_collect"],
            "update_kwargs": {
                "batch_size": trainer_config["batch_size"],
                "repeat": trainer_config["repeat_per_collect"],
            },
            "val_initial_states": valid_dataset,
        }
        train(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            policy=policy,
            reward=reward,
            initial_states=cast(List[Order], train_dataset),
            trainer_kwargs=trainer_kwargs,
            vessel_kwargs=vessel_kwargs,
        )

    if run_backtest:
        test_dataset = _load_dataset("test")

        backtest(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            initial_states=test_dataset,
            policy=policy,
            logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
            reward=reward,
            finite_env_type=env_config["parallel_mode"],
            concurrency=env_config["concurrency"],
        )
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
    """Instantiate interpreters, reward and policy from ``config`` and run the job.

    Warns and returns immediately when both ``run_training`` and ``run_backtest``
    are false. Otherwise seeds the RNGs (if a seed is configured), extends
    ``sys.path`` with ``config["env"]["extra_module_paths"]``, builds the
    components with ``init_instance_by_config`` and hands them to ``train_and_test``.

    Note: the ``kwargs`` of ``config["network"]`` and ``config["policy"]`` are
    updated in place.
    """
    if not run_training and not run_backtest:
        warnings.warn("Skip the entire job since training and backtest are both skipped.")
        return

    # Treat an explicit ``seed: null`` the same as a missing seed; passing ``None``
    # on to ``seed_everything`` would hand it to ``torch.manual_seed``.
    seed = config["runtime"].get("seed")
    if seed is not None:
        seed_everything(seed)

    # Make user-supplied modules importable before any class is instantiated.
    for extra_module_path in config["env"].get("extra_module_paths", []):
        sys.path.append(extra_module_path)

    state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
    action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
    reward: Reward = init_instance_by_config(config["reward"])

    additional_policy_kwargs = {
        "obs_space": state_interpreter.observation_space,
        "action_space": action_interpreter.action_space,
    }

    # Create torch network
    if "network" in config:
        if "kwargs" not in config["network"]:
            config["network"]["kwargs"] = {}
        config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
        additional_policy_kwargs["network"] = init_instance_by_config(config["network"])

    # Create policy
    if "kwargs" not in config["policy"]:
        config["policy"]["kwargs"] = {}
    config["policy"]["kwargs"].update(additional_policy_kwargs)
    policy: BasePolicy = init_instance_by_config(config["policy"])

    use_cuda = config["runtime"].get("use_cuda", False)
    if use_cuda:
        policy.cuda()

    train_and_test(
        env_config=config["env"],
        data_config=config["data"],
        exchange_config=config["exchange"],
        qlib_config=config["qlib"],
        trainer_config=config["trainer"],
        action_interpreter=action_interpreter,
        state_interpreter=state_interpreter,
        policy=policy,
        reward=reward,
        run_training=run_training,
        run_backtest=run_backtest,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Suppress deprecation and runtime warnings for the whole run.
    for _category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=_category)

    # Command line: a mandatory config path plus two workflow switches.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
    parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
    args = parser.parse_args()

    # Training runs unless --no_training is given; backtest runs only with --run_backtest.
    config = parse_backtest_config(args.config_path)
    main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
@@ -140,6 +140,15 @@ def load_backtest_data(
|
||||
return backtest_data
|
||||
|
||||
|
||||
@cachetools.cached(  # type: ignore
    cache=cachetools.LRUCache(1000),
    key=lambda path: path,
)
def _load_handler_pickle(path: str) -> object:
    """Unpickle and return the object stored at ``path``.

    Results are memoized in an LRU cache of up to 1000 entries, keyed by ``path``.
    """
    # NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)
|
||||
|
||||
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
|
||||
|
||||
@@ -151,7 +160,6 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = df.reset_index()
|
||||
@@ -161,31 +169,17 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
dataset = _load_handler_pickle(path)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
self.today = _drop_stock_id(data[[]])
|
||||
self.yesterday = _drop_stock_id(data[[]])
|
||||
else:
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
|
||||
stock_id,
|
||||
date,
|
||||
backtest,
|
||||
index_only,
|
||||
),
|
||||
)
|
||||
def load_handler_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
@@ -193,10 +187,9 @@ def load_handler_intraday_processed_data(
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> HandlerIntradayProcessedData:
|
||||
return HandlerIntradayProcessedData(
|
||||
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
|
||||
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest,
|
||||
)
|
||||
|
||||
|
||||
@@ -229,5 +222,4 @@ class HandlerProcessedDataProvider(ProcessedDataProvider):
|
||||
self.feature_columns_today,
|
||||
self.feature_columns_yesterday,
|
||||
backtest=self.backtest,
|
||||
index_only=False,
|
||||
)
|
||||
|
||||
@@ -90,6 +90,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
trade_strategy=strategy,
|
||||
trade_executor=self._executor,
|
||||
return_value=self.report_dict,
|
||||
show_progress=False,
|
||||
)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
|
||||
@@ -42,8 +42,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
Path to load backtest data.
|
||||
feature_columns_today
|
||||
Columns of today's feature.
|
||||
feature_columns_yesterday
|
||||
Columns of yesterday's feature.
|
||||
data_granularity
|
||||
Number of ticks between consecutive data entries.
|
||||
ticks_per_step
|
||||
@@ -80,7 +78,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
order: Order,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str] = [],
|
||||
feature_columns_yesterday: List[str] = [],
|
||||
data_granularity: int = 1,
|
||||
ticks_per_step: int = 30,
|
||||
vol_threshold: Optional[float] = None,
|
||||
@@ -92,7 +89,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.order = order
|
||||
self.data_dir = data_dir
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.vol_threshold = vol_threshold
|
||||
|
||||
@@ -127,9 +123,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
stock_id=self.order.stock_id,
|
||||
date=pd.Timestamp(self.order.start_time.date()),
|
||||
feature_columns_today=self.feature_columns_today,
|
||||
feature_columns_yesterday=self.feature_columns_yesterday,
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=False,
|
||||
)
|
||||
return DataframeIntradayBacktestData(data.today)
|
||||
except (AttributeError, FileNotFoundError):
|
||||
|
||||
@@ -208,6 +208,7 @@ class Trainer:
|
||||
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
print(msg)
|
||||
_logger.info(msg)
|
||||
|
||||
self.initialize_iter()
|
||||
|
||||
20
qlib/rl/utils/profiling.py
Normal file
20
qlib/rl/utils/profiling.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from line_profiler import LineProfiler
|
||||
|
||||
@contextmanager
def simple_perf(desc: str = ""):
    """Context manager that prints the wall-clock time spent in its body.

    The measurement is written to stdout as ``"{desc}: {milliseconds} ms"``
    once the ``with`` block is left.

    Parameters
    ----------
    desc : str
        Label printed in front of the measurement.

    Yields
    ------
    None
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report even when the body raises, so a failing section still shows up
        # in the timings; the exception itself propagates unchanged.
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        print(f"{desc}: {elapsed_ms} ms")
|
||||
|
||||
|
||||
def lprofile(func):
    """Decorator that line-profiles every call of ``func``.

    Each call creates a fresh ``LineProfiler``, runs ``func`` under it and
    prints the collected statistics via ``print_stats()``.

    Parameters
    ----------
    func : callable
        The function to profile.

    Returns
    -------
    callable
        A wrapper with the same call signature that returns whatever ``func``
        returns; it carries ``func``'s name and docstring.
    """
    # Imported locally so this block does not depend on a module-level import.
    import functools

    @functools.wraps(func)  # keep __name__/__doc__ of the profiled function
    def wrapper(*args, **kwargs):
        lp = LineProfiler()
        lpw = lp(func)
        try:
            return lpw(*args, **kwargs)
        finally:
            # Print the profile even if the call raised, so the partial
            # measurements are not lost; the exception still propagates.
            lp.print_stats()

    return wrapper
|
||||
Reference in New Issue
Block a user