mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Use | instead of Union
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Callable, List, Sequence, Union
|
||||
from typing import Callable, List, Sequence
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
@@ -25,7 +25,7 @@ def backtest(
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: Union[LogWriter, List[LogWriter]],
|
||||
logger: LogWriter | List[LogWriter],
|
||||
reward: Reward = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union, cast
|
||||
from typing import Any, List, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -26,7 +26,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def canonicalize(value: Union[int, float, np.ndarray, pd.DataFrame, dict]) -> Union[np.ndarray, dict]:
|
||||
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
|
||||
"""To 32-bit numeric types. Recursively."""
|
||||
if isinstance(value, pd.DataFrame):
|
||||
return value.to_numpy()
|
||||
@@ -188,7 +188,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
|
||||
"""
|
||||
|
||||
def __init__(self, values: Union[int, List[float]]) -> None:
|
||||
def __init__(self, values: int | List[float]) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, Union, cast
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@@ -48,7 +50,7 @@ class AllOne(NonLearnablePolicy):
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
state: dict | Batch | np.ndarray = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
return Batch(act=np.full(len(batch), 1.0), state=state)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, Union, cast
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -33,7 +33,7 @@ class SAOEMetrics(TypedDict):
|
||||
|
||||
stock_id: str
|
||||
"""Stock ID of this record."""
|
||||
datetime: Union[pd.Timestamp, pd.DatetimeIndex] # TODO: check this
|
||||
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
|
||||
"""Datetime of this record (this is index in the dataframe)."""
|
||||
direction: int
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
@@ -392,7 +392,7 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: Union[OrderDir, int],
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, Union, cast
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
|
||||
|
||||
import gym
|
||||
from gym import Space
|
||||
@@ -92,7 +92,7 @@ class EnvWrapper(
|
||||
"""
|
||||
|
||||
simulator: Simulator[InitialStateType, StateType, ActType]
|
||||
seed_iterator: Union[str, Iterator[InitialStateType], None]
|
||||
seed_iterator: str | Iterator[InitialStateType] | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
import copy
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, cast
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@@ -36,7 +36,7 @@ __all__ = [
|
||||
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
|
||||
|
||||
|
||||
def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> Union[np.ndarray, dict, list, tuple]:
|
||||
def fill_invalid(obj: int | float | bool | np.ndarray | dict | list | tuple) -> np.ndarray | dict | list | tuple:
|
||||
if isinstance(obj, (int, float, bool)):
|
||||
return fill_invalid(np.array(obj))
|
||||
if hasattr(obj, "dtype"):
|
||||
@@ -55,7 +55,7 @@ def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) ->
|
||||
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
|
||||
|
||||
|
||||
def is_invalid(arr: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> bool:
|
||||
def is_invalid(arr: int | float | bool | np.ndarray | dict | list | tuple) -> bool:
|
||||
if hasattr(arr, "dtype"):
|
||||
if np.issubdtype(arr.dtype, np.floating):
|
||||
return np.isnan(arr).all()
|
||||
@@ -121,7 +121,7 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, logger: Union[LogWriter, List[LogWriter]], env_fns: List[Callable[..., gym.Env]], **kwargs: Any
|
||||
self, logger: LogWriter | List[LogWriter], env_fns: List[Callable[..., gym.Env]], **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(env_fns, **kwargs)
|
||||
|
||||
@@ -199,7 +199,7 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
def reset(
|
||||
self,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||
id: int | List[int] | np.ndarray = None,
|
||||
) -> np.ndarray:
|
||||
assert not self._zombie
|
||||
|
||||
@@ -251,7 +251,7 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||
id: int | List[int] | np.ndarray = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
id = self._wrap_id(id)
|
||||
@@ -304,7 +304,7 @@ def vectorize_env(
|
||||
env_factory: Callable[..., gym.Env],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: Union[LogWriter, List[LogWriter]],
|
||||
logger: LogWriter | List[LogWriter],
|
||||
) -> FiniteVectorEnv:
|
||||
"""Helper function to create a vector env.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -65,19 +65,19 @@ class LogCollector:
|
||||
_logged: Dict[str, Tuple[int, Any]]
|
||||
_min_loglevel: int
|
||||
|
||||
def __init__(self, min_loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
self._min_loglevel = int(min_loglevel)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all collected contents."""
|
||||
self._logged = {}
|
||||
|
||||
def _add_metric(self, name: str, metric: Any, loglevel: Union[int, LogLevel]) -> None:
|
||||
def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
|
||||
if name in self._logged:
|
||||
raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
|
||||
self._logged[name] = (int(loglevel), metric)
|
||||
|
||||
def add_string(self, name: str, string: str, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
"""Add a string with name into logged contents."""
|
||||
if loglevel < self._min_loglevel:
|
||||
return
|
||||
@@ -85,7 +85,7 @@ class LogCollector:
|
||||
raise TypeError(f"{string} is not a string.")
|
||||
self._add_metric(name, string, loglevel)
|
||||
|
||||
def add_scalar(self, name: str, scalar: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
"""Add a scalar with name into logged contents.
|
||||
Scalar will be converted into a float.
|
||||
"""
|
||||
@@ -103,8 +103,8 @@ class LogCollector:
|
||||
def add_array(
|
||||
self,
|
||||
name: str,
|
||||
array: Union[np.ndarray, pd.DataFrame, pd.Series],
|
||||
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
|
||||
array: np.ndarray | pd.DataFrame | pd.Series,
|
||||
loglevel: int | LogLevel = LogLevel.PERIODIC,
|
||||
) -> None:
|
||||
"""Add an array with name into logging."""
|
||||
if loglevel < self._min_loglevel:
|
||||
@@ -114,7 +114,7 @@ class LogCollector:
|
||||
raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
|
||||
self._add_metric(name, array, loglevel)
|
||||
|
||||
def add_any(self, name: str, obj: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
"""Log something with any type.
|
||||
|
||||
As it's an "any" object, the only LogWriter accepting it is pickle.
|
||||
@@ -163,7 +163,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
episode_logs: Dict[int, list]
|
||||
"""Map from environment id to episode logs."""
|
||||
|
||||
def __init__(self, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
self.loglevel = loglevel
|
||||
|
||||
self.global_step = 0
|
||||
@@ -279,7 +279,7 @@ class ConsoleWriter(LogWriter):
|
||||
total_episodes: int = None,
|
||||
float_format: str = ":.4f",
|
||||
counter_format: str = ":4d",
|
||||
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
|
||||
loglevel: int | LogLevel = LogLevel.PERIODIC,
|
||||
) -> None:
|
||||
super().__init__(loglevel)
|
||||
# TODO: support log_every_n_step
|
||||
@@ -354,7 +354,7 @@ class CsvWriter(LogWriter):
|
||||
|
||||
all_records: List[Dict[str, Any]]
|
||||
|
||||
def __init__(self, output_dir: Path, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
super().__init__(loglevel)
|
||||
self.output_dir = output_dir
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
@@ -232,8 +232,8 @@ class RLIntStrategy(RLStrategy, metaclass=ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
|
||||
Reference in New Issue
Block a user