From b184cc4125d83b8f8e3f466dcdb2f7bf1beae812 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 16 Jun 2022 13:33:59 +0800 Subject: [PATCH] Refine previous version RL codes --- qlib/rl/aux_info.py | 4 +- qlib/rl/data/pickle_styled.py | 46 ++++++++----- qlib/rl/entries/test.py | 13 ++-- qlib/rl/interpreter.py | 12 ++-- qlib/rl/order_execution/__init__.py | 5 -- qlib/rl/order_execution/interpreter.py | 22 +++--- qlib/rl/order_execution/network.py | 11 +-- qlib/rl/order_execution/policy.py | 66 ++++++++++++------ qlib/rl/order_execution/simulator_simple.py | 47 +++++++------ qlib/rl/reward.py | 8 +-- qlib/rl/simulator.py | 4 +- qlib/rl/utils/__init__.py | 5 -- qlib/rl/utils/data_queue.py | 36 +++++----- qlib/rl/utils/env_wrapper.py | 34 +++++----- qlib/rl/utils/finite_env.py | 48 +++++++------ qlib/rl/utils/log.py | 74 +++++++++++---------- tests/rl/test_logger.py | 14 ++-- 17 files changed, 254 insertions(+), 195 deletions(-) diff --git a/qlib/rl/aux_info.py b/qlib/rl/aux_info.py index 65cd95d5d..1fd581544 100644 --- a/qlib/rl/aux_info.py +++ b/qlib/rl/aux_info.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Generic, TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Generic, Optional, TypeVar from qlib.typehint import final @@ -21,7 +21,7 @@ AuxInfoType = TypeVar("AuxInfoType") class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]): """Override this class to collect customized auxiliary information from environment.""" - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None @final def __call__(self, simulator_state: StateType) -> AuxInfoType: diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 6cf386801..7a58512df 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -20,18 +20,17 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge from __future__ import annotations from functools import lru_cache -from typing import List, Sequence, cast from pathlib import Path +from typing import List, Sequence, cast import cachetools import numpy as np import pandas as pd from cachetools.keys import hashkey -from qlib.backtest.decision import OrderDir, Order +from qlib.backtest.decision import Order, OrderDir from qlib.typehint import Literal - DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"] """Several ad-hoc deal price. ``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``. @@ -40,7 +39,7 @@ DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"] """ -def _infer_processed_data_column_names(shape: int) -> list[str]: +def _infer_processed_data_column_names(shape: int) -> List[str]: if shape == 16: return [ "$open", @@ -95,8 +94,8 @@ class IntradayBacktestData: stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", - order_dir: int | None = None, - ): + order_dir: int = None, + ) -> None: backtest = _read_pickle(data_dir / stock_id) backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]] @@ -105,13 +104,13 @@ class IntradayBacktestData: self.data: pd.DataFrame = backtest self.deal_price_type: DealPriceType = deal_price - self.order_dir: int | None = order_dir + self.order_dir = order_dir - def __repr__(self): + def __repr__(self) -> str: with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): return f"{self.__class__.__name__}({self.data})" - def __len__(self): + def __len__(self) -> int: return len(self.data) def get_deal_price(self) -> pd.Series: @@ -162,7 +161,14 @@ class IntradayProcessedData: """Processed data for "yesterday". Number of records must be ``time_length``, and columns must be ``feature_dim``.""" - def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index): + def __init__( + self, + data_dir: Path, + stock_id: str, + date: pd.Timestamp, + feature_dim: int, + time_index: pd.Index, + ) -> None: proc = _read_pickle(data_dir / stock_id) # We have to infer the names here because, # unfortunately they are not included in the original data. @@ -190,14 +196,18 @@ class IntradayProcessedData: assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim assert len(self.today) == len(self.yesterday) == time_length - def __repr__(self): + def __repr__(self) -> str: with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): return f"{self.__class__.__name__}({self.today}, {self.yesterday})" @lru_cache(maxsize=100) # 100 * 50K = 5MB def load_intraday_backtest_data( - data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None + data_dir: Path, + stock_id: str, + date: pd.Timestamp, + deal_price: DealPriceType = "close", + order_dir: int = None, ) -> IntradayBacktestData: return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) @@ -207,13 +217,19 @@ def load_intraday_backtest_data( key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date), ) def load_intraday_processed_data( - data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index + data_dir: Path, + stock_id: str, + date: pd.Timestamp, + feature_dim: int, + time_index: pd.Index, ) -> IntradayProcessedData: return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index) def load_orders( - order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None + order_path: Path, + start_time: pd.Timestamp = None, + end_time: pd.Timestamp = None, ) -> Sequence[Order]: """Load orders, and set start time and end time for the orders.""" @@ -251,7 +267,7 @@ def load_orders( int(row["order_type"]), row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second), row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second), - ) + ), ) return orders diff --git a/qlib/rl/entries/test.py b/qlib/rl/entries/test.py index ca311407b..8cd891200 100644 --- a/qlib/rl/entries/test.py +++ b/qlib/rl/entries/test.py @@ -4,19 +4,18 @@ from __future__ import annotations import copy -from typing import Callable, Sequence +from typing import Callable, List, Sequence, Union from tianshou.data import Collector from tianshou.policy import BasePolicy from qlib.constant import INF from qlib.log import get_module_logger -from qlib.rl.simulator import InitialStateType, Simulator -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.reward import Reward +from qlib.rl.simulator import InitialStateType, Simulator from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env - _logger = get_module_logger(__name__) @@ -26,8 +25,8 @@ def backtest( action_interpreter: ActionInterpreter, initial_states: Sequence[InitialStateType], policy: BasePolicy, - logger: LogWriter | list[LogWriter], - reward: Reward | None = None, + logger: Union[LogWriter, List[LogWriter]], + reward: Reward = None, finite_env_type: FiniteEnvType = "subproc", concurrency: int = 2, ) -> None: @@ -60,7 +59,7 @@ def backtest( # To save bandwidth min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel - def env_factory(): + def env_factory() -> EnvWrapper: # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env), # and could be thread unsafe. # I'm not sure whether it's a design flaw. diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index 3835b5b92..61c9b8381 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -3,13 +3,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypeVar, Generic, Any +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import numpy as np from qlib.typehint import final -from .simulator import StateType, ActType +from .simulator import ActType, StateType if TYPE_CHECKING: from .utils.env_wrapper import EnvWrapper @@ -40,7 +40,7 @@ class Interpreter: class StateInterpreter(Generic[StateType, ObsType], Interpreter): """State Interpreter that interpret execution result of qlib executor into rl env state""" - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None @property def observation_space(self) -> gym.Space: @@ -74,7 +74,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter): class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter): """Action Interpreter that interpret rl agent action into qlib orders""" - env: "EnvWrapper" | None = None + env: Optional[EnvWrapper] = None @property def action_space(self) -> gym.Space: @@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None: class GymSpaceValidationError(Exception): - def __init__(self, message: str, space: gym.Space, x: Any): + def __init__(self, message: str, space: gym.Space, x: Any) -> None: self.message = message self.space = space self.x = x - def __str__(self): + def __str__(self) -> str: return f"{self.message}\n Space: {self.space}\n Sample: {self.x}" diff --git a/qlib/rl/order_execution/__init__.py b/qlib/rl/order_execution/__init__.py index 048dfecac..ea599a9e8 100644 --- a/qlib/rl/order_execution/__init__.py +++ b/qlib/rl/order_execution/__init__.py @@ -5,8 +5,3 @@ Currently it supports single-asset order execution. Multi-asset is on the way. """ - -from .interpreter import * -from .network import * -from .policy import * -from .simulator_simple import * diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 9bb5dc2cf..788d22ae0 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -5,15 +5,15 @@ from __future__ import annotations import math from pathlib import Path -from typing import Any, cast +from typing import Any, List, Union, cast import numpy as np import pandas as pd from gym import spaces from qlib.constant import EPS -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter from qlib.rl.data import pickle_styled +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.typehint import TypedDict from .simulator_simple import SAOEState @@ -26,7 +26,7 @@ __all__ = [ ] -def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict: +def canonicalize(value: Union[int, float, np.ndarray, pd.DataFrame, dict]) -> Union[np.ndarray, dict]: """To 32-bit numeric types. Recursively.""" if isinstance(value, pd.DataFrame): return value.to_numpy() @@ -99,18 +99,18 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): "data_processed": self._mask_future_info(processed.today, state.cur_time), "data_processed_prev": processed.yesterday, "acquiring": state.order.direction == state.order.BUY, - "cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1), + "cur_tick": min(float(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), "cur_step": min(self.env.status["cur_step"], self.max_step - 1), "num_step": self.max_step, "target": state.order.amount, "position": state.position, "position_history": position_history[: self.max_step], - } + }, ), ) @property - def observation_space(self): + def observation_space(self) -> spaces.Dict: space = { "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), @@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]): The key list is not full. You can add more if more information is needed by your policy. """ - def __init__(self, max_step: int): + def __init__(self, max_step: int) -> None: self.max_step = max_step @property - def observation_space(self): + def observation_space(self) -> spaces.Dict: space = { "acquiring": spaces.Discrete(2), "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32), @@ -165,7 +165,7 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]): assert self.env is not None assert self.env.status["cur_step"] <= self.max_step obs = CurrentStateObs( - { + **{ "acquiring": state.order.direction == state.order.BUY, "cur_step": self.env.status["cur_step"], "num_step": self.max_step, @@ -188,7 +188,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): i.e., $[0, 1/n, 2/n, \\ldots, n/n]$. """ - def __init__(self, values: int | list[float]): + def __init__(self, values: Union[int, List[float]]) -> None: if isinstance(values, int): values = [i / values for i in range(0, values + 1)] self.action_values = values @@ -203,7 +203,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]): - """Convert a continous ratio to deal amount. + """Convert a continuous ratio to deal amount. The ratio is relative to TWAP on the remainder of the day. For example, there are 5 steps left, and the left position is 300. diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py index 908f96130..3d0279559 100644 --- a/qlib/rl/order_execution/network.py +++ b/qlib/rl/order_execution/network.py @@ -3,13 +3,14 @@ from __future__ import annotations -from typing import cast +from typing import List, Tuple, cast import torch import torch.nn as nn from tianshou.data import Batch from qlib.typehint import Literal + from .interpreter import FullHistoryObs __all__ = ["Recurrent"] @@ -18,7 +19,7 @@ __all__ = ["Recurrent"] class Recurrent(nn.Module): """The network architecture proposed in `OPD `_. - At every timestep the input of policy network is divided into two parts, + At every time step the input of policy network is divided into two parts, the public variables and the private variables. which are handled by ``raw_rnn`` and ``pri_rnn`` in this network, respectively. @@ -33,7 +34,7 @@ class Recurrent(nn.Module): output_dim: int = 32, rnn_type: Literal["rnn", "lstm", "gru"] = "gru", rnn_num_layers: int = 1, - ): + ) -> None: super().__init__() self.hidden_dim = hidden_dim @@ -62,10 +63,10 @@ class Recurrent(nn.Module): nn.ReLU(), ) - def _init_extra_branches(self): + def _init_extra_branches(self) -> None: pass - def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]: + def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]: bs, _, data_dim = obs["data_processed"].size() data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1) cur_step = obs["cur_step"].long() diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index f95a53c75..e9737ca98 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -1,16 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +from abc import ABCMeta from pathlib import Path -from typing import Optional, cast +from typing import Any, Dict, Generator, Iterable, Optional, Tuple, Union, cast -import numpy as np import gym +import numpy as np import torch import torch.nn as nn from gym.spaces import Discrete -from tianshou.data import Batch, to_torch -from tianshou.policy import PPOPolicy, BasePolicy +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import BasePolicy, PPOPolicy __all__ = ["AllOne", "PPO"] @@ -18,29 +18,39 @@ __all__ = ["AllOne", "PPO"] # baselines # -class NonlearnablePolicy(BasePolicy): +class NonLearnablePolicy(BasePolicy, metaclass=ABCMeta): """Tianshou's BasePolicy with empty ``learn`` and ``process_fn``. This could be moved outside in future. """ - def __init__(self, obs_space: gym.Space, action_space: gym.Space): + def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None: super().__init__() - def learn(self, batch, batch_size, repeat): + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]: pass - def process_fn(self, batch, buffer, indice): + def process_fn( + self, + batch: Batch, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> Batch: pass -class AllOne(NonlearnablePolicy): +class AllOne(NonLearnablePolicy): """Forward returns a batch full of 1. Useful when implementing some baselines (e.g., TWAP). """ - def forward(self, batch, state=None, **kwargs): + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: return Batch(act=np.full(len(batch), 1.0), state=state) @@ -48,24 +58,34 @@ class AllOne(NonlearnablePolicy): class PPOActor(nn.Module): - def __init__(self, extractor: nn.Module, action_dim: int): + def __init__(self, extractor: nn.Module, action_dim: int) -> None: super().__init__() self.extractor = extractor self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1)) - def forward(self, obs, state=None, info={}): + def forward( + self, + obs: torch.Tensor, + state: torch.Tensor = None, + info: dict = {}, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: feature = self.extractor(to_torch(obs, device=auto_device(self))) out = self.layer_out(feature) return out, state class PPOCritic(nn.Module): - def __init__(self, extractor: nn.Module): + def __init__(self, extractor: nn.Module) -> None: super().__init__() self.extractor = extractor self.value_out = nn.Linear(cast(int, extractor.output_dim), 1) - def forward(self, obs, state=None, info={}): + def forward( + self, + obs: torch.Tensor, + state: torch.Tensor = None, + info: dict = {}, + ) -> torch.Tensor: feature = self.extractor(to_torch(obs, device=auto_device(self))) return self.value_out(feature).squeeze(dim=-1) @@ -93,18 +113,20 @@ class PPO(PPOPolicy): max_grad_norm: float = 100.0, reward_normalization: bool = True, eps_clip: float = 0.3, - value_clip: float = True, + value_clip: bool = True, vf_coef: float = 1.0, gae_lambda: float = 1.0, - max_batchsize: int = 256, + max_batch_size: int = 256, deterministic_eval: bool = True, weight_file: Optional[Path] = None, - ): + ) -> None: assert isinstance(action_space, Discrete) actor = PPOActor(network, action_space.n) critic = PPOCritic(network) optimizer = torch.optim.Adam( - chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay + chain_dedup(actor.parameters(), critic.parameters()), + lr=lr, + weight_decay=weight_decay, ) super().__init__( actor, @@ -118,7 +140,7 @@ class PPO(PPOPolicy): value_clip=value_clip, vf_coef=vf_coef, gae_lambda=gae_lambda, - max_batchsize=max_batchsize, + max_batchsize=max_batch_size, deterministic_eval=deterministic_eval, observation_space=obs_space, action_space=action_space, @@ -136,7 +158,7 @@ def auto_device(module: nn.Module) -> torch.device: return torch.device("cpu") # fallback to cpu -def load_weight(policy, path): +def load_weight(policy: nn.Module, path: Path) -> None: assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight." loaded_weight = torch.load(path, map_location="cpu") try: @@ -149,7 +171,7 @@ def load_weight(policy, path): policy.load_state_dict(loaded_weight) -def chain_dedup(*iterables): +def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]: seen = set() for iterable in iterables: for i in iterable: diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 8022c34ce..048898118 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -4,15 +4,15 @@ from __future__ import annotations from pathlib import Path -from typing import NamedTuple, Any, TypeVar, cast +from typing import Any, NamedTuple, Optional, TypeVar, Union, cast import numpy as np import pandas as pd from qlib.backtest.decision import Order, OrderDir from qlib.constant import EPS +from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_intraday_backtest_data from qlib.rl.simulator import Simulator -from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType from qlib.rl.utils import LogLevel from qlib.typehint import TypedDict @@ -33,7 +33,7 @@ class SAOEMetrics(TypedDict): stock_id: str """Stock ID of this record.""" - datetime: pd.Timestamp + datetime: Union[pd.Timestamp, pd.DatetimeIndex] # TODO: check this """Datetime of this record (this is index in the dataframe).""" direction: int """Direction of the order. 0 for sell, 1 for buy.""" @@ -87,7 +87,7 @@ class SAOEState(NamedTuple): history_steps: pd.DataFrame """See :attr:`SingleAssetOrderExecution.history_steps`.""" - metrics: SAOEMetrics | None + metrics: Optional[SAOEMetrics] """Daily metric, only available when the trading is in "done" state.""" backtest_data: IntradayBacktestData @@ -114,13 +114,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): If such fine granularity is not needed, use ``ticks_per_step`` to lengthen the ticks for each step. - In each step, the traded amount are "equally" splitted to each tick, - then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``), + In each step, the traded amount are "equally" separated to each tick, + then bounded by volume maximum execution volume (i.e., ``vol_threshold``), and if it's the last step, try to ensure all the amount to be executed. Parameters ---------- - initial + order The seed to start an SAOE simulator is an order. ticks_per_step How many ticks per step. @@ -137,7 +137,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): """Positions at each step. The position before first step is also recorded. See :class:`SAOEMetrics` for available columns.""" - metrics: SAOEMetrics | None + metrics: Optional[SAOEMetrics] """Metrics. Only available when done.""" twap_price: float @@ -156,15 +156,21 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): data_dir: Path, ticks_per_step: int = 30, deal_price_type: DealPriceType = "close", - vol_threshold: float | None = None, + vol_threshold: Optional[float] = None, ) -> None: + super(SingleAssetOrderExecution, self).__init__(initial=order) + self.order = order self.ticks_per_step: int = ticks_per_step self.deal_price_type = deal_price_type self.vol_threshold = vol_threshold self.data_dir = data_dir self.backtest_data = load_intraday_backtest_data( - self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction + self.data_dir, + order.stock_id, + pd.Timestamp(order.start_time.date()), + self.deal_price_type, + order.direction, ) self.ticks_index = self.backtest_data.get_time_index() @@ -185,9 +191,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") self.metrics = None - self.market_price: np.ndarray | None = None - self.market_vol: np.ndarray | None = None - self.market_vol_limit: np.ndarray | None = None + self.market_price: Optional[np.ndarray] = None + self.market_vol: Optional[np.ndarray] = None + self.market_vol_limit: Optional[np.ndarray] = None def step(self, amount: float) -> None: """Execute one step or SAOE. @@ -202,7 +208,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): self.market_price = self.market_vol = None # avoid misuse exec_vol = self._split_exec_vol(amount) - assert self.market_price is not None and self.market_vol is not None + assert self.market_price is not None + assert self.market_vol is not None ticks_position = self.position - np.cumsum(exec_vol) @@ -224,9 +231,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): direction=self.order.direction, market_volume=self.market_vol, market_price=self.market_price, - amount=exec_vol, - inner_amount=exec_vol, - deal_amount=exec_vol, + amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao + inner_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao + deal_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao trade_price=self.market_price, trade_value=self.market_price * exec_vol, position=ticks_position, @@ -360,7 +367,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): inner_amount=exec_vol.sum(), deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions trade_price=exec_avg_price, - trade_value=np.sum(market_price * exec_vol), + trade_value=float(np.sum(market_price * exec_vol)), position=self.position, ffr=float(exec_vol.sum() / self.order.amount), pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction), @@ -383,7 +390,9 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray) def price_advantage( - exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int + exec_price: _float_or_ndarray, + baseline_price: float, + direction: Union[OrderDir, int], ) -> _float_or_ndarray: if baseline_price == 0: # something is wrong with data. Should be nan here if isinstance(exec_price, float): diff --git a/qlib/rl/reward.py b/qlib/rl/reward.py index 20d985874..0df7006f2 100644 --- a/qlib/rl/reward.py +++ b/qlib/rl/reward.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Generic, Any, TypeVar, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar from qlib.typehint import final @@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]): Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe. """ - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None @final def __call__(self, simulator_state: SimulatorState) -> float: @@ -30,14 +30,14 @@ class Reward(Generic[SimulatorState]): """Implement this method for your own reward.""" raise NotImplementedError("Implement reward calculation recipe in `reward()`.") - def log(self, name, value): + def log(self, name: str, value: Any) -> None: self.env.logger.add_scalar(name, value) class RewardCombination(Reward): """Combination of multiple reward.""" - def __init__(self, rewards: dict[str, tuple[Reward, float]]): + def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None: self.rewards = rewards def reward(self, simulator_state: Any) -> float: diff --git a/qlib/rl/simulator.py b/qlib/rl/simulator.py index 56fc12042..72e74b64f 100644 --- a/qlib/rl/simulator.py +++ b/qlib/rl/simulator.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TypeVar, Generic, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from .seed import InitialStateType @@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]): Simulators are discouraged to use this, because it's prone to induce errors. """ - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None def __init__(self, initial: InitialStateType, **kwargs: Any) -> None: pass diff --git a/qlib/rl/utils/__init__.py b/qlib/rl/utils/__init__.py index 4a1fa9d90..59e481eb9 100644 --- a/qlib/rl/utils/__init__.py +++ b/qlib/rl/utils/__init__.py @@ -1,7 +1,2 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -from .data_queue import * -from .env_wrapper import * -from .finite_env import * -from .log import * diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py index 32041abef..2432e85b7 100644 --- a/qlib/rl/utils/data_queue.py +++ b/qlib/rl/utils/data_queue.py @@ -1,13 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import os +from __future__ import annotations + import multiprocessing +import os import threading import time import warnings from queue import Empty -from typing import TypeVar, Generic, Sequence, cast +from typing import Any, Generator, Generic, Sequence, TypeVar, cast from qlib.log import get_module_logger @@ -60,7 +62,7 @@ class DataQueue(Generic[T]): shuffle: bool = True, producer_num_workers: int = 0, queue_maxsize: int = 0, - ): + ) -> None: if queue_maxsize == 0: if os.cpu_count() is not None: queue_maxsize = cast(int, os.cpu_count()) @@ -78,14 +80,14 @@ class DataQueue(Generic[T]): self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize) self._done = multiprocessing.Value("i", 0) - def __enter__(self): + def __enter__(self) -> DataQueue: self.activate() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.cleanup() - def cleanup(self): + def cleanup(self) -> None: with self._done.get_lock(): self._done.value += 1 for repeat in range(500): @@ -105,7 +107,7 @@ class DataQueue(Generic[T]): break _logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}") - def get(self, block=True): + def get(self, block: bool = True) -> Any: if not hasattr(self, "_first_get"): self._first_get = True if self._first_get: @@ -120,17 +122,17 @@ class DataQueue(Generic[T]): if self._done.value: raise StopIteration # pylint: disable=raise-missing-from - def put(self, obj, block=True, timeout=None): - return self._queue.put(obj, block=block, timeout=timeout) + def put(self, obj: Any, block: bool = True, timeout: int = None) -> None: + self._queue.put(obj, block=block, timeout=timeout) - def mark_as_done(self): + def mark_as_done(self) -> None: with self._done.get_lock(): self._done.value = 1 - def done(self): + def done(self) -> int: return self._done.value - def activate(self): + def activate(self) -> DataQueue: if self._activated: raise ValueError("DataQueue can not activate twice.") thread = threading.Thread(target=self._producer, daemon=True) @@ -138,18 +140,18 @@ class DataQueue(Generic[T]): self._activated = True return self - def __del__(self): + def __del__(self) -> None: _logger.debug(f"__del__ of {__name__}.DataQueue") self.cleanup() - def __iter__(self): + def __iter__(self) -> Generator[Any, None, None]: if not self._activated: raise ValueError( - "Need to call activate() to launch a daemon worker to produce data into data queue before using it." + "Need to call activate() to launch a daemon worker to produce data into data queue before using it.", ) return self._consumer() - def _consumer(self): + def _consumer(self) -> Generator[Any, None, None]: while True: try: yield self.get() @@ -157,7 +159,7 @@ class DataQueue(Generic[T]): _logger.debug("Data consumer timed-out from get.") return - def _producer(self): + def _producer(self) -> None: # pytorch dataloader is used here only because we need its sampler and multi-processing from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index f343e5b9b..4d50a32d8 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -4,14 +4,15 @@ from __future__ import annotations import weakref -from typing import Callable, Any, Iterable, Iterator, Generic, cast +from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, Union, cast import gym +from gym import Space from qlib.rl.aux_info import AuxiliaryInfoCollector -from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType +from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter from qlib.rl.reward import Reward +from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType from qlib.typehint import TypedDict from .finite_env import generate_nan_observation @@ -28,7 +29,7 @@ class InfoDict(TypedDict): aux_info: dict """Any information depends on auxiliary info collector.""" - log: dict[str, Any] + log: Dict[str, Any] """Collected by LogCollector.""" @@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict): cur_step: int done: bool - initial_state: Any | None + initial_state: Optional[Any] obs_history: list action_history: list reward_history: list class EnvWrapper( - gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType] + gym.Env[ObsType, PolicyActType], + Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType], ): """Qlib-based RL environment, subclassing ``gym.Env``. A wrapper of components, including simulator, state-interpreter, action-interpreter, reward. @@ -90,18 +92,18 @@ class EnvWrapper( """ simulator: Simulator[InitialStateType, StateType, ActType] - seed_iterator: str | Iterator[InitialStateType] | None + seed_iterator: Union[str, Iterator[InitialStateType], None] def __init__( self, simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]], state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], - seed_iterator: Iterable[InitialStateType] | None, - reward_fn: Reward | None = None, - aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None, - logger: LogCollector | None = None, - ): + seed_iterator: Optional[Iterable[InitialStateType]], + reward_fn: Reward = None, + aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None, + logger: LogCollector = None, + ) -> None: # Assign weak reference to wrapper. # # Use weak reference here, because: @@ -135,11 +137,11 @@ class EnvWrapper( self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None) @property - def action_space(self): + def action_space(self) -> Space: return self.action_interpreter.action_space @property - def observation_space(self): + def observation_space(self) -> Space: return self.state_interpreter.observation_space def reset(self, **kwargs: Any) -> ObsType: @@ -191,7 +193,7 @@ class EnvWrapper( self.seed_iterator = None return generate_nan_observation(self.observation_space) - def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]: + def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]: """Environment step. See the code along with comments to get a sequence of things happening here. @@ -245,5 +247,5 @@ class EnvWrapper( info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info) return obs, rew, done, info_dict - def render(self): + def render(self, mode: str = "human") -> None: raise NotImplementedError("Render is not implemented in EnvWrapper.") diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index fc9c2c75e..725dbe975 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -11,14 +11,14 @@ from __future__ import annotations import copy import warnings from contextlib import contextmanager +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast import gym import numpy as np -from typing import Any, Set, Callable, Type - from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv from qlib.typehint import Literal + from .log import LogWriter __all__ = [ @@ -36,7 +36,7 @@ __all__ = [ FiniteEnvType = Literal["dummy", "subproc", "shmem"] -def fill_invalid(obj): +def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> Union[np.ndarray, dict, list, tuple]: if isinstance(obj, (int, float, bool)): return fill_invalid(np.array(obj)) if hasattr(obj, "dtype"): @@ -55,7 +55,7 @@ def fill_invalid(obj): raise ValueError(f"Unsupported value to fill with invalid: {obj}") -def is_invalid(arr): +def is_invalid(arr: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> bool: if hasattr(arr, "dtype"): if np.issubdtype(arr.dtype, np.floating): return np.isnan(arr).all() @@ -121,11 +121,11 @@ class FiniteVectorEnv(BaseVectorEnv): """ def __init__( - self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any + self, logger: Union[LogWriter, List[LogWriter]], env_fns: List[Callable[..., gym.Env]], **kwargs: Any ) -> None: super().__init__(env_fns, **kwargs) - self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger] + self._logger: List[LogWriter] = logger if isinstance(logger, list) else [logger] self._alive_env_ids: Set[int] = set() self._reset_alive_envs() self._default_obs = self._default_info = self._default_rew = None @@ -133,44 +133,44 @@ class FiniteVectorEnv(BaseVectorEnv): self._collector_guarded: bool = False - def _reset_alive_envs(self): + def _reset_alive_envs(self) -> None: if not self._alive_env_ids: # starting or running out self._alive_env_ids = set(range(self.env_num)) # to workaround with tianshou's buffer and batch - def _set_default_obs(self, obs): + def _set_default_obs(self, obs: Any) -> None: if obs is not None and self._default_obs is None: self._default_obs = copy.deepcopy(obs) - def _set_default_info(self, info): + def _set_default_info(self, info: Any) -> None: if info is not None and self._default_info is None: self._default_info = copy.deepcopy(info) - def _set_default_rew(self, rew): + def _set_default_rew(self, rew: Any) -> None: if rew is not None and self._default_rew is None: self._default_rew = copy.deepcopy(rew) - def _get_default_obs(self): + def _get_default_obs(self) -> Any: return copy.deepcopy(self._default_obs) - def _get_default_info(self): + def _get_default_info(self) -> Any: return copy.deepcopy(self._default_info) - def _get_default_rew(self): + def _get_default_rew(self) -> Any: return copy.deepcopy(self._default_rew) # END @staticmethod - def _postproc_env_obs(obs): + def _postproc_env_obs(obs: Any) -> Optional[Any]: # reserved for shmem vector env to restore empty observation if obs is None or check_nan_observation(obs): return None return obs @contextmanager - def collector_guard(self): + def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]: """Guard the collector. Recommended to guard every collect. This guard is for two purposes. @@ -197,7 +197,10 @@ class FiniteVectorEnv(BaseVectorEnv): for logger in self._logger: logger.on_env_all_done() - def reset(self, id=None): + def reset( + self, + id: Optional[Union[int, List[int], np.ndarray]] = None, + ) -> np.ndarray: assert not self._zombie # Check whether it's guarded by collector_guard() @@ -245,7 +248,11 @@ class FiniteVectorEnv(BaseVectorEnv): return np.stack(obs) - def step(self, action, id=None): + def step( + self, + action: np.ndarray, + id: Optional[Union[int, List[int], np.ndarray]] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: assert not self._zombie id = self._wrap_id(id) id2idx = {i: k for k, i in enumerate(id)} @@ -277,7 +284,8 @@ class FiniteVectorEnv(BaseVectorEnv): if r[3] is None: result[i][3] = self._get_default_info() - return list(map(np.stack, zip(*result))) + ret = list(map(np.stack, zip(*result))) + return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret) class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv): @@ -296,7 +304,7 @@ def vectorize_env( env_factory: Callable[..., gym.Env], env_type: FiniteEnvType, concurrency: int, - logger: LogWriter | list[LogWriter], + logger: Union[LogWriter, List[LogWriter]], ) -> FiniteVectorEnv: """Helper function to create a vector env. @@ -326,7 +334,7 @@ def vectorize_env( def env_factory(): ... vectorize_env(env_factory, ...) """ - env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = { + env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = { "dummy": FiniteDummyVectorEnv, "subproc": FiniteSubprocVectorEnv, "shmem": FiniteShmemVectorEnv, diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index 3d495b11d..915dcab1e 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -18,7 +18,7 @@ import logging from collections import defaultdict from enum import IntEnum from pathlib import Path -from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar, Union import numpy as np import pandas as pd @@ -62,22 +62,22 @@ class LogCollector: ``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe. """ - _logged: dict[str, tuple[int, Any]] + _logged: Dict[str, Tuple[int, Any]] _min_loglevel: int - def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC): + def __init__(self, min_loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: self._min_loglevel = int(min_loglevel) - def reset(self): + def reset(self) -> None: """Clear all collected contents.""" self._logged = {} - def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None: + def _add_metric(self, name: str, metric: Any, loglevel: Union[int, LogLevel]) -> None: if name in self._logged: raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.") self._logged[name] = (int(loglevel), metric) - def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def add_string(self, name: str, string: str, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: """Add a string with name into logged contents.""" if loglevel < self._min_loglevel: return @@ -85,7 +85,7 @@ class LogCollector: raise TypeError(f"{string} is not a string.") self._add_metric(name, string, loglevel) - def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def add_scalar(self, name: str, scalar: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: """Add a scalar with name into logged contents. Scalar will be converted into a float. """ @@ -101,7 +101,10 @@ class LogCollector: self._add_metric(name, scalar, loglevel) def add_array( - self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC + self, + name: str, + array: Union[np.ndarray, pd.DataFrame, pd.Series], + loglevel: Union[int, LogLevel] = LogLevel.PERIODIC, ) -> None: """Add an array with name into logging.""" if loglevel < self._min_loglevel: @@ -111,7 +114,7 @@ class LogCollector: raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.") self._add_metric(name, array, loglevel) - def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def add_any(self, name: str, obj: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: """Log something with any type. As it's an "any" object, the only LogWriter accepting it is pickle. @@ -124,7 +127,7 @@ class LogCollector: self._add_metric(name, obj, loglevel) - def logs(self) -> dict[str, np.ndarray]: + def logs(self) -> Dict[str, np.ndarray]: return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()} @@ -151,16 +154,16 @@ class LogWriter(Generic[ObsType, ActType]): active_env_ids: Set[int] """Active environment ids in vector env.""" - episode_lengths: dict[int, int] + episode_lengths: Dict[int, int] """Map from environment id to episode length.""" - episode_rewards: dict[int, list[float]] + episode_rewards: Dict[int, List[float]] """Map from environment id to episode total reward.""" - episode_logs: dict[int, list] + episode_logs: Dict[int, list] """Map from environment id to episode logs.""" - def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC): + def __init__(self, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: self.loglevel = loglevel self.global_step = 0 @@ -174,12 +177,13 @@ class LogWriter(Generic[ObsType, ActType]): self.clear() - def clear(self): + def clear(self) -> None: self.episode_count = self.step_count = 0 self.active_env_ids = set() self.logs = [] - def aggregation(self, array: Sequence[Any]) -> Any: + @staticmethod + def aggregation(array: Sequence[Any]) -> Any: """Aggregation function from step-wise to episode-wise. If it's a sequence of float, take the mean. @@ -191,7 +195,7 @@ class LogWriter(Generic[ObsType, ActType]): else: return array[0] - def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: """This is triggered at the end of each trajectory. Parameters @@ -204,7 +208,7 @@ class LogWriter(Generic[ObsType, ActType]): Logged contents for every steps. """ - def log_step(self, reward: float, contents: dict[str, Any]) -> None: + def log_step(self, reward: float, contents: Dict[str, Any]) -> None: """This is triggered at each step. Parameters @@ -227,7 +231,7 @@ class LogWriter(Generic[ObsType, ActType]): # TODO: reward can be a list of list for MARL self.episode_rewards[env_id].append(rew) - values: dict[str, Any] = {} + values: Dict[str, Any] = {} for key, (loglevel, value) in info["log"].items(): if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME) @@ -272,11 +276,11 @@ class ConsoleWriter(LogWriter): def __init__( self, log_every_n_episode: int = 20, - total_episodes: int | None = None, + total_episodes: int = None, float_format: str = ":.4f", counter_format: str = ":4d", - loglevel: int | LogLevel = LogLevel.PERIODIC, - ): + loglevel: Union[int, LogLevel] = LogLevel.PERIODIC, + ) -> None: super().__init__(loglevel) # TODO: support log_every_n_step self.log_every_n_episode = log_every_n_episode @@ -289,15 +293,15 @@ class ConsoleWriter(LogWriter): self.console_logger = get_module_logger(__name__, level=logging.INFO) - def clear(self): + def clear(self) -> None: super().clear() # Clear average meters - self.metric_counts: dict[str, int] = defaultdict(int) - self.metric_sums: dict[str, float] = defaultdict(float) + self.metric_counts: Dict[str, int] = defaultdict(int) + self.metric_sums: Dict[str, float] = defaultdict(float) - def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: # Aggregate step-wise to episode-wise - episode_wise_contents: dict[str, list] = defaultdict(list) + episode_wise_contents: Dict[str, list] = defaultdict(list) for step_contents in contents: for name, value in step_contents.items(): @@ -306,7 +310,7 @@ class ConsoleWriter(LogWriter): # Generate log contents and track them in average-meter. # This should be done at every step, regardless of periodic or not. - logs: dict[str, float] = {} + logs: Dict[str, float] = {} for name, values in episode_wise_contents.items(): logs[name] = self.aggregation(values) # type: ignore @@ -318,7 +322,7 @@ class ConsoleWriter(LogWriter): # Only log periodically or at the end self.console_logger.info(self.generate_log_message(logs)) - def generate_log_message(self, logs: dict[str, float]) -> str: + def generate_log_message(self, logs: Dict[str, float]) -> str: if self.prefix: msg_prefix = self.prefix + " " else: @@ -348,27 +352,27 @@ class CsvWriter(LogWriter): SUPPORTED_TYPES = (float, str, pd.Timestamp) - all_records: list[dict[str, Any]] + all_records: List[Dict[str, Any]] - def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC): + def __init__(self, output_dir: Path, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None: super().__init__(loglevel) self.output_dir = output_dir self.output_dir.mkdir(exist_ok=True) - def clear(self): + def clear(self) -> None: super().clear() self.all_records = [] - def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup - episode_wise_contents: dict[str, list] = defaultdict(list) + episode_wise_contents: Dict[str, list] = defaultdict(list) for step_contents in contents: for name, value in step_contents.items(): if isinstance(value, self.SUPPORTED_TYPES): episode_wise_contents[name].append(value) - logs: dict[str, float] = {} + logs: Dict[str, float] = {} for name, values in episode_wise_contents.items(): logs[name] = self.aggregation(values) # type: ignore diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py index 240ffc1e1..8e4789dfd 100644 --- a/tests/rl/test_logger.py +++ b/tests/rl/test_logger.py @@ -5,6 +5,8 @@ from random import randint, choice from pathlib import Path import re +from typing import Any, Tuple + import gym import numpy as np import pandas as pd @@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env class SimpleEnv(gym.Env[int, int]): - def __init__(self): + def __init__(self) -> None: self.logger = LogCollector() self.observation_space = gym.spaces.Discrete(2) self.action_space = gym.spaces.Discrete(2) - def reset(self): + def reset(self, *args: Any, **kwargs: Any) -> int: self.step_count = 0 return 0 - def step(self, action: int): + def step(self, action: int) -> Tuple[int, float, bool, dict]: self.logger.reset() self.logger.add_scalar("reward", 42.0) @@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]): return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={}) + def render(self, mode: str = "human") -> None: + pass + class AnyPolicy(BasePolicy): def forward(self, batch, state=None): @@ -86,7 +91,8 @@ def test_simple_env_logger(caplog): class SimpleSimulator(Simulator[int, float, float]): - def __init__(self, initial: int, **kwargs) -> None: + def __init__(self, initial: int, **kwargs: Any) -> None: + super(SimpleSimulator, self).__init__(initial, **kwargs) self.initial = float(initial) def step(self, action: float) -> None: