1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Train on full simulation

This commit is contained in:
Huoran Li
2023-05-24 10:36:27 +08:00
parent 94268619c4
commit 3e9ccd3ad2
10 changed files with 809 additions and 35 deletions

View File

@@ -56,6 +56,7 @@ def collect_data_loop(
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
return_value: dict | None = None,
show_progress: bool = True,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
"""Generator for collecting the trade decision data for rl training
@@ -74,6 +75,8 @@ def collect_data_loop(
the outermost executor
return_value : dict
used for backtest_loop
show_progress: bool
whether to show execution progress
Yields
-------
@@ -83,7 +86,8 @@ def collect_data_loop(
trade_executor.reset(start_time=start_time, end_time=end_time)
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
disable = not show_progress
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
_execute_result = None
while not trade_executor.finished():
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)

View File

@@ -15,7 +15,7 @@ import pandas as pd
import torch
from joblib import Parallel, delayed
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_exchange, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
from qlib.backtest.executor import SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
@@ -250,8 +250,6 @@ def single_with_collect_data_loop(
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
init_qlib(backtest_config["qlib"])
trade_start_time = orders["datetime"].min()
trade_end_time = orders["datetime"].max()
stocks = orders.instrument.unique().tolist()
@@ -275,13 +273,13 @@ def single_with_collect_data_loop(
data_granularity=backtest_config["data_granularity"],
)
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
exchange_config = {
**backtest_config["exchange"],
**{
"codes": stocks,
"freq": backtest_config["data_granularity"],
}
)
}
strategy, executor = get_strategy_executor(
start_time=pd.Timestamp(trade_start_time),
@@ -326,6 +324,8 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
single = single_with_simulator if with_simulator else single_with_collect_data_loop
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
init_qlib(backtest_config["qlib"])
res = Parallel(**mp_config)(
delayed(single)(
backtest_config=backtest_config,

View File

@@ -105,3 +105,100 @@ def get_backtest_config_fromfile(path: str) -> dict:
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
return backtest_config
class TrainingConfigParser:
    """Expand a raw training config file into a config dict with defaults filled in.

    The raw config is loaded once at construction time through
    ``parse_backtest_config``; :meth:`parse` then assembles the individual
    sections, overlaying the user-provided values on top of built-in defaults.
    """

    def __init__(self, path: str) -> None:
        # Raw, un-defaulted config; every ``_parse_*`` helper reads from it.
        self.raw_config = parse_backtest_config(path)

    @staticmethod
    def _overlay(defaults: dict, overrides: dict) -> dict:
        """Return a new dict made of ``defaults`` updated by ``overrides`` (overrides win)."""
        return {**defaults, **overrides}

    def parse(self) -> dict:
        """Build the complete training config.

        Returns
        -------
        dict
            A dict with the keys ``general``, ``policy``, ``interpreter``,
            ``runtime``, ``training`` and ``simulator``. ``policy`` and
            ``interpreter`` are passed through from the raw config untouched.
        """
        parsed = {}
        parsed["general"] = self._parse_general()
        parsed["policy"] = self.raw_config["policy"]
        parsed["interpreter"] = self.raw_config["interpreter"]
        parsed["runtime"] = self._parse_runtime()
        parsed["training"] = self._parse_training()
        parsed["simulator"] = self._parse_simulator()
        return parsed

    def _parse_general(self) -> dict:
        """Return the ``general`` section with defaults for ``freq`` and ``extra_module_paths``."""
        return self._overlay(
            {
                "freq": "1min",
                "extra_module_paths": [],
            },
            self.raw_config["general"],
        )

    def _parse_runtime(self) -> dict:
        """Return the ``runtime`` section with defaults for seed, CUDA usage and parallelism."""
        return self._overlay(
            {
                "seed": None,
                "use_cuda": False,
                "concurrency": 1,
                "parallel_mode": "dummy",
            },
            self.raw_config["runtime"],
        )

    def _parse_training(self) -> dict:
        """Return the ``training`` section with defaults; ``order_dir`` has no default and is required."""
        user_config = self.raw_config["training"]
        assert "order_dir" in user_config
        return self._overlay(
            {
                "max_epoch": 100,
                "repeat_per_collect": 2,
                "earlystop_patience": float("inf"),
                "episode_per_collect": 10000,
                "batch_size": 256,
                "val_every_n_epoch": None,
                "checkpoint_path": "./outputs",
                "checkpoint_every_n_iters": 10,
            },
            user_config,
        )

    def _parse_simulator(self) -> dict:
        """Return the ``simulator`` section for either the ``"simple"`` or the ``"full"`` simulator.

        Only the keys relevant to the selected simulator type are kept.
        """
        raw = self.raw_config["simulator"]
        sim_type = raw["type"]
        assert sim_type in ("simple", "full")

        if sim_type == "simple":
            data = raw["data"]
            return {
                "type": sim_type,
                "data": {
                    "feature_root_dir": data["feature_root_dir"],
                    "feature_columns_today": data["feature_columns_today"],
                    "default_start_time_index": data.get("default_start_time_index", 0),
                    "default_end_time_index": data.get("default_end_time_index", 240),
                },
                "time_per_step": raw["time_per_step"],
                "vol_limit": raw["vol_limit"],
            }

        # "full" simulator: exchange settings get defaults, and lists are turned into tuples.
        exchange_config = self._overlay(
            {
                "open_cost": 0.0005,
                "close_cost": 0.0015,
                "min_cost": 5.0,
                "trade_unit": 100.0,
                # "cash_limit": None,
            },
            _convert_all_list_to_tuple(raw["exchange"]),
        )
        # The exchange frequency always follows the general frequency setting.
        exchange_config["freq"] = self.raw_config["general"].get("freq", "1min")

        data = raw["data"]
        return {
            "type": sim_type,
            "data": {
                "feature_root_dir": data["feature_root_dir"],
                "default_start_time_index": data.get("default_start_time_index", 0),
                "default_end_time_index": data.get("default_end_time_index", 240),
            },
            "qlib": {
                "provider_uri_1min": raw["qlib"]["provider_uri_1min"],
            },
            "exchange": exchange_config,
        }
if __name__ == "__main__":
    # Debug entry point: parse a training config file and pretty-print the result.
    import sys
    from pprint import pprint

    # The built-in path only exists on the original author's machine, so allow it to be
    # overridden by the first command-line argument; it is kept as the fallback so the
    # previous no-argument invocation behaves as before.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "/home/huoran/exp_configs/amc4th_training_refined.yml"
    parser = TrainingConfigParser(config_path)
    pprint(parser.parse())

333
qlib/rl/contrib/train.py Normal file
View File

@@ -0,0 +1,333 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import os
import random
import sys
import warnings
from pathlib import Path
from typing import Any, cast, List, Optional
import numpy as np
import pandas as pd
import torch
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
from qlib.rl.data.integration import init_qlib
from qlib.rl.data.native import _load_handler_pickle, load_handler_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, backtest, train
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch.utils.data import Dataset
def get_executor_config(freq: int) -> dict:
    """Build the three-level nested executor config used by the full simulator.

    The config is assembled from the innermost level outwards:

    * a ``SimulatorExecutor`` stepping every ``freq`` minutes,
    * wrapped by a ``NestedExecutor`` stepping every 30 minutes with a ``TWAPStrategy``,
    * wrapped by a ``NestedExecutor`` stepping every day with a ``ProxySAOEStrategy``.

    Parameters
    ----------
    freq : int
        Step size of the innermost executor, in minutes.

    Returns
    -------
    dict
        A freshly built config dict (nothing is shared between calls).
    """
    executor_module = "qlib.backtest.executor"

    innermost_executor = {
        "class": "SimulatorExecutor",
        "module_path": executor_module,
        "kwargs": {
            "generate_report": False,
            "time_per_step": f"{freq}min",
            "track_data": True,
            "trade_type": "serial",
            "verbose": False,
        },
    }
    twap_strategy = {
        "class": "TWAPStrategy",
        "kwargs": {},
        "module_path": "qlib.contrib.strategy.rule_strategy",
    }
    middle_executor = {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "inner_executor": innermost_executor,
            "inner_strategy": twap_strategy,
            "time_per_step": "30min",
            "track_data": True,
        },
    }
    proxy_strategy = {
        "class": "ProxySAOEStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {},
    }
    return {
        "class": "NestedExecutor",
        "module_path": executor_module,
        "kwargs": {
            "inner_executor": middle_executor,
            "inner_strategy": proxy_strategy,
            "time_per_step": "1day",
            "track_data": True,
        },
    }
def seed_everything(seed: int) -> None:
    """Seed the torch (CPU and all CUDA devices), NumPy and stdlib random generators.

    Also switches cuDNN to deterministic mode.

    Parameters
    ----------
    seed : int
        The seed handed to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
def _read_orders(order_dir: Path) -> pd.DataFrame:
if os.path.isfile(order_dir):
return pd.read_pickle(order_dir)
else:
orders = []
for file in order_dir.iterdir():
order_data = pd.read_pickle(file)
orders.append(order_data)
return pd.concat(orders)
def _freq_str_to_int(freq: str) -> int:
if freq.endswith("min"):
return int(freq.replace("min", ""))
elif freq.endswith("hour"):
return int(freq.replace("hour", "") * 60)
else:
raise ValueError(f"Unrecognized freq string: {freq}")
class LazyLoadDataset(Dataset):
    """Dataset of orders that builds each ``Order`` on demand from a table of order records.

    Parameters
    ----------
    data_dir : str
        Directory handed to ``load_handler_intraday_processed_data`` to look up tick data.
    order_file_path : Path
        File or directory read by ``_read_orders``.
    default_start_time_index : int
        Position, within the per-day tick list, of the first tick of every order.
    default_end_time_index : int
        Position one past the last tick of every order.
    """

    def __init__(
        self,
        data_dir: str,
        order_file_path: Path,
        default_start_time_index: int,
        default_end_time_index: int,
    ) -> None:
        self._default_start_time_index = default_start_time_index
        self._default_end_time_index = default_end_time_index

        self._order_df = _read_orders(order_file_path).reset_index()
        self._data_dir = Path(data_dir)

        # Offsets of every tick relative to its order date; filled by the first ``__getitem__`` call.
        self._ticks_index: Optional[list] = None

    def __len__(self) -> int:
        """Return the number of order records."""
        return len(self._order_df)

    def _load_tick_offsets(self, stock_id: Any, date: pd.Timestamp) -> list:
        """Load one day of tick data for ``stock_id`` and return each tick's offset from ``date``."""
        intraday = load_handler_intraday_processed_data(
            data_dir=self._data_dir,
            stock_id=stock_id,
            date=date,
            feature_columns_today=[],
            feature_columns_yesterday=[],
            backtest=True,
        )
        return [tick - date for tick in intraday.today.index]

    def __getitem__(self, index: int) -> Order:
        """Build the ``Order`` described by the record at position ``index``."""
        record = self._order_df.iloc[index]
        order_date = pd.Timestamp(str(record["date"]))

        if self._ticks_index is None:
            # TODO: The tick offsets are loaded only once, on the assumption that every date in
            # TODO: one experiment shares the same ticks. If that assumption does not hold, the
            # TODO: offsets need to be loaded for all dates.
            self._ticks_index = self._load_tick_offsets(record["instrument"], order_date)

        return Order(
            stock_id=record["instrument"],
            amount=record["amount"],
            direction=OrderDir(int(record["order_type"])),
            start_time=order_date + self._ticks_index[self._default_start_time_index],
            # The end index is exclusive: take the last included tick and add one minute.
            end_time=order_date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
        )
def train_and_test(
    freq: str,
    concurrency: int,
    parallel_mode: str,
    training_config: dict,
    simulator_config: dict,
    policy: BasePolicy,
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    reward: Reward,
    run_training: bool,
    run_backtest: bool,
) -> None:
    """Train and/or backtest an order-execution policy on the configured simulator.

    Parameters
    ----------
    freq : str
        Data frequency string accepted by ``_freq_str_to_int`` (e.g. ``"1min"``).
    concurrency : int
        Forwarded to the trainer / backtest as ``concurrency``.
    parallel_mode : str
        Forwarded to the trainer / backtest as ``finite_env_type``.
    training_config : dict
        Training section of the config; ``order_dir``, ``checkpoint_path`` and the
        trainer hyper-parameters are read from it.
    simulator_config : dict
        Simulator section of the config. ``type`` selects the ``"simple"`` or the
        ``"full"`` simulator; ``data`` holds the feature directory and tick indices.
    policy, state_interpreter, action_interpreter, reward
        Already-instantiated RL components, forwarded unchanged.
    run_training : bool
        Whether to run training on the ``train`` / ``valid`` order sets.
    run_backtest : bool
        Whether to run a backtest on the ``test`` order set.

    Raises
    ------
    ValueError
        If ``simulator_config["type"]`` is neither ``"simple"`` nor ``"full"``.
    """
    freq_minutes = _freq_str_to_int(freq)

    order_root_path = Path(training_config["order_dir"])

    data_config = simulator_config["data"]
    feature_root_dir = data_config["feature_root_dir"]

    # Both indices are floor-divided by the frequency below, so they must be exact multiples of it.
    assert data_config["default_start_time_index"] % freq_minutes == 0
    assert data_config["default_end_time_index"] % freq_minutes == 0
    start_time_index = data_config["default_start_time_index"] // freq_minutes
    end_time_index = data_config["default_end_time_index"] // freq_minutes

    sim_type = simulator_config["type"]
    if sim_type == "simple":

        def _simulator_factory(order: Order) -> SingleAssetOrderExecutionSimple:
            simulator = SingleAssetOrderExecutionSimple(
                order=order,
                data_dir=feature_root_dir,
                feature_columns_today=simulator_config["data"]["feature_columns_today"],
                data_granularity=freq_minutes,
                ticks_per_step=simulator_config["time_per_step"],
                vol_threshold=simulator_config["vol_limit"],
            )
            return simulator

    elif sim_type == "full":
        # Qlib is initialized once here, so each simulator is created with ``qlib_config=None``.
        init_qlib(simulator_config["qlib"])
        executor_config = get_executor_config(freq_minutes)
        exchange_config = simulator_config["exchange"]

        def _simulator_factory(order: Order) -> SingleAssetOrderExecution:
            simulator = SingleAssetOrderExecution(
                order=order,
                executor_config=executor_config,
                # Each simulator only needs the instrument of its own order.
                exchange_config={**exchange_config, **{"codes": [order.stock_id]}},
                qlib_config=None,
                cash_limit=None,
            )
            return simulator

    else:
        # Without this branch ``_simulator_factory`` would stay unbound and the failure would
        # only surface later, as a confusing unbound-name error inside training or backtest.
        raise ValueError(f"Unsupported simulator type: {sim_type!r} (expected 'simple' or 'full').")

    def _make_dataset(tag: str) -> LazyLoadDataset:
        """Create the lazily-loaded order dataset stored under ``order_root_path / tag``."""
        return LazyLoadDataset(
            data_dir=feature_root_dir,
            order_file_path=order_root_path / tag,
            default_start_time_index=start_time_index,
            default_end_time_index=end_time_index,
        )

    if run_training:
        train_dataset = _make_dataset("train")
        valid_dataset = _make_dataset("valid")

        callbacks: List[Callback] = [
            MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
            Checkpoint(
                dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
                every_n_iters=training_config["checkpoint_every_n_iters"],
                save_latest="copy",
            ),
            EarlyStopping(
                patience=training_config["earlystop_patience"],
                monitor="val/pa",
            ),
        ]

        train(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            policy=policy,
            reward=reward,
            initial_states=cast(List[Order], train_dataset),
            trainer_kwargs={
                "max_iters": training_config["max_epoch"],
                "finite_env_type": parallel_mode,
                "concurrency": concurrency,
                "val_every_n_iters": training_config["val_every_n_epoch"],
                "callbacks": callbacks,
            },
            vessel_kwargs={
                "episode_per_iter": training_config["episode_per_collect"],
                "update_kwargs": {
                    "batch_size": training_config["batch_size"],
                    "repeat": training_config["repeat_per_collect"],
                },
                "val_initial_states": valid_dataset,
            },
        )

    if run_backtest:
        test_dataset = _make_dataset("test")

        backtest(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            initial_states=test_dataset,
            policy=policy,
            logger=CsvWriter(Path(training_config["checkpoint_path"])),
            reward=reward,
            finite_env_type=parallel_mode,
            concurrency=concurrency,
        )
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
    """Instantiate the RL components described by ``config`` and run training and/or backtest.

    Parameters
    ----------
    config : dict
        Parsed config with the sections ``general``, ``policy``, ``interpreter``,
        ``runtime``, ``training`` and ``simulator``. The ``kwargs`` entries of the
        network and policy configs are replaced in place.
    run_training : bool
        Whether the training workflow should run.
    run_backtest : bool
        Whether the backtest workflow should run.
    """
    if not (run_training or run_backtest):
        warnings.warn("Skip the entire job since training and backtest are both skipped.")
        return

    runtime_config = config["runtime"]
    seed = runtime_config["seed"]
    if seed is not None:
        seed_everything(seed)

    # Make user-provided modules importable before any class is looked up by name.
    sys.path.extend(config["general"]["extra_module_paths"])

    interpreter_config = config["interpreter"]
    state_interpreter: StateInterpreter = init_instance_by_config(interpreter_config["state"])
    action_interpreter: ActionInterpreter = init_instance_by_config(interpreter_config["action"])
    reward: Reward = init_instance_by_config(interpreter_config["reward"])

    policy_section = config["policy"]
    extra_policy_kwargs = {
        "obs_space": state_interpreter.observation_space,
        "action_space": action_interpreter.action_space,
    }

    # Create torch network
    if "network" in policy_section:
        network_config = policy_section["network"]
        network_kwargs = {**network_config.get("kwargs", {})}
        network_kwargs["obs_space"] = state_interpreter.observation_space
        network_config["kwargs"] = network_kwargs
        extra_policy_kwargs["network"] = init_instance_by_config(network_config)

    # Create policy
    policy_config = policy_section["policy"]
    policy_config["kwargs"] = {**policy_config.get("kwargs", {}), **extra_policy_kwargs}
    policy: BasePolicy = init_instance_by_config(policy_config)

    if runtime_config["use_cuda"]:
        policy.cuda()

    train_and_test(
        freq=config["general"]["freq"],
        concurrency=runtime_config["concurrency"],
        parallel_mode=runtime_config["parallel_mode"],
        training_config=config["training"],
        simulator_config=config["simulator"],
        policy=policy,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        reward=reward,
        run_training=run_training,
        run_backtest=run_backtest,
    )
if __name__ == "__main__":
    # Command-line entry point: parse the config file, then train by default and
    # backtest only on request.
    for ignored_category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=ignored_category)

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
    parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
    args = parser.parse_args()

    config = TrainingConfigParser(args.config_path).parse()
    main(
        config,
        run_training=not args.no_training,
        run_backtest=args.run_backtest,
    )

View File

@@ -0,0 +1,331 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import os
import random
import sys
import warnings
from pathlib import Path
from typing import Any, cast, List, Optional
import numpy as np
import pandas as pd
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.contrib.naive_config_parser import parse_backtest_config
from qlib.rl.data.integration import init_qlib
from qlib.rl.data.native import load_handler_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, backtest, train
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch.utils.data import Dataset
def get_executor_config(data_granularity: int = 1) -> dict:
    """Build the nested executor config: 1day -> 30min -> ``data_granularity`` minutes.

    Parameters
    ----------
    data_granularity : int
        Step size of the innermost ``SimulatorExecutor``, in minutes.

    Returns
    -------
    dict
        A freshly built config dict; the outer two levels are ``NestedExecutor``
        entries driven by ``ProxySAOEStrategy`` (1day) and ``TWAPStrategy`` (30min).
    """
    executor_module = "qlib.backtest.executor"

    def _nested(inner_executor: dict, inner_strategy: dict, time_per_step: str) -> dict:
        # One NestedExecutor level wrapping the given executor/strategy pair.
        return {
            "class": "NestedExecutor",
            "module_path": executor_module,
            "kwargs": {
                "inner_executor": inner_executor,
                "inner_strategy": inner_strategy,
                "time_per_step": time_per_step,
                "track_data": True,
            },
        }

    simulator_level = {
        "class": "SimulatorExecutor",
        "module_path": executor_module,
        "kwargs": {
            "generate_report": False,
            "time_per_step": f"{data_granularity}min",
            "track_data": True,
            "trade_type": "serial",
            "verbose": False,
        },
    }
    half_hour_level = _nested(
        simulator_level,
        {
            "class": "TWAPStrategy",
            "kwargs": {},
            "module_path": "qlib.contrib.strategy.rule_strategy",
        },
        "30min",
    )
    return _nested(
        half_hour_level,
        {
            "class": "ProxySAOEStrategy",
            "module_path": "qlib.rl.order_execution.strategy",
            "kwargs": {},
        },
        "1day",
    )
