diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index d3894f230..d4a4b075e 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -86,11 +86,12 @@ jobs: # W1309: f-string-without-interpolation # E1102: not-callable # E1136: unsubscriptable-object + # FIXME: Due to the version change of Pylint, some code will cause W0719 error after PR 1417. W0719 is temporarily disabled in PR 1417 and should be fixed. # References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962 # We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000). - name: Check Qlib with pylint run: | - pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" # The following flake8 error codes were ignored: # E501 line too long diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index ec0725230..bb8ca731b 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -40,8 +40,8 @@ def get_exchange( open_cost: float = 0.0015, close_cost: float = 0.0025, min_cost: float = 5.0, - limit_threshold: Union[Tuple[str, str], float, None] = None, - deal_price: Union[str, Tuple[str, str], List[str]] = None, + limit_threshold: Union[Tuple[str, str], float, None] | None = None, + deal_price: Union[str, Tuple[str, str], List[str]] | None = None, **kwargs: Any, ) -> Exchange: """get_exchange @@ -284,7 +284,7 @@ def collect_data( account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", - return_value: dict = None, + return_value: dict | None = None, ) -> Generator[object, None, None]: """initialize the strategy and executor, then collect the trade decision data for rl training diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 9d60ff092..b0e416f8f 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -152,7 +152,9 @@ class Account: # trading related metrics(e.g. high-frequency trading) self.indicator = Indicator() - def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None: + def reset( + self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None + ) -> None: """reset freq and report of account Parameters diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index cf0a3a578..5e5edacaf 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -55,7 +55,7 @@ def collect_data_loop( end_time: Union[pd.Timestamp, str], trade_strategy: BaseStrategy, trade_executor: BaseExecutor, - return_value: dict = None, + return_value: dict | None = None, ) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]: """Generator for collecting the trade decision data for rl training diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 27026b25e..7188bec7a 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -254,7 +254,7 @@ class IdxTradeRange(TradeRange): self._start_idx = start_idx self._end_idx = end_idx - def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]: + def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]: return self._start_idx, self._end_idx def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]: @@ -315,7 +315,7 @@ class BaseTradeDecision(Generic[DecisionType]): 2. Same as `case 1.3` """ - def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None: + def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None: """ Parameters ---------- @@ -554,7 +554,7 @@ class TradeDecisionWO(BaseTradeDecision[Order]): self, order_list: List[Order], strategy: BaseStrategy, - trade_range: Union[Tuple[int, int], TradeRange] = None, + trade_range: Union[Tuple[int, int], TradeRange, None] = None, ) -> None: super().__init__(strategy, trade_range=trade_range) self.order_list = cast(List[Order], order_list) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 3a238156e..a752a9f8c 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -41,10 +41,10 @@ class Exchange: start_time: Union[pd.Timestamp, str] = None, end_time: Union[pd.Timestamp, str] = None, codes: Union[list, str] = "all", - deal_price: Union[str, Tuple[str, str], List[str]] = None, + deal_price: Union[str, Tuple[str, str], List[str], None] = None, subscribe_fields: list = [], limit_threshold: Union[Tuple[str, str], float, None] = None, - volume_threshold: Union[tuple, dict] = None, + volume_threshold: Union[tuple, dict, None] = None, open_cost: float = 0.0015, close_cost: float = 0.0025, min_cost: float = 5.0, @@ -340,7 +340,7 @@ class Exchange: stock_id: str, start_time: pd.Timestamp, end_time: pd.Timestamp, - direction: int = None, + direction: int | None = None, ) -> bool: """ Parameters @@ -406,7 +406,7 @@ class Exchange: stock_id: str, start_time: pd.Timestamp, end_time: pd.Timestamp, - direction: int = None, + direction: int | None = None, ) -> bool: # check if stock can be traded return not ( @@ -421,8 +421,8 @@ class Exchange: def deal_order( self, order: Order, - trade_account: Account = None, - position: BasePosition = None, + trade_account: Account | None = None, + position: BasePosition | None = None, dealt_order_amount: Dict[str, float] = defaultdict(float), ) -> Tuple[float, float, float]: """ @@ -586,7 +586,7 @@ class Exchange: ) return amount_dict - def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float: + def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float: """ Calculate the real adjust deal amount when considering the trading unit :param current_amount: @@ -712,8 +712,8 @@ class Exchange: def _get_factor_or_raise_error( self, - factor: float = None, - stock_id: str = None, + factor: float | None = None, + stock_id: str | None = None, start_time: pd.Timestamp = None, end_time: pd.Timestamp = None, ) -> float: @@ -728,8 +728,8 @@ class Exchange: def get_amount_of_trade_unit( self, - factor: float = None, - stock_id: str = None, + factor: float | None = None, + stock_id: str | None = None, start_time: pd.Timestamp = None, end_time: pd.Timestamp = None, ) -> Optional[float]: @@ -762,8 +762,8 @@ class Exchange: def round_amount_by_trade_unit( self, deal_amount: float, - factor: float = None, - stock_id: str = None, + factor: float | None = None, + stock_id: str | None = None, start_time: pd.Timestamp = None, end_time: pd.Timestamp = None, ) -> float: diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index afed973ba..b5d4326a7 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -31,8 +31,8 @@ class BaseExecutor: generate_portfolio_metrics: bool = False, verbose: bool = False, track_data: bool = False, - trade_exchange: Exchange = None, - common_infra: CommonInfrastructure = None, + trade_exchange: Exchange | None = None, + common_infra: CommonInfrastructure | None = None, settle_type: str = BasePosition.ST_NO, **kwargs: Any, ) -> None: @@ -161,7 +161,7 @@ class BaseExecutor: """ return self.level_infra.get("trade_calendar") - def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None: + def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None: """ - reset `start_time` and `end_time`, used in trade calendar - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc @@ -227,7 +227,7 @@ class BaseExecutor: def collect_data( self, trade_decision: BaseTradeDecision, - return_value: dict = None, + return_value: dict | None = None, level: int = 0, ) -> Generator[Any, Any, List[object]]: """Generator for collecting the trade decision data for rl training @@ -327,7 +327,7 @@ class NestedExecutor(BaseExecutor): track_data: bool = False, skip_empty_decision: bool = True, align_range_limit: bool = True, - common_infra: CommonInfrastructure = None, + common_infra: CommonInfrastructure | None = None, **kwargs: Any, ) -> None: """ @@ -534,7 +534,7 @@ class SimulatorExecutor(BaseExecutor): generate_portfolio_metrics: bool = False, verbose: bool = False, track_data: bool = False, - common_infra: CommonInfrastructure = None, + common_infra: CommonInfrastructure | None = None, trade_type: str = TT_SERIAL, **kwargs: Any, ) -> None: diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index ea6b7c57b..18b084fb6 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations from datetime import timedelta from typing import Any, Dict, List, Union @@ -320,7 +321,7 @@ class Position(BasePosition): self.position[stock]["price"] = price_dict[stock] self.position["now_account_value"] = self.calculate_value() - def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None: + def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None: """ initialization the stock in current position diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index b8aa8273c..8e7440ba9 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations import pathlib from collections import OrderedDict @@ -86,7 +87,7 @@ class PortfolioMetrics: self.benches: dict = OrderedDict() self.latest_pm_time: Optional[pd.TimeStamp] = None - def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None: + def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None: if freq is not None: self.freq = freq self.benchmark_config = benchmark_config @@ -149,15 +150,15 @@ class PortfolioMetrics: self, trade_start_time: Union[str, pd.Timestamp] = None, trade_end_time: Union[str, pd.Timestamp] = None, - account_value: float = None, - cash: float = None, - return_rate: float = None, - total_turnover: float = None, - turnover_rate: float = None, - total_cost: float = None, - cost_rate: float = None, - stock_value: float = None, - bench_value: float = None, + account_value: float | None = None, + cash: float | None = None, + return_rate: float | None = None, + total_turnover: float | None = None, + turnover_rate: float | None = None, + total_cost: float | None = None, + cost_rate: float | None = None, + stock_value: float | None = None, + bench_value: float | None = None, ) -> None: # check data if None in [ diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 595b2accc..4210c9548 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -31,7 +31,7 @@ class TradeCalendarManager: freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - level_infra: LevelInfrastructure = None, + level_infra: LevelInfrastructure | None = None, ) -> None: """ Parameters @@ -99,7 +99,7 @@ class TradeCalendarManager: def get_trade_step(self) -> int: return self.trade_step - def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]: + def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]: """ Get the left and right endpoints of the trade_step'th trading interval diff --git a/qlib/contrib/ops/high_freq.py b/qlib/contrib/ops/high_freq.py index ee2825fbf..65b84fed6 100644 --- a/qlib/contrib/ops/high_freq.py +++ b/qlib/contrib/ops/high_freq.py @@ -70,7 +70,7 @@ class DayCumsum(ElemOperator): Otherwise, the value is zero. """ - def __init__(self, feature, start: str = "9:30", end: str = "14:59"): + def __init__(self, feature, start: str = "9:30", end: str = "14:59", data_granularity: int = 1): self.feature = feature self.start = datetime.strptime(start, "%H:%M") self.end = datetime.strptime(end, "%H:%M") @@ -80,15 +80,17 @@ class DayCumsum(ElemOperator): self.noon_open = datetime.strptime("13:00", "%H:%M") self.noon_close = datetime.strptime("15:00", "%H:%M") - self.start_id = time_to_day_index(self.start) - self.end_id = time_to_day_index(self.end) + self.data_granularity = data_granularity + self.start_id = time_to_day_index(self.start) // self.data_granularity + self.end_id = time_to_day_index(self.end) // self.data_granularity + assert 240 % self.data_granularity == 0 def period_cusum(self, df): df = df.copy() - assert len(df) == 240 + assert len(df) == 240 // self.data_granularity df.iloc[0 : self.start_id] = 0 df = df.cumsum() - df.iloc[self.end_id + 1 : 240] = 0 + df.iloc[self.end_id + 1 : 240 // self.data_granularity] = 0 return df def _load_internal(self, instrument, start_index, end_index, freq): diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 4d1eae46d..2818f788c 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -28,14 +28,14 @@ from qlib.typehint import Literal def _get_multi_level_executor_config( strategy_config: dict, - cash_limit: float = None, + cash_limit: float | None = None, generate_report: bool = False, ) -> dict: executor_config = { "class": "SimulatorExecutor", "module_path": "qlib.backtest.executor", "kwargs": { - "time_per_step": "1min", + "time_per_step": "5min", # FIXME: move this into config "verbose": False, "trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL, "generate_report": generate_report, @@ -127,7 +127,7 @@ def single_with_simulator( backtest_config: dict, orders: pd.DataFrame, split: Literal["stock", "day"] = "stock", - cash_limit: float = None, + cash_limit: float | None = None, generate_report: bool = False, ) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: """Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day. @@ -187,7 +187,7 @@ def single_with_simulator( exchange_config.update( { "codes": stocks, - "freq": "1min", + "freq": "5min", # FIXME: move this into config } ) @@ -226,7 +226,7 @@ def single_with_collect_data_loop( backtest_config: dict, orders: pd.DataFrame, split: Literal["stock", "day"] = "stock", - cash_limit: float = None, + cash_limit: float | None = None, generate_report: bool = False, ) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: """Run backtest in a single thread with collect_data_loop. @@ -286,7 +286,7 @@ def single_with_collect_data_loop( exchange_config.update( { "codes": stocks, - "freq": "1min", + "freq": "5min", # FIXME: move this into config } ) diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index ab5e95359..a6409f828 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -98,7 +98,7 @@ def get_backtest_config_fromfile(path: str) -> dict: "debug_single_day": None, "concurrency": -1, "multiplier": 1.0, - "output_dir": "outputs/", + "output_dir": "outputs_backtest/", "generate_report": False, } backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default) diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py index d05994854..d131ff244 100644 --- a/qlib/rl/contrib/train_onpolicy.py +++ b/qlib/rl/contrib/train_onpolicy.py @@ -3,6 +3,7 @@ import argparse import os import random +import warnings from pathlib import Path from typing import cast, List, Optional @@ -23,7 +24,6 @@ from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter from qlib.rl.utils.log import CsvWriter from qlib.utils import init_instance_by_config from tianshou.policy import BasePolicy -from torch import nn from torch.utils.data import Dataset @@ -101,6 +101,7 @@ def train_and_test( action_interpreter: ActionInterpreter, policy: BasePolicy, reward: Reward, + run_training: bool, run_backtest: bool, ) -> None: qlib.init() @@ -122,62 +123,67 @@ def train_and_test( assert data_config["source"]["default_start_time_index"] % data_granularity == 0 assert data_config["source"]["default_end_time_index"] % data_granularity == 0 - train_dataset, valid_dataset, test_dataset = [ - LazyLoadDataset( - order_file_path=order_root_path / tag, + if run_training: + train_dataset, valid_dataset = [ + LazyLoadDataset( + order_file_path=order_root_path / tag, + data_dir=Path(data_config["source"]["data_dir"]), + default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + ) + for tag in ("train", "valid") + ] + + callbacks: List[Callback] = [] + if "checkpoint_path" in trainer_config: + callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) + callbacks.append( + Checkpoint( + dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", + every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), + save_latest="copy", + ), + ) + if "earlystop_patience" in trainer_config: + callbacks.append( + EarlyStopping( + patience=trainer_config["earlystop_patience"], + monitor="val/pa", + ) + ) + + train( + simulator_fn=_simulator_factory_simple, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + policy=policy, + reward=reward, + initial_states=cast(List[Order], train_dataset), + trainer_kwargs={ + "max_iters": trainer_config["max_epoch"], + "finite_env_type": env_config["parallel_mode"], + "concurrency": env_config["concurrency"], + "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), + "callbacks": callbacks, + }, + vessel_kwargs={ + "episode_per_iter": trainer_config["episode_per_collect"], + "update_kwargs": { + "batch_size": trainer_config["batch_size"], + "repeat": trainer_config["repeat_per_collect"], + }, + "val_initial_states": valid_dataset, + }, + ) + + if run_backtest: + test_dataset = LazyLoadDataset( + order_file_path=order_root_path / "test", data_dir=Path(data_config["source"]["data_dir"]), default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, ) - for tag in ("train", "valid", "test") - ] - if "checkpoint_path" in trainer_config: - callbacks: List[Callback] = [] - callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) - callbacks.append( - Checkpoint( - dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", - every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), - save_latest="copy", - ), - ) - if "earlystop_patience" in trainer_config: - callbacks.append( - EarlyStopping( - patience=trainer_config["earlystop_patience"], - monitor="val/pa", - ) - ) - - trainer_kwargs = { - "max_iters": trainer_config["max_epoch"], - "finite_env_type": env_config["parallel_mode"], - "concurrency": env_config["concurrency"], - "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), - "callbacks": callbacks, - } - vessel_kwargs = { - "episode_per_iter": trainer_config["episode_per_collect"], - "update_kwargs": { - "batch_size": trainer_config["batch_size"], - "repeat": trainer_config["repeat_per_collect"], - }, - "val_initial_states": valid_dataset, - } - - train( - simulator_fn=_simulator_factory_simple, - state_interpreter=state_interpreter, - action_interpreter=action_interpreter, - policy=policy, - reward=reward, - initial_states=cast(List[Order], train_dataset), - trainer_kwargs=trainer_kwargs, - vessel_kwargs=vessel_kwargs, - ) - - if run_backtest: backtest( simulator_fn=_simulator_factory_simple, state_interpreter=state_interpreter, @@ -186,35 +192,39 @@ def train_and_test( policy=policy, logger=CsvWriter(Path(trainer_config["checkpoint_path"])), reward=reward, - finite_env_type=trainer_kwargs["finite_env_type"], - concurrency=trainer_kwargs["concurrency"], + finite_env_type=env_config["parallel_mode"], + concurrency=env_config["concurrency"], ) -def main(config: dict, run_backtest: bool) -> None: +def main(config: dict, run_training: bool, run_backtest: bool) -> None: + if not run_training and not run_backtest: + warnings.warn("Skip the entire job since training and backtest are both skipped.") + return + if "seed" in config["runtime"]: seed_everything(config["runtime"]["seed"]) - state_config = config["state_interpreter"] - state_interpreter: StateInterpreter = init_instance_by_config(state_config) - + state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"]) action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) reward: Reward = init_instance_by_config(config["reward"]) + additional_policy_kwargs = { + "obs_space": state_interpreter.observation_space, + "action_space": action_interpreter.action_space, + } + # Create torch network - if "kwargs" not in config["network"]: - config["network"]["kwargs"] = {} - config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) - network: nn.Module = init_instance_by_config(config["network"]) + if "network" in config: + if "kwargs" not in config["network"]: + config["network"]["kwargs"] = {} + config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) + additional_policy_kwargs["network"] = init_instance_by_config(config["network"]) # Create policy - config["policy"]["kwargs"].update( - { - "network": network, - "obs_space": state_interpreter.observation_space, - "action_space": action_interpreter.action_space, - } - ) + if "kwargs" not in config["policy"]: + config["policy"]["kwargs"] = {} + config["policy"]["kwargs"].update(additional_policy_kwargs) policy: BasePolicy = init_instance_by_config(config["policy"]) use_cuda = config["runtime"].get("use_cuda", False) @@ -230,22 +240,22 @@ def main(config: dict, run_backtest: bool) -> None: state_interpreter=state_interpreter, policy=policy, reward=reward, + run_training=run_training, run_backtest=run_backtest, ) if __name__ == "__main__": - import warnings - warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") - parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished") + parser.add_argument("--no_training", action="store_true", help="Skip training workflow.") + parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.") args = parser.parse_args() with open(args.config_path, "r") as input_stream: config = yaml.safe_load(input_stream) - main(config, run_backtest=args.run_backtest) + main(config, run_training=not args.no_training, run_backtest=args.run_backtest) diff --git a/qlib/rl/data/integration.py b/qlib/rl/data/integration.py index af5025c84..58311367f 100644 --- a/qlib/rl/data/integration.py +++ b/qlib/rl/data/integration.py @@ -49,7 +49,7 @@ class DataWrapper: return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None) -def init_qlib(qlib_config: dict, part: str = None) -> None: +def init_qlib(qlib_config: dict, part: str | None = None) -> None: """Initialize necessary resource to launch the workflow, including data direction, feature columns, etc.. Parameters @@ -82,10 +82,9 @@ def init_qlib(qlib_config: dict, part: str = None) -> None: return path if isinstance(path, Path) else Path(path) provider_uri_map = {} - if "provider_uri_day" in qlib_config: - provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix() - if "provider_uri_1min" in qlib_config: - provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix() + for granularity in ["1min", "5min", "day"]: + if f"provider_uri_{granularity}" in qlib_config: + provider_uri_map[f"{granularity}"] = _convert_to_path(qlib_config[f"provider_uri_{granularity}"]).as_posix() qlib.init( region=REG_CN, diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 63b55d6e0..3f21c0855 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -104,7 +104,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData): stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", - order_dir: int = None, + order_dir: int | None = None, ) -> None: super(SimpleIntradayBacktestData, self).__init__() @@ -208,7 +208,7 @@ def load_simple_intraday_backtest_data( stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", - order_dir: int = None, + order_dir: int | None = None, ) -> SimpleIntradayBacktestData: return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 0d45624bd..01b081153 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -53,6 +53,18 @@ class FullHistoryObs(TypedDict): position_history: Any +class DummyStateInterpreter(StateInterpreter[SAOEState, dict]): + """Dummy interpreter for policies that do not need inputs (for example, AllOne).""" + + def interpret(self, state: SAOEState) -> dict: + # TODO: A fake state, used to pass `check_nan_observation`. Find a better way in the future. + return {"DUMMY": _to_int32(1)} + + @property + def observation_space(self) -> spaces.Dict: + return spaces.Dict({"DUMMY": spaces.Box(-np.inf, np.inf, shape=(), dtype=np.int32)}) + + class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): """The observation of all the history, including today (until this moment), and yesterday. diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index 598e6b589..2102ff6ab 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -32,7 +32,7 @@ class NonLearnablePolicy(BasePolicy): super().__init__() def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]: - pass + return {} def process_fn( self, @@ -40,7 +40,7 @@ class NonLearnablePolicy(BasePolicy): buffer: ReplayBuffer, indices: np.ndarray, ) -> Batch: - pass + return Batch({}) class AllOne(NonLearnablePolicy): @@ -49,13 +49,18 @@ class AllOne(NonLearnablePolicy): Useful when implementing some baselines (e.g., TWAP). """ + def __init__(self, obs_space: gym.Space, action_space: gym.Space, fill_value: float | int = 1.0) -> None: + super().__init__(obs_space, action_space) + + self.fill_value = fill_value + def forward( self, batch: Batch, state: dict | Batch | np.ndarray = None, **kwargs: Any, ) -> Batch: - return Batch(act=np.full(len(batch), 1.0), state=state) + return Batch(act=np.full(len(batch), self.fill_value), state=state) # ppo # diff --git a/qlib/rl/order_execution/reward.py b/qlib/rl/order_execution/reward.py index e83066d85..c6acc4394 100644 --- a/qlib/rl/order_execution/reward.py +++ b/qlib/rl/order_execution/reward.py @@ -7,6 +7,7 @@ from typing import cast import numpy as np +from qlib.backtest.decision import OrderDir from qlib.rl.order_execution.state import SAOEMetrics, SAOEState from qlib.rl.reward import Reward @@ -47,3 +48,40 @@ class PAPenaltyReward(Reward[SAOEState]): self.log("reward/pa", pa) self.log("reward/penalty", penalty) return reward * self.scale + + +class PPOReward(Reward[SAOEState]): + """Reward proposed by paper "An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization". + + Parameters + ---------- + max_step + Maximum number of steps. + start_time_index + First time index that allowed to trade. + end_time_index + Last time index that allowed to trade. + """ + + def __init__(self, max_step: int, start_time_index: int = 0, end_time_index: int = 239) -> None: + self.max_step = max_step + self.start_time_index = start_time_index + self.end_time_index = end_time_index + + def reward(self, simulator_state: SAOEState) -> float: + if simulator_state.cur_step == self.max_step - 1 or simulator_state.position < 1e-6: + vwap_price = cast(dict, simulator_state.metrics)["trade_price"] + twap_price = simulator_state.backtest_data.get_deal_price().mean() + + if simulator_state.order.direction == OrderDir.SELL: + ratio = vwap_price / twap_price if twap_price != 0 else 1.0 + else: + ratio = twap_price / vwap_price if vwap_price != 0 else 1.0 + if ratio < 1.0: + return -1.0 + elif ratio < 1.1: + return 0.0 + else: + return 1.0 + else: + return 0.0 diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 610a0c0bd..ab6b46376 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -38,8 +38,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): order: Order, executor_config: dict, exchange_config: dict, - qlib_config: dict = None, - cash_limit: Optional[float] = None, + qlib_config: dict | None = None, + cash_limit: float | None = None, ) -> None: super().__init__(initial=order) @@ -63,7 +63,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): strategy_config: dict, executor_config: dict, exchange_config: dict, - qlib_config: dict = None, + qlib_config: dict | None = None, cash_limit: Optional[float] = None, ) -> None: if qlib_config is not None: diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index 0102b9e57..b6f5e12b2 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -89,6 +89,7 @@ class SAOEStateAdapter: exchange: Exchange, ticks_per_step: int, backtest_data: IntradayBacktestData, + data_granularity: int = 1, ) -> None: self.position = order.amount self.order = order @@ -106,11 +107,13 @@ class SAOEStateAdapter: self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time) self.ticks_per_step = ticks_per_step + self.data_granularity = data_granularity + assert self.ticks_per_step % self.data_granularity == 0 def _next_time(self) -> pd.Timestamp: current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time) - next_loc = current_loc + self.ticks_per_step - next_loc = next_loc - next_loc % self.ticks_per_step + next_loc = current_loc + (self.ticks_per_step // self.data_granularity) + next_loc = next_loc - next_loc % (self.ticks_per_step // self.data_granularity) if ( next_loc < len(self.backtest_data.ticks_index) and self.backtest_data.ticks_index[next_loc] < self.order.end_time @@ -130,7 +133,7 @@ class SAOEStateAdapter: exec_vol = np.zeros(last_step_size) for order, _, __, ___ in execute_result: - idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN) + idx, _ = get_day_min_idx_range(order.start_time, order.end_time, f"{self.data_granularity}min", REG_CN) exec_vol[idx - last_step_range[0]] = order.deal_amount if exec_vol.sum() > self.position and exec_vol.sum() > 0.0: @@ -168,7 +171,9 @@ class SAOEStateAdapter: self.history_exec, self._collect_multi_order_metric( order=self.order, - datetime=_get_all_timestamps(start_time, end_time, include_end=True), + datetime=_get_all_timestamps( + start_time, end_time, include_end=True, granularity=ONE_MIN * self.data_granularity + ), market_vol=market_volume, market_price=market_price, exec_vol=exec_vol, @@ -293,9 +298,10 @@ class SAOEStrategy(RLStrategy): def __init__( self, policy: BasePolicy, - outer_trade_decision: BaseTradeDecision = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, + outer_trade_decision: BaseTradeDecision | None = None, + level_infra: LevelInfrastructure | None = None, + common_infra: CommonInfrastructure | None = None, + data_granularity: int = 1, **kwargs: Any, ) -> None: super(SAOEStrategy, self).__init__( @@ -306,6 +312,7 @@ class SAOEStrategy(RLStrategy): **kwargs, ) + self._data_granularity = data_granularity self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {} self._last_step_range = (0, 0) @@ -324,9 +331,10 @@ class SAOEStrategy(RLStrategy): exchange=self.trade_exchange, ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN), backtest_data=backtest_data, + data_granularity=self._data_granularity, ) - def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: + def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None: super(SAOEStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) self.adapter_dict = {} @@ -366,7 +374,7 @@ class SAOEStrategy(RLStrategy): def generate_trade_decision( self, - execute_result: list = None, + execute_result: list | None = None, ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]: """ For SAOEStrategy, we need to update the `self._last_step_range` every time a decision is generated. @@ -385,7 +393,7 @@ class SAOEStrategy(RLStrategy): def _generate_trade_decision( self, - execute_result: list = None, + execute_result: list | None = None, ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]: raise NotImplementedError @@ -399,14 +407,14 @@ class ProxySAOEStrategy(SAOEStrategy): def __init__( self, - outer_trade_decision: BaseTradeDecision = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, + outer_trade_decision: BaseTradeDecision | None = None, + level_infra: LevelInfrastructure | None = None, + common_infra: CommonInfrastructure | None = None, **kwargs: Any, ) -> None: super().__init__(None, outer_trade_decision, level_infra, common_infra, **kwargs) - def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]: + def _generate_trade_decision(self, execute_result: list | None = None) -> Generator[Any, Any, BaseTradeDecision]: # Once the following line is executed, this ProxySAOEStrategy (self) will be yielded to the outside # of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`, # the item will be captured by `exec_vol`. The outside policy could communicate with the inner @@ -418,7 +426,7 @@ class ProxySAOEStrategy(SAOEStrategy): return TradeDecisionWO([order], self) - def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: + def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None: super().reset(outer_trade_decision=outer_trade_decision, **kwargs) assert isinstance(outer_trade_decision, TradeDecisionWO) @@ -437,9 +445,9 @@ class SAOEIntStrategy(SAOEStrategy): state_interpreter: dict | StateInterpreter, action_interpreter: dict | ActionInterpreter, network: dict | torch.nn.Module | None = None, - outer_trade_decision: BaseTradeDecision = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, + outer_trade_decision: BaseTradeDecision | None = None, + level_infra: LevelInfrastructure | None = None, + common_infra: CommonInfrastructure | None = None, **kwargs: Any, ) -> None: super(SAOEIntStrategy, self).__init__( @@ -488,7 +496,7 @@ class SAOEIntStrategy(SAOEStrategy): if self._policy is not None: self._policy.eval() - def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: + def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None: super().reset(outer_trade_decision=outer_trade_decision, **kwargs) def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame: @@ -508,7 +516,7 @@ class SAOEIntStrategy(SAOEStrategy): trade_details[-1]["rl_action"] = a return pd.DataFrame.from_records(trade_details) - def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: + def _generate_trade_decision(self, execute_result: list | None = None) -> BaseTradeDecision: states = [] obs_batch = [] for decision in self.outer_trade_decision.get_decision(): diff --git a/qlib/rl/strategy/single_order.py b/qlib/rl/strategy/single_order.py index 9d8e396ce..45db0d9c8 100644 --- a/qlib/rl/strategy/single_order.py +++ b/qlib/rl/strategy/single_order.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + from qlib.backtest import Order from qlib.backtest.decision import OrderHelper, TradeDecisionWO, TradeRange from qlib.strategy.base import BaseStrategy @@ -12,14 +14,14 @@ class SingleOrderStrategy(BaseStrategy): def __init__( self, order: Order, - trade_range: TradeRange = None, + trade_range: TradeRange | None = None, ) -> None: super().__init__() self._order = order self._trade_range = trade_range - def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO: + def generate_trade_decision(self, execute_result: list | None = None) -> TradeDecisionWO: oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() order_list = [ oh.create( diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py index 828288871..71c2dff65 100644 --- a/qlib/rl/utils/data_queue.py +++ b/qlib/rl/utils/data_queue.py @@ -4,6 +4,7 @@ from __future__ import annotations import multiprocessing +from multiprocessing.sharedctypes import Synchronized import os import threading import time @@ -78,7 +79,9 @@ class DataQueue(Generic[T]): self._activated: bool = False self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize) - self._done = multiprocessing.Value("i", 0) + # Mypy 0.981 brought '"SynchronizedBase[Any]" has no attribute "value" [attr-defined]' bug. + # Therefore, add this type casting to pass Mypy checking. + self._done = cast(Synchronized, multiprocessing.Value("i", 0)) def __enter__(self) -> DataQueue: self.activate() @@ -122,7 +125,7 @@ class DataQueue(Generic[T]): if self._done.value: raise StopIteration # pylint: disable=raise-missing-from - def put(self, obj: Any, block: bool = True, timeout: int = None) -> None: + def put(self, obj: Any, block: bool = True, timeout: int | None = None) -> None: self._queue.put(obj, block=block, timeout=timeout) def mark_as_done(self) -> None: diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index e0c009b7b..e863b709a 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -99,9 +99,9 @@ class EnvWrapper( state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], seed_iterator: Optional[Iterable[InitialStateType]], - reward_fn: Reward = None, - aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None, - logger: LogCollector = None, + reward_fn: Reward | None = None, + aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None, + logger: LogCollector | None = None, ) -> None: # Assign weak reference to wrapper. # diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index 4b0e68c68..75aab2068 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -397,7 +397,7 @@ class ConsoleWriter(LogWriter): def __init__( self, log_every_n_episode: int = 20, - total_episodes: int = None, + total_episodes: int | None = None, float_format: str = ":.4f", counter_format: str = ":4d", loglevel: int | LogLevel = LogLevel.PERIODIC,