diff --git a/qlib/rl/entries/test.py b/qlib/rl/entries/test.py index 8cd891200..6054c866a 100644 --- a/qlib/rl/entries/test.py +++ b/qlib/rl/entries/test.py @@ -4,7 +4,7 @@ from __future__ import annotations import copy -from typing import Callable, List, Sequence, Union +from typing import Callable, List, Sequence from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -25,7 +25,7 @@ def backtest( action_interpreter: ActionInterpreter, initial_states: Sequence[InitialStateType], policy: BasePolicy, - logger: Union[LogWriter, List[LogWriter]], + logger: LogWriter | List[LogWriter], reward: Reward = None, finite_env_type: FiniteEnvType = "subproc", concurrency: int = 2, diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 788d22ae0..14af3ed36 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -5,7 +5,7 @@ from __future__ import annotations import math from pathlib import Path -from typing import Any, List, Union, cast +from typing import Any, List, cast import numpy as np import pandas as pd @@ -26,7 +26,7 @@ __all__ = [ ] -def canonicalize(value: Union[int, float, np.ndarray, pd.DataFrame, dict]) -> Union[np.ndarray, dict]: +def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict: """To 32-bit numeric types. Recursively.""" if isinstance(value, pd.DataFrame): return value.to_numpy() @@ -188,7 +188,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): i.e., $[0, 1/n, 2/n, \\ldots, n/n]$. """ - def __init__(self, values: Union[int, List[float]]) -> None: + def __init__(self, values: int | List[float]) -> None: if isinstance(values, int): values = [i / values for i in range(0, values + 1)] self.action_values = values diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index e9737ca98..df6dc0454 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + from abc import ABCMeta from pathlib import Path -from typing import Any, Dict, Generator, Iterable, Optional, Tuple, Union, cast +from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast import gym import numpy as np @@ -48,7 +50,7 @@ class AllOne(NonLearnablePolicy): def forward( self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, + state: dict | Batch | np.ndarray = None, **kwargs: Any, ) -> Batch: return Batch(act=np.full(len(batch), 1.0), state=state) diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index c0d364433..eb1cc10a4 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, NamedTuple, Optional, TypeVar, Union, cast +from typing import Any, NamedTuple, Optional, TypeVar, cast import numpy as np import pandas as pd @@ -33,7 +33,7 @@ class SAOEMetrics(TypedDict): stock_id: str """Stock ID of this record.""" - datetime: Union[pd.Timestamp, pd.DatetimeIndex] # TODO: check this + datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this """Datetime of this record (this is index in the dataframe).""" direction: int """Direction of the order. 0 for sell, 1 for buy.""" @@ -392,7 +392,7 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray) def price_advantage( exec_price: _float_or_ndarray, baseline_price: float, - direction: Union[OrderDir, int], + direction: OrderDir | int, ) -> _float_or_ndarray: if baseline_price == 0: # something is wrong with data. Should be nan here if isinstance(exec_price, float): diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index 4d50a32d8..529bfe597 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -4,7 +4,7 @@ from __future__ import annotations import weakref -from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast import gym from gym import Space @@ -92,7 +92,7 @@ class EnvWrapper( """ simulator: Simulator[InitialStateType, StateType, ActType] - seed_iterator: Union[str, Iterator[InitialStateType], None] + seed_iterator: str | Iterator[InitialStateType] | None def __init__( self, diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index 725dbe975..27b498414 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -11,7 +11,7 @@ from __future__ import annotations import copy import warnings from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, cast import gym import numpy as np @@ -36,7 +36,7 @@ __all__ = [ FiniteEnvType = Literal["dummy", "subproc", "shmem"] -def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> Union[np.ndarray, dict, list, tuple]: +def fill_invalid(obj: int | float | bool | np.ndarray | dict | list | tuple) -> np.ndarray | dict | list | tuple: if isinstance(obj, (int, float, bool)): return fill_invalid(np.array(obj)) if hasattr(obj, "dtype"): @@ -55,7 +55,7 @@ def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> raise ValueError(f"Unsupported value to fill with invalid: {obj}") -def is_invalid(arr: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> bool: +def is_invalid(arr: int | float | bool | np.ndarray | dict | list | tuple) -> bool: if hasattr(arr, "dtype"): if np.issubdtype(arr.dtype, np.floating): return np.isnan(arr).all() @@ -121,7 +121,7 @@ class FiniteVectorEnv(BaseVectorEnv): """ def __init__( - self, logger: Union[LogWriter, List[LogWriter]], env_fns: List[Callable[..., gym.Env]], **kwargs: Any + self, logger: LogWriter | List[LogWriter], env_fns: List[Callable[..., gym.Env]], **kwargs: Any ) -> None: super().__init__(env_fns, **kwargs) @@ -199,7 +199,7 @@ class FiniteVectorEnv(BaseVectorEnv): def reset( self, - id: Optional[Union[int, List[int], np.ndarray]] = None, + id: int | List[int] | np.ndarray = None, ) -> np.ndarray: assert not self._zombie @@ -251,7 +251,7 @@ class FiniteVectorEnv(BaseVectorEnv): def step( self, action: np.ndarray, - id: Optional[Union[int, List[int], np.ndarray]] = None, + id: int | List[int] | np.ndarray = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: assert not self._zombie id = self._wrap_id(id) @@ -304,7 +304,7 @@ def vectorize_env( env_factory: Callable[..., gym.Env], env_type: FiniteEnvType, concurrency: int, - logger: Union[LogWriter, List[LogWriter]], + logger: LogWriter | List[LogWriter], ) -> FiniteVectorEnv: """Helper function to create a vector env. diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index 915dcab1e..4e172343e 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -18,7 +18,7 @@ import logging from collections import defaultdict from enum import IntEnum from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar import numpy as np import pandas as pd @@ -65,19 +65,19 @@ class LogCollector: _logged: Dict[str, Tuple[int, Any]] _min_loglevel: int - def __init__(self, min_loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: + def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: self._min_loglevel = int(min_loglevel) def reset(self) -> None: """Clear all collected contents.""" self._logged = {} - def _add_metric(self, name: str, metric: Any, loglevel: Union[int, LogLevel]) -> None: + def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None: if name in self._logged: raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.") self._logged[name] = (int(loglevel), metric) - def add_string(self, name: str, string: str, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: + def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: """Add a string with name into logged contents.""" if loglevel < self._min_loglevel: return @@ -85,7 +85,7 @@ class LogCollector: raise TypeError(f"{string} is not a string.") self._add_metric(name, string, loglevel) - def add_scalar(self, name: str, scalar: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: + def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: """Add a scalar with name into logged contents. Scalar will be converted into a float. """ @@ -103,8 +103,8 @@ class LogCollector: def add_array( self, name: str, - array: Union[np.ndarray, pd.DataFrame, pd.Series], - loglevel: Union[int, LogLevel] = LogLevel.PERIODIC, + array: np.ndarray | pd.DataFrame | pd.Series, + loglevel: int | LogLevel = LogLevel.PERIODIC, ) -> None: """Add an array with name into logging.""" if loglevel < self._min_loglevel: @@ -114,7 +114,7 @@ class LogCollector: raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.") self._add_metric(name, array, loglevel) - def add_any(self, name: str, obj: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: + def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: """Log something with any type. As it's an "any" object, the only LogWriter accepting it is pickle. @@ -163,7 +163,7 @@ class LogWriter(Generic[ObsType, ActType]): episode_logs: Dict[int, list] """Map from environment id to episode logs.""" - def __init__(self, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: + def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: self.loglevel = loglevel self.global_step = 0 @@ -279,7 +279,7 @@ class ConsoleWriter(LogWriter): total_episodes: int = None, float_format: str = ":.4f", counter_format: str = ":4d", - loglevel: Union[int, LogLevel] = LogLevel.PERIODIC, + loglevel: int | LogLevel = LogLevel.PERIODIC, ) -> None: super().__init__(loglevel) # TODO: support log_every_n_step @@ -354,7 +354,7 @@ class CsvWriter(LogWriter): all_records: List[Dict[str, Any]] - def __init__(self, output_dir: Path, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: + def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: super().__init__(loglevel) self.output_dir = output_dir self.output_dir.mkdir(exist_ok=True) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index ff6291c44..bf8c9a55f 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from qlib.backtest.exchange import Exchange from qlib.backtest.position import BasePosition -from typing import Tuple, Union +from typing import Tuple from ..backtest.decision import BaseTradeDecision from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager @@ -232,8 +232,8 @@ class RLIntStrategy(RLStrategy, metaclass=ABCMeta): def __init__( self, policy, - state_interpreter: Union[dict, StateInterpreter], - action_interpreter: Union[dict, ActionInterpreter], + state_interpreter: dict | StateInterpreter, + action_interpreter: dict | ActionInterpreter, outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None,