mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
151 lines
5.3 KiB
Python
151 lines
5.3 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
|
|
|
import numpy as np
|
|
|
|
from qlib.typehint import final
|
|
|
|
from .simulator import ActType, StateType
|
|
|
|
if TYPE_CHECKING:
|
|
from .utils.env_wrapper import EnvWrapper
|
|
|
|
import gym
|
|
from gym import spaces
|
|
|
|
# Observation type: what a state interpreter returns for the policy to consume.
ObsType = TypeVar("ObsType")
# Policy action type: the raw action an action interpreter receives from the policy.
PolicyActType = TypeVar("PolicyActType")
|
|
|
|
|
|
class Interpreter:
    """Interpreter is a bridge between states produced by simulators and states needed by RL policies.

    Interpreters are two-way:

    1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`.
    2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`.

    Inherit one of the two sub-classes to define your own interpreter.
    This super-class is only used for isinstance checks.

    Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx``
    in an interpreter is an anti-pattern. In the future, we might support registering some interpreter-related
    states by calling ``self.env.register_state()``, but it's not planned for the first iteration.
    """
|
|
|
|
|
|
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
    """Interprets the execution result of the qlib executor into an RL env state.

    Calling an instance runs :meth:`interpret` on the simulator state and then
    :meth:`validate` on the produced observation before returning it.
    """

    # Optional reference to an EnvWrapper; ``None`` at class level.
    env: Optional[EnvWrapper] = None

    @property
    def observation_space(self) -> gym.Space:
        """Gym space that observations are validated against; not implemented here."""
        raise NotImplementedError

    @final  # subclasses must not override this
    def __call__(self, simulator_state: StateType) -> ObsType:
        """Interpret ``simulator_state``, validate the observation, and return it."""
        observation = self.interpret(simulator_state)
        self.validate(observation)
        return observation

    def validate(self, obs: ObsType) -> None:
        """Check that ``obs`` belongs to ``observation_space``.

        Delegates to ``_gym_space_contains``, which raises on failure instead of
        returning a boolean.
        """
        _gym_space_contains(self.observation_space, obs)

    def interpret(self, simulator_state: StateType) -> ObsType:
        """Interpret the state of the simulator.

        Parameters
        ----------
        simulator_state
            Retrieved with ``simulator.get_state()``.

        Returns
        -------
        State needed by the policy. Should conform to the state space defined in ``observation_space``.
        """
        raise NotImplementedError("interpret is not implemented!")
|
|
|
|
|
|
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
    """Interprets an RL agent action into qlib orders.

    Calling an instance first runs :meth:`validate` on the policy action and then
    returns the result of :meth:`interpret`.
    """

    # Optional reference to an EnvWrapper; ``None`` at class level.
    env: Optional[EnvWrapper] = None

    @property
    def action_space(self) -> gym.Space:
        """Gym space that policy actions are validated against; not implemented here."""
        raise NotImplementedError

    @final  # subclasses must not override this
    def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType:
        """Validate ``action``, then convert it with :meth:`interpret`."""
        self.validate(action)
        return self.interpret(simulator_state, action)

    def validate(self, action: PolicyActType) -> None:
        """Check that ``action`` belongs to ``action_space``.

        Delegates to ``_gym_space_contains``, which raises on failure instead of
        returning a boolean.
        """
        _gym_space_contains(self.action_space, action)

    def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType:
        """Convert the policy action to a simulator action.

        Parameters
        ----------
        simulator_state
            Retrieved with ``simulator.get_state()``.
        action
            Raw action given by policy.

        Returns
        -------
        The action needed by the simulator.
        """
        raise NotImplementedError("interpret is not implemented!")
|
|
|
|
|
|
def _gym_space_contains(space: gym.Space, x: Any) -> None:
    """Strengthened version of ``gym.Space.contains``.

    Gives more diagnostic information on why validation fails, and throws an
    exception rather than returning true or false.

    ``spaces.Dict`` and ``spaces.Tuple`` are checked recursively, member by member,
    so the chained exceptions name the failing key or index. Every other space is
    delegated to ``space.contains``.

    Parameters
    ----------
    space
        The gym space the sample is expected to belong to.
    x
        The sample to check. For a ``spaces.Tuple``, a list or a non-scalar
        ``np.ndarray`` is accepted and treated as a tuple.

    Raises
    ------
    GymSpaceValidationError
        If ``x`` does not belong to ``space``.
    """
    if isinstance(space, spaces.Dict):
        if not isinstance(x, dict) or len(x) != len(space):
            raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x)
        for k, subspace in space.spaces.items():
            if k not in x:
                raise GymSpaceValidationError(f"Key {k} not found in sample.", space, x)
            try:
                _gym_space_contains(subspace, x[k])
            except GymSpaceValidationError as e:
                # Re-raise at this level so the outer space/sample are reported too.
                raise GymSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e

    elif isinstance(space, spaces.Tuple):
        # Promote list and ndarray to tuple for the contains check. A 0-d ndarray
        # cannot be iterated (``tuple()`` would raise TypeError), so it is left
        # untouched and reported as a validation error by the type check below.
        if isinstance(x, list) or (isinstance(x, np.ndarray) and x.ndim > 0):
            x = tuple(x)
        if not isinstance(x, tuple) or len(x) != len(space):
            raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, x)
        for i, (subspace, part) in enumerate(zip(space, x)):
            try:
                _gym_space_contains(subspace, part)
            except GymSpaceValidationError as e:
                # Re-raise at this level so the outer space/sample are reported too.
                raise GymSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e

    else:
        if not space.contains(x):
            raise GymSpaceValidationError("Validation error reported by gym.", space, x)
|
|
|
|
|
|
class GymSpaceValidationError(Exception):
    """Error describing why a sample failed validation against a gym space.

    Keeps the explanation (``message``), the space (``space``) and the offending
    sample (``x``) as attributes; all three are rendered by ``__str__``.
    """

    def __init__(self, message: str, space: gym.Space, x: Any) -> None:
        self.message, self.space, self.x = message, space, x

    def __str__(self) -> str:
        # One line per piece of information: the message, then the space, then the sample.
        parts = [f"{self.message}", f" Space: {self.space}", f" Sample: {self.x}"]
        return "\n".join(parts)
|