diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index bb8ca731b..0cb37a61c 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -179,7 +179,7 @@ def get_strategy_executor( executor: Union[str, dict, object, Path], benchmark: Optional[str] = "SH000300", account: Union[float, int, dict] = 1e9, - exchange_kwargs: dict = {}, + exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter pos_type: str = "Position", ) -> Tuple[BaseStrategy, BaseExecutor]: @@ -197,12 +197,15 @@ def get_strategy_executor( pos_type=pos_type, ) - exchange_kwargs = copy.copy(exchange_kwargs) - if "start_time" not in exchange_kwargs: - exchange_kwargs["start_time"] = start_time - if "end_time" not in exchange_kwargs: - exchange_kwargs["end_time"] = end_time - trade_exchange = get_exchange(**exchange_kwargs) + if isinstance(exchange_kwargs, Exchange): + trade_exchange = exchange_kwargs + else: + exchange_kwargs = copy.copy(exchange_kwargs) + if "start_time" not in exchange_kwargs: + exchange_kwargs["start_time"] = start_time + if "end_time" not in exchange_kwargs: + exchange_kwargs["end_time"] = end_time + trade_exchange = get_exchange(**exchange_kwargs) common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange) trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index a752a9f8c..bcfd8610f 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -177,7 +177,7 @@ class Exchange: necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} if self.limit_type == self.LT_TP_EXP: - assert isinstance(limit_threshold, tuple) + assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2) for exp in limit_threshold: necessary_fields.add(exp) all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields)) @@ -263,6 +263,9 @@ class Exchange: """get limit type""" if isinstance(limit_threshold, tuple): return self.LT_TP_EXP + if isinstance(limit_threshold, list): + assert len(limit_threshold) == 2 + return self.LT_TP_EXP elif isinstance(limit_threshold, float): return self.LT_FLT elif limit_threshold is None: @@ -325,7 +328,7 @@ class Exchange: assert isinstance(volume_threshold, dict) for key, vol_limit in volume_threshold.items(): - assert isinstance(vol_limit, tuple) + assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2) fields.add(vol_limit[1]) if key in ("buy", "all"): @@ -803,7 +806,7 @@ class Exchange: vol_limit_num: List[float] = [] for limit in vol_limit: - assert isinstance(limit, tuple) + assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2) if limit[0] == "current": limit_value = self.quote.get_data( order.stock_id, diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 1786c07b1..a0af23175 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -15,14 +15,13 @@ import pandas as pd import torch from joblib import Parallel, delayed -from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_exchange, get_strategy_executor -from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime +from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor +from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime from qlib.backtest.executor import SimulatorExecutor from qlib.backtest.high_performance_ds import BaseOrderIndicator -from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile +from qlib.rl.contrib.naive_config_parser import BacktestConfigParser from qlib.rl.contrib.utils import read_order_file from qlib.rl.data.integration import init_qlib -from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution from qlib.typehint import Literal @@ -124,105 +123,13 @@ def _generate_report( return report -def single_with_simulator( - backtest_config: dict, - orders: pd.DataFrame, - split: Literal["stock", "day"] = "stock", - cash_limit: float | None = None, - generate_report: bool = False, -) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: - """Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day. - A new simulator will be created and used for every single-day order. - - Parameters - ---------- - backtest_config: - Backtest config - orders: - Orders to be executed. Example format: - datetime instrument amount direction - 0 2020-06-01 INST 600.0 0 - 1 2020-06-02 INST 700.0 1 - ... - split - Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date. - cash_limit - Limitation of cash. - generate_report - Whether to generate reports. - - Returns - ------- - If generate_report is True, return execution records and the generated report. Otherwise, return only records. - """ - init_qlib(backtest_config["qlib"]) - - stocks = orders.instrument.unique().tolist() - - reports = [] - decisions = [] - for _, row in orders.iterrows(): - date = pd.Timestamp(row["datetime"]) - start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day) - end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day) - order = Order( - stock_id=row["instrument"], - amount=row["amount"], - direction=OrderDir(row["direction"]), - start_time=start_time, - end_time=end_time, - ) - - executor_config = _get_multi_level_executor_config( - strategy_config=backtest_config["strategies"], - cash_limit=cash_limit, - generate_report=generate_report, - data_granularity=backtest_config["data_granularity"], - ) - - exchange_config = copy.deepcopy(backtest_config["exchange"]) - exchange_config.update( - { - "codes": stocks, - "freq": backtest_config["data_granularity"], - } - ) - - simulator = SingleAssetOrderExecution( - order=order, - executor_config=executor_config, - exchange_config=exchange_config, - qlib_config=None, - cash_limit=None, - ) - - reports.append(simulator.report_dict) - decisions += simulator.decisions - - indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports] - indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()} - records = _convert_indicator_to_dataframe(indicator_info) - assert records is None or not np.isnan(records["ffr"]).any() - - if generate_report: - _report = _generate_report(decisions, [report["indicator"] for report in reports]) - - if split == "stock": - stock_id = orders.iloc[0].instrument - report = {stock_id: _report} - else: - day = orders.iloc[0].datetime - report = {day: _report} - - return records, report - else: - return records - - def single_with_collect_data_loop( - backtest_config: dict, orders: pd.DataFrame, + time_range: Tuple[str, str], + exchange_config: dict, + strategy_config: dict, split: Literal["stock", "day"] = "stock", + data_granularity: str = "1min", cash_limit: float | None = None, generate_report: bool = False, ) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: @@ -254,38 +161,38 @@ def single_with_collect_data_loop( trade_end_time = orders["datetime"].max() stocks = orders.instrument.unique().tolist() - strategy_config = { + top_strategy_config = { "class": "FileOrderStrategy", "module_path": "qlib.contrib.strategy.rule_strategy", "kwargs": { "file": orders, "trade_range": TradeRangeByTime( - pd.Timestamp(backtest_config["start_time"]).time(), - pd.Timestamp(backtest_config["end_time"]).time(), + pd.Timestamp(time_range[0]).time(), + pd.Timestamp(time_range[1]).time(), ), }, } - executor_config = _get_multi_level_executor_config( - strategy_config=backtest_config["strategies"], + top_executor_config = _get_multi_level_executor_config( + strategy_config=strategy_config, cash_limit=cash_limit, generate_report=generate_report, - data_granularity=backtest_config["data_granularity"], + data_granularity=data_granularity, ) exchange_config = { - **backtest_config["exchange"], + **exchange_config, **{ "codes": stocks, - "freq": backtest_config["data_granularity"], + "freq": data_granularity, }, } strategy, executor = get_strategy_executor( start_time=pd.Timestamp(trade_start_time), end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1), - strategy=strategy_config, - executor=executor_config, + strategy=top_strategy_config, + executor=top_executor_config, benchmark=None, account=cash_limit if cash_limit is not None else int(1e12), exchange_kwargs=exchange_config, @@ -293,7 +200,7 @@ def single_with_collect_data_loop( ) report_dict: dict = {} - decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict)) + decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict, show_progress=False)) indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict")) records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his) @@ -313,48 +220,54 @@ def single_with_collect_data_loop( def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame: - order_df = read_order_file(backtest_config["order_file"]) - - cash_limit = backtest_config["exchange"].pop("cash_limit") - generate_report = backtest_config.pop("generate_report") - - stock_pool = order_df["instrument"].unique().tolist() - stock_pool.sort() - - single = single_with_simulator if with_simulator else single_with_collect_data_loop - mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} + init_qlib(backtest_config["simulator"]["qlib"]) torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 - - init_qlib(backtest_config["qlib"]) - res = Parallel(**mp_config)( - delayed(single)( - backtest_config=backtest_config, - orders=order_df[order_df["instrument"] == stock].copy(), - split="stock", - cash_limit=cash_limit, - generate_report=generate_report, + + single = single_with_collect_data_loop + mp_config = {"n_jobs": backtest_config["runtime"]["concurrency"], "verbose": 10, "backend": "multiprocessing"} + + for task_config in backtest_config["tasks"]: + order_df = read_order_file(task_config["order_file"]) + exchange_config = task_config["exchange"] + cash_limit = exchange_config.pop("cash_limit") + generate_report = backtest_config["runtime"]["generate_report"] + + stock_pool = order_df["instrument"].unique().tolist() + stock_pool.sort() + + # + res = Parallel(**mp_config)( + delayed(single)( + orders=order_df[order_df["instrument"] == stock].copy(), + time_range=task_config["time_range"], + exchange_config=task_config["exchange"], + strategy_config=backtest_config["strategies"], + split="stock", + data_granularity=task_config["data_granularity"], + cash_limit=cash_limit, + generate_report=generate_report, + ) + for stock in stock_pool ) - for stock in stock_pool - ) - - output_path = Path(backtest_config["output_dir"]) - if generate_report: - with (output_path / "report.pkl").open("wb") as f: - report = {} - for r in res: - report.update(r[1]) - pickle.dump(report, f) - res = pd.concat([r[0] for r in res], 0) - else: - res = pd.concat(res) - - if not output_path.exists(): - os.makedirs(output_path) - - if "pa" in res.columns: - res["pa"] = res["pa"] * 10000.0 # align with training metrics - res.to_csv(output_path / "backtest_result.csv") - return res + + # + output_path = Path(task_config["output_dir"]) + os.makedirs(output_path, exist_ok=True) + + if generate_report: + with (output_path / "report.pkl").open("wb") as f: + report = {} + for r in res: + report.update(r[1]) + pickle.dump(report, f) + res = pd.concat([r[0] for r in res], 0) + else: + res = pd.concat(res) + + if "pa" in res.columns: + res["pa"] = res["pa"] * 10000.0 # align with training metrics + res.to_csv(output_path / "backtest_result.csv") + # return res # TODO if __name__ == "__main__": @@ -362,6 +275,7 @@ if __name__ == "__main__": warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) + warnings.filterwarnings("ignore", category=FutureWarning) parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") @@ -374,9 +288,11 @@ if __name__ == "__main__": ) args = parser.parse_args() - config = get_backtest_config_fromfile(args.config_path) - if args.n_jobs is not None: - config["concurrency"] = args.n_jobs + + config_parser = BacktestConfigParser(args.config_path) + config = config_parser.parse() + if args.n_jobs is not None: # Overwrite concurrency + config["runtime"]["concurrency"] = args.n_jobs backtest( backtest_config=config, diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index 5de307362..bc0792343 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import copy import os import platform import shutil @@ -30,7 +31,7 @@ def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') raise FileNotFoundError(msg_tmpl.format(filename)) -def parse_backtest_config(path: str) -> dict: +def load_config(path: str) -> dict: abs_path = os.path.abspath(path) check_file_exist(abs_path) @@ -65,51 +66,63 @@ def parse_backtest_config(path: str) -> dict: base_file_name = [base_file_name] for f in base_file_name: - base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f)) + base_config = load_config(os.path.join(os.path.dirname(abs_path), f)) config = merge_a_into_b(a=config, b=base_config) return config -def _convert_all_list_to_tuple(config: dict) -> dict: - for k, v in config.items(): - if isinstance(v, list): - config[k] = tuple(v) - elif isinstance(v, dict): - config[k] = _convert_all_list_to_tuple(v) - return config +class BacktestConfigParser: + def __init__(self, path: str) -> None: + self.raw_config = load_config(path) + + def parse(self) -> dict: + self._simulator_config = self._parse_simulator() + self._exchange_config = self._simulator_config.pop("exchange") + config = { + "strategies": self.raw_config["strategies"], + "runtime": self.raw_config["runtime"], + "tasks": self._parse_tasks(), + "simulator": self._simulator_config, + } + return config + + def _parse_tasks(self) -> dict: + task_config = [] + for task in self.raw_config["tasks"]: + if "output_dir" not in task: + task["output_dir"] = os.path.join("outputs_backtest", task["name"]) + if "exchange" not in task: + task["exchange"] = copy.deepcopy(self._exchange_config) + else: + task["exchange"] = self._complete_exchange_config(task["exchange"]) + task_config.append(task) + + return task_config + + def _complete_exchange_config(self, exchange_config: dict) -> dict: + exchange_config_default = { + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5.0, + "trade_unit": 100.0, + "cash_limit": None, + } + exchange_config = merge_a_into_b(a=exchange_config, b=exchange_config_default) + return exchange_config + + def _parse_simulator(self) -> dict: + config = self.raw_config["simulator"] - -def get_backtest_config_fromfile(path: str) -> dict: - backtest_config = parse_backtest_config(path) - - exchange_config_default = { - "open_cost": 0.0005, - "close_cost": 0.0015, - "min_cost": 5.0, - "trade_unit": 100.0, - "cash_limit": None, - } - backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default) - backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"]) - - backtest_config_default = { - "debug_single_stock": None, - "debug_single_day": None, - "concurrency": -1, - "multiplier": 1.0, - "output_dir": "outputs_backtest/", - "generate_report": False, - "data_granularity": "1min", - } - backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default) - - return backtest_config + return { + "qlib": config["qlib"], + "exchange": self._complete_exchange_config(config["exchange"]), + } class TrainingConfigParser: def __init__(self, path: str) -> None: - self.raw_config = parse_backtest_config(path) + self.raw_config = load_config(path) def parse(self) -> dict: return { @@ -179,7 +192,7 @@ class TrainingConfigParser: "trade_unit": 100.0, # "cash_limit": None, } - exchange_config = {**exchange_config_default, **_convert_all_list_to_tuple(config["exchange"])} + exchange_config = {**exchange_config_default, **config["exchange"]} exchange_config["freq"] = self.raw_config["general"].get("freq", "1min") ret_config = { diff --git a/qlib/rl/contrib/train.py b/qlib/rl/contrib/train.py index ca5f54def..ac7b9156e 100644 --- a/qlib/rl/contrib/train.py +++ b/qlib/rl/contrib/train.py @@ -8,7 +8,7 @@ import random import sys import warnings from pathlib import Path -from typing import Any, cast, List, Optional +from typing import Callable, cast, List, Optional, Sequence import numpy as np import pandas as pd @@ -16,9 +16,10 @@ import torch from qlib.backtest import Order from qlib.backtest.decision import OrderDir from qlib.constant import ONE_MIN +from qlib.rl import Simulator from qlib.rl.contrib.naive_config_parser import TrainingConfigParser from qlib.rl.data.integration import init_qlib -from qlib.rl.data.native import _load_handler_pickle, load_handler_intraday_processed_data +from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.order_execution import SingleAssetOrderExecutionSimple from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution @@ -103,14 +104,14 @@ class LazyLoadDataset(Dataset): def __init__( self, data_dir: str, - order_file_path: Path, + order_df: pd.DataFrame, default_start_time_index: int, default_end_time_index: int, ) -> None: self._default_start_time_index = default_start_time_index self._default_end_time_index = default_end_time_index - self._order_df = _read_orders(order_file_path).reset_index() + self._order_df = order_df self._ticks_index: Optional[pd.DatetimeIndex] = None self._data_dir = Path(data_dir) @@ -126,7 +127,7 @@ class LazyLoadDataset(Dataset): # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index # TODO: of all dates. - data = load_handler_intraday_processed_data( + data = load_pickle_intraday_processed_data( data_dir=self._data_dir, stock_id=row["instrument"], date=date, @@ -147,6 +148,53 @@ class LazyLoadDataset(Dataset): return order +def _split_order_df_by_instrument(df: pd.DataFrame, k: int) -> List[pd.DataFrame]: + df = df.copy() + df["group"] = df["instrument"].apply(lambda s: hash(s) % k) + dfs = [df[df["group"] == i].drop(columns=["group"]) for i in range(k)] + return dfs + + +def _get_simulator_factory( + sim_type: str, + data_dir: Path, + freq_min: int, + simulator_config: dict, +) -> Callable[[Order], Simulator]: + if sim_type == "simple": + + def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: + simulator = SingleAssetOrderExecutionSimple( + order=order, + data_dir=data_dir, + feature_columns_today=simulator_config["data"]["feature_columns_today"], + data_granularity=freq_min, + ticks_per_step=simulator_config["time_per_step"], + vol_threshold=simulator_config["vol_limit"], + ) + return simulator + + return _simulator_factory_simple + elif sim_type == "full": + init_qlib(simulator_config["qlib"]) + executor_config = get_executor_config(freq_min) + exchange_config = simulator_config["exchange"] + + def _simulator_factory_full(order: Order) -> SingleAssetOrderExecution: + simulator = SingleAssetOrderExecution( + order=order, + executor_config=executor_config, + exchange_config=exchange_config, # `codes` will be set in SingleAssetOrderExecution.__init__() + qlib_config=None, + cash_limit=None, + ) + return simulator + + return _simulator_factory_full + else: + raise ValueError(f"Unknown simulator type: {sim_type}") + + def train_and_test( freq: str, concurrency: int, @@ -160,52 +208,41 @@ def train_and_test( run_training: bool, run_backtest: bool, ) -> None: - freq = _freq_str_to_int(freq) + freq_min: int = _freq_str_to_int(freq) order_root_path = Path(training_config["order_dir"]) feature_root_dir = simulator_config["data"]["feature_root_dir"] - assert simulator_config["data"]["default_start_time_index"] % freq == 0 - assert simulator_config["data"]["default_end_time_index"] % freq == 0 + assert simulator_config["data"]["default_start_time_index"] % freq_min == 0 + assert simulator_config["data"]["default_end_time_index"] % freq_min == 0 - sim_type = simulator_config["type"] - if sim_type == "simple": - - def _simulator_factory(order: Order) -> SingleAssetOrderExecutionSimple: - simulator = SingleAssetOrderExecutionSimple( - order=order, - data_dir=feature_root_dir, - feature_columns_today=simulator_config["data"]["feature_columns_today"], - data_granularity=freq, - ticks_per_step=simulator_config["time_per_step"], - vol_threshold=simulator_config["vol_limit"], - ) - return simulator - - elif sim_type == "full": - init_qlib(simulator_config["qlib"]) - executor_config = get_executor_config(freq) - exchange_config = simulator_config["exchange"] - - def _simulator_factory(order: Order) -> SingleAssetOrderExecution: - simulator = SingleAssetOrderExecution( - order=order, - executor_config=executor_config, - exchange_config={**exchange_config, **{"codes": [order.stock_id]}}, - qlib_config=None, - cash_limit=None, - ) - return simulator + _simulator_factory = _get_simulator_factory( + sim_type=simulator_config["type"], + data_dir=feature_root_dir, + freq_min=freq_min, + simulator_config=simulator_config, + ) + # Load orders + load_data_tags = [] + orders_by_tag = {} if run_training: - train_dataset, valid_dataset = [ + load_data_tags += ["train", "valid"] + if run_backtest: + load_data_tags += ["test"] + for tag in load_data_tags: + order_df = _read_orders(order_root_path / tag).reset_index() + dfs = _split_order_df_by_instrument(order_df, concurrency) + datasets = [ LazyLoadDataset( data_dir=feature_root_dir, - order_file_path=order_root_path / tag, - default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq, - default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq, + order_df=df, + default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq_min, + default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq_min, ) - for tag in ("train", "valid") + for df in dfs ] + orders_by_tag[tag] = datasets + if run_training: callbacks: List[Callback] = [ MetricsWriter(dirpath=Path(training_config["checkpoint_path"])), Checkpoint( @@ -225,7 +262,7 @@ def train_and_test( action_interpreter=action_interpreter, policy=policy, reward=reward, - initial_states=cast(List[Order], train_dataset), + initial_states=cast(List[Sequence[Order]], orders_by_tag["train"]), trainer_kwargs={ "max_iters": training_config["max_epoch"], "finite_env_type": parallel_mode, @@ -239,27 +276,20 @@ def train_and_test( "batch_size": training_config["batch_size"], "repeat": training_config["repeat_per_collect"], }, - "val_initial_states": valid_dataset, + "val_initial_states": cast(List[Sequence[Order]], orders_by_tag["valid"]), }, ) if run_backtest: - test_dataset = LazyLoadDataset( - data_dir=feature_root_dir, - order_file_path=order_root_path / "test", - default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq, - default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq, - ) - backtest( simulator_fn=_simulator_factory, state_interpreter=state_interpreter, action_interpreter=action_interpreter, - initial_states=test_dataset, + initial_states=cast(List[Sequence[Order]], orders_by_tag["test"]), policy=policy, logger=CsvWriter(Path(training_config["checkpoint_path"])), reward=reward, - finite_env_type=parallel_mode, + finite_env_type=parallel_mode, # type: ignore[arg-type] concurrency=concurrency, ) diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py deleted file mode 100644 index cd5d0e55e..000000000 --- a/qlib/rl/contrib/train_onpolicy.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import argparse -import os -import random -import sys -import warnings -from pathlib import Path -from typing import cast, List, Optional - -import numpy as np -import pandas as pd -import torch -import yaml -from qlib.backtest import Order -from qlib.backtest.decision import OrderDir -from qlib.constant import ONE_MIN -from qlib.rl.data.native import load_handler_intraday_processed_data -from qlib.rl.interpreter import ActionInterpreter, StateInterpreter -from qlib.rl.order_execution import SingleAssetOrderExecutionSimple -from qlib.rl.reward import Reward -from qlib.rl.trainer import Checkpoint, backtest, train -from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter -from qlib.rl.utils.log import CsvWriter -from qlib.utils import init_instance_by_config -from tianshou.policy import BasePolicy -from torch.utils.data import Dataset - - -def seed_everything(seed: int) -> None: - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - torch.backends.cudnn.deterministic = True - - -def _read_orders(order_dir: Path) -> pd.DataFrame: - if os.path.isfile(order_dir): - return pd.read_pickle(order_dir) - else: - orders = [] - for file in order_dir.iterdir(): - order_data = pd.read_pickle(file) - orders.append(order_data) - return pd.concat(orders) - - -class LazyLoadDataset(Dataset): - def __init__( - self, - data_dir: str, - order_file_path: Path, - default_start_time_index: int, - default_end_time_index: int, - ) -> None: - self._default_start_time_index = default_start_time_index - self._default_end_time_index = default_end_time_index - - self._order_df = _read_orders(order_file_path).reset_index() - self._ticks_index: Optional[pd.DatetimeIndex] = None - self._data_dir = Path(data_dir) - - def __len__(self) -> int: - return len(self._order_df) - - def __getitem__(self, index: int) -> Order: - row = self._order_df.iloc[index] - date = pd.Timestamp(str(row["date"])) - - if self._ticks_index is None: - # TODO: We only load ticks index once based on the assumption that ticks index of different dates - # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index - # TODO: of all dates. - - data = load_handler_intraday_processed_data( - data_dir=self._data_dir, - stock_id=row["instrument"], - date=date, - feature_columns_today=[], - feature_columns_yesterday=[], - backtest=True, - index_only=True, - ) - self._ticks_index = [t - date for t in data.today.index] - - order = Order( - stock_id=row["instrument"], - amount=row["amount"], - direction=OrderDir(int(row["order_type"])), - start_time=date + self._ticks_index[self._default_start_time_index], - end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, - ) - - return order - - -def train_and_test( - env_config: dict, - simulator_config: dict, - trainer_config: dict, - data_config: dict, - state_interpreter: StateInterpreter, - action_interpreter: ActionInterpreter, - policy: BasePolicy, - reward: Reward, - run_training: bool, - run_backtest: bool, -) -> None: - order_root_path = Path(data_config["source"]["order_dir"]) - - data_granularity = simulator_config.get("data_granularity", 1) - - def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: - return SingleAssetOrderExecutionSimple( - order=order, - data_dir=data_config["source"]["feature_root_dir"], - feature_columns_today=data_config["source"]["feature_columns_today"], - feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"], - data_granularity=data_granularity, - ticks_per_step=simulator_config["time_per_step"], - vol_threshold=simulator_config["vol_limit"], - ) - - assert data_config["source"]["default_start_time_index"] % data_granularity == 0 - assert data_config["source"]["default_end_time_index"] % data_granularity == 0 - - if run_training: - train_dataset, valid_dataset = [ - LazyLoadDataset( - data_dir=data_config["source"]["feature_root_dir"], - order_file_path=order_root_path / tag, - default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, - default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, - ) - for tag in ("train", "valid") - ] - - callbacks: List[Callback] = [] - if "checkpoint_path" in trainer_config: - callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) - callbacks.append( - Checkpoint( - dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", - every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), - save_latest="copy", - ), - ) - if "earlystop_patience" in trainer_config: - callbacks.append( - EarlyStopping( - patience=trainer_config["earlystop_patience"], - monitor="val/pa", - ) - ) - - train( - simulator_fn=_simulator_factory_simple, - state_interpreter=state_interpreter, - action_interpreter=action_interpreter, - policy=policy, - reward=reward, - initial_states=cast(List[Order], train_dataset), - trainer_kwargs={ - "max_iters": trainer_config["max_epoch"], - "finite_env_type": env_config["parallel_mode"], - "concurrency": env_config["concurrency"], - "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), - "callbacks": callbacks, - }, - vessel_kwargs={ - "episode_per_iter": trainer_config["episode_per_collect"], - "update_kwargs": { - "batch_size": trainer_config["batch_size"], - "repeat": trainer_config["repeat_per_collect"], - }, - "val_initial_states": valid_dataset, - }, - ) - - if run_backtest: - test_dataset = LazyLoadDataset( - data_dir=data_config["source"]["feature_root_dir"], - order_file_path=order_root_path / "test", - default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, - default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, - ) - - backtest( - simulator_fn=_simulator_factory_simple, - state_interpreter=state_interpreter, - action_interpreter=action_interpreter, - initial_states=test_dataset, - policy=policy, - logger=CsvWriter(Path(trainer_config["checkpoint_path"])), - reward=reward, - finite_env_type=env_config["parallel_mode"], - concurrency=env_config["concurrency"], - ) - - -def main(config: dict, run_training: bool, run_backtest: bool) -> None: - if not run_training and not run_backtest: - warnings.warn("Skip the entire job since training and backtest are both skipped.") - return - - if "seed" in config["runtime"]: - seed_everything(config["runtime"]["seed"]) - - for extra_module_path in config["env"].get("extra_module_paths", []): - sys.path.append(extra_module_path) - - state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"]) - action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) - reward: Reward = init_instance_by_config(config["reward"]) - - additional_policy_kwargs = { - "obs_space": state_interpreter.observation_space, - "action_space": action_interpreter.action_space, - } - - # Create torch network - if "network" in config: - if "kwargs" not in config["network"]: - config["network"]["kwargs"] = {} - config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) - additional_policy_kwargs["network"] = init_instance_by_config(config["network"]) - - # Create policy - if "kwargs" not in config["policy"]: - config["policy"]["kwargs"] = {} - config["policy"]["kwargs"].update(additional_policy_kwargs) - policy: BasePolicy = init_instance_by_config(config["policy"]) - - use_cuda = config["runtime"].get("use_cuda", False) - if use_cuda: - policy.cuda() - - train_and_test( - env_config=config["env"], - simulator_config=config["simulator"], - data_config=config["data"], - trainer_config=config["trainer"], - action_interpreter=action_interpreter, - state_interpreter=state_interpreter, - policy=policy, - reward=reward, - run_training=run_training, - run_backtest=run_backtest, - ) - - -if __name__ == "__main__": - warnings.filterwarnings("ignore", category=DeprecationWarning) - warnings.filterwarnings("ignore", category=RuntimeWarning) - - parser = argparse.ArgumentParser() - parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") - parser.add_argument("--no_training", action="store_true", help="Skip training workflow.") - parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.") - args = parser.parse_args() - - with open(args.config_path, "r") as input_stream: - config = yaml.safe_load(input_stream) - - main(config, run_training=not args.no_training, run_backtest=args.run_backtest) diff --git a/qlib/rl/contrib/train_onpolicy_full_simulation.py b/qlib/rl/contrib/train_onpolicy_full_simulation.py deleted file mode 100644 index bd7732786..000000000 --- a/qlib/rl/contrib/train_onpolicy_full_simulation.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import argparse -import os -import random -import sys -import warnings -from pathlib import Path -from typing import Any, cast, List, Optional - -import numpy as np -import pandas as pd -import torch -import yaml -from qlib.backtest import Order -from qlib.backtest.decision import OrderDir -from qlib.constant import ONE_MIN -from qlib.rl.contrib.naive_config_parser import parse_backtest_config -from qlib.rl.data.integration import init_qlib -from qlib.rl.data.native import load_handler_intraday_processed_data -from qlib.rl.interpreter import ActionInterpreter, StateInterpreter -from qlib.rl.order_execution import SingleAssetOrderExecutionSimple -from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution -from qlib.rl.reward import Reward -from qlib.rl.trainer import Checkpoint, backtest, train -from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter -from qlib.rl.utils.log import CsvWriter -from qlib.utils import init_instance_by_config -from tianshou.policy import BasePolicy -from torch.utils.data import Dataset - - -def get_executor_config(data_granularity: int = 1) -> dict: - return { - "class": "NestedExecutor", - "module_path": "qlib.backtest.executor", - "kwargs": { - "inner_executor": { - "class": "NestedExecutor", - "module_path": "qlib.backtest.executor", - "kwargs": { - "inner_executor": { - "class": "SimulatorExecutor", - "module_path": "qlib.backtest.executor", - "kwargs": { - "generate_report": False, - "time_per_step": f"{data_granularity}min", - "track_data": True, - "trade_type": "serial", - "verbose": False, - }, - }, - "inner_strategy": { - "class": "TWAPStrategy", - "kwargs": {}, - "module_path": "qlib.contrib.strategy.rule_strategy", - }, - "time_per_step": "30min", - "track_data": True, - }, - }, - "inner_strategy": { - "class": "ProxySAOEStrategy", - "module_path": "qlib.rl.order_execution.strategy", - "kwargs": {}, - }, - "time_per_step": "1day", - "track_data": True, - }, - } - - -def _convert_list_to_tuple(obj: Any) -> Any: - if isinstance(obj, dict): - return {k: _convert_list_to_tuple(v) for k, v in obj.items()} - elif isinstance(obj, list): - return tuple(obj) - else: - return obj - - -def seed_everything(seed: int) -> None: - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - torch.backends.cudnn.deterministic = True - - -def _read_orders(order_dir: Path) -> pd.DataFrame: - if os.path.isfile(order_dir): - return pd.read_pickle(order_dir) - else: - orders = [] - for file in order_dir.iterdir(): - order_data = pd.read_pickle(file) - orders.append(order_data) - return pd.concat(orders) - - -class LazyLoadDataset(Dataset): - def __init__( - self, - data_dir: str, - order_file_path: Path, - default_start_time_index: int, - default_end_time_index: int, - ) -> None: - self._default_start_time_index = default_start_time_index - self._default_end_time_index = default_end_time_index - - self._order_df = _read_orders(order_file_path).reset_index() - self._ticks_index: Optional[pd.DatetimeIndex] = None - self._data_dir = Path(data_dir) - - def __len__(self) -> int: - return len(self._order_df) - - def __getitem__(self, index: int) -> Order: - row = self._order_df.iloc[index] - date = pd.Timestamp(str(row["date"])) - - if self._ticks_index is None: - # TODO: We only load ticks index once based on the assumption that ticks index of different dates - # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index - # TODO: of all dates. - - data = load_handler_intraday_processed_data( - data_dir=self._data_dir, - stock_id=row["instrument"], - date=date, - feature_columns_today=[], - feature_columns_yesterday=[], - backtest=True, - index_only=True, - ) - self._ticks_index = [t - date for t in data.today.index] - - order = Order( - stock_id=row["instrument"], - amount=row["amount"], - direction=OrderDir(int(row["order_type"])), - start_time=date + self._ticks_index[self._default_start_time_index], - end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, - ) - - return order - - -def train_and_test( - env_config: dict, - trainer_config: dict, - data_config: dict, - exchange_config: dict, - qlib_config: dict, - state_interpreter: StateInterpreter, - action_interpreter: ActionInterpreter, - policy: BasePolicy, - reward: Reward, - run_training: bool, - run_backtest: bool, -) -> None: - init_qlib(qlib_config) - - order_root_path = Path(data_config["source"]["order_dir"]) - - data_granularity = 1 # simulator_config.get("data_granularity", 1) - - exchange_config_default = { - "open_cost": 0.0005, - "close_cost": 0.0015, - "min_cost": 5.0, - "trade_unit": 100.0, - # "cash_limit": None, - } - exchange_config = {**exchange_config_default, **exchange_config} - exchange_config = _convert_list_to_tuple(exchange_config) - - def _simulator_factory(order: Order) -> SingleAssetOrderExecution: - simulator = SingleAssetOrderExecution( - order=order, - executor_config=get_executor_config(data_granularity), - exchange_config={**exchange_config, **{"codes": [order.stock_id]}}, - qlib_config=None, - cash_limit=None, - ) - return simulator - - assert data_config["source"]["default_start_time_index"] % data_granularity == 0 - assert data_config["source"]["default_end_time_index"] % data_granularity == 0 - - if run_training: - train_dataset, valid_dataset = [ - LazyLoadDataset( - data_dir=data_config["source"]["feature_root_dir"], - order_file_path=order_root_path / tag, - default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, - default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, - ) - for tag in ("train", "valid") - ] - - callbacks: List[Callback] = [] - if "checkpoint_path" in trainer_config: - callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) - callbacks.append( - Checkpoint( - dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", - every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), - save_latest="copy", - ), - ) - if "earlystop_patience" in trainer_config: - callbacks.append( - EarlyStopping( - patience=trainer_config["earlystop_patience"], - monitor="val/pa", - ) - ) - - train( - simulator_fn=_simulator_factory, - state_interpreter=state_interpreter, - action_interpreter=action_interpreter, - policy=policy, - reward=reward, - initial_states=cast(List[Order], train_dataset), - trainer_kwargs={ - "max_iters": trainer_config["max_epoch"], - "finite_env_type": env_config["parallel_mode"], - "concurrency": env_config["concurrency"], - "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), - "callbacks": callbacks, - }, - vessel_kwargs={ - "episode_per_iter": trainer_config["episode_per_collect"], - "update_kwargs": { - "batch_size": trainer_config["batch_size"], - "repeat": trainer_config["repeat_per_collect"], - }, - "val_initial_states": valid_dataset, - }, - ) - - if run_backtest: - test_dataset = LazyLoadDataset( - data_dir=data_config["source"]["feature_root_dir"], - order_file_path=order_root_path / "test", - default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, - default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, - ) - - backtest( - simulator_fn=_simulator_factory, - state_interpreter=state_interpreter, - action_interpreter=action_interpreter, - initial_states=test_dataset, - policy=policy, - logger=CsvWriter(Path(trainer_config["checkpoint_path"])), - reward=reward, - finite_env_type=env_config["parallel_mode"], - concurrency=env_config["concurrency"], - ) - - -def main(config: dict, run_training: bool, run_backtest: bool) -> None: - if not run_training and not run_backtest: - warnings.warn("Skip the entire job since training and backtest are both skipped.") - return - - if "seed" in config["runtime"]: - seed_everything(config["runtime"]["seed"]) - - for extra_module_path in config["env"].get("extra_module_paths", []): - sys.path.append(extra_module_path) - - state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"]) - action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) - reward: Reward = init_instance_by_config(config["reward"]) - - additional_policy_kwargs = { - "obs_space": state_interpreter.observation_space, - "action_space": action_interpreter.action_space, - } - - # Create torch network - if "network" in config: - if "kwargs" not in config["network"]: - config["network"]["kwargs"] = {} - config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) - additional_policy_kwargs["network"] = init_instance_by_config(config["network"]) - - # Create policy - if "kwargs" not in config["policy"]: - config["policy"]["kwargs"] = {} - config["policy"]["kwargs"].update(additional_policy_kwargs) - policy: BasePolicy = init_instance_by_config(config["policy"]) - - use_cuda = config["runtime"].get("use_cuda", False) - if use_cuda: - policy.cuda() - - train_and_test( - env_config=config["env"], - data_config=config["data"], - exchange_config=config["exchange"], - qlib_config=config["qlib"], - trainer_config=config["trainer"], - action_interpreter=action_interpreter, - state_interpreter=state_interpreter, - policy=policy, - reward=reward, - run_training=run_training, - run_backtest=run_backtest, - ) - - -if __name__ == "__main__": - warnings.filterwarnings("ignore", category=DeprecationWarning) - warnings.filterwarnings("ignore", category=RuntimeWarning) - - parser = argparse.ArgumentParser() - parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") - parser.add_argument("--no_training", action="store_true", help="Skip training workflow.") - parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.") - args = parser.parse_args() - - config = parse_backtest_config(args.config_path) - main(config, run_training=not args.no_training, run_backtest=args.run_backtest) diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index 04dad778b..5b701ec9c 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -13,6 +13,7 @@ import os from qlib.backtest import Exchange, Order from qlib.backtest.decision import TradeRange, TradeRangeByTime from qlib.constant import EPS_T +from qlib.data.dataset import DatasetH from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider @@ -144,7 +145,7 @@ def load_backtest_data( cache=cachetools.LRUCache(1000), key=lambda path: path, ) -def _load_handler_pickle(path: str) -> object: +def _load_handler_pickle(path: str) -> DatasetH: with open(path, "rb") as fstream: obj = pickle.load(fstream) return obj diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 4905b026a..63e11db2e 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -26,7 +26,6 @@ from typing import List, Sequence, cast import cachetools import numpy as np import pandas as pd -from cachetools.keys import hashkey from qlib.backtest.decision import Order, OrderDir from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider @@ -158,6 +157,15 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData): return cast(pd.DatetimeIndex, self.data.index) +@cachetools.cached( # type: ignore + cache=cachetools.LRUCache(1000), + key=lambda path: path, +) +def _load_df_pickle(path: str) -> pd.DataFrame: + df = pd.read_pickle(path) + return df + + class PickleIntradayProcessedData(BaseIntradayProcessedData): """Subclass of IntradayProcessedData. Used to handle pickle-styled data.""" @@ -166,36 +174,18 @@ class PickleIntradayProcessedData(BaseIntradayProcessedData): data_dir: Path | str, stock_id: str, date: pd.Timestamp, - feature_dim: int, - time_index: pd.Index, + feature_columns_today: List[str], + feature_columns_yesterday: List[str], + backtest: bool, ) -> None: - proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) + if isinstance(data_dir, str): + data_dir = Path(data_dir) + path = data_dir / ("backtest" if backtest else "feature") / f"{stock_id}.pkl" + df = _load_df_pickle(str(path)) + df = df.loc[pd.IndexSlice[stock_id, :, date]] - # We have to infer the names here because, - # unfortunately they are not included in the original data. - cnames = _infer_processed_data_column_names(feature_dim) - - time_length: int = len(time_index) - - try: - # new data format - proc = proc.loc[pd.IndexSlice[stock_id, :, date]] - assert len(proc) == time_length and len(proc.columns) == feature_dim * 2 - proc_today = proc[cnames] - proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2]) - except (IndexError, KeyError): - # legacy data - proc = proc.loc[pd.IndexSlice[stock_id, date]] - assert time_length * feature_dim * 2 == len(proc) - proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim)) - proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim)) - proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames) - proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames) - - self.today: pd.DataFrame = proc_today - self.yesterday: pd.DataFrame = proc_yesterday - assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim - assert len(self.today) == len(self.yesterday) == time_length + self.today = df[feature_columns_today] + self.yesterday = df[feature_columns_yesterday] def __repr__(self) -> str: with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): @@ -213,25 +203,38 @@ def load_simple_intraday_backtest_data( return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) -@cachetools.cached( # type: ignore - cache=cachetools.LRUCache(100), # 100 * 50K = 5MB - key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date), -) def load_pickle_intraday_processed_data( data_dir: Path, stock_id: str, date: pd.Timestamp, - feature_dim: int, - time_index: pd.Index, + feature_columns_today: List[str], + feature_columns_yesterday: List[str], + backtest: bool = False, ) -> BaseIntradayProcessedData: - return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index) + return PickleIntradayProcessedData( + data_dir, + stock_id, + date, + feature_columns_today, + feature_columns_yesterday, + backtest, + ) class PickleProcessedDataProvider(ProcessedDataProvider): - def __init__(self, data_dir: Path) -> None: + def __init__( + self, + data_dir: Path, + feature_columns_today: List[str], + feature_columns_yesterday: List[str], + backtest: bool = False, + ) -> None: super().__init__() self._data_dir = data_dir + self._backtest = backtest + self._feature_columns_today = feature_columns_today + self._feature_columns_yesterday = feature_columns_yesterday def get_data( self, @@ -244,8 +247,9 @@ class PickleProcessedDataProvider(ProcessedDataProvider): data_dir=self._data_dir, stock_id=stock_id, date=date, - feature_dim=feature_dim, - time_index=time_index, + feature_columns_today=self._feature_columns_today, + feature_columns_yesterday=self._feature_columns_yesterday, + backtest=self._backtest, ) diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index e2d7defcc..bbcf72d7f 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -4,10 +4,11 @@ from __future__ import annotations from typing import Generator, List, Optional +import cachetools import pandas as pd -from qlib.backtest import collect_data_loop, get_strategy_executor +from qlib.backtest import collect_data_loop, Exchange, get_exchange, get_strategy_executor from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime from qlib.backtest.executor import NestedExecutor from qlib.rl.data.integration import init_qlib @@ -16,6 +17,18 @@ from .state import SAOEState from .strategy import SAOEStateAdapter, SAOEStrategy +@cachetools.cached( # type: ignore + cache=cachetools.LRUCache(1000), + key=lambda order, _: order.stock_id, +) +def _create_exchange(order: Order, exchange_config: dict) -> Exchange: + exchange_kwargs = { + **exchange_config, + "codes": [order.stock_id], + } + return get_exchange(**exchange_kwargs) + + class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): """Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools. @@ -76,7 +89,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): executor=executor_config, benchmark=order.stock_id, account=cash_limit if cash_limit is not None else int(1e12), - exchange_kwargs=exchange_config, + exchange_kwargs=_create_exchange(order, exchange_config), pos_type="Position" if cash_limit is not None else "InfPosition", ) diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index cdfbd2098..808d26bf9 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -12,7 +12,8 @@ from pathlib import Path from qlib.backtest.decision import Order, OrderDir from qlib.constant import EPS, EPS_T, float_or_ndarray from qlib.rl.data.base import BaseIntradayBacktestData -from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data +from qlib.rl.data.native import DataframeIntradayBacktestData +from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data from qlib.rl.simulator import Simulator from qlib.rl.utils import LogLevel @@ -118,7 +119,7 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): def get_backtest_data(self) -> BaseIntradayBacktestData: try: - data = load_handler_intraday_processed_data( + data = load_pickle_intraday_processed_data( data_dir=self.data_dir, stock_id=self.order.stock_id, date=pd.Timestamp(self.order.start_time.date()), diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index 7e66a1f08..1ad6e5fb4 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -451,6 +451,7 @@ class SAOEIntStrategy(SAOEStrategy): state_interpreter: dict | StateInterpreter, action_interpreter: dict | ActionInterpreter, network: dict | torch.nn.Module | None = None, + immediate_addition: bool = False, outer_trade_decision: BaseTradeDecision | None = None, level_infra: LevelInfrastructure | None = None, common_infra: CommonInfrastructure | None = None, @@ -501,9 +502,12 @@ class SAOEIntStrategy(SAOEStrategy): if self._policy is not None: self._policy.eval() + + self.immediate_addition = immediate_addition def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None: super().reset(outer_trade_decision=outer_trade_decision, **kwargs) + self.trade_amount_planned = collections.defaultdict(float) def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame: assert hasattr(self.outer_trade_decision, "order_list") @@ -539,9 +543,15 @@ class SAOEIntStrategy(SAOEStrategy): oh = self.trade_exchange.get_order_helper() order_list = [] - for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols): + for decision, exec_vol, state in zip(self.outer_trade_decision.get_decision(), exec_vols, states): + order = cast(Order, decision) + if self.immediate_addition: + self.trade_amount_planned[order.stock_id] += exec_vol + amount_planned = self.trade_amount_planned[order.stock_id] + amount_finished = order.amount - state.position + exec_vol = min(state.position, amount_planned - amount_finished) + if exec_vol != 0: - order = cast(Order, decision) order_list.append(oh.create(order.stock_id, exec_vol, order.direction)) return TradeDecisionWithDetails( diff --git a/qlib/rl/trainer/api.py b/qlib/rl/trainer/api.py index aea99dc3d..238186dd8 100644 --- a/qlib/rl/trainer/api.py +++ b/qlib/rl/trainer/api.py @@ -20,7 +20,7 @@ def train( simulator_fn: Callable[[InitialStateType], Simulator], state_interpreter: StateInterpreter, action_interpreter: ActionInterpreter, - initial_states: Sequence[InitialStateType], + initial_states: List[Sequence[InitialStateType]], policy: BasePolicy, reward: Reward, vessel_kwargs: Dict[str, Any], @@ -39,7 +39,9 @@ def train( action_interpreter Interprets the policy actions. initial_states - Initial states to iterate over. Every state will be run exactly once. + List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in + the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every + state will be run exactly once. Otherwise, every worker will have its own iterator. policy Policy to train against. reward @@ -67,7 +69,7 @@ def backtest( simulator_fn: Callable[[InitialStateType], Simulator], state_interpreter: StateInterpreter, action_interpreter: ActionInterpreter, - initial_states: Sequence[InitialStateType], + initial_states: List[Sequence[InitialStateType]], policy: BasePolicy, logger: LogWriter | List[LogWriter], reward: Reward | None = None, @@ -87,7 +89,9 @@ def backtest( action_interpreter Interprets the policy actions. initial_states - Initial states to iterate over. Every state will be run exactly once. + List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in + the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every + state will be run exactly once. Otherwise, every worker will have its own iterator. policy Policy to test against. logger diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index 9d457f82d..c7d7def39 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -5,8 +5,9 @@ from __future__ import annotations import collections import copy -from contextlib import AbstractContextManager, contextmanager +from contextlib import AbstractContextManager, ExitStack, contextmanager from datetime import datetime +from functools import partial from pathlib import Path from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast @@ -206,46 +207,50 @@ class Trainer: self._call_callback_hooks("on_fit_start") - while not self.should_stop: - msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}" - print(msg) - _logger.info(msg) + with _wrap_context(vessel.train_seed_iterators()) as train_iterators, _wrap_context( + vessel.val_seed_iterators() + ) as valid_iterators: + train_vector_env = self.venv_from_iterator(train_iterators) + valid_vector_env = self.venv_from_iterator(valid_iterators) - self.initialize_iter() + while not self.should_stop: + msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}" + print(msg) + _logger.info(msg) - self._call_callback_hooks("on_iter_start") + self.initialize_iter() - self.current_stage = "train" - self._call_callback_hooks("on_train_start") + self._call_callback_hooks("on_iter_start") - # TODO - # Add a feature that supports reloading the training environment every few iterations. - with _wrap_context(vessel.train_seed_iterator()) as iterator: - vector_env = self.venv_from_iterator(iterator) - self.vessel.train(vector_env) - del vector_env # FIXME: Explicitly delete this object to avoid memory leak. + self.current_stage = "train" + self._call_callback_hooks("on_train_start") - self._call_callback_hooks("on_train_end") + # TODO + # Add a feature that supports reloading the training environment every few iterations. + self.vessel.train(train_vector_env) - if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0: - # Implementation of validation loop - self.current_stage = "val" - self._call_callback_hooks("on_validate_start") - with _wrap_context(vessel.val_seed_iterator()) as iterator: - vector_env = self.venv_from_iterator(iterator) - self.vessel.validate(vector_env) - del vector_env # FIXME: Explicitly delete this object to avoid memory leak. + self._call_callback_hooks("on_train_end") - self._call_callback_hooks("on_validate_end") + if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0: + # Implementation of validation loop + self.current_stage = "val" + self._call_callback_hooks("on_validate_start") - # This iteration is considered complete. - # Bumping the current iteration counter. - self.current_iter += 1 + self.vessel.validate(valid_vector_env) - if self.max_iters is not None and self.current_iter >= self.max_iters: - self.should_stop = True + self._call_callback_hooks("on_validate_end") - self._call_callback_hooks("on_iter_end") + # This iteration is considered complete. + # Bumping the current iteration counter. + self.current_iter += 1 + + if self.max_iters is not None and self.current_iter >= self.max_iters: + self.should_stop = True + + self._call_callback_hooks("on_iter_end") + + del train_vector_env # FIXME: Explicitly delete this object to avoid memory leak. + del valid_vector_env # FIXME: Explicitly delete this object to avoid memory leak. self._call_callback_hooks("on_fit_end") @@ -266,16 +271,16 @@ class Trainer: self.current_stage = "test" self._call_callback_hooks("on_test_start") - with _wrap_context(vessel.test_seed_iterator()) as iterator: - vector_env = self.venv_from_iterator(iterator) + with _wrap_context(vessel.test_seed_iterators()) as iterators: + vector_env = self.venv_from_iterator(iterators) self.vessel.test(vector_env) del vector_env # FIXME: Explicitly delete this object to avoid memory leak. self._call_callback_hooks("on_test_end") - def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv: + def venv_from_iterator(self, iterators: List[Iterable[InitialStateType]]) -> FiniteVectorEnv: """Create a vectorized environment from iterator and the training vessel.""" - def env_factory(): + def env_factory(iterator): # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env), # and could be thread unsafe. # I'm not sure whether it's a design flaw. @@ -301,7 +306,7 @@ class Trainer: ) return vectorize_env( - env_factory, + [partial(env_factory, iterator=it) for it in iterators], self.finite_env_type, self.concurrency, self.loggers, @@ -335,8 +340,11 @@ class Trainer: @contextmanager def _wrap_context(obj): """Make any object a (possibly dummy) context manager.""" - - if isinstance(obj, AbstractContextManager): + if isinstance(obj, list) and isinstance(obj[0], AbstractContextManager): + with ExitStack() as stack: + yield [stack.enter_context(e) for e in obj] + stack.pop_all().close() + elif isinstance(obj, AbstractContextManager): # obj has __enter__ and __exit__ with obj as ctx: yield ctx diff --git a/qlib/rl/trainer/vessel.py b/qlib/rl/trainer/vessel.py index 6cd2eb3e9..027f57a18 100644 --- a/qlib/rl/trainer/vessel.py +++ b/qlib/rl/trainer/vessel.py @@ -4,7 +4,7 @@ from __future__ import annotations import weakref -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast +from typing import List, TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast import numpy as np from tianshou.data import Collector, VectorReplayBuffer @@ -49,19 +49,23 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, def assign_trainer(self, trainer: Trainer) -> None: self.trainer = weakref.proxy(trainer) # type: ignore - def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: - """Override this to create a seed iterator for training. + def train_seed_iterators( + self, + ) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]: + """Override this to create a seed iterators for training. If the iterable is a context manager, the whole training will be invoked in the with-block, and the iterator will be automatically closed after the training is done.""" - raise SeedIteratorNotAvailable("Seed iterator for training is not available.") + raise SeedIteratorNotAvailable("Seed iterators for training is not available.") - def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: - """Override this to create a seed iterator for validation.""" - raise SeedIteratorNotAvailable("Seed iterator for validation is not available.") + def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]: + """Override this to create a seed iterators for validation.""" + raise SeedIteratorNotAvailable("Seed iterators for validation is not available.") - def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: - """Override this to create a seed iterator for testing.""" - raise SeedIteratorNotAvailable("Seed iterator for testing is not available.") + def test_seed_iterators( + self, + ) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]: + """Override this to create a seed iterators for testing.""" + raise SeedIteratorNotAvailable("Seed iterators for testing is not available.") def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]: """Implement this to train one iteration. In RL, one iteration usually refers to one collect.""" @@ -120,9 +124,9 @@ class TrainingVessel(TrainingVesselBase): action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], policy: BasePolicy, reward: Reward, - train_initial_states: Sequence[InitialStateType] | None = None, - val_initial_states: Sequence[InitialStateType] | None = None, - test_initial_states: Sequence[InitialStateType] | None = None, + train_initial_states: List[Sequence[InitialStateType]] | None = None, + val_initial_states: List[Sequence[InitialStateType]] | None = None, + test_initial_states: List[Sequence[InitialStateType]] | None = None, buffer_size: int = 20000, episode_per_iter: int = 1000, update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None), @@ -132,34 +136,49 @@ class TrainingVessel(TrainingVesselBase): self.action_interpreter = action_interpreter self.policy = policy self.reward = reward - self.train_initial_states = train_initial_states - self.val_initial_states = val_initial_states - self.test_initial_states = test_initial_states + self.train_initial_states = None if train_initial_states is None else train_initial_states + self.val_initial_states = None if val_initial_states is None else val_initial_states + self.test_initial_states = None if test_initial_states is None else test_initial_states self.buffer_size = buffer_size self.episode_per_iter = episode_per_iter self.update_kwargs = update_kwargs or {} - def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def train_seed_iterators( + self, + ) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]: if self.train_initial_states is not None: - _logger.info("Training initial states collection size: %d", len(self.train_initial_states)) - # Implement fast_dev_run here. - train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run) - return DataQueue(train_initial_states, repeat=-1, shuffle=True) - return super().train_seed_iterator() + _logger.info(f"Training initial states collection sizes: {[len(e) for e in self.train_initial_states]}") + train_initial_states = [ + self._random_subset("train", e, self.trainer.fast_dev_run) for e in self.train_initial_states + ] + iterators = [DataQueue(e, repeat=-1, shuffle=True) for e in train_initial_states] + return cast(List[Iterable[InitialStateType]], iterators) + else: + return super().train_seed_iterators() - def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]: if self.val_initial_states is not None: - _logger.info("Validation initial states collection size: %d", len(self.val_initial_states)) - val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run) - return DataQueue(val_initial_states, repeat=1) - return super().val_seed_iterator() + _logger.info(f"Validation initial states collection sizes: {[len(e) for e in self.val_initial_states]}") + val_initial_states = [ + self._random_subset("val", e, self.trainer.fast_dev_run) for e in self.val_initial_states + ] + iterators = [DataQueue(e, repeat=1) for e in val_initial_states] + return cast(List[Iterable[InitialStateType]], iterators) + else: + return super().val_seed_iterators() - def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def test_seed_iterators( + self, + ) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]: if self.test_initial_states is not None: - _logger.info("Testing initial states collection size: %d", len(self.test_initial_states)) - test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run) - return DataQueue(test_initial_states, repeat=1) - return super().test_seed_iterator() + _logger.info(f"Testing initial states collection sizes: {[len(e) for e in self.test_initial_states]}") + test_initial_states = [ + self._random_subset("test", e, self.trainer.fast_dev_run) for e in self.test_initial_states + ] + iterators = [DataQueue(e, repeat=1) for e in test_initial_states] + return cast(List[Iterable[InitialStateType]], iterators) + else: + return super().test_seed_iterators() def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: """Create a collector and collects ``episode_per_iter`` episodes. diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index 87f0900e1..50a83194e 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -258,6 +258,46 @@ class FiniteVectorEnv(BaseVectorEnv): return np.stack(obs) + def step2( + self, + action: np.ndarray, + id: int | List[int] | np.ndarray | None = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + assert not self._zombie + wrapped_id = self._wrap_id(id) + id2idx = {i: k for k, i in enumerate(wrapped_id)} + request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id)) + result = {} + + # ask super to step alive envs and remap to current index + if request_id: + valid_act = np.stack([action[id2idx[i]] for i in request_id]) + tmp = super().step(valid_act, request_id) + + for obs_next, rew, done, info in zip(*tmp): + obs_next = self._postproc_env_obs(obs_next) + result[info["env_id"]] = [obs_next, rew, done, info] + + # logging + for i, r in result.items(): + if i in self._alive_env_ids and r[0] is not None: + for logger in self._logger: + logger.on_env_step(i, *r) + + for _, reward, __, info in result.values(): + self._set_default_info(info) + self._set_default_rew(reward) + for r in result.values(): + if r[0] is None: + r[0] = self._get_default_obs() + if r[1] is None: + r[1] = self._get_default_rew() + if r[3] is None: + r[3] = self._get_default_info() + + ret = list(map(np.stack, zip(*result.values()))) + return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret) + def step( self, action: np.ndarray, @@ -311,7 +351,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv): def vectorize_env( - env_factory: Callable[..., gym.Env], + env_factories: List[Callable[..., gym.Env]], env_type: FiniteEnvType, concurrency: int, logger: LogWriter | List[LogWriter], @@ -334,9 +374,10 @@ def vectorize_env( Parameters ---------- - env_factory - Callable to instantiate one single ``gym.Env``. - All concurrent workers will have the same ``env_factory``. + env_factories + Callables to instantiate one single ``gym.Env``. + There should be 1 or `concurrency` env_factories. If there is 1 env_factory, all concurrent workers will have + the same env_factory. Otherwise, each worker will have its own env_factory. env_type dummy or subproc or shmem. Corresponding to `parallelism in tianshou `_. @@ -358,6 +399,8 @@ def vectorize_env( def env_factory(): ... vectorize_env(env_factory, ...) """ + assert len(env_factories) in (1, concurrency) + env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = { "dummy": FiniteDummyVectorEnv, "subproc": FiniteSubprocVectorEnv, @@ -366,4 +409,7 @@ def vectorize_env( finite_env_cls = env_type_cls_mapping[env_type] - return finite_env_cls(logger, [env_factory for _ in range(concurrency)]) + if len(env_factories) == 1: + return finite_env_cls(logger, [env_factories[0] for _ in range(concurrency)]) + else: + return finite_env_cls(logger, env_factories) diff --git a/qlib/rl/utils/profiling.py b/qlib/rl/utils/profiling.py index 05792b8ec..24bb9e5da 100644 --- a/qlib/rl/utils/profiling.py +++ b/qlib/rl/utils/profiling.py @@ -1,17 +1,25 @@ import time from contextlib import contextmanager +from typing import Callable, Generator + from line_profiler import LineProfiler @contextmanager -def simple_perf(desc: str = ""): +def simple_perf(desc: str = "", out_path: str = None) -> Generator[None, None, None]: s = time.perf_counter() yield e = time.perf_counter() - print(f"{desc}: {(e - s) * 1000.0} ms") + msg = f"{desc}: {(e - s) * 1000.0:.4f} ms" + + if out_path is not None: + with open(out_path, "a") as fstream: + fstream.write(msg + "\n") + else: + print(msg) -def lprofile(func): +def lprofile(func: Callable) -> Callable: def wrapper(*args, **kwargs): lp = LineProfiler() lpw = lp(func)