def _convert_list_to_tuple(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: _convert_list_to_tuple(v) for k, v in obj.items()}
elif isinstance(obj, list):
return tuple(obj)
else:
return obj
def seed_everything(seed: int) -> None:
    """Make runs repeatable by seeding every random number generator in use.

    Seeds torch (CPU, then all CUDA devices), NumPy and the stdlib ``random``
    module with the same value, and turns on deterministic cuDNN behaviour.
    """
    # torch: CPU generator first, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # NumPy and the standard library.
    np.random.seed(seed)
    random.seed(seed)

    torch.backends.cudnn.deterministic = True
def _read_orders(order_dir: Path) -> pd.DataFrame:
if os.path.isfile(order_dir):
return pd.read_pickle(order_dir)
else:
orders = []
for file in order_dir.iterdir():
order_data = pd.read_pickle(file)
orders.append(order_data)
return pd.concat(orders)
class LazyLoadDataset(Dataset):
    """Order dataset that materializes an ``Order`` only when an item is requested.

    Parameters
    ----------
    data_dir : str
        Directory passed to ``load_handler_intraday_processed_data`` for tick lookup.
    order_file_path : Path
        File or directory read by ``_read_orders``.
    default_start_time_index : int
        Position of each order's first tick in the per-day tick list.
    default_end_time_index : int
        Position one past each order's last tick.
    """

    def __init__(
        self,
        data_dir: str,
        order_file_path: Path,
        default_start_time_index: int,
        default_end_time_index: int,
    ) -> None:
        self._default_start_time_index = default_start_time_index
        self._default_end_time_index = default_end_time_index

        self._order_df = _read_orders(order_file_path).reset_index()
        self._data_dir = Path(data_dir)

        # Per-tick offsets relative to the order date; loaded by the first ``__getitem__`` call.
        self._ticks_index: Optional[list] = None

    def __len__(self) -> int:
        """Return how many order records are available."""
        return len(self._order_df)

    def _load_tick_offsets(self, stock_id: Any, date: pd.Timestamp) -> list:
        """Load the tick index of ``stock_id`` on ``date`` and return each tick's offset from ``date``."""
        # NOTE(review): this relies on ``load_handler_intraday_processed_data`` accepting the
        # ``index_only`` keyword - confirm against the loader's current signature.
        intraday = load_handler_intraday_processed_data(
            data_dir=self._data_dir,
            stock_id=stock_id,
            date=date,
            feature_columns_today=[],
            feature_columns_yesterday=[],
            backtest=True,
            index_only=True,
        )
        return [tick - date for tick in intraday.today.index]

    def __getitem__(self, index: int) -> Order:
        """Create the ``Order`` for the record stored at position ``index``."""
        record = self._order_df.iloc[index]
        order_date = pd.Timestamp(str(record["date"]))

        if self._ticks_index is None:
            # TODO: Tick offsets are loaded a single time, assuming that all dates of one
            # TODO: experiment share identical ticks. Should that assumption break, the offsets
            # TODO: have to be loaded for every date.
            self._ticks_index = self._load_tick_offsets(record["instrument"], order_date)

        return Order(
            stock_id=record["instrument"],
            amount=record["amount"],
            direction=OrderDir(int(record["order_type"])),
            start_time=order_date + self._ticks_index[self._default_start_time_index],
            # End index is exclusive, hence "last included tick + one minute".
            end_time=order_date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
        )
