1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Use | instead of Union

This commit is contained in:
Huoran Li
2022-06-24 15:48:55 +08:00
parent 15340ff835
commit e23504c1d7
8 changed files with 35 additions and 33 deletions

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
import copy
from typing import Callable, List, Sequence, Union
from typing import Callable, List, Sequence
from tianshou.data import Collector
from tianshou.policy import BasePolicy
@@ -25,7 +25,7 @@ def backtest(
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: Union[LogWriter, List[LogWriter]],
logger: LogWriter | List[LogWriter],
reward: Reward = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import math
from pathlib import Path
from typing import Any, List, Union, cast
from typing import Any, List, cast
import numpy as np
import pandas as pd
@@ -26,7 +26,7 @@ __all__ = [
]
def canonicalize(value: Union[int, float, np.ndarray, pd.DataFrame, dict]) -> Union[np.ndarray, dict]:
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
"""To 32-bit numeric types. Recursively."""
if isinstance(value, pd.DataFrame):
return value.to_numpy()
@@ -188,7 +188,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
"""
def __init__(self, values: Union[int, List[float]]) -> None:
def __init__(self, values: int | List[float]) -> None:
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values

View File

@@ -1,8 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from abc import ABCMeta
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, Union, cast
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
import gym
import numpy as np
@@ -48,7 +50,7 @@ class AllOne(NonLearnablePolicy):
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
state: dict | Batch | np.ndarray = None,
**kwargs: Any,
) -> Batch:
return Batch(act=np.full(len(batch), 1.0), state=state)

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple, Optional, TypeVar, Union, cast
from typing import Any, NamedTuple, Optional, TypeVar, cast
import numpy as np
import pandas as pd
@@ -33,7 +33,7 @@ class SAOEMetrics(TypedDict):
stock_id: str
"""Stock ID of this record."""
datetime: Union[pd.Timestamp, pd.DatetimeIndex] # TODO: check this
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
"""Datetime of this record (this is index in the dataframe)."""
direction: int
"""Direction of the order. 0 for sell, 1 for buy."""
@@ -392,7 +392,7 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
def price_advantage(
exec_price: _float_or_ndarray,
baseline_price: float,
direction: Union[OrderDir, int],
direction: OrderDir | int,
) -> _float_or_ndarray:
if baseline_price == 0: # something is wrong with data. Should be nan here
if isinstance(exec_price, float):

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
import weakref
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, Union, cast
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
import gym
from gym import Space
@@ -92,7 +92,7 @@ class EnvWrapper(
"""
simulator: Simulator[InitialStateType, StateType, ActType]
seed_iterator: Union[str, Iterator[InitialStateType], None]
seed_iterator: str | Iterator[InitialStateType] | None
def __init__(
self,

View File

@@ -11,7 +11,7 @@ from __future__ import annotations
import copy
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, cast
import gym
import numpy as np
@@ -36,7 +36,7 @@ __all__ = [
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> Union[np.ndarray, dict, list, tuple]:
def fill_invalid(obj: int | float | bool | np.ndarray | dict | list | tuple) -> np.ndarray | dict | list | tuple:
if isinstance(obj, (int, float, bool)):
return fill_invalid(np.array(obj))
if hasattr(obj, "dtype"):
@@ -55,7 +55,7 @@ def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) ->
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
def is_invalid(arr: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> bool:
def is_invalid(arr: int | float | bool | np.ndarray | dict | list | tuple) -> bool:
if hasattr(arr, "dtype"):
if np.issubdtype(arr.dtype, np.floating):
return np.isnan(arr).all()
@@ -121,7 +121,7 @@ class FiniteVectorEnv(BaseVectorEnv):
"""
def __init__(
self, logger: Union[LogWriter, List[LogWriter]], env_fns: List[Callable[..., gym.Env]], **kwargs: Any
self, logger: LogWriter | List[LogWriter], env_fns: List[Callable[..., gym.Env]], **kwargs: Any
) -> None:
super().__init__(env_fns, **kwargs)
@@ -199,7 +199,7 @@ class FiniteVectorEnv(BaseVectorEnv):
def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
id: int | List[int] | np.ndarray = None,
) -> np.ndarray:
assert not self._zombie
@@ -251,7 +251,7 @@ class FiniteVectorEnv(BaseVectorEnv):
def step(
self,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None,
id: int | List[int] | np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert not self._zombie
id = self._wrap_id(id)
@@ -304,7 +304,7 @@ def vectorize_env(
env_factory: Callable[..., gym.Env],
env_type: FiniteEnvType,
concurrency: int,
logger: Union[LogWriter, List[LogWriter]],
logger: LogWriter | List[LogWriter],
) -> FiniteVectorEnv:
"""Helper function to create a vector env.

View File

@@ -18,7 +18,7 @@ import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar
import numpy as np
import pandas as pd
@@ -65,19 +65,19 @@ class LogCollector:
_logged: Dict[str, Tuple[int, Any]]
_min_loglevel: int
def __init__(self, min_loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
self._min_loglevel = int(min_loglevel)
def reset(self) -> None:
"""Clear all collected contents."""
self._logged = {}
def _add_metric(self, name: str, metric: Any, loglevel: Union[int, LogLevel]) -> None:
def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
if name in self._logged:
raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
self._logged[name] = (int(loglevel), metric)
def add_string(self, name: str, string: str, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Add a string with name into logged contents."""
if loglevel < self._min_loglevel:
return
@@ -85,7 +85,7 @@ class LogCollector:
raise TypeError(f"{string} is not a string.")
self._add_metric(name, string, loglevel)
def add_scalar(self, name: str, scalar: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Add a scalar with name into logged contents.
Scalar will be converted into a float.
"""
@@ -103,8 +103,8 @@ class LogCollector:
def add_array(
self,
name: str,
array: Union[np.ndarray, pd.DataFrame, pd.Series],
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
array: np.ndarray | pd.DataFrame | pd.Series,
loglevel: int | LogLevel = LogLevel.PERIODIC,
) -> None:
"""Add an array with name into logging."""
if loglevel < self._min_loglevel:
@@ -114,7 +114,7 @@ class LogCollector:
raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
self._add_metric(name, array, loglevel)
def add_any(self, name: str, obj: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Log something with any type.
As it's an "any" object, the only LogWriter accepting it is pickle.
@@ -163,7 +163,7 @@ class LogWriter(Generic[ObsType, ActType]):
episode_logs: Dict[int, list]
"""Map from environment id to episode logs."""
def __init__(self, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
self.loglevel = loglevel
self.global_step = 0
@@ -279,7 +279,7 @@ class ConsoleWriter(LogWriter):
total_episodes: int = None,
float_format: str = ":.4f",
counter_format: str = ":4d",
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
loglevel: int | LogLevel = LogLevel.PERIODIC,
) -> None:
super().__init__(loglevel)
# TODO: support log_every_n_step
@@ -354,7 +354,7 @@ class CsvWriter(LogWriter):
all_records: List[Dict[str, Any]]
def __init__(self, output_dir: Path, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
super().__init__(loglevel)
self.output_dir = output_dir
self.output_dir.mkdir(exist_ok=True)

View File

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import Tuple, Union
from typing import Tuple
from ..backtest.decision import BaseTradeDecision
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
@@ -232,8 +232,8 @@ class RLIntStrategy(RLStrategy, metaclass=ABCMeta):
def __init__(
self,
policy,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
state_interpreter: dict | StateInterpreter,
action_interpreter: dict | ActionInterpreter,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,