mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Refine previous version RL codes
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, TYPE_CHECKING, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
@@ -21,7 +21,7 @@ AuxInfoType = TypeVar("AuxInfoType")
|
||||
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
|
||||
"""Override this class to collect customized auxiliary information from environment."""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: StateType) -> AuxInfoType:
|
||||
|
||||
@@ -20,18 +20,17 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Sequence, cast
|
||||
from pathlib import Path
|
||||
from typing import List, Sequence, cast
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import OrderDir, Order
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""Several ad-hoc deal price.
|
||||
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
|
||||
@@ -40,7 +39,7 @@ DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""
|
||||
|
||||
|
||||
def _infer_processed_data_column_names(shape: int) -> list[str]:
|
||||
def _infer_processed_data_column_names(shape: int) -> List[str]:
|
||||
if shape == 16:
|
||||
return [
|
||||
"$open",
|
||||
@@ -95,8 +94,8 @@ class IntradayBacktestData:
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int | None = None,
|
||||
):
|
||||
order_dir: int = None,
|
||||
) -> None:
|
||||
backtest = _read_pickle(data_dir / stock_id)
|
||||
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
@@ -105,13 +104,13 @@ class IntradayBacktestData:
|
||||
|
||||
self.data: pd.DataFrame = backtest
|
||||
self.deal_price_type: DealPriceType = deal_price
|
||||
self.order_dir: int | None = order_dir
|
||||
self.order_dir = order_dir
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.data})"
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
@@ -162,7 +161,14 @@ class IntradayProcessedData:
|
||||
"""Processed data for "yesterday".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
|
||||
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> None:
|
||||
proc = _read_pickle(data_dir / stock_id)
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
@@ -190,14 +196,18 @@ class IntradayProcessedData:
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@lru_cache(maxsize=100) # 100 * 50K = 5MB
|
||||
def load_intraday_backtest_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int = None,
|
||||
) -> IntradayBacktestData:
|
||||
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
@@ -207,13 +217,19 @@ def load_intraday_backtest_data(
|
||||
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_intraday_processed_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> IntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
def load_orders(
|
||||
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
|
||||
order_path: Path,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Sequence[Order]:
|
||||
"""Load orders, and set start time and end time for the orders."""
|
||||
|
||||
@@ -251,7 +267,7 @@ def load_orders(
|
||||
int(row["order_type"]),
|
||||
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
||||
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return orders
|
||||
|
||||
@@ -4,19 +4,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Callable, Sequence
|
||||
from typing import Callable, List, Sequence, Union
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
@@ -26,8 +25,8 @@ def backtest(
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
logger: Union[LogWriter, List[LogWriter]],
|
||||
reward: Reward = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
) -> None:
|
||||
@@ -60,7 +59,7 @@ def backtest(
|
||||
# To save bandwidth
|
||||
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
|
||||
|
||||
def env_factory():
|
||||
def env_factory() -> EnvWrapper:
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Any
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType, ActType
|
||||
from .simulator import ActType, StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
@@ -40,7 +40,7 @@ class Interpreter:
|
||||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
@@ -74,7 +74,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
env: "EnvWrapper" | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
@@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None:
|
||||
|
||||
|
||||
class GymSpaceValidationError(Exception):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any) -> None:
|
||||
self.message = message
|
||||
self.space = space
|
||||
self.x = x
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
|
||||
|
||||
@@ -5,8 +5,3 @@
|
||||
Currently it supports single-asset order execution.
|
||||
Multi-asset is on the way.
|
||||
"""
|
||||
|
||||
from .interpreter import *
|
||||
from .network import *
|
||||
from .policy import *
|
||||
from .simulator_simple import *
|
||||
|
||||
@@ -5,15 +5,15 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any, List, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .simulator_simple import SAOEState
|
||||
@@ -26,7 +26,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
|
||||
def canonicalize(value: Union[int, float, np.ndarray, pd.DataFrame, dict]) -> Union[np.ndarray, dict]:
|
||||
"""To 32-bit numeric types. Recursively."""
|
||||
if isinstance(value, pd.DataFrame):
|
||||
return value.to_numpy()
|
||||
@@ -99,18 +99,18 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
"data_processed": self._mask_future_info(processed.today, state.cur_time),
|
||||
"data_processed_prev": processed.yesterday,
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
|
||||
"cur_tick": min(float(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
|
||||
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
"position_history": position_history[: self.max_step],
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
@@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
The key list is not full. You can add more if more information is needed by your policy.
|
||||
"""
|
||||
|
||||
def __init__(self, max_step: int):
|
||||
def __init__(self, max_step: int) -> None:
|
||||
self.max_step = max_step
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"acquiring": spaces.Discrete(2),
|
||||
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
|
||||
@@ -165,7 +165,7 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
assert self.env is not None
|
||||
assert self.env.status["cur_step"] <= self.max_step
|
||||
obs = CurrentStateObs(
|
||||
{
|
||||
**{
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_step": self.env.status["cur_step"],
|
||||
"num_step": self.max_step,
|
||||
@@ -188,7 +188,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
|
||||
"""
|
||||
|
||||
def __init__(self, values: int | list[float]):
|
||||
def __init__(self, values: Union[int, List[float]]) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
@@ -203,7 +203,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
|
||||
|
||||
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
|
||||
"""Convert a continous ratio to deal amount.
|
||||
"""Convert a continuous ratio to deal amount.
|
||||
|
||||
The ratio is relative to TWAP on the remainder of the day.
|
||||
For example, there are 5 steps left, and the left position is 300.
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tianshou.data import Batch
|
||||
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .interpreter import FullHistoryObs
|
||||
|
||||
__all__ = ["Recurrent"]
|
||||
@@ -18,7 +19,7 @@ __all__ = ["Recurrent"]
|
||||
class Recurrent(nn.Module):
|
||||
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
|
||||
|
||||
At every timestep the input of policy network is divided into two parts,
|
||||
At every time step the input of policy network is divided into two parts,
|
||||
the public variables and the private variables. which are handled by ``raw_rnn``
|
||||
and ``pri_rnn`` in this network, respectively.
|
||||
|
||||
@@ -33,7 +34,7 @@ class Recurrent(nn.Module):
|
||||
output_dim: int = 32,
|
||||
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
|
||||
rnn_num_layers: int = 1,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
@@ -62,10 +63,10 @@ class Recurrent(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def _init_extra_branches(self):
|
||||
def _init_extra_branches(self) -> None:
|
||||
pass
|
||||
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
||||
bs, _, data_dim = obs["data_processed"].size()
|
||||
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
|
||||
cur_step = obs["cur_step"].long()
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from abc import ABCMeta
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym.spaces import Discrete
|
||||
from tianshou.data import Batch, to_torch
|
||||
from tianshou.policy import PPOPolicy, BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.policy import BasePolicy, PPOPolicy
|
||||
|
||||
__all__ = ["AllOne", "PPO"]
|
||||
|
||||
@@ -18,29 +18,39 @@ __all__ = ["AllOne", "PPO"]
|
||||
# baselines #
|
||||
|
||||
|
||||
class NonlearnablePolicy(BasePolicy):
|
||||
class NonLearnablePolicy(BasePolicy, metaclass=ABCMeta):
|
||||
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
|
||||
|
||||
This could be moved outside in future.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:
|
||||
super().__init__()
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
def process_fn(
|
||||
self,
|
||||
batch: Batch,
|
||||
buffer: ReplayBuffer,
|
||||
indices: np.ndarray,
|
||||
) -> Batch:
|
||||
pass
|
||||
|
||||
|
||||
class AllOne(NonlearnablePolicy):
|
||||
class AllOne(NonLearnablePolicy):
|
||||
"""Forward returns a batch full of 1.
|
||||
|
||||
Useful when implementing some baselines (e.g., TWAP).
|
||||
"""
|
||||
|
||||
def forward(self, batch, state=None, **kwargs):
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
return Batch(act=np.full(len(batch), 1.0), state=state)
|
||||
|
||||
|
||||
@@ -48,24 +58,34 @@ class AllOne(NonlearnablePolicy):
|
||||
|
||||
|
||||
class PPOActor(nn.Module):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
out = self.layer_out(feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class PPOCritic(nn.Module):
|
||||
def __init__(self, extractor: nn.Module):
|
||||
def __init__(self, extractor: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> torch.Tensor:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
return self.value_out(feature).squeeze(dim=-1)
|
||||
|
||||
@@ -93,18 +113,20 @@ class PPO(PPOPolicy):
|
||||
max_grad_norm: float = 100.0,
|
||||
reward_normalization: bool = True,
|
||||
eps_clip: float = 0.3,
|
||||
value_clip: float = True,
|
||||
value_clip: bool = True,
|
||||
vf_coef: float = 1.0,
|
||||
gae_lambda: float = 1.0,
|
||||
max_batchsize: int = 256,
|
||||
max_batch_size: int = 256,
|
||||
deterministic_eval: bool = True,
|
||||
weight_file: Optional[Path] = None,
|
||||
):
|
||||
) -> None:
|
||||
assert isinstance(action_space, Discrete)
|
||||
actor = PPOActor(network, action_space.n)
|
||||
critic = PPOCritic(network)
|
||||
optimizer = torch.optim.Adam(
|
||||
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
|
||||
chain_dedup(actor.parameters(), critic.parameters()),
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
super().__init__(
|
||||
actor,
|
||||
@@ -118,7 +140,7 @@ class PPO(PPOPolicy):
|
||||
value_clip=value_clip,
|
||||
vf_coef=vf_coef,
|
||||
gae_lambda=gae_lambda,
|
||||
max_batchsize=max_batchsize,
|
||||
max_batchsize=max_batch_size,
|
||||
deterministic_eval=deterministic_eval,
|
||||
observation_space=obs_space,
|
||||
action_space=action_space,
|
||||
@@ -136,7 +158,7 @@ def auto_device(module: nn.Module) -> torch.device:
|
||||
return torch.device("cpu") # fallback to cpu
|
||||
|
||||
|
||||
def load_weight(policy, path):
|
||||
def load_weight(policy: nn.Module, path: Path) -> None:
|
||||
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
|
||||
loaded_weight = torch.load(path, map_location="cpu")
|
||||
try:
|
||||
@@ -149,7 +171,7 @@ def load_weight(policy, path):
|
||||
policy.load_state_dict(loaded_weight)
|
||||
|
||||
|
||||
def chain_dedup(*iterables):
|
||||
def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]:
|
||||
seen = set()
|
||||
for iterable in iterables:
|
||||
for i in iterable:
|
||||
|
||||
@@ -4,15 +4,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Any, TypeVar, cast
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
|
||||
from qlib.rl.utils import LogLevel
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
@@ -33,7 +33,7 @@ class SAOEMetrics(TypedDict):
|
||||
|
||||
stock_id: str
|
||||
"""Stock ID of this record."""
|
||||
datetime: pd.Timestamp
|
||||
datetime: Union[pd.Timestamp, pd.DatetimeIndex] # TODO: check this
|
||||
"""Datetime of this record (this is index in the dataframe)."""
|
||||
direction: int
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
@@ -87,7 +87,7 @@ class SAOEState(NamedTuple):
|
||||
history_steps: pd.DataFrame
|
||||
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Daily metric, only available when the trading is in "done" state."""
|
||||
|
||||
backtest_data: IntradayBacktestData
|
||||
@@ -114,13 +114,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
If such fine granularity is not needed, use ``ticks_per_step`` to
|
||||
lengthen the ticks for each step.
|
||||
|
||||
In each step, the traded amount are "equally" splitted to each tick,
|
||||
then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``),
|
||||
In each step, the traded amount are "equally" separated to each tick,
|
||||
then bounded by volume maximum execution volume (i.e., ``vol_threshold``),
|
||||
and if it's the last step, try to ensure all the amount to be executed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
initial
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
@@ -137,7 +137,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""Positions at each step. The position before first step is also recorded.
|
||||
See :class:`SAOEMetrics` for available columns."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Metrics. Only available when done."""
|
||||
|
||||
twap_price: float
|
||||
@@ -156,15 +156,21 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
data_dir: Path,
|
||||
ticks_per_step: int = 30,
|
||||
deal_price_type: DealPriceType = "close",
|
||||
vol_threshold: float | None = None,
|
||||
vol_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
super(SingleAssetOrderExecution, self).__init__(initial=order)
|
||||
|
||||
self.order = order
|
||||
self.ticks_per_step: int = ticks_per_step
|
||||
self.deal_price_type = deal_price_type
|
||||
self.vol_threshold = vol_threshold
|
||||
self.data_dir = data_dir
|
||||
self.backtest_data = load_intraday_backtest_data(
|
||||
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
|
||||
self.data_dir,
|
||||
order.stock_id,
|
||||
pd.Timestamp(order.start_time.date()),
|
||||
self.deal_price_type,
|
||||
order.direction,
|
||||
)
|
||||
|
||||
self.ticks_index = self.backtest_data.get_time_index()
|
||||
@@ -185,9 +191,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics = None
|
||||
|
||||
self.market_price: np.ndarray | None = None
|
||||
self.market_vol: np.ndarray | None = None
|
||||
self.market_vol_limit: np.ndarray | None = None
|
||||
self.market_price: Optional[np.ndarray] = None
|
||||
self.market_vol: Optional[np.ndarray] = None
|
||||
self.market_vol_limit: Optional[np.ndarray] = None
|
||||
|
||||
def step(self, amount: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
@@ -202,7 +208,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
|
||||
self.market_price = self.market_vol = None # avoid misuse
|
||||
exec_vol = self._split_exec_vol(amount)
|
||||
assert self.market_price is not None and self.market_vol is not None
|
||||
assert self.market_price is not None
|
||||
assert self.market_vol is not None
|
||||
|
||||
ticks_position = self.position - np.cumsum(exec_vol)
|
||||
|
||||
@@ -224,9 +231,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
direction=self.order.direction,
|
||||
market_volume=self.market_vol,
|
||||
market_price=self.market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
|
||||
inner_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
|
||||
deal_amount=exec_vol.sum(), # TODO: check this logic with Yuge & Xiao
|
||||
trade_price=self.market_price,
|
||||
trade_value=self.market_price * exec_vol,
|
||||
position=ticks_position,
|
||||
@@ -360,7 +367,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
inner_amount=exec_vol.sum(),
|
||||
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=np.sum(market_price * exec_vol),
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position,
|
||||
ffr=float(exec_vol.sum() / self.order.amount),
|
||||
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
|
||||
@@ -383,7 +390,9 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: Union[OrderDir, int],
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Any, TypeVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
@@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]):
|
||||
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: SimulatorState) -> float:
|
||||
@@ -30,14 +30,14 @@ class Reward(Generic[SimulatorState]):
|
||||
"""Implement this method for your own reward."""
|
||||
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
|
||||
|
||||
def log(self, name, value):
|
||||
def log(self, name: str, value: Any) -> None:
|
||||
self.env.logger.add_scalar(name, value)
|
||||
|
||||
|
||||
class RewardCombination(Reward):
|
||||
"""Combination of multiple reward."""
|
||||
|
||||
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
|
||||
def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None:
|
||||
self.rewards = rewards
|
||||
|
||||
def reward(self, simulator_state: Any) -> float:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar, Generic, Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
from .seed import InitialStateType
|
||||
|
||||
@@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]):
|
||||
Simulators are discouraged to use this, because it's prone to induce errors.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,2 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_queue import *
|
||||
from .env_wrapper import *
|
||||
from .finite_env import *
|
||||
from .log import *
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from queue import Empty
|
||||
from typing import TypeVar, Generic, Sequence, cast
|
||||
from typing import Any, Generator, Generic, Sequence, TypeVar, cast
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
@@ -60,7 +62,7 @@ class DataQueue(Generic[T]):
|
||||
shuffle: bool = True,
|
||||
producer_num_workers: int = 0,
|
||||
queue_maxsize: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
if queue_maxsize == 0:
|
||||
if os.cpu_count() is not None:
|
||||
queue_maxsize = cast(int, os.cpu_count())
|
||||
@@ -78,14 +80,14 @@ class DataQueue(Generic[T]):
|
||||
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
|
||||
self._done = multiprocessing.Value("i", 0)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> DataQueue:
|
||||
self.activate()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
def cleanup(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value += 1
|
||||
for repeat in range(500):
|
||||
@@ -105,7 +107,7 @@ class DataQueue(Generic[T]):
|
||||
break
|
||||
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
|
||||
|
||||
def get(self, block=True):
|
||||
def get(self, block: bool = True) -> Any:
|
||||
if not hasattr(self, "_first_get"):
|
||||
self._first_get = True
|
||||
if self._first_get:
|
||||
@@ -120,17 +122,17 @@ class DataQueue(Generic[T]):
|
||||
if self._done.value:
|
||||
raise StopIteration # pylint: disable=raise-missing-from
|
||||
|
||||
def put(self, obj, block=True, timeout=None):
|
||||
return self._queue.put(obj, block=block, timeout=timeout)
|
||||
def put(self, obj: Any, block: bool = True, timeout: int = None) -> None:
|
||||
self._queue.put(obj, block=block, timeout=timeout)
|
||||
|
||||
def mark_as_done(self):
|
||||
def mark_as_done(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value = 1
|
||||
|
||||
def done(self):
|
||||
def done(self) -> int:
|
||||
return self._done.value
|
||||
|
||||
def activate(self):
|
||||
def activate(self) -> DataQueue:
|
||||
if self._activated:
|
||||
raise ValueError("DataQueue can not activate twice.")
|
||||
thread = threading.Thread(target=self._producer, daemon=True)
|
||||
@@ -138,18 +140,18 @@ class DataQueue(Generic[T]):
|
||||
self._activated = True
|
||||
return self
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
_logger.debug(f"__del__ of {__name__}.DataQueue")
|
||||
self.cleanup()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Generator[Any, None, None]:
|
||||
if not self._activated:
|
||||
raise ValueError(
|
||||
"Need to call activate() to launch a daemon worker to produce data into data queue before using it."
|
||||
"Need to call activate() to launch a daemon worker to produce data into data queue before using it.",
|
||||
)
|
||||
return self._consumer()
|
||||
|
||||
def _consumer(self):
|
||||
def _consumer(self) -> Generator[Any, None, None]:
|
||||
while True:
|
||||
try:
|
||||
yield self.get()
|
||||
@@ -157,7 +159,7 @@ class DataQueue(Generic[T]):
|
||||
_logger.debug("Data consumer timed-out from get.")
|
||||
return
|
||||
|
||||
def _producer(self):
|
||||
def _producer(self) -> None:
|
||||
# pytorch dataloader is used here only because we need its sampler and multi-processing
|
||||
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
|
||||
|
||||
|
||||
@@ -4,14 +4,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, Any, Iterable, Iterator, Generic, cast
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, Union, cast
|
||||
|
||||
import gym
|
||||
from gym import Space
|
||||
|
||||
from qlib.rl.aux_info import AuxiliaryInfoCollector
|
||||
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
|
||||
from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .finite_env import generate_nan_observation
|
||||
@@ -28,7 +29,7 @@ class InfoDict(TypedDict):
|
||||
|
||||
aux_info: dict
|
||||
"""Any information depends on auxiliary info collector."""
|
||||
log: dict[str, Any]
|
||||
log: Dict[str, Any]
|
||||
"""Collected by LogCollector."""
|
||||
|
||||
|
||||
@@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict):
|
||||
|
||||
cur_step: int
|
||||
done: bool
|
||||
initial_state: Any | None
|
||||
initial_state: Optional[Any]
|
||||
obs_history: list
|
||||
action_history: list
|
||||
reward_history: list
|
||||
|
||||
|
||||
class EnvWrapper(
|
||||
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
|
||||
gym.Env[ObsType, PolicyActType],
|
||||
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
|
||||
):
|
||||
"""Qlib-based RL environment, subclassing ``gym.Env``.
|
||||
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
|
||||
@@ -90,18 +92,18 @@ class EnvWrapper(
|
||||
"""
|
||||
|
||||
simulator: Simulator[InitialStateType, StateType, ActType]
|
||||
seed_iterator: str | Iterator[InitialStateType] | None
|
||||
seed_iterator: Union[str, Iterator[InitialStateType], None]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
|
||||
state_interpreter: StateInterpreter[StateType, ObsType],
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
seed_iterator: Iterable[InitialStateType] | None,
|
||||
reward_fn: Reward | None = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
|
||||
logger: LogCollector | None = None,
|
||||
):
|
||||
seed_iterator: Optional[Iterable[InitialStateType]],
|
||||
reward_fn: Reward = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None,
|
||||
logger: LogCollector = None,
|
||||
) -> None:
|
||||
# Assign weak reference to wrapper.
|
||||
#
|
||||
# Use weak reference here, because:
|
||||
@@ -135,11 +137,11 @@ class EnvWrapper(
|
||||
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
def action_space(self) -> Space:
|
||||
return self.action_interpreter.action_space
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> Space:
|
||||
return self.state_interpreter.observation_space
|
||||
|
||||
def reset(self, **kwargs: Any) -> ObsType:
|
||||
@@ -191,7 +193,7 @@ class EnvWrapper(
|
||||
self.seed_iterator = None
|
||||
return generate_nan_observation(self.observation_space)
|
||||
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]:
|
||||
"""Environment step.
|
||||
|
||||
See the code along with comments to get a sequence of things happening here.
|
||||
@@ -245,5 +247,5 @@ class EnvWrapper(
|
||||
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
|
||||
return obs, rew, done, info_dict
|
||||
|
||||
def render(self):
|
||||
def render(self, mode: str = "human") -> None:
|
||||
raise NotImplementedError("Render is not implemented in EnvWrapper.")
|
||||
|
||||
@@ -11,14 +11,14 @@ from __future__ import annotations
|
||||
import copy
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import Any, Set, Callable, Type
|
||||
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
||||
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .log import LogWriter
|
||||
|
||||
__all__ = [
|
||||
@@ -36,7 +36,7 @@ __all__ = [
|
||||
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
|
||||
|
||||
|
||||
def fill_invalid(obj):
|
||||
def fill_invalid(obj: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> Union[np.ndarray, dict, list, tuple]:
|
||||
if isinstance(obj, (int, float, bool)):
|
||||
return fill_invalid(np.array(obj))
|
||||
if hasattr(obj, "dtype"):
|
||||
@@ -55,7 +55,7 @@ def fill_invalid(obj):
|
||||
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
|
||||
|
||||
|
||||
def is_invalid(arr):
|
||||
def is_invalid(arr: Union[int, float, bool, np.ndarray, dict, list, tuple]) -> bool:
|
||||
if hasattr(arr, "dtype"):
|
||||
if np.issubdtype(arr.dtype, np.floating):
|
||||
return np.isnan(arr).all()
|
||||
@@ -121,11 +121,11 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
|
||||
self, logger: Union[LogWriter, List[LogWriter]], env_fns: List[Callable[..., gym.Env]], **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(env_fns, **kwargs)
|
||||
|
||||
self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
|
||||
self._logger: List[LogWriter] = logger if isinstance(logger, list) else [logger]
|
||||
self._alive_env_ids: Set[int] = set()
|
||||
self._reset_alive_envs()
|
||||
self._default_obs = self._default_info = self._default_rew = None
|
||||
@@ -133,44 +133,44 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
self._collector_guarded: bool = False
|
||||
|
||||
def _reset_alive_envs(self):
|
||||
def _reset_alive_envs(self) -> None:
|
||||
if not self._alive_env_ids:
|
||||
# starting or running out
|
||||
self._alive_env_ids = set(range(self.env_num))
|
||||
|
||||
# to workaround with tianshou's buffer and batch
|
||||
def _set_default_obs(self, obs):
|
||||
def _set_default_obs(self, obs: Any) -> None:
|
||||
if obs is not None and self._default_obs is None:
|
||||
self._default_obs = copy.deepcopy(obs)
|
||||
|
||||
def _set_default_info(self, info):
|
||||
def _set_default_info(self, info: Any) -> None:
|
||||
if info is not None and self._default_info is None:
|
||||
self._default_info = copy.deepcopy(info)
|
||||
|
||||
def _set_default_rew(self, rew):
|
||||
def _set_default_rew(self, rew: Any) -> None:
|
||||
if rew is not None and self._default_rew is None:
|
||||
self._default_rew = copy.deepcopy(rew)
|
||||
|
||||
def _get_default_obs(self):
|
||||
def _get_default_obs(self) -> Any:
|
||||
return copy.deepcopy(self._default_obs)
|
||||
|
||||
def _get_default_info(self):
|
||||
def _get_default_info(self) -> Any:
|
||||
return copy.deepcopy(self._default_info)
|
||||
|
||||
def _get_default_rew(self):
|
||||
def _get_default_rew(self) -> Any:
|
||||
return copy.deepcopy(self._default_rew)
|
||||
|
||||
# END
|
||||
|
||||
@staticmethod
|
||||
def _postproc_env_obs(obs):
|
||||
def _postproc_env_obs(obs: Any) -> Optional[Any]:
|
||||
# reserved for shmem vector env to restore empty observation
|
||||
if obs is None or check_nan_observation(obs):
|
||||
return None
|
||||
return obs
|
||||
|
||||
@contextmanager
|
||||
def collector_guard(self):
|
||||
def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:
|
||||
"""Guard the collector. Recommended to guard every collect.
|
||||
|
||||
This guard is for two purposes.
|
||||
@@ -197,7 +197,10 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
for logger in self._logger:
|
||||
logger.on_env_all_done()
|
||||
|
||||
def reset(self, id=None):
|
||||
def reset(
|
||||
self,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||
) -> np.ndarray:
|
||||
assert not self._zombie
|
||||
|
||||
# Check whether it's guarded by collector_guard()
|
||||
@@ -245,7 +248,11 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step(self, action, id=None):
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
@@ -277,7 +284,8 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
if r[3] is None:
|
||||
result[i][3] = self._get_default_info()
|
||||
|
||||
return list(map(np.stack, zip(*result)))
|
||||
ret = list(map(np.stack, zip(*result)))
|
||||
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
|
||||
|
||||
|
||||
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
|
||||
@@ -296,7 +304,7 @@ def vectorize_env(
|
||||
env_factory: Callable[..., gym.Env],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
logger: Union[LogWriter, List[LogWriter]],
|
||||
) -> FiniteVectorEnv:
|
||||
"""Helper function to create a vector env.
|
||||
|
||||
@@ -326,7 +334,7 @@ def vectorize_env(
|
||||
def env_factory(): ...
|
||||
vectorize_env(env_factory, ...)
|
||||
"""
|
||||
env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
|
||||
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
|
||||
"dummy": FiniteDummyVectorEnv,
|
||||
"subproc": FiniteSubprocVectorEnv,
|
||||
"shmem": FiniteShmemVectorEnv,
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Sequence, Set, Tuple, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -62,22 +62,22 @@ class LogCollector:
|
||||
``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
|
||||
"""
|
||||
|
||||
_logged: dict[str, tuple[int, Any]]
|
||||
_logged: Dict[str, Tuple[int, Any]]
|
||||
_min_loglevel: int
|
||||
|
||||
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, min_loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
self._min_loglevel = int(min_loglevel)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Clear all collected contents."""
|
||||
self._logged = {}
|
||||
|
||||
def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
|
||||
def _add_metric(self, name: str, metric: Any, loglevel: Union[int, LogLevel]) -> None:
|
||||
if name in self._logged:
|
||||
raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
|
||||
self._logged[name] = (int(loglevel), metric)
|
||||
|
||||
def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
def add_string(self, name: str, string: str, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
"""Add a string with name into logged contents."""
|
||||
if loglevel < self._min_loglevel:
|
||||
return
|
||||
@@ -85,7 +85,7 @@ class LogCollector:
|
||||
raise TypeError(f"{string} is not a string.")
|
||||
self._add_metric(name, string, loglevel)
|
||||
|
||||
def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
def add_scalar(self, name: str, scalar: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
"""Add a scalar with name into logged contents.
|
||||
Scalar will be converted into a float.
|
||||
"""
|
||||
@@ -101,7 +101,10 @@ class LogCollector:
|
||||
self._add_metric(name, scalar, loglevel)
|
||||
|
||||
def add_array(
|
||||
self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
|
||||
self,
|
||||
name: str,
|
||||
array: Union[np.ndarray, pd.DataFrame, pd.Series],
|
||||
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
|
||||
) -> None:
|
||||
"""Add an array with name into logging."""
|
||||
if loglevel < self._min_loglevel:
|
||||
@@ -111,7 +114,7 @@ class LogCollector:
|
||||
raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
|
||||
self._add_metric(name, array, loglevel)
|
||||
|
||||
def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
def add_any(self, name: str, obj: Any, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
"""Log something with any type.
|
||||
|
||||
As it's an "any" object, the only LogWriter accepting it is pickle.
|
||||
@@ -124,7 +127,7 @@ class LogCollector:
|
||||
|
||||
self._add_metric(name, obj, loglevel)
|
||||
|
||||
def logs(self) -> dict[str, np.ndarray]:
|
||||
def logs(self) -> Dict[str, np.ndarray]:
|
||||
return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
|
||||
|
||||
|
||||
@@ -151,16 +154,16 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
active_env_ids: Set[int]
|
||||
"""Active environment ids in vector env."""
|
||||
|
||||
episode_lengths: dict[int, int]
|
||||
episode_lengths: Dict[int, int]
|
||||
"""Map from environment id to episode length."""
|
||||
|
||||
episode_rewards: dict[int, list[float]]
|
||||
episode_rewards: Dict[int, List[float]]
|
||||
"""Map from environment id to episode total reward."""
|
||||
|
||||
episode_logs: dict[int, list]
|
||||
episode_logs: Dict[int, list]
|
||||
"""Map from environment id to episode logs."""
|
||||
|
||||
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
self.loglevel = loglevel
|
||||
|
||||
self.global_step = 0
|
||||
@@ -174,12 +177,13 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
self.episode_count = self.step_count = 0
|
||||
self.active_env_ids = set()
|
||||
self.logs = []
|
||||
|
||||
def aggregation(self, array: Sequence[Any]) -> Any:
|
||||
@staticmethod
|
||||
def aggregation(array: Sequence[Any]) -> Any:
|
||||
"""Aggregation function from step-wise to episode-wise.
|
||||
|
||||
If it's a sequence of float, take the mean.
|
||||
@@ -191,7 +195,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
else:
|
||||
return array[0]
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
"""This is triggered at the end of each trajectory.
|
||||
|
||||
Parameters
|
||||
@@ -204,7 +208,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
Logged contents for every steps.
|
||||
"""
|
||||
|
||||
def log_step(self, reward: float, contents: dict[str, Any]) -> None:
|
||||
def log_step(self, reward: float, contents: Dict[str, Any]) -> None:
|
||||
"""This is triggered at each step.
|
||||
|
||||
Parameters
|
||||
@@ -227,7 +231,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
||||
# TODO: reward can be a list of list for MARL
|
||||
self.episode_rewards[env_id].append(rew)
|
||||
|
||||
values: dict[str, Any] = {}
|
||||
values: Dict[str, Any] = {}
|
||||
|
||||
for key, (loglevel, value) in info["log"].items():
|
||||
if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
|
||||
@@ -272,11 +276,11 @@ class ConsoleWriter(LogWriter):
|
||||
def __init__(
|
||||
self,
|
||||
log_every_n_episode: int = 20,
|
||||
total_episodes: int | None = None,
|
||||
total_episodes: int = None,
|
||||
float_format: str = ":.4f",
|
||||
counter_format: str = ":4d",
|
||||
loglevel: int | LogLevel = LogLevel.PERIODIC,
|
||||
):
|
||||
loglevel: Union[int, LogLevel] = LogLevel.PERIODIC,
|
||||
) -> None:
|
||||
super().__init__(loglevel)
|
||||
# TODO: support log_every_n_step
|
||||
self.log_every_n_episode = log_every_n_episode
|
||||
@@ -289,15 +293,15 @@ class ConsoleWriter(LogWriter):
|
||||
|
||||
self.console_logger = get_module_logger(__name__, level=logging.INFO)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
# Clear average meters
|
||||
self.metric_counts: dict[str, int] = defaultdict(int)
|
||||
self.metric_sums: dict[str, float] = defaultdict(float)
|
||||
self.metric_counts: Dict[str, int] = defaultdict(int)
|
||||
self.metric_sums: Dict[str, float] = defaultdict(float)
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
# Aggregate step-wise to episode-wise
|
||||
episode_wise_contents: dict[str, list] = defaultdict(list)
|
||||
episode_wise_contents: Dict[str, list] = defaultdict(list)
|
||||
|
||||
for step_contents in contents:
|
||||
for name, value in step_contents.items():
|
||||
@@ -306,7 +310,7 @@ class ConsoleWriter(LogWriter):
|
||||
|
||||
# Generate log contents and track them in average-meter.
|
||||
# This should be done at every step, regardless of periodic or not.
|
||||
logs: dict[str, float] = {}
|
||||
logs: Dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values) # type: ignore
|
||||
|
||||
@@ -318,7 +322,7 @@ class ConsoleWriter(LogWriter):
|
||||
# Only log periodically or at the end
|
||||
self.console_logger.info(self.generate_log_message(logs))
|
||||
|
||||
def generate_log_message(self, logs: dict[str, float]) -> str:
|
||||
def generate_log_message(self, logs: Dict[str, float]) -> str:
|
||||
if self.prefix:
|
||||
msg_prefix = self.prefix + " "
|
||||
else:
|
||||
@@ -348,27 +352,27 @@ class CsvWriter(LogWriter):
|
||||
|
||||
SUPPORTED_TYPES = (float, str, pd.Timestamp)
|
||||
|
||||
all_records: list[dict[str, Any]]
|
||||
all_records: List[Dict[str, Any]]
|
||||
|
||||
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, output_dir: Path, loglevel: Union[int, LogLevel] = LogLevel.PERIODIC) -> None:
|
||||
super().__init__(loglevel)
|
||||
self.output_dir = output_dir
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.all_records = []
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
# FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
|
||||
episode_wise_contents: dict[str, list] = defaultdict(list)
|
||||
episode_wise_contents: Dict[str, list] = defaultdict(list)
|
||||
|
||||
for step_contents in contents:
|
||||
for name, value in step_contents.items():
|
||||
if isinstance(value, self.SUPPORTED_TYPES):
|
||||
episode_wise_contents[name].append(value)
|
||||
|
||||
logs: dict[str, float] = {}
|
||||
logs: Dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values) # type: ignore
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ from random import randint, choice
|
||||
from pathlib import Path
|
||||
|
||||
import re
|
||||
from typing import Any, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env
|
||||
|
||||
|
||||
class SimpleEnv(gym.Env[int, int]):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.logger = LogCollector()
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, *args: Any, **kwargs: Any) -> int:
|
||||
self.step_count = 0
|
||||
return 0
|
||||
|
||||
def step(self, action: int):
|
||||
def step(self, action: int) -> Tuple[int, float, bool, dict]:
|
||||
self.logger.reset()
|
||||
|
||||
self.logger.add_scalar("reward", 42.0)
|
||||
@@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]):
|
||||
|
||||
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
|
||||
|
||||
def render(self, mode: str = "human") -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AnyPolicy(BasePolicy):
|
||||
def forward(self, batch, state=None):
|
||||
@@ -86,7 +91,8 @@ def test_simple_env_logger(caplog):
|
||||
|
||||
|
||||
class SimpleSimulator(Simulator[int, float, float]):
|
||||
def __init__(self, initial: int, **kwargs) -> None:
|
||||
def __init__(self, initial: int, **kwargs: Any) -> None:
|
||||
super(SimpleSimulator, self).__init__(initial, **kwargs)
|
||||
self.initial = float(initial)
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
|
||||
Reference in New Issue
Block a user