1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

Refine previous version RL codes

This commit is contained in:
Huoran Li
2022-06-16 13:33:59 +08:00
parent 13d904d9a9
commit b184cc4125
17 changed files with 254 additions and 195 deletions

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import Generic, TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Generic, Optional, TypeVar
from qlib.typehint import final
@@ -21,7 +21,7 @@ AuxInfoType = TypeVar("AuxInfoType")
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
"""Override this class to collect customized auxiliary information from environment."""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
@final
def __call__(self, simulator_state: StateType) -> AuxInfoType:

View File

@@ -20,18 +20,17 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge
from __future__ import annotations
from functools import lru_cache
from typing import List, Sequence, cast
from pathlib import Path
from typing import List, Sequence, cast
import cachetools
import numpy as np
import pandas as pd
from cachetools.keys import hashkey
from qlib.backtest.decision import OrderDir, Order
from qlib.backtest.decision import Order, OrderDir
from qlib.typehint import Literal
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
"""Several ad-hoc deal price.
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
@@ -40,7 +39,7 @@ DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
"""
def _infer_processed_data_column_names(shape: int) -> list[str]:
def _infer_processed_data_column_names(shape: int) -> List[str]:
if shape == 16:
return [
"$open",
@@ -95,8 +94,8 @@ class IntradayBacktestData:
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int | None = None,
):
order_dir: int = None,
) -> None:
backtest = _read_pickle(data_dir / stock_id)
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
@@ -105,13 +104,13 @@ class IntradayBacktestData:
self.data: pd.DataFrame = backtest
self.deal_price_type: DealPriceType = deal_price
self.order_dir: int | None = order_dir
self.order_dir = order_dir
def __repr__(self):
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.data})"
def __len__(self):
def __len__(self) -> int:
return len(self.data)
def get_deal_price(self) -> pd.Series:
@@ -162,7 +161,14 @@ class IntradayProcessedData:
"""Processed data for "yesterday".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
def __init__(
self,
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> None:
proc = _read_pickle(data_dir / stock_id)
# We have to infer the names here because,
# unfortunately they are not included in the original data.
@@ -190,14 +196,18 @@ class IntradayProcessedData:
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
assert len(self.today) == len(self.yesterday) == time_length
def __repr__(self):
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@lru_cache(maxsize=100) # 100 * 50K = 5MB
def load_intraday_backtest_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int = None,
) -> IntradayBacktestData:
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
@@ -207,13 +217,19 @@ def load_intraday_backtest_data(
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
)
def load_intraday_processed_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> IntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
def load_orders(
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
order_path: Path,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> Sequence[Order]:
"""Load orders, and set start time and end time for the orders."""
@@ -251,7 +267,7 @@ def load_orders(
int(row["order_type"]),
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
)
),
)
return orders

View File

@@ -4,19 +4,18 @@
from __future__ import annotations
import copy
from typing import Callable, Sequence
from typing import Callable, List, Sequence, Union
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.log import get_module_logger
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.reward import Reward
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
_logger = get_module_logger(__name__)
@@ -26,8 +25,8 @@ def backtest(
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | list[LogWriter],
reward: Reward | None = None,
logger: Union[LogWriter, List[LogWriter]],
reward: Reward = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
) -> None:
@@ -60,7 +59,7 @@ def backtest(
# To save bandwidth
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
def env_factory():
def env_factory() -> EnvWrapper:
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
# and could be thread unsafe.
# I'm not sure whether it's a design flaw.

View File

@@ -3,13 +3,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar, Generic, Any
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import numpy as np
from qlib.typehint import final
from .simulator import StateType, ActType
from .simulator import ActType, StateType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
@@ -40,7 +40,7 @@ class Interpreter:
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
@property
def observation_space(self) -> gym.Space:
@@ -74,7 +74,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
env: "EnvWrapper" | None = None
env: Optional[EnvWrapper] = None
@property
def action_space(self) -> gym.Space:
@@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None:
class GymSpaceValidationError(Exception):
def __init__(self, message: str, space: gym.Space, x: Any):
def __init__(self, message: str, space: gym.Space, x: Any) -> None:
self.message = message
self.space = space
self.x = x
def __str__(self):
def __str__(self) -> str:
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"

View File

@@ -5,8 +5,3 @@
Currently it supports single-asset order execution.
Multi-asset is on the way.
"""
from .interpreter import *
from .network import *
from .policy import *
from .simulator_simple import *