def train_and_test(
    env_config: dict,
    trainer_config: dict,
    data_config: dict,
    exchange_config: dict,
    qlib_config: dict,
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    policy: BasePolicy,
    reward: Reward,
    run_training: bool,
    run_backtest: bool,
) -> None:
    """Train and/or backtest a policy on the qlib-backed order execution simulator.

    Parameters
    ----------
    env_config : dict
        Provides ``parallel_mode`` and ``concurrency`` for the trainer and the backtest.
    trainer_config : dict
        Trainer hyper-parameters; ``checkpoint_path`` and ``earlystop_patience`` are
        optional for training, while the backtest always reads ``checkpoint_path``.
    data_config : dict
        Its ``source`` entry holds ``order_dir``, ``feature_root_dir`` and the default
        start/end tick indices.
    exchange_config : dict
        Exchange settings, overlaid on built-in cost / trade-unit defaults.
    qlib_config : dict
        Passed to ``init_qlib`` once, before anything else happens.
    state_interpreter, action_interpreter, policy, reward
        Already-instantiated RL components, forwarded unchanged.
    run_training : bool
        Whether to train on the ``train`` / ``valid`` order sets.
    run_backtest : bool
        Whether to backtest on the ``test`` order set.
    """
    init_qlib(qlib_config)

    source_config = data_config["source"]
    order_root_path = Path(source_config["order_dir"])

    # Data granularity is fixed to one minute in this script.
    data_granularity = 1

    merged_exchange_config = _convert_list_to_tuple(
        {
            **{
                "open_cost": 0.0005,
                "close_cost": 0.0015,
                "min_cost": 5.0,
                "trade_unit": 100.0,
                # "cash_limit": None,
            },
            **exchange_config,
        }
    )

    def _simulator_factory(order: Order) -> SingleAssetOrderExecution:
        # Qlib is already initialized above, hence ``qlib_config=None``; the exchange is
        # restricted to the instrument of the order being simulated.
        return SingleAssetOrderExecution(
            order=order,
            executor_config=get_executor_config(data_granularity),
            exchange_config={**merged_exchange_config, "codes": [order.stock_id]},
            qlib_config=None,
            cash_limit=None,
        )

    # The indices are floor-divided by the granularity below, so they must be exact multiples of it.
    assert source_config["default_start_time_index"] % data_granularity == 0
    assert source_config["default_end_time_index"] % data_granularity == 0

    def _dataset_for(tag: str) -> LazyLoadDataset:
        """Create the lazily-loaded order dataset stored under ``order_root_path / tag``."""
        return LazyLoadDataset(
            data_dir=data_config["source"]["feature_root_dir"],
            order_file_path=order_root_path / tag,
            default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
            default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
        )

    parallel_mode = env_config["parallel_mode"] if (run_training or run_backtest) else None

    if run_training:
        train_dataset = _dataset_for("train")
        valid_dataset = _dataset_for("valid")

        callbacks: List[Callback] = []
        if "checkpoint_path" in trainer_config:
            checkpoint_root = Path(trainer_config["checkpoint_path"])
            callbacks.append(MetricsWriter(dirpath=checkpoint_root))
            callbacks.append(
                Checkpoint(
                    dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
                    every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
                    save_latest="copy",
                )
            )
        if "earlystop_patience" in trainer_config:
            callbacks.append(
                EarlyStopping(
                    patience=trainer_config["earlystop_patience"],
                    monitor="val/pa",
                )
            )

        trainer_kwargs = {
            "max_iters": trainer_config["max_epoch"],
            "finite_env_type": parallel_mode,
            "concurrency": env_config["concurrency"],
            "val_every_n_iters": trainer_config.get("val_every_n_epoch"),
            "callbacks": callbacks,
        }
        vessel_kwargs = {
            "episode_per_iter": trainer_config["episode_per_collect"],
            "update_kwargs": {
                "batch_size": trainer_config["batch_size"],
                "repeat": trainer_config["repeat_per_collect"],
            },
            "val_initial_states": valid_dataset,
        }
        train(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            policy=policy,
            reward=reward,
            initial_states=cast(List[Order], train_dataset),
            trainer_kwargs=trainer_kwargs,
            vessel_kwargs=vessel_kwargs,
        )

    if run_backtest:
        test_dataset = _dataset_for("test")

        # NOTE: unlike training, the backtest requires ``checkpoint_path`` to be configured.
        backtest(
            simulator_fn=_simulator_factory,
            state_interpreter=state_interpreter,
            action_interpreter=action_interpreter,
            initial_states=test_dataset,
            policy=policy,
            logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
            reward=reward,
            finite_env_type=env_config["parallel_mode"],
            concurrency=env_config["concurrency"],
        )
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
    """Instantiate the RL components described by ``config`` and run training and/or backtest.

    Parameters
    ----------
    config : dict
        Parsed config. The sections ``runtime``, ``env``, ``state_interpreter``,
        ``action_interpreter``, ``reward``, ``policy``, ``data``, ``exchange``, ``qlib``
        and ``trainer`` are read; ``network`` is optional. The ``kwargs`` entries of the
        network and policy configs are updated in place.
    run_training : bool
        Whether the training workflow should run.
    run_backtest : bool
        Whether the backtest workflow should run.
    """
    if not run_training and not run_backtest:
        warnings.warn("Skip the entire job since training and backtest are both skipped.")
        return

    # A config may spell out "seed: null"; treat that like a missing seed instead of
    # handing ``None`` to the seeding functions, which only accept integers.
    seed = config["runtime"].get("seed")
    if seed is not None:
        seed_everything(seed)

    # Make user-provided modules importable before any class is looked up by name.
    for extra_module_path in config["env"].get("extra_module_paths", []):
        sys.path.append(extra_module_path)

    state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
    action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
    reward: Reward = init_instance_by_config(config["reward"])

    additional_policy_kwargs = {
        "obs_space": state_interpreter.observation_space,
        "action_space": action_interpreter.action_space,
    }

    # Create torch network
    if "network" in config:
        if "kwargs" not in config["network"]:
            config["network"]["kwargs"] = {}
        config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
        additional_policy_kwargs["network"] = init_instance_by_config(config["network"])

    # Create policy
    if "kwargs" not in config["policy"]:
        config["policy"]["kwargs"] = {}
    config["policy"]["kwargs"].update(additional_policy_kwargs)
    policy: BasePolicy = init_instance_by_config(config["policy"])

    use_cuda = config["runtime"].get("use_cuda", False)
    if use_cuda:
        policy.cuda()

    train_and_test(
        env_config=config["env"],
        data_config=config["data"],
        exchange_config=config["exchange"],
        qlib_config=config["qlib"],
        trainer_config=config["trainer"],
        action_interpreter=action_interpreter,
        state_interpreter=state_interpreter,
        policy=policy,
        reward=reward,
        run_training=run_training,
        run_backtest=run_backtest,
    )
