diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index 5e5edacaf..8f65606a6 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -56,6 +56,7 @@ def collect_data_loop( trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict | None = None, + show_progress: bool = True, ) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]: """Generator for collecting the trade decision data for rl training @@ -74,6 +75,8 @@ def collect_data_loop( the outermost executor return_value : dict used for backtest_loop + show_progress: bool + whether to show execution progress Yields ------- @@ -83,7 +86,8 @@ def collect_data_loop( trade_executor.reset(start_time=start_time, end_time=end_time) trade_strategy.reset(level_infra=trade_executor.get_level_infra()) - with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar: + disable = not show_progress + with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar: _execute_result = None while not trade_executor.finished(): _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result) diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 60602c10d..fec7ba35d 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -15,7 +15,7 @@ import pandas as pd import torch from joblib import Parallel, delayed -from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor +from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_exchange, get_strategy_executor from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime from qlib.backtest.executor import SimulatorExecutor from qlib.backtest.high_performance_ds import BaseOrderIndicator @@ -250,8 +250,6 @@ def single_with_collect_data_loop( If generate_report is True, return execution records and the generated report. Otherwise, return only records. """ - init_qlib(backtest_config["qlib"]) - trade_start_time = orders["datetime"].min() trade_end_time = orders["datetime"].max() stocks = orders.instrument.unique().tolist() @@ -275,13 +273,13 @@ def single_with_collect_data_loop( data_granularity=backtest_config["data_granularity"], ) - exchange_config = copy.deepcopy(backtest_config["exchange"]) - exchange_config.update( - { + exchange_config = { + **backtest_config["exchange"], + **{ "codes": stocks, "freq": backtest_config["data_granularity"], } - ) + } strategy, executor = get_strategy_executor( start_time=pd.Timestamp(trade_start_time), @@ -326,6 +324,8 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram single = single_with_simulator if with_simulator else single_with_collect_data_loop mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 + + init_qlib(backtest_config["qlib"]) res = Parallel(**mp_config)( delayed(single)( backtest_config=backtest_config, diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index 2255c7414..6a4e9d410 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -105,3 +105,100 @@ def get_backtest_config_fromfile(path: str) -> dict: backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default) return backtest_config + + +class TrainingConfigParser: + def __init__(self, path: str) -> None: + self.raw_config = parse_backtest_config(path) + + def parse(self) -> dict: + return { + "general": self._parse_general(), + "policy": self.raw_config["policy"], + "interpreter": self.raw_config["interpreter"], + "runtime": self._parse_runtime(), + "training": self._parse_training(), + "simulator": self._parse_simulator(), + } + + def _parse_general(self) -> dict: + default = { + "freq": "1min", + "extra_module_paths": [], + } + return {**default, **self.raw_config["general"]} + + def _parse_runtime(self) -> dict: + default = { + "seed": None, + "use_cuda": False, + "concurrency": 1, + "parallel_mode": "dummy", + } + return {**default, **self.raw_config["runtime"]} + + def _parse_training(self) -> dict: + default = { + "max_epoch": 100, + "repeat_per_collect": 2, + "earlystop_patience": float("inf"), + "episode_per_collect": 10000, + "batch_size": 256, + "val_every_n_epoch": None, + "checkpoint_path": "./outputs", + "checkpoint_every_n_iters": 10, + } + + config = self.raw_config["training"] + assert "order_dir" in config + + return {**default, **config} + + def _parse_simulator(self) -> dict: + config = self.raw_config["simulator"] + sim_type = config["type"] + assert sim_type in ("simple", "full") + + if sim_type == "simple": + return { + "type": sim_type, + "data": { + "feature_root_dir": config["data"]["feature_root_dir"], + "feature_columns_today": config["data"]["feature_columns_today"], + "default_start_time_index": config["data"].get("default_start_time_index", 0), + "default_end_time_index": config["data"].get("default_end_time_index", 240), + }, + "time_per_step": config["time_per_step"], + "vol_limit": config["vol_limit"], + } + else: + exchange_config_default = { + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5.0, + "trade_unit": 100.0, + # "cash_limit": None, + } + exchange_config = {**exchange_config_default, **_convert_all_list_to_tuple(config["exchange"])} + exchange_config["freq"] = self.raw_config["general"].get("freq", "1min") + + ret_config = { + "type": sim_type, + "data": { + "feature_root_dir": config["data"]["feature_root_dir"], + "default_start_time_index": config["data"].get("default_start_time_index", 0), + "default_end_time_index": config["data"].get("default_end_time_index", 240), + }, + "qlib": { + "provider_uri_1min": config["qlib"]["provider_uri_1min"], + }, + "exchange": exchange_config + } + + return ret_config + +if __name__ == "__main__": + parser = TrainingConfigParser("/home/huoran/exp_configs/amc4th_training_refined.yml") + + from pprint import pprint + pprint(parser.parse()) diff --git a/qlib/rl/contrib/train.py b/qlib/rl/contrib/train.py new file mode 100644 index 000000000..dbfdf8507 --- /dev/null +++ b/qlib/rl/contrib/train.py @@ -0,0 +1,333 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import argparse +import os +import random +import sys +import warnings +from pathlib import Path +from typing import Any, cast, List, Optional + +import numpy as np +import pandas as pd +import torch +from qlib.backtest import Order +from qlib.backtest.decision import OrderDir +from qlib.constant import ONE_MIN +from qlib.rl.contrib.naive_config_parser import TrainingConfigParser +from qlib.rl.data.integration import init_qlib +from qlib.rl.data.native import _load_handler_pickle, load_handler_intraday_processed_data +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.order_execution import SingleAssetOrderExecutionSimple +from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution +from qlib.rl.reward import Reward +from qlib.rl.trainer import Checkpoint, backtest, train +from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter +from qlib.rl.utils.log import CsvWriter +from qlib.utils import init_instance_by_config +from tianshou.policy import BasePolicy +from torch.utils.data import Dataset + + +def get_executor_config(freq: int) -> dict: + return { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "inner_executor": { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "inner_executor": { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "generate_report": False, + "time_per_step": f"{freq}min", + "track_data": True, + "trade_type": "serial", + "verbose": False, + } + }, + "inner_strategy": { + "class": "TWAPStrategy", + "kwargs": {}, + "module_path": "qlib.contrib.strategy.rule_strategy", + }, + "time_per_step": "30min", + "track_data": True, + } + }, + "inner_strategy": { + "class": "ProxySAOEStrategy", + "module_path": "qlib.rl.order_execution.strategy", + "kwargs": {}, + }, + "time_per_step": "1day", + "track_data": True, + } + } + + +def seed_everything(seed: int) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +def _read_orders(order_dir: Path) -> pd.DataFrame: + if os.path.isfile(order_dir): + return pd.read_pickle(order_dir) + else: + orders = [] + for file in order_dir.iterdir(): + order_data = pd.read_pickle(file) + orders.append(order_data) + return pd.concat(orders) + + +def _freq_str_to_int(freq: str) -> int: + if freq.endswith("min"): + return int(freq.replace("min", "")) + elif freq.endswith("hour"): + return int(freq.replace("hour", "") * 60) + else: + raise ValueError(f"Unrecognized freq string: {freq}") + + +class LazyLoadDataset(Dataset): + def __init__( + self, + data_dir: str, + order_file_path: Path, + default_start_time_index: int, + default_end_time_index: int, + ) -> None: + self._default_start_time_index = default_start_time_index + self._default_end_time_index = default_end_time_index + + self._order_df = _read_orders(order_file_path).reset_index() + self._ticks_index: Optional[pd.DatetimeIndex] = None + self._data_dir = Path(data_dir) + + def __len__(self) -> int: + return len(self._order_df) + + def __getitem__(self, index: int) -> Order: + row = self._order_df.iloc[index] + date = pd.Timestamp(str(row["date"])) + + if self._ticks_index is None: + # TODO: We only load ticks index once based on the assumption that ticks index of different dates + # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index + # TODO: of all dates. + + data = load_handler_intraday_processed_data( + data_dir=self._data_dir, + stock_id=row["instrument"], + date=date, + feature_columns_today=[], + feature_columns_yesterday=[], + backtest=True, + ) + self._ticks_index = [t - date for t in data.today.index] + + order = Order( + stock_id=row["instrument"], + amount=row["amount"], + direction=OrderDir(int(row["order_type"])), + start_time=date + self._ticks_index[self._default_start_time_index], + end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, + ) + + return order + + +def train_and_test( + freq: str, + concurrency: int, + parallel_mode: str, + training_config: dict, + simulator_config: dict, + policy: BasePolicy, + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + reward: Reward, + run_training: bool, + run_backtest: bool, +) -> None: + freq = _freq_str_to_int(freq) + order_root_path = Path(training_config["order_dir"]) + feature_root_dir = simulator_config["data"]["feature_root_dir"] + assert simulator_config["data"]["default_start_time_index"] % freq == 0 + assert simulator_config["data"]["default_end_time_index"] % freq == 0 + + sim_type = simulator_config["type"] + if sim_type == "simple": + def _simulator_factory(order: Order) -> SingleAssetOrderExecutionSimple: + simulator = SingleAssetOrderExecutionSimple( + order=order, + data_dir=feature_root_dir, + feature_columns_today=simulator_config["data"]["feature_columns_today"], + data_granularity=freq, + ticks_per_step=simulator_config["time_per_step"], + vol_threshold=simulator_config["vol_limit"], + ) + return simulator + elif sim_type == "full": + init_qlib(simulator_config["qlib"]) + executor_config = get_executor_config(freq) + exchange_config = simulator_config["exchange"] + + def _simulator_factory(order: Order) -> SingleAssetOrderExecution: + simulator = SingleAssetOrderExecution( + order=order, + executor_config=executor_config, + exchange_config={**exchange_config, **{"codes": [order.stock_id]}}, + qlib_config=None, + cash_limit=None, + ) + return simulator + + if run_training: + train_dataset, valid_dataset = [ + LazyLoadDataset( + data_dir=feature_root_dir, + order_file_path=order_root_path / tag, + default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq, + default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq, + ) + for tag in ("train", "valid") + ] + + callbacks: List[Callback] = [ + MetricsWriter(dirpath=Path(training_config["checkpoint_path"])), + Checkpoint( + dirpath=Path(training_config["checkpoint_path"]) / "checkpoints", + every_n_iters=training_config["checkpoint_every_n_iters"], + save_latest="copy", + ), + EarlyStopping( + patience=training_config["earlystop_patience"], + monitor="val/pa", + ), + ] + + train( + simulator_fn=_simulator_factory, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + policy=policy, + reward=reward, + initial_states=cast(List[Order], train_dataset), + trainer_kwargs={ + "max_iters": training_config["max_epoch"], + "finite_env_type": parallel_mode, + "concurrency": concurrency, + "val_every_n_iters": training_config["val_every_n_epoch"], + "callbacks": callbacks, + }, + vessel_kwargs={ + "episode_per_iter": training_config["episode_per_collect"], + "update_kwargs": { + "batch_size": training_config["batch_size"], + "repeat": training_config["repeat_per_collect"], + }, + "val_initial_states": valid_dataset, + }, + ) + + if run_backtest: + test_dataset = LazyLoadDataset( + data_dir=feature_root_dir, + order_file_path=order_root_path / "test", + default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq, + default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq, + ) + + backtest( + simulator_fn=_simulator_factory, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + initial_states=test_dataset, + policy=policy, + logger=CsvWriter(Path(training_config["checkpoint_path"])), + reward=reward, + finite_env_type=parallel_mode, + concurrency=concurrency, + ) + + +def main(config: dict, run_training: bool, run_backtest: bool) -> None: + if not run_training and not run_backtest: + warnings.warn("Skip the entire job since training and backtest are both skipped.") + return + + seed = config["runtime"]["seed"] + if seed is not None: + seed_everything(seed) + + for extra_module_path in config["general"]["extra_module_paths"]: + sys.path.append(extra_module_path) + + state_interpreter: StateInterpreter = init_instance_by_config(config["interpreter"]["state"]) + action_interpreter: ActionInterpreter = init_instance_by_config(config["interpreter"]["action"]) + reward: Reward = init_instance_by_config(config["interpreter"]["reward"]) + + additional_policy_kwargs = { + "obs_space": state_interpreter.observation_space, + "action_space": action_interpreter.action_space, + } + # Create torch network + if "network" in config["policy"]: + network_config = config["policy"]["network"] + network_config["kwargs"] = { + **network_config.get("kwargs", {}), + **{"obs_space": state_interpreter.observation_space} + } + additional_policy_kwargs["network"] = init_instance_by_config(network_config) + + # Create policy + policy_config = config["policy"]["policy"] + policy_config["kwargs"] = { + **policy_config.get("kwargs", {}), + **additional_policy_kwargs + } + policy: BasePolicy = init_instance_by_config(policy_config) + + use_cuda = config["runtime"]["use_cuda"] + if use_cuda: + policy.cuda() + + train_and_test( + freq=config["general"]["freq"], + concurrency=config["runtime"]["concurrency"], + parallel_mode=config["runtime"]["parallel_mode"], + training_config=config["training"], + simulator_config=config["simulator"], + policy=policy, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + reward=reward, + run_training=run_training, + run_backtest=run_backtest, + ) + + +if __name__ == "__main__": + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + parser.add_argument("--no_training", action="store_true", help="Skip training workflow.") + parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.") + args = parser.parse_args() + + config_parser = TrainingConfigParser(args.config_path) + config = config_parser.parse() + main(config, run_training=not args.no_training, run_backtest=args.run_backtest) diff --git a/qlib/rl/contrib/train_onpolicy_full_simulation.py b/qlib/rl/contrib/train_onpolicy_full_simulation.py new file mode 100644 index 000000000..af2abb3e3 --- /dev/null +++ b/qlib/rl/contrib/train_onpolicy_full_simulation.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import argparse +import os +import random +import sys +import warnings +from pathlib import Path +from typing import Any, cast, List, Optional + +import numpy as np +import pandas as pd +import torch +import yaml +from qlib.backtest import Order +from qlib.backtest.decision import OrderDir +from qlib.constant import ONE_MIN +from qlib.rl.contrib.naive_config_parser import parse_backtest_config +from qlib.rl.data.integration import init_qlib +from qlib.rl.data.native import load_handler_intraday_processed_data +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.order_execution import SingleAssetOrderExecutionSimple +from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution +from qlib.rl.reward import Reward +from qlib.rl.trainer import Checkpoint, backtest, train +from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter +from qlib.rl.utils.log import CsvWriter +from qlib.utils import init_instance_by_config +from tianshou.policy import BasePolicy +from torch.utils.data import Dataset + + +def get_executor_config(data_granularity: int = 1) -> dict: + return { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "inner_executor": { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "inner_executor": { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "generate_report": False, + "time_per_step": f"{data_granularity}min", + "track_data": True, + "trade_type": "serial", + "verbose": False, + } + }, + "inner_strategy": { + "class": "TWAPStrategy", + "kwargs": {}, + "module_path": "qlib.contrib.strategy.rule_strategy", + }, + "time_per_step": "30min", + "track_data": True, + } + }, + "inner_strategy": { + "class": "ProxySAOEStrategy", + "module_path": "qlib.rl.order_execution.strategy", + "kwargs": {}, + }, + "time_per_step": "1day", + "track_data": True, + } + } + + +def _convert_list_to_tuple(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: _convert_list_to_tuple(v) for k, v in obj.items()} + elif isinstance(obj, list): + return tuple(obj) + else: + return obj + + +def seed_everything(seed: int) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +def _read_orders(order_dir: Path) -> pd.DataFrame: + if os.path.isfile(order_dir): + return pd.read_pickle(order_dir) + else: + orders = [] + for file in order_dir.iterdir(): + order_data = pd.read_pickle(file) + orders.append(order_data) + return pd.concat(orders) + + +class LazyLoadDataset(Dataset): + def __init__( + self, + data_dir: str, + order_file_path: Path, + default_start_time_index: int, + default_end_time_index: int, + ) -> None: + self._default_start_time_index = default_start_time_index + self._default_end_time_index = default_end_time_index + + self._order_df = _read_orders(order_file_path).reset_index() + self._ticks_index: Optional[pd.DatetimeIndex] = None + self._data_dir = Path(data_dir) + + def __len__(self) -> int: + return len(self._order_df) + + def __getitem__(self, index: int) -> Order: + row = self._order_df.iloc[index] + date = pd.Timestamp(str(row["date"])) + + if self._ticks_index is None: + # TODO: We only load ticks index once based on the assumption that ticks index of different dates + # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index + # TODO: of all dates. + + data = load_handler_intraday_processed_data( + data_dir=self._data_dir, + stock_id=row["instrument"], + date=date, + feature_columns_today=[], + feature_columns_yesterday=[], + backtest=True, + index_only=True, + ) + self._ticks_index = [t - date for t in data.today.index] + + order = Order( + stock_id=row["instrument"], + amount=row["amount"], + direction=OrderDir(int(row["order_type"])), + start_time=date + self._ticks_index[self._default_start_time_index], + end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, + ) + + return order + + +def train_and_test( + env_config: dict, + trainer_config: dict, + data_config: dict, + exchange_config: dict, + qlib_config: dict, + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + policy: BasePolicy, + reward: Reward, + run_training: bool, + run_backtest: bool, +) -> None: + init_qlib(qlib_config) + + order_root_path = Path(data_config["source"]["order_dir"]) + + data_granularity = 1 # simulator_config.get("data_granularity", 1) + + exchange_config_default = { + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5.0, + "trade_unit": 100.0, + # "cash_limit": None, + } + exchange_config = {**exchange_config_default, **exchange_config} + exchange_config = _convert_list_to_tuple(exchange_config) + + def _simulator_factory(order: Order) -> SingleAssetOrderExecution: + simulator = SingleAssetOrderExecution( + order=order, + executor_config=get_executor_config(data_granularity), + exchange_config={**exchange_config, **{"codes": [order.stock_id]}}, + qlib_config=None, + cash_limit=None, + ) + return simulator + + assert data_config["source"]["default_start_time_index"] % data_granularity == 0 + assert data_config["source"]["default_end_time_index"] % data_granularity == 0 + + if run_training: + train_dataset, valid_dataset = [ + LazyLoadDataset( + data_dir=data_config["source"]["feature_root_dir"], + order_file_path=order_root_path / tag, + default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + ) + for tag in ("train", "valid") + ] + + callbacks: List[Callback] = [] + if "checkpoint_path" in trainer_config: + callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) + callbacks.append( + Checkpoint( + dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", + every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), + save_latest="copy", + ), + ) + if "earlystop_patience" in trainer_config: + callbacks.append( + EarlyStopping( + patience=trainer_config["earlystop_patience"], + monitor="val/pa", + ) + ) + + train( + simulator_fn=_simulator_factory, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + policy=policy, + reward=reward, + initial_states=cast(List[Order], train_dataset), + trainer_kwargs={ + "max_iters": trainer_config["max_epoch"], + "finite_env_type": env_config["parallel_mode"], + "concurrency": env_config["concurrency"], + "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), + "callbacks": callbacks, + }, + vessel_kwargs={ + "episode_per_iter": trainer_config["episode_per_collect"], + "update_kwargs": { + "batch_size": trainer_config["batch_size"], + "repeat": trainer_config["repeat_per_collect"], + }, + "val_initial_states": valid_dataset, + }, + ) + + if run_backtest: + test_dataset = LazyLoadDataset( + data_dir=data_config["source"]["feature_root_dir"], + order_file_path=order_root_path / "test", + default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + ) + + backtest( + simulator_fn=_simulator_factory, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + initial_states=test_dataset, + policy=policy, + logger=CsvWriter(Path(trainer_config["checkpoint_path"])), + reward=reward, + finite_env_type=env_config["parallel_mode"], + concurrency=env_config["concurrency"], + ) + + +def main(config: dict, run_training: bool, run_backtest: bool) -> None: + if not run_training and not run_backtest: + warnings.warn("Skip the entire job since training and backtest are both skipped.") + return + + if "seed" in config["runtime"]: + seed_everything(config["runtime"]["seed"]) + + for extra_module_path in config["env"].get("extra_module_paths", []): + sys.path.append(extra_module_path) + + state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"]) + action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) + reward: Reward = init_instance_by_config(config["reward"]) + + additional_policy_kwargs = { + "obs_space": state_interpreter.observation_space, + "action_space": action_interpreter.action_space, + } + + # Create torch network + if "network" in config: + if "kwargs" not in config["network"]: + config["network"]["kwargs"] = {} + config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) + additional_policy_kwargs["network"] = init_instance_by_config(config["network"]) + + # Create policy + if "kwargs" not in config["policy"]: + config["policy"]["kwargs"] = {} + config["policy"]["kwargs"].update(additional_policy_kwargs) + policy: BasePolicy = init_instance_by_config(config["policy"]) + + use_cuda = config["runtime"].get("use_cuda", False) + if use_cuda: + policy.cuda() + + train_and_test( + env_config=config["env"], + data_config=config["data"], + exchange_config=config["exchange"], + qlib_config=config["qlib"], + trainer_config=config["trainer"], + action_interpreter=action_interpreter, + state_interpreter=state_interpreter, + policy=policy, + reward=reward, + run_training=run_training, + run_backtest=run_backtest, + ) + + +if __name__ == "__main__": + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + parser.add_argument("--no_training", action="store_true", help="Skip training workflow.") + parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.") + args = parser.parse_args() + + config = parse_backtest_config(args.config_path) + main(config, run_training=not args.no_training, run_backtest=args.run_backtest) diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index ceb540882..eb9d44d58 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -140,6 +140,15 @@ def load_backtest_data( return backtest_data +@cachetools.cached( # type: ignore + cache=cachetools.LRUCache(1000), + key=lambda path: path, +) +def _load_handler_pickle(path: str) -> object: + with open(path, "rb") as fstream: + obj = pickle.load(fstream) + return obj + class HandlerIntradayProcessedData(BaseIntradayProcessedData): """Subclass of IntradayProcessedData. Used to handle handler (bin format) style data.""" @@ -151,7 +160,6 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData): feature_columns_today: List[str], feature_columns_yesterday: List[str], backtest: bool = False, - index_only: bool = False, ) -> None: def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame: df = df.reset_index() @@ -161,31 +169,17 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData): path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl") start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59) - with open(path, "rb") as fstream: - dataset = pickle.load(fstream) + dataset = _load_handler_pickle(path) data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None) - if index_only: - self.today = _drop_stock_id(data[[]]) - self.yesterday = _drop_stock_id(data[[]]) - else: - self.today = _drop_stock_id(data[feature_columns_today]) - self.yesterday = _drop_stock_id(data[feature_columns_yesterday]) + self.today = _drop_stock_id(data[feature_columns_today]) + self.yesterday = _drop_stock_id(data[feature_columns_yesterday]) def __repr__(self) -> str: with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): return f"{self.__class__.__name__}({self.today}, {self.yesterday})" -@cachetools.cached( # type: ignore - cache=cachetools.LRUCache(100), # 100 * 50K = 5MB - key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: ( - stock_id, - date, - backtest, - index_only, - ), -) def load_handler_intraday_processed_data( data_dir: Path, stock_id: str, @@ -193,10 +187,9 @@ def load_handler_intraday_processed_data( feature_columns_today: List[str], feature_columns_yesterday: List[str], backtest: bool = False, - index_only: bool = False, ) -> HandlerIntradayProcessedData: return HandlerIntradayProcessedData( - data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only + data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, ) @@ -229,5 +222,4 @@ class HandlerProcessedDataProvider(ProcessedDataProvider): self.feature_columns_today, self.feature_columns_yesterday, backtest=self.backtest, - index_only=False, ) diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 1417e2ab4..e2d7defcc 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -90,6 +90,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): trade_strategy=strategy, trade_executor=self._executor, return_value=self.report_dict, + show_progress=False, ) assert isinstance(self._collect_data_loop, Generator) diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 48aa03a17..cdfbd2098 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -42,8 +42,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): Path to load backtest data. feature_columns_today Columns of today's feature. - feature_columns_yesterday - Columns of yesterday's feature. data_granularity Number of ticks between consecutive data entries. ticks_per_step @@ -80,7 +78,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): order: Order, data_dir: Path, feature_columns_today: List[str] = [], - feature_columns_yesterday: List[str] = [], data_granularity: int = 1, ticks_per_step: int = 30, vol_threshold: Optional[float] = None, @@ -92,7 +89,6 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): self.order = order self.data_dir = data_dir self.feature_columns_today = feature_columns_today - self.feature_columns_yesterday = feature_columns_yesterday self.ticks_per_step: int = ticks_per_step // data_granularity self.vol_threshold = vol_threshold @@ -127,9 +123,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): stock_id=self.order.stock_id, date=pd.Timestamp(self.order.start_time.date()), feature_columns_today=self.feature_columns_today, - feature_columns_yesterday=self.feature_columns_yesterday, + feature_columns_yesterday=[], backtest=True, - index_only=False, ) return DataframeIntradayBacktestData(data.today) except (AttributeError, FileNotFoundError): diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index fb73dd549..9d457f82d 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -208,6 +208,7 @@ class Trainer: while not self.should_stop: msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}" + print(msg) _logger.info(msg) self.initialize_iter() diff --git a/qlib/rl/utils/profiling.py b/qlib/rl/utils/profiling.py new file mode 100644 index 000000000..15d350e9c --- /dev/null +++ b/qlib/rl/utils/profiling.py @@ -0,0 +1,20 @@ +import time +from contextlib import contextmanager +from line_profiler import LineProfiler + +@contextmanager +def simple_perf(desc: str = ""): + s = time.perf_counter() + yield + e = time.perf_counter() + print(f"{desc}: {(e - s) * 1000.0} ms") + + +def lprofile(func): + def wrapper(*args, **kwargs): + lp = LineProfiler() + lpw = lp(func) + res = lpw(*args, **kwargs) + lp.print_stats() + return res + return wrapper