View File

@@ -5,15 +5,15 @@ from __future__ import annotations
import math
from pathlib import Path
from typing import Any, cast
from typing import Any, List, Union, cast
import numpy as np
import pandas as pd
from gym import spaces
from qlib.constant import EPS
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.data import pickle_styled
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.typehint import TypedDict
from .simulator_simple import SAOEState
@@ -26,7 +26,7 @@ __all__ = [
]
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
def canonicalize(value: Union[int, float, np.ndarray, pd.DataFrame, dict]) -> Union[np.ndarray, dict]:
"""To 32-bit numeric types. Recursively."""
if isinstance(value, pd.DataFrame):
return value.to_numpy()
@@ -99,18 +99,18 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
"data_processed": self._mask_future_info(processed.today, state.cur_time),
"data_processed_prev": processed.yesterday,
"acquiring": state.order.direction == state.order.BUY,
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
"cur_tick": min(float(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
"position_history": position_history[: self.max_step],
}
},
),
)
@property
def observation_space(self):
def observation_space(self) -> spaces.Dict:
space = {
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
@@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
The key list is not full. You can add more if more information is needed by your policy.
"""
def __init__(self, max_step: int):
def __init__(self, max_step: int) -> None:
self.max_step = max_step
@property
def observation_space(self):
def observation_space(self) -> spaces.Dict:
space = {
"acquiring": spaces.Discrete(2),
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
@@ -165,7 +165,7 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
assert self.env is not None
assert self.env.status["cur_step"] <= self.max_step
obs = CurrentStateObs(
{
**{
"acquiring": state.order.direction == state.order.BUY,
"cur_step": self.env.status["cur_step"],
"num_step": self.max_step,
@@ -188,7 +188,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
"""
def __init__(self, values: int | list[float]):
def __init__(self, values: Union[int, List[float]]) -> None:
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values
@@ -203,7 +203,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
"""Convert a continous ratio to deal amount.
"""Convert a continuous ratio to deal amount.
The ratio is relative to TWAP on the remainder of the day.
For example, there are 5 steps left, and the left position is 300.

View File

@@ -3,13 +3,14 @@
from __future__ import annotations
from typing import cast
from typing import List, Tuple, cast
import torch
import torch.nn as nn
from tianshou.data import Batch
from qlib.typehint import Literal
from .interpreter import FullHistoryObs
__all__ = ["Recurrent"]
@@ -18,7 +19,7 @@ __all__ = ["Recurrent"]
class Recurrent(nn.Module):
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
At every timestep the input of policy network is divided into two parts,
At every time step the input of policy network is divided into two parts,
the public variables and the private variables. which are handled by ``raw_rnn``
and ``pri_rnn`` in this network, respectively.
@@ -33,7 +34,7 @@ class Recurrent(nn.Module):
output_dim: int = 32,
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
rnn_num_layers: int = 1,
):
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
@@ -62,10 +63,10 @@ class Recurrent(nn.Module):
nn.ReLU(),
)
def _init_extra_branches(self):
def _init_extra_branches(self) -> None:
pass
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
bs, _, data_dim = obs["data_processed"].size()
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
cur_step = obs["cur_step"].long()

View File