if __name__ == "__main__":
    # Command-line entry point: load the config file, then train by default and
    # backtest only when asked to.
    for ignored_category in (DeprecationWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=ignored_category)

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
    parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
    args = parser.parse_args()

    config = parse_backtest_config(args.config_path)
    main(
        config,
        run_training=not args.no_training,
        run_backtest=args.run_backtest,
    )

View File

@@ -140,6 +140,15 @@ def load_backtest_data(
return backtest_data
@cachetools.cached(  # type: ignore
    cache=cachetools.LRUCache(1000),
    key=lambda path: path,
)
def _load_handler_pickle(path: str) -> object:
    """Unpickle the object stored at ``path``, memoized per path in an LRU cache of 1000 entries.

    NOTE: unpickling runs arbitrary code, so ``path`` must point to a trusted file.
    """
    with open(path, "rb") as pickle_file:
        return pickle.load(pickle_file)
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
@@ -151,7 +160,6 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> None:
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
df = df.reset_index()
@@ -161,31 +169,17 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
with open(path, "rb") as fstream:
dataset = pickle.load(fstream)
dataset = _load_handler_pickle(path)
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
if index_only:
self.today = _drop_stock_id(data[[]])
self.yesterday = _drop_stock_id(data[[]])
else:
self.today = _drop_stock_id(data[feature_columns_today])
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
self.today = _drop_stock_id(data[feature_columns_today])
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
stock_id,
date,
backtest,
index_only,
),
)
def load_handler_intraday_processed_data(
data_dir: Path,
stock_id: str,
@@ -193,10 +187,9 @@ def load_handler_intraday_processed_data(
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> HandlerIntradayProcessedData:
return HandlerIntradayProcessedData(
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest,
)
@@ -229,5 +222,4 @@ class HandlerProcessedDataProvider(ProcessedDataProvider):
self.feature_columns_today,
self.feature_columns_yesterday,
backtest=self.backtest,
index_only=False,
)

View File

@@ -90,6 +90,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
trade_strategy=strategy,
trade_executor=self._executor,
return_value=self.report_dict,
show_progress=False,
)
assert isinstance(self._collect_data_loop, Generator)

View File

@@ -42,8 +42,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
Path to load backtest data.
feature_columns_today
Columns of today's feature.
feature_columns_yesterday
Columns of yesterday's feature.
data_granularity
Number of ticks between consecutive data entries.
ticks_per_step
@@ -80,7 +78,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
order: Order,
data_dir: Path,
feature_columns_today: List[str] = [],
feature_columns_yesterday: List[str] = [],
data_granularity: int = 1,
ticks_per_step: int = 30,
vol_threshold: Optional[float] = None,
@@ -92,7 +89,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self.order = order
self.data_dir = data_dir
self.feature_columns_today = feature_columns_today
self.feature_columns_yesterday = feature_columns_yesterday
self.ticks_per_step: int = ticks_per_step // data_granularity
self.vol_threshold = vol_threshold
@@ -127,9 +123,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
stock_id=self.order.stock_id,
date=pd.Timestamp(self.order.start_time.date()),
feature_columns_today=self.feature_columns_today,
feature_columns_yesterday=self.feature_columns_yesterday,
feature_columns_yesterday=[],
backtest=True,
index_only=False,
)
return DataframeIntradayBacktestData(data.today)
except (AttributeError, FileNotFoundError):

View File

@@ -208,6 +208,7 @@ class Trainer:
while not self.should_stop:
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
print(msg)
_logger.info(msg)
self.initialize_iter()

View File

@@ -0,0 +1,20 @@
import time
from contextlib import contextmanager
from line_profiler import LineProfiler
@contextmanager
def simple_perf(desc: str = ""):
    """Context manager that prints the wall-clock time spent in its body, in milliseconds.

    Parameters
    ----------
    desc : str
        Label printed in front of the measurement.

    Notes
    -----
    The measurement is printed even when the body raises; the exception then
    propagates unchanged.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Without try/finally an exception raised in the body would skip the report entirely.
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        print(f"{desc}: {elapsed_ms} ms")
def lprofile(func):
    """Decorator that line-profiles every call of ``func`` and prints the statistics.

    Each call creates a fresh ``LineProfiler``, runs ``func`` under it, prints the
    collected statistics and returns the result of ``func``.
    """
    import functools

    # Preserve the name, docstring and other metadata of the wrapped function, so that
    # decorated functions stay identifiable in logs, tracebacks and profiler output.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        profiler = LineProfiler()
        profiled_func = profiler(func)
        result = profiled_func(*args, **kwargs)
        profiler.print_stats()
        return result

    return wrapper