@@ -1,16 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABCMeta
from pathlib import Path
from typing import Optional, cast
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, Union, cast
import numpy as np
import gym
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Discrete
from tianshou.data import Batch, to_torch
from tianshou.policy import PPOPolicy, BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import BasePolicy, PPOPolicy
__all__ = ["AllOne", "PPO"]
@@ -18,29 +18,39 @@ __all__ = ["AllOne", "PPO"]
# baselines #
class NonlearnablePolicy(BasePolicy):
class NonLearnablePolicy(BasePolicy, metaclass=ABCMeta):
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
This could be moved outside in future.
"""
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:
super().__init__()
def learn(self, batch, batch_size, repeat):
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
pass
def process_fn(self, batch, buffer, indice):
def process_fn(
self,
batch: Batch,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> Batch:
pass
class AllOne(NonlearnablePolicy):
class AllOne(NonLearnablePolicy):
"""Forward returns a batch full of 1.
Useful when implementing some baselines (e.g., TWAP).
"""
def forward(self, batch, state=None, **kwargs):
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
return Batch(act=np.full(len(batch), 1.0), state=state)
@@ -48,24 +58,34 @@ class AllOne(NonlearnablePolicy):
class PPOActor(nn.Module):
def __init__(self, extractor: nn.Module, action_dim: int):
def __init__(self, extractor: nn.Module, action_dim: int) -> None:
super().__init__()
self.extractor = extractor
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
def forward(self, obs, state=None, info={}):
def forward(
self,
obs: torch.Tensor,
state: torch.Tensor = None,
info: dict = {},
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
feature = self.extractor(to_torch(obs, device=auto_device(self)))
out = self.layer_out(feature)
return out, state
class PPOCritic(nn.Module):
def __init__(self, extractor: nn.Module):
def __init__(self, extractor: nn.Module) -> None:
super().__init__()
self.extractor = extractor
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
def forward(self, obs, state=None, info={}):
def forward(
self,
obs: torch.Tensor,
state: torch.Tensor = None,
info: dict = {},
) -> torch.Tensor:
feature = self.extractor(to_torch(obs, device=auto_device(self)))
return self.value_out(feature).squeeze(dim=-1)
@@ -93,18 +113,20 @@ class PPO(PPOPolicy):
max_grad_norm: float = 100.0,
reward_normalization: bool = True,
eps_clip: float = 0.3,
value_clip: float = True,
value_clip: bool = True,
vf_coef: float = 1.0,
gae_lambda: float = 1.0,
max_batchsize: int = 256,
max_batch_size: int = 256,
deterministic_eval: bool = True,
weight_file: Optional[Path] = None,
):
) -> None:
assert isinstance(action_space, Discrete)
actor = PPOActor(network, action_space.n)
critic = PPOCritic(network)
optimizer = torch.optim.Adam(
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
chain_dedup(actor.parameters(), critic.parameters()),
lr=lr,
weight_decay=weight_decay,
)
super().__init__(
actor,
@@ -118,7 +140,7 @@ class PPO(PPOPolicy):
value_clip=value_clip,
vf_coef=vf_coef,
gae_lambda=gae_lambda,
max_batchsize=max_batchsize,
max_batchsize=max_batch_size,
deterministic_eval=deterministic_eval,
observation_space=obs_space,
action_space=action_space,
@@ -136,7 +158,7 @@ def auto_device(module: nn.Module) -> torch.device:
return torch.device("cpu") # fallback to cpu
def load_weight(policy, path):
def load_weight(policy: nn.Module, path: Path) -> None:
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
loaded_weight = torch.load(path, map_location="cpu")
try:
@@ -149,7 +171,7 @@ def load_weight(policy, path):
policy.load_state_dict(loaded_weight)
def chain_dedup(*iterables):
def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]:
seen = set()
for iterable in iterables:
for i in iterable:

View File

@@ -4,15 +4,15 @@
from __future__ import annotations
from pathlib import Path
from typing import NamedTuple, Any, TypeVar, cast
from typing import Any, NamedTuple, Optional, TypeVar, Union, cast
import numpy as np
import pandas as pd
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS
from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_intraday_backtest_data
from qlib.rl.simulator import Simulator
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
from qlib.rl.utils import LogLevel
from qlib.typehint import TypedDict
@@ -33,7 +33,7 @@ class SAOEMetrics(TypedDict):
stock_id: str
"""Stock ID of this record."""
datetime: pd.Timestamp
datetime: Union[pd.Timestamp, pd.DatetimeIndex] # TODO: check this
"""Datetime of this record (this is index in the dataframe)."""
direction: int
"""Direction of the order. 0 for sell, 1 for buy."""
@@ -87,7 +87,7 @@ class SAOEState(NamedTuple):
history_steps: pd.DataFrame
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
metrics: SAOEMetrics | None
metrics: Optional[SAOEMetrics]
"""Daily metric, only available when the trading is in "done" state."""
backtest_data: IntradayBacktestData
@@ -114,13 +114,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
If such fine granularity is not needed, use ``ticks_per_step`` to
lengthen the ticks for each step.
In each step, the traded amount are "equally" splitted to each tick,
then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``),
In each step, the traded amount are "equally" separated to each tick,
then bounded by volume maximum execution volume (i.e., ``vol_threshold``),
and if it's the last step, try to ensure all the amount to be executed.
Parameters
----------
initial
order
The seed to start an SAOE simulator is an order.
ticks_per_step
How many ticks per step.
@@ -137,7 +137,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""Positions at each step. The position before first step is also recorded.
See :class:`SAOEMetrics` for available columns."""
metrics: SAOEMetrics | None
metrics: Optional[SAOEMetrics]
"""Metrics. Only available when done."""
twap_price: float
@@ -156,15 +156,21 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
data_dir: Path,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: float | None = None,
vol_threshold: Optional[float] = None,
) -> None:
super(SingleAssetOrderExecution, self).__init__(initial=order)
self.order = order
self.ticks_per_step: int = ticks_per_step
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
self.backtest_data = load_intraday_backtest_data(
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
self.data_dir,
order.stock_id,
pd.Timestamp(order.start_time.date()),
self.deal_price_type,
order.direction,
)
self.ticks_index = self.backtest_data.get_time_index()
@@ -185,9 +191,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.metrics = None
self.market_price: np.ndarray | None = None
self.market_vol: np.ndarray | None = None
self.market_vol_limit: np.ndarray | None = None
self.market_price: Optional[np.ndarray] = None
self.market_vol: Optional[np.ndarray] = None
self.market_vol_limit: Optional[np.ndarray] = None
def step(self, amount: float) -> None:
"""Execute one step or SAOE.
@@ -202,7 +208,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
self.market_price = self.market_vol = None # avoid misuse
exec_vol = self._split_exec_vol(amount)
assert self.market_price is not None and self.market_vol is not None
assert self.market_price is not None
assert self.market_vol is not None
ticks_position = self.position - np.cumsum(exec_vol)
@@ -224,9 +231,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
direction=self.order.direction,
market_volume=self.market_vol,
market_price=self.market_price,
amount=exec_vol,
inner_amount=exec_vol,
deal_amount=exec_vol,
amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
inner_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
deal_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
trade_price=self.market_price,
trade_value=self.market_price * exec_vol,
position=ticks_position,
@@ -360,7 +367,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
inner_amount=exec_vol.sum(),
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
trade_price=exec_avg_price,
trade_value=np.sum(market_price * exec_vol),
trade_value=float(np.sum(market_price * exec_vol)),
position=self.position,
ffr=float(exec_vol.sum() / self.order.amount),
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
@@ -383,7 +390,9 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
def price_advantage(
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
exec_price: _float_or_ndarray,
baseline_price: float,
direction: Union[OrderDir, int],
) -> _float_or_ndarray:
if baseline_price == 0: # something is wrong with data. Should be nan here
if isinstance(exec_price, float):

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import Generic, Any, TypeVar, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar
from qlib.typehint import final
@@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]):
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
"""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
@final
def __call__(self, simulator_state: SimulatorState) -> float:
@@ -30,14 +30,14 @@ class Reward(Generic[SimulatorState]):
"""Implement this method for your own reward."""
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
def log(self, name, value):
def log(self, name: str, value: Any) -> None:
self.env.logger.add_scalar(name, value)
class RewardCombination(Reward):
"""Combination of multiple reward."""
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None:
self.rewards = rewards
def reward(self, simulator_state: Any) -> float:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import TypeVar, Generic, Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from .seed import InitialStateType
@@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]):
Simulators are discouraged to use this, because it's prone to induce errors.
"""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
pass

View File

@@ -1,7 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .data_queue import *
from .env_wrapper import *
from .finite_env import *
from .log import *

View File

@@ -1,13 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from __future__ import annotations
import multiprocessing
import os
import threading
import time
import warnings
from queue import Empty
from typing import TypeVar, Generic, Sequence, cast
from typing import Any, Generator, Generic, Sequence, TypeVar, cast
from qlib.log import get_module_logger
@@ -60,7 +62,7 @@ class DataQueue(Generic[T]):
shuffle: bool = True,
producer_num_workers: int = 0,
queue_maxsize: int = 0,
):
) -> None:
if queue_maxsize == 0:
if os.cpu_count() is not None:
queue_maxsize = cast(int, os.cpu_count())
@@ -78,14 +80,14 @@ class DataQueue(Generic[T]):
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
self._done = multiprocessing.Value("i", 0)
def __enter__(self):
def __enter__(self) -> DataQueue:
self.activate()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.cleanup()
def cleanup(self):
def cleanup(self) -> None:
with self._done.get_lock():
self._done.value += 1
for repeat in range(500):
@@ -105,7 +107,7 @@ class DataQueue(Generic[T]):
break
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
def get(self, block=True):
def get(self, block: bool = True) -> Any:
if not hasattr(self, "_first_get"):
self._first_get = True
if self._first_get:
@@ -120,17 +122,17 @@ class DataQueue(Generic[T]):
if self._done.value:
raise StopIteration # pylint: disable=raise-missing-from
def put(self, obj, block=True, timeout=None):
return self._queue.put(obj, block=block, timeout=timeout)
def put(self, obj: Any, block: bool = True, timeout: int = None) -> None:
self._queue.put(obj, block=block, timeout=timeout)
def mark_as_done(self):
def mark_as_done(self) -> None:
with self._done.get_lock():
self._done.value = 1
def done(self):
def done(self) -> int:
return self._done.value
def activate(self):
def activate(self) -> DataQueue:
if self._activated:
raise ValueError("DataQueue can not activate twice.")
thread = threading.Thread(target=self._producer, daemon=True)
@@ -138,18 +140,18 @@ class DataQueue(Generic[T]):
self._activated = True
return self
def __del__(self):
def __del__(self) -> None:
_logger.debug(f"__del__ of {__name__}.DataQueue")
self.cleanup()
def __iter__(self):
def __iter__(self) -> Generator[Any, None, None]:
if not self._activated:
raise ValueError(
"Need to call activate() to launch a daemon worker to produce data into data queue before using it."
"Need to call activate() to launch a daemon worker to produce data into data queue before using it.",
)
return self._consumer()
def _consumer(self):
def _consumer(self) -> Generator[Any, None, None]:
while True:
try:
yield self.get()
@@ -157,7 +159,7 @@ class DataQueue(Generic[T]):
_logger.debug("Data consumer timed-out from get.")
return
def _producer(self):
def _producer(self) -> None:
# pytorch dataloader is used here only because we need its sampler and multi-processing
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel

View File

@@ -4,14 +4,15 @@
from __future__ import annotations
import weakref
from typing import Callable, Any, Iterable, Iterator, Generic, cast
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, Union, cast
import gym
from gym import Space
from qlib.rl.aux_info import AuxiliaryInfoCollector
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter
from qlib.rl.reward import Reward
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
from qlib.typehint import TypedDict
from .finite_env import generate_nan_observation
@@ -28,7 +29,7 @@ class InfoDict(TypedDict):
aux_info: dict
"""Any information depends on auxiliary info collector."""
log: dict[str, Any]
log: Dict[str, Any]
"""Collected by LogCollector."""
@@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict):
cur_step: int
done: bool
initial_state: Any | None
initial_state: Optional[Any]
obs_history: list
action_history: list
reward_history: list
class EnvWrapper(
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
gym.Env[ObsType, PolicyActType],
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
):
"""Qlib-based RL environment, subclassing ``gym.Env``.
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
@@ -90,18 +92,18 @@ class EnvWrapper(
"""
simulator: Simulator[InitialStateType, StateType, ActType]
seed_iterator: str | Iterator[InitialStateType] | None
seed_iterator: Union[str, Iterator[InitialStateType], None]
def __init__(
self,
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
state_interpreter: StateInterpreter[StateType, ObsType],
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
seed_iterator: Iterable[InitialStateType] | None,
reward_fn: Reward | None = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
logger: LogCollector | None = None,
):
seed_iterator: Optional[Iterable[InitialStateType]],
reward_fn: Reward = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None,
logger: LogCollector = None,
) -> None:
# Assign weak reference to wrapper.
#
# Use weak reference here, because:
@@ -135,11 +137,11 @@ class EnvWrapper(
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
@property
def action_space(self):
def action_space(self) -> Space:
return self.action_interpreter.action_space
@property
def observation_space(self):
def observation_space(self) -> Space:
return self.state_interpreter.observation_space
def reset(self, **kwargs: Any) -> ObsType:
@@ -191,7 +193,7 @@ class EnvWrapper(
self.seed_iterator = None
return generate_nan_observation(self.observation_space)
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]:
"""Environment step.
See the code along with comments to get a sequence of things happening here.
@@ -245,5 +247,5 @@ class EnvWrapper(
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
return obs, rew, done, info_dict
def render(self):
def render(self, mode: str = "human") -> None:
raise NotImplementedError("Render is not implemented in EnvWrapper.")

View File

@@ -11,14 +11,14 @@ from __future__ import annotations
import copy
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast
import gym
import numpy as np
from typing import Any, Set, Callable, Type
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
from qlib.typehint import Literal
from .log import LogWriter
__all__ = [
@@ -36,7 +36,7 @@ __all__ = [
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
def fill_invalid(obj):
def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> Union[np.ndarray, dict, list, tuple]:
if isinstance(obj, (int, float, bool)):
return fill_invalid(np.array(obj))
if hasattr(obj, "dtype"):
@@ -55,7 +55,7 @@ def fill_invalid(obj):
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
def is_invalid(arr):
def is_invalid(arr: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> bool:
if hasattr(arr, "dtype"):
if np.issubdtype(arr.dtype, np.floating):
return np.isnan(arr).all()
@@ -121,11 +121,11 @@ class FiniteVectorEnv(BaseVectorEnv):
"""
def __init__(
self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
self, logger: Union[LogWriter, List[LogWriter]], env_fns: List[Callable[..., gym.Env]], **kwargs: Any
) -> None:
super().__init__(env_fns, **kwargs)
self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
self._logger: List[LogWriter] = logger if isinstance(logger, list) else [logger]
self._alive_env_ids: Set[int] = set()
self._reset_alive_envs()
self._default_obs = self._default_info = self._default_rew = None
@@ -133,44 +133,44 @@ class FiniteVectorEnv(BaseVectorEnv):
self._collector_guarded: bool = False
def _reset_alive_envs(self):
def _reset_alive_envs(self) -> None:
if not self._alive_env_ids:
# starting or running out
self._alive_env_ids = set(range(self.env_num))
# to workaround with tianshou's buffer and batch
def _set_default_obs(self, obs):
def _set_default_obs(self, obs: Any) -> None:
if obs is not None and self._default_obs is None:
self._default_obs = copy.deepcopy(obs)
def _set_default_info(self, info):
def _set_default_info(self, info: Any) -> None:
if info is not None and self._default_info is None:
self._default_info = copy.deepcopy(info)
def _set_default_rew(self, rew):
def _set_default_rew(self, rew: Any) -> None:
if rew is not None and self._default_rew is None:
self._default_rew = copy.deepcopy(rew)
def _get_default_obs(self):
def _get_default_obs(self) -> Any:
return copy.deepcopy(self._default_obs)
def _get_default_info(self):
def _get_default_info(self) -> Any:
return copy.deepcopy(self._default_info)
def _get_default_rew(self):
def _get_default_rew(self) -> Any:
return copy.deepcopy(self._default_rew)
# END
@staticmethod
def _postproc_env_obs(obs):
def _postproc_env_obs(obs: Any) -> Optional[Any]:
# reserved for shmem vector env to restore empty observation
if obs is None or check_nan_observation(obs):
return None
return obs
@contextmanager
def collector_guard(self):
def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:
"""Guard the collector. Recommended to guard every collect.
This guard is for two purposes.
@@ -197,7 +197,10 @@ class FiniteVectorEnv(BaseVectorEnv):
for logger in self._logger:
logger.on_env_all_done()
def reset(self, id=None):
def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
) -> np.ndarray:
assert not self._zombie
# Check whether it's guarded by collector_guard()
@@ -245,7 +248,11 @@ class FiniteVectorEnv(BaseVectorEnv):
return np.stack(obs)
def step(self, action, id=None):
def step(
self,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert not self._zombie
id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(id)}
@@ -277,7 +284,8 @@ class FiniteVectorEnv(BaseVectorEnv):
if r[3] is None:
result[i][3] = self._get_default_info()
return list(map(np.stack, zip(*result)))
ret = list(map(np.stack, zip(*result)))
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
@@ -296,7 +304,7 @@ def vectorize_env(
env_factory: Callable[..., gym.Env],
env_type: FiniteEnvType,
concurrency: int,
logger: LogWriter | list[LogWriter],
logger: Union[LogWriter, List[LogWriter]],
) -> FiniteVectorEnv:
"""Helper function to create a vector env.
@@ -326,7 +334,7 @@ def vectorize_env(
def env_factory(): ...
vectorize_env(env_factory, ...)
"""
env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
"dummy": FiniteDummyVectorEnv,
"subproc": FiniteSubprocVectorEnv,
"shmem": FiniteShmemVectorEnv,

View File

@@ -18,7 +18,7 @@ import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar, Union
import numpy as np
import pandas as pd
@@ -62,22 +62,22 @@ class LogCollector:
``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
"""
_logged: dict[str, tuple[int, Any]]
_logged: Dict[str, Tuple[int, Any]]
_min_loglevel: int
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
def __init__(self, min_loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
self._min_loglevel = int(min_loglevel)
def reset(self):
def reset(self) -> None:
"""Clear all collected contents."""
self._logged = {}
def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
def _add_metric(self, name: str, metric: Any, loglevel: Union[int, LogLevel]) -> None:
if name in self._logged:
raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
self._logged[name] = (int(loglevel), metric)
def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
def add_string(self, name: str, string: str, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
"""Add a string with name into logged contents."""
if loglevel < self._min_loglevel:
return
@@ -85,7 +85,7 @@ class LogCollector:
raise TypeError(f"{string} is not a string.")
self._add_metric(name, string, loglevel)
def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
def add_scalar(self, name: str, scalar: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
"""Add a scalar with name into logged contents.
Scalar will be converted into a float.
"""
@@ -101,7 +101,10 @@ class LogCollector:
self._add_metric(name, scalar, loglevel)
def add_array(
self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
self,
name: str,
array: Union[np.ndarray, pd.DataFrame, pd.Series],
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
) -> None:
"""Add an array with name into logging."""
if loglevel < self._min_loglevel:
@@ -111,7 +114,7 @@ class LogCollector:
raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
self._add_metric(name, array, loglevel)
def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
def add_any(self, name: str, obj: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
"""Log something with any type.
As it's an "any" object, the only LogWriter accepting it is pickle.
@@ -124,7 +127,7 @@ class LogCollector:
self._add_metric(name, obj, loglevel)
def logs(self) -> dict[str, np.ndarray]:
def logs(self) -> Dict[str, np.ndarray]:
return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
@@ -151,16 +154,16 @@ class LogWriter(Generic[ObsType, ActType]):
active_env_ids: Set[int]
"""Active environment ids in vector env."""
episode_lengths: dict[int, int]
episode_lengths: Dict[int, int]
"""Map from environment id to episode length."""
episode_rewards: dict[int, list[float]]
episode_rewards: Dict[int, List[float]]
"""Map from environment id to episode total reward."""
episode_logs: dict[int, list]
episode_logs: Dict[int, list]
"""Map from environment id to episode logs."""
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
def __init__(self, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
self.loglevel = loglevel
self.global_step = 0
@@ -174,12 +177,13 @@ class LogWriter(Generic[ObsType, ActType]):
self.clear()
def clear(self):
def clear(self) -> None:
self.episode_count = self.step_count = 0
self.active_env_ids = set()
self.logs = []
def aggregation(self, array: Sequence[Any]) -> Any:
@staticmethod
def aggregation(array: Sequence[Any]) -> Any:
"""Aggregation function from step-wise to episode-wise.
If it's a sequence of float, take the mean.
@@ -191,7 +195,7 @@ class LogWriter(Generic[ObsType, ActType]):
else:
return array[0]
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
"""This is triggered at the end of each trajectory.
Parameters
@@ -204,7 +208,7 @@ class LogWriter(Generic[ObsType, ActType]):
Logged contents for every steps.
"""
def log_step(self, reward: float, contents: dict[str, Any]) -> None:
def log_step(self, reward: float, contents: Dict[str, Any]) -> None:
"""This is triggered at each step.
Parameters
@@ -227,7 +231,7 @@ class LogWriter(Generic[ObsType, ActType]):
# TODO: reward can be a list of list for MARL
self.episode_rewards[env_id].append(rew)
values: dict[str, Any] = {}
values: Dict[str, Any] = {}
for key, (loglevel, value) in info["log"].items():
if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
@@ -272,11 +276,11 @@ class ConsoleWriter(LogWriter):
def __init__(
self,
log_every_n_episode: int = 20,
total_episodes: int | None = None,
total_episodes: int = None,
float_format: str = ":.4f",
counter_format: str = ":4d",
loglevel: int | LogLevel = LogLevel.PERIODIC,
):
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
) -> None:
super().__init__(loglevel)
# TODO: support log_every_n_step
self.log_every_n_episode = log_every_n_episode
@@ -289,15 +293,15 @@ class ConsoleWriter(LogWriter):
self.console_logger = get_module_logger(__name__, level=logging.INFO)
def clear(self):
def clear(self) -> None:
super().clear()
# Clear average meters
self.metric_counts: dict[str, int] = defaultdict(int)
self.metric_sums: dict[str, float] = defaultdict(float)
self.metric_counts: Dict[str, int] = defaultdict(int)
self.metric_sums: Dict[str, float] = defaultdict(float)
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
# Aggregate step-wise to episode-wise
episode_wise_contents: dict[str, list] = defaultdict(list)
episode_wise_contents: Dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
@@ -306,7 +310,7 @@ class ConsoleWriter(LogWriter):
# Generate log contents and track them in average-meter.
# This should be done at every step, regardless of periodic or not.
logs: dict[str, float] = {}
logs: Dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
@@ -318,7 +322,7 @@ class ConsoleWriter(LogWriter):
# Only log periodically or at the end
self.console_logger.info(self.generate_log_message(logs))
def generate_log_message(self, logs: dict[str, float]) -> str:
def generate_log_message(self, logs: Dict[str, float]) -> str:
if self.prefix:
msg_prefix = self.prefix + " "
else:
@@ -348,27 +352,27 @@ class CsvWriter(LogWriter):
SUPPORTED_TYPES = (float, str, pd.Timestamp)
all_records: list[dict[str, Any]]
all_records: List[Dict[str, Any]]
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
def __init__(self, output_dir: Path, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
super().__init__(loglevel)
self.output_dir = output_dir
self.output_dir.mkdir(exist_ok=True)
def clear(self):
def clear(self) -> None:
super().clear()
self.all_records = []
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
# FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
episode_wise_contents: dict[str, list] = defaultdict(list)
episode_wise_contents: Dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
if isinstance(value, self.SUPPORTED_TYPES):
episode_wise_contents[name].append(value)
logs: dict[str, float] = {}
logs: Dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore

View File

@@ -5,6 +5,8 @@ from random import randint, choice
from pathlib import Path
import re
from typing import Any, Tuple
import gym
import numpy as np
import pandas as pd
@@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
def __init__(self):
def __init__(self) -> None:
self.logger = LogCollector()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
def reset(self, *args: Any, **kwargs: Any) -> int:
self.step_count = 0
return 0
def step(self, action: int):
def step(self, action: int) -> Tuple[int, float, bool, dict]:
self.logger.reset()
self.logger.add_scalar("reward", 42.0)
@@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]):
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
def render(self, mode: str = "human") -> None:
pass
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
@@ -86,7 +91,8 @@ def test_simple_env_logger(caplog):
class SimpleSimulator(Simulator[int, float, float]):
def __init__(self, initial: int, **kwargs) -> None:
def __init__(self, initial: int, **kwargs: Any) -> None:
super(SimpleSimulator, self).__init__(initial, **kwargs)
self.initial = float(initial)
def step(self, action: float) -> None: