1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Qlib RL framework (stage 1) - single-asset order execution (#1076)

* rl init

* aux info

* Reward config

* update

* simple

* update saoe init

* update simulator and seed

* minor

* minor

* update sim

* checkpoint

* obs

* Update interpreter

* init qlib simulator

* checkpoint

* Refine codebase

* checkpoint

* checkpoint

* Add one test

* More tests

* Simulator checkpoint

* checkpoint

* First-step tested

* Checkpoint

* Update data_queue API

* Checkpoint

* Update test

* Move files

* Checkpoint

* Single-quote -> double-quote

* Fix finite env tests

* Tested with mypy

* pep-574

* No call for env done

* Update finite env docs

* Fix csv writer

* Refine tester

* Update logger

* Add another logger test

* Checkpoint

* Add network sanity test

* steps per episode is not correct

* Cleanup code, ready for PR

* Reformat with black

* Fix pylint for py37

* Fix lint

* Fix lint

* Fix flake

* update mypy command

* mypy

* Update exclude pattern

* Use pyproject.toml

* test

* .

* .

* Refactor pipeline

* .

* defaults run bash

* .

* Revert and skip follow_imports

* Fix toml issue

* fix mypy

* .

* .

* .

* Fix install

* Minor fix

* Fix test

* Fix test

* Remove requirements

* Revert

* fix tests

* Fix lint

* .

* .

* .

* .

* .

* update install from source command

* .

* Fix data download

* .

* .

* .

* .

* .

* .

* Fix py37

* Ignore tests on non-linux

* resolve comments

* fix tests

* resolve comments

* some typo

* style updates

* More comments

* fix dummy

* add warning

* Align precision in some system

* Added some impl notes

Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
Yuge Zhang
2022-05-21 18:19:24 +08:00
committed by GitHub
parent c4281121e3
commit 9a40fd3cdc
36 changed files with 3680 additions and 121 deletions

View File

@@ -35,7 +35,7 @@ jobs:
pip install numpy==1.19.5 ruamel.yaml
pip install pyqlib --ignore-installed
- name: Make html with sphnix
- name: Make html with sphinx
run: |
pip install -U sphinx
pip install sphinx_rtd_theme readthedocs_sphinx_ext
@@ -97,12 +97,21 @@ jobs:
run: |
pip install --upgrade pip
pip install flake8
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
# https://github.com/python/mypy/issues/10600
- name: Check Qlib with mypy
run: |
pip install mypy
mypy qlib --install-types --non-interactive || true
mypy qlib
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
run: |
@@ -113,6 +122,7 @@ jobs:
- name: Install Qlib from source
run: |
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
@@ -129,4 +139,3 @@ jobs:
- name: Test workflow by config (install from source)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -38,7 +38,7 @@ jobs:
run: |
pip install --upgrade pip
pip install flake8
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
- name: Install Qlib with pip
run: |
@@ -65,6 +65,8 @@ jobs:
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
@@ -75,6 +77,7 @@ jobs:
python -m pip install --upgrade cython
python -m pip install numpy jupyter jupyter_contrib_nbextensions
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
python -m pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
run: |

5
.gitignore vendored
View File

@@ -27,6 +27,10 @@ examples/estimator/estimator_example/
*.egg-info/
# test related
test-output.xml
.output
.data
# special software
mlruns/
@@ -34,6 +38,7 @@ mlruns/
tags
.pytest_cache/
.mypy_cache/
.vscode/
*.swp

17
.mypy.ini Normal file
View File

@@ -0,0 +1,17 @@
[mypy]
exclude = (?x)(
^qlib/backtest
| ^qlib/contrib
| ^qlib/data
| ^qlib/model
| ^qlib/strategy
| ^qlib/tests
| ^qlib/utils
| ^qlib/workflow
| ^qlib/config\.py$
| ^qlib/log\.py$
| ^qlib/__init__\.py$
)
ignore_missing_imports = true
disallow_incomplete_defs = true
follow_imports = skip

View File

@@ -8,3 +8,6 @@ REG_TW = "tw"
# Epsilon for avoiding division by zero.
EPS = 1e-12
# Infinity in integer
INF = 10**18

View File

@@ -61,7 +61,11 @@ def get_module_logger(module_name, level: Optional[int] = None) -> QlibLogger:
if level is None:
level = C.logging_level
if not module_name.startswith("qlib."):
# Add a prefix of qlib. when the requested ``module_name`` doesn't start with ``qlib.``.
# If the module_name is already qlib.xxx, we do not format here. Otherwise, it will become qlib.qlib.xxx.
module_name = "qlib.{}".format(module_name)
# Get logger.
module_logger = QlibLogger(module_name)
module_logger.setLevel(level)

43
qlib/rl/aux_info.py Normal file
View File

@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Generic, TYPE_CHECKING, TypeVar
from qlib.typehint import final
from .simulator import StateType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
__all__ = ["AuxiliaryInfoCollector"]
AuxInfoType = TypeVar("AuxInfoType")
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
"""Override this class to collect customized auxiliary information from environment."""
env: EnvWrapper | None = None
@final
def __call__(self, simulator_state: StateType) -> AuxInfoType:
return self.collect(simulator_state)
def collect(self, simulator_state: StateType) -> AuxInfoType:
"""Override this for customized auxiliary info.
Usually useful in Multi-agent RL.
Parameters
----------
simulator_state
Retrieved with ``simulator.get_state()``.
Returns
-------
Auxiliary information.
"""
raise NotImplementedError("collect is not implemented!")

8
qlib/rl/data/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Common utilities to handle ad-hoc-styled data.
Most of these snippets comes from research project (paper code).
Please take caution when using them in production.
"""

View File

@@ -0,0 +1,257 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""This module contains utilities to read financial data from pickle-styled files.
This is the format used in `OPD paper <https://seqml.github.io/opd/>`__. NOT the standard data format in qlib.
The data here are all wrapped with ``@lru_cache``, which saves the expensive IO cost to repetitively read the data.
We also encourage users to use ``get_xxx_yyy`` rather than ``XxxYyy`` (although they are the same thing),
because ``get_xxx_yyy`` is cache-optimized.
Note that these pickle files are dumped with Python 3.8. Python lower than 3.7 might not be able to load them.
See `PEP 574 <https://peps.python.org/pep-0574/>`__ for details.
This file shows resemblence to qlib.backtest.high_performance_ds. We might merge those two in future.
"""
# TODO: merge with qlib/backtest/high_performance_ds.py
from __future__ import annotations
from functools import lru_cache
from typing import List, Sequence, cast
from pathlib import Path
import cachetools
import numpy as np
import pandas as pd
from cachetools.keys import hashkey
from qlib.backtest.decision import OrderDir, Order
from qlib.typehint import Literal
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
"""Several ad-hoc deal price.
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
``bid_or_ask_fill``: Based on ``bid_or_ask``. If price is 0, use another price (``$ask0`` / ``$bid0``) instead.
``close``: Use close price (``$close0``) as deal price.
"""
def _infer_processed_data_column_names(shape: int) -> list[str]:
if shape == 16:
return [
"$open",
"$high",
"$low",
"$close",
"$vwap",
"$bid",
"$ask",
"$volume",
"$bidV",
"$bidV1",
"$bidV3",
"$bidV5",
"$askV",
"$askV1",
"$askV3",
"$askV5",
]
if shape == 6:
return ["$high", "$low", "$open", "$close", "$vwap", "$volume"]
elif shape == 5:
return ["$high", "$low", "$open", "$close", "$volume"]
raise ValueError(f"Unrecognized data shape: {shape}")
def _find_pickle(filename_without_suffix: Path) -> Path:
suffix_list = [".pkl", ".pkl.backtest"]
paths: List[Path] = []
for suffix in suffix_list:
path = filename_without_suffix.parent / (filename_without_suffix.name + suffix)
if path.exists():
paths.append(path)
if not paths:
raise FileNotFoundError(f"No file starting with '{filename_without_suffix}' found")
if len(paths) > 1:
raise ValueError(f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}")
return paths[0]
@lru_cache(maxsize=10) # 10 * 40M = 400MB
def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
return pd.read_pickle(_find_pickle(filename_without_suffix))
class IntradayBacktestData:
"""Raw market data that is often used in backtesting (thus called BacktestData)."""
def __init__(
self,
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int | None = None,
):
backtest = _read_pickle(data_dir / stock_id)
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
# No longer need for pandas >= 1.4
# backtest = backtest.droplevel([0, 2])
self.data: pd.DataFrame = backtest
self.deal_price_type: DealPriceType = deal_price
self.order_dir: int | None = order_dir
def __repr__(self):
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.data})"
def __len__(self):
return len(self.data)
def get_deal_price(self) -> pd.Series:
"""Return a pandas series that can be indexed with time.
See :attribute:`DealPriceType` for details."""
if self.deal_price_type in ("bid_or_ask", "bid_or_ask_fill"):
if self.order_dir is None:
raise ValueError("Order direction cannot be none when deal_price_type is not close.")
if self.order_dir == OrderDir.SELL:
col = "$bid0"
else: # BUY
col = "$ask0"
elif self.deal_price_type == "close":
col = "$close0"
else:
raise ValueError(f"Unsupported deal_price_type: {self.deal_price_type}")
price = self.data[col]
if self.deal_price_type == "bid_or_ask_fill":
if self.order_dir == OrderDir.SELL:
fill_col = "$ask0"
else:
fill_col = "$bid0"
price = price.replace(0, np.nan).fillna(self.data[fill_col])
return price
def get_volume(self) -> pd.Series:
"""Return a volume series that can be indexed with time."""
return self.data["$volume0"]
def get_time_index(self) -> pd.DatetimeIndex:
return cast(pd.DatetimeIndex, self.data.index)
class IntradayProcessedData:
"""Processed market data after data cleanup and feature engineering.
It contains both processed data for "today" and "yesterday", as some algorithms
might use the market information of the previous day to assist decision making.
"""
today: pd.DataFrame
"""Processed data for "today".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
yesterday: pd.DataFrame
"""Processed data for "yesterday".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
proc = _read_pickle(data_dir / stock_id)
# We have to infer the names here because,
# unfortunately they are not included in the original data.
cnames = _infer_processed_data_column_names(feature_dim)
time_length: int = len(time_index)
try:
# new data format
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
proc_today = proc[cnames]
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
except (IndexError, KeyError):
# legacy data
proc = proc.loc[pd.IndexSlice[stock_id, date]]
assert time_length * feature_dim * 2 == len(proc)
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
self.today: pd.DataFrame = proc_today
self.yesterday: pd.DataFrame = proc_yesterday
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
assert len(self.today) == len(self.yesterday) == time_length
def __repr__(self):
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@lru_cache(maxsize=100) # 100 * 50K = 5MB
def load_intraday_backtest_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
) -> IntradayBacktestData:
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
)
def load_intraday_processed_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
) -> IntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
def load_orders(
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
) -> Sequence[Order]:
"""Load orders, and set start time and end time for the orders."""
start_time = start_time or pd.Timestamp("0:00:00")
end_time = end_time or pd.Timestamp("23:59:59")
if order_path.is_file():
order_df = pd.read_pickle(order_path)
else:
order_df = []
for file in order_path.iterdir():
order_data = pd.read_pickle(file)
order_df.append(order_data)
order_df = pd.concat(order_df)
order_df = order_df.reset_index()
# Legacy-style orders have "date" instead of "datetime"
if "date" in order_df.columns:
order_df = order_df.rename(columns={"date": "datetime"})
# Sometimes "date" are str rather than Timestamp
order_df["datetime"] = pd.to_datetime(order_df["datetime"])
orders: List[Order] = []
for _, row in order_df.iterrows():
# filter out orders with amount == 0
if row["amount"] <= 0:
continue
orders.append(
Order(
row["instrument"],
row["amount"],
int(row["order_type"]),
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
)
)
return orders

View File

@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Train, test, inference utilities.
The APIs in this directory are NOT considered final and are subject to change!
"""

99
qlib/rl/entries/test.py Normal file
View File

@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import Callable, Sequence
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.log import get_module_logger
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
_logger = get_module_logger(__name__)
def backtest(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | list[LogWriter],
reward: Reward | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
) -> None:
"""Backtest with the parallelism provided by RL framework.
Parameters
----------
simulator_fn
Callable receiving initial seed, returning a simulator.
state_interpreter
Interprets the state of simulators.
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to test against.
logger
Logger to record the backtest results. Logger must be present because
without logger, all information will be lost.
reward
Optional reward function. For backtest, this is for testing the rewards
and logging them only.
finite_env_type
Type of finite env implementation.
concurrency
Parallel workers.
"""
# To save bandwidth
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
def env_factory():
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
# and could be thread unsafe.
# I'm not sure whether it's a design flaw.
# I'll rethink about this when designing the trainer.
if finite_env_type == "dummy":
# We could only experience the "threading-unsafe" problem in dummy.
state = copy.deepcopy(state_interpreter)
action = copy.deepcopy(action_interpreter)
rew = copy.deepcopy(reward)
else:
state, action, rew = state_interpreter, action_interpreter, reward
return EnvWrapper(
simulator_fn,
state,
action,
seed_iterator,
rew,
logger=LogCollector(min_loglevel=min_loglevel),
)
with DataQueue(initial_states) as seed_iterator:
vector_env = vectorize_env(
env_factory,
finite_env_type,
concurrency,
logger,
)
policy.eval()
with vector_env.collector_guard():
test_collector = Collector(policy, vector_env)
_logger.info("All ready. Start backtest.")
test_collector.collect(n_step=INF * len(vector_env))

4
qlib/rl/entries/train.py Normal file
View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TBD

View File

@@ -1,94 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from ..backtest.executor import BaseExecutor
from .interpreter import StateInterpreter, ActionInterpreter
from ..utils import init_instance_by_config
class BaseRLEnv:
"""Base environment for reinforcement learning"""
def reset(self, **kwargs):
raise NotImplementedError("reset is not implemented!")
def step(self, action):
"""
step method of rl env
Parameters
----------
action :
action from rl policy
Returns
-------
env state to rl policy
"""
raise NotImplementedError("step is not implemented!")
class QlibRLEnv:
"""qlib-based RL env"""
def __init__(
self,
executor: BaseExecutor,
):
"""
Parameters
----------
executor : BaseExecutor
qlib multi-level/single-level executor, which can be regarded as gamecore in RL
"""
self.executor = executor
def reset(self, **kwargs):
self.executor.reset(**kwargs)
class QlibIntRLEnv(QlibRLEnv):
"""(Qlib)-based RL (Env) with (Interpreter)"""
def __init__(
self,
executor: BaseExecutor,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
):
"""
Parameters
----------
state_interpreter : Union[dict, StateInterpreter]
interpreter that interprets the qlib execute result into rl env state.
action_interpreter : Union[dict, ActionInterpreter]
interpreter that interprets the rl agent action into qlib order list
"""
super(QlibIntRLEnv, self).__init__(executor=executor)
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def step(self, action):
"""
step method of rl env, it run as following step:
- Use `action_interpreter.interpret` method to interpret the agent action into order list
- Execute the order list with qlib executor, and get the executed result
- Use `state_interpreter.interpret` method to interpret the executed result into env state
Parameters
----------
action :
action from rl policy
Returns
-------
env state to rl policy
"""
_interpret_decision = self.action_interpreter.interpret(action=action)
_execute_result = self.executor.execute(trade_decision=_interpret_decision)
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
return _interpret_state

View File

@@ -1,47 +1,150 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
class BaseInterpreter:
"""Base Interpreter"""
from typing import TYPE_CHECKING, TypeVar, Generic, Any
def interpret(self, **kwargs):
raise NotImplementedError("interpret is not implemented!")
import numpy as np
from qlib.typehint import final
from .simulator import StateType, ActType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
import gym
from gym import spaces
ObsType = TypeVar("ObsType")
PolicyActType = TypeVar("PolicyActType")
class ActionInterpreter(BaseInterpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
class Interpreter:
"""Interpreter is a media between states produced by simulators and states needed by RL policies.
Interpreters are two-way:
def interpret(self, action, **kwargs):
"""interpret method
1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`.
2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`.
Inherit one of the two sub-classes to define your own interpreter.
This super-class is only used for isinstance check.
Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx``
in interpreter is anti-pattern. In future, we might support register some interpreter-related
states by calling ``self.env.register_state()``, but it's not planned for first iteration.
"""
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
env: EnvWrapper | None = None
@property
def observation_space(self) -> gym.Space:
raise NotImplementedError()
@final # no overridden
def __call__(self, simulator_state: StateType) -> ObsType:
obs = self.interpret(simulator_state)
self.validate(obs)
return obs
def validate(self, obs: ObsType) -> None:
"""Validate whether an observation belongs to the pre-defined observation space."""
_gym_space_contains(self.observation_space, obs)
def interpret(self, simulator_state: StateType) -> ObsType:
"""Interpret the state of simulator.
Parameters
----------
action :
rl agent action
simulator_state
Retrieved with ``simulator.get_state()``.
Returns
-------
qlib orders
State needed by policy. Should conform with the state space defined in ``observation_space``.
"""
raise NotImplementedError("interpret is not implemented!")
class StateInterpreter(BaseInterpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
def interpret(self, execute_result, **kwargs):
"""interpret method
env: "EnvWrapper" | None = None
@property
def action_space(self) -> gym.Space:
raise NotImplementedError()
@final # no overridden
def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType:
self.validate(action)
obs = self.interpret(simulator_state, action)
return obs
def validate(self, action: PolicyActType) -> None:
"""Validate whether an action belongs to the pre-defined action space."""
_gym_space_contains(self.action_space, action)
def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType:
"""Convert the policy action to simulator action.
Parameters
----------
execute_result :
qlib execution result
simulator_state
Retrieved with ``simulator.get_state()``.
action
Raw action given by policy.
Returns
----------
rl env state
-------
The action needed by simulator,
"""
raise NotImplementedError("interpret is not implemented!")
def _gym_space_contains(space: gym.Space, x: Any) -> None:
"""Strengthened version of gym.Space.contains.
Giving more diagnostic information on why validation fails.
Throw exception rather than returning true or false.
"""
if isinstance(space, spaces.Dict):
if not isinstance(x, dict) or len(x) != len(space):
raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x)
for k, subspace in space.spaces.items():
if k not in x:
raise GymSpaceValidationError(f"Key {k} not found in sample.", space, x)
try:
_gym_space_contains(subspace, x[k])
except GymSpaceValidationError as e:
raise GymSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e
elif isinstance(space, spaces.Tuple):
if isinstance(x, (list, np.ndarray)):
x = tuple(x) # Promote list and ndarray to tuple for contains check
if not isinstance(x, tuple) or len(x) != len(space):
raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, x)
for i, (subspace, part) in enumerate(zip(space, x)):
try:
_gym_space_contains(subspace, part)
except GymSpaceValidationError as e:
raise GymSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e
else:
if not space.contains(x):
raise GymSpaceValidationError("Validation error reported by gym.", space, x)
class GymSpaceValidationError(Exception):
def __init__(self, message: str, space: gym.Space, x: Any):
self.message = message
self.space = space
self.x = x
def __str__(self):
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"

View File

@@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Currently it supports single-asset order execution.
Multi-asset is on the way.
"""
from .interpreter import *
from .network import *
from .policy import *
from .simulator_simple import *

View File

@@ -0,0 +1,222 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import math
from pathlib import Path
from typing import Any, cast
import numpy as np
import pandas as pd
from gym import spaces
from qlib.constant import EPS
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.data import pickle_styled
from qlib.typehint import TypedDict
from .simulator_simple import SAOEState
__all__ = [
"FullHistoryStateInterpreter",
"CurrentStepStateInterpreter",
"CategoricalActionInterpreter",
"TwapRelativeActionInterpreter",
]
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
"""To 32-bit numeric types. Recursively."""
if isinstance(value, pd.DataFrame):
return value.to_numpy()
if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == "f"):
return np.array(value, dtype=np.float32)
elif isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == "i"):
return np.array(value, dtype=np.int32)
elif isinstance(value, dict):
return {k: canonicalize(v) for k, v in value.items()}
else:
return value
class FullHistoryObs(TypedDict):
data_processed: Any
data_processed_prev: Any
acquiring: Any
cur_tick: Any
cur_step: Any
num_step: Any
target: Any
position: Any
position_history: Any
class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
"""The observation of all the history, including today (until this moment), and yesterday.
Parameters
----------
data_dir
Path to load data after feature engineering.
max_step
Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
data_ticks
Equal to the total number of records. For example, in SAOE per minute,
the total ticks is the length of day in minutes.
data_dim
Number of dimensions in data.
"""
def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None:
self.data_dir = data_dir
self.max_step = max_step
self.data_ticks = data_ticks
self.data_dim = data_dim
def interpret(self, state: SAOEState) -> FullHistoryObs:
processed = pickle_styled.load_intraday_processed_data(
self.data_dir,
state.order.stock_id,
pd.Timestamp(state.order.start_time.date()),
self.data_dim,
state.ticks_index,
)
position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
position_history[0] = state.order.amount
position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()
assert self.env is not None
# The min, slice here are to make sure that indices fit into the range,
# even after the final step of the simulator (in the done step),
# to make network in policy happy.
return cast(
FullHistoryObs,
canonicalize(
{
"data_processed": self._mask_future_info(processed.today, state.cur_time),
"data_processed_prev": processed.yesterday,
"acquiring": state.order.direction == state.order.BUY,
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
"position_history": position_history[: self.max_step],
}
),
)
@property
def observation_space(self):
space = {
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
"acquiring": spaces.Discrete(2),
"cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32),
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
# TODO: support arbitrary length index
"num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
"target": spaces.Box(-EPS, np.inf, shape=()),
"position": spaces.Box(-EPS, np.inf, shape=()),
"position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)),
}
return spaces.Dict(space)
@staticmethod
def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame:
arr = arr.copy(deep=True)
arr.loc[current:] = 0.0 # mask out data after this moment (inclusive)
return arr
class CurrentStateObs(TypedDict):
acquiring: bool
cur_step: int
num_step: int
target: float
position: float
class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
"""The observation of current step.
Used when policy only depends on the latest state, but not history.
The key list is not full. You can add more if more information is needed by your policy.
"""
def __init__(self, max_step: int):
self.max_step = max_step
@property
def observation_space(self):
space = {
"acquiring": spaces.Discrete(2),
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
"num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
"target": spaces.Box(-EPS, np.inf, shape=()),
"position": spaces.Box(-EPS, np.inf, shape=()),
}
return spaces.Dict(space)
def interpret(self, state: SAOEState) -> CurrentStateObs:
assert self.env is not None
assert self.env.status["cur_step"] <= self.max_step
obs = CurrentStateObs(
{
"acquiring": state.order.direction == state.order.BUY,
"cur_step": self.env.status["cur_step"],
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
}
)
return obs
class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
"""Convert a discrete policy action to a continuous action, then multiplied by ``order.amount``.
Parameters
----------
values
It can be a list of length $L$: $[a_1, a_2, \\ldots, a_L]$.
Then when policy givens decision $x$, $a_x$ times order amount is the output.
It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
"""
def __init__(self, values: int | list[float]):
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values
@property
def action_space(self) -> spaces.Discrete:
return spaces.Discrete(len(self.action_values))
def interpret(self, state: SAOEState, action: int) -> float:
assert 0 <= action < len(self.action_values)
return min(state.position, state.order.amount * self.action_values[action])
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
"""Convert a continous ratio to deal amount.
The ratio is relative to TWAP on the remainder of the day.
For example, there are 5 steps left, and the left position is 300.
With TWAP strategy, in each position, 60 should be traded.
When this interpreter receives action $a$, its output is $60 \\cdot a$.
"""
@property
def action_space(self) -> spaces.Box:
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
def interpret(self, state: SAOEState, action: float) -> float:
assert self.env is not None
estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
return min(state.position, twap_volume * action)

View File

@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
import torch
import torch.nn as nn
from tianshou.data import Batch
from qlib.typehint import Literal
from .interpreter import FullHistoryObs
__all__ = ["Recurrent"]
class Recurrent(nn.Module):
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
At every timestep the input of policy network is divided into two parts,
the public variables and the private variables. which are handled by ``raw_rnn``
and ``pri_rnn`` in this network, respectively.
One minor difference is that, in this implementation, we don't assume the direction to be fixed.
Thus, another ``dire_fc`` is added to produce an extra direction-related feature.
"""
def __init__(
self,
obs_space: FullHistoryObs,
hidden_dim: int = 64,
output_dim: int = 32,
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
rnn_num_layers: int = 1,
):
super().__init__()
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.num_sources = 3
rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}
self.rnn_class = rnn_classes[rnn_type]
self.rnn_layers = rnn_num_layers
self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
self.raw_fc = nn.Sequential(nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU())
self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU())
self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
self._init_extra_branches()
self.fc = nn.Sequential(
nn.Linear(hidden_dim * self.num_sources, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
nn.ReLU(),
)
def _init_extra_branches(self):
pass
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
bs, _, data_dim = obs["data_processed"].size()
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
cur_step = obs["cur_step"].long()
cur_tick = obs["cur_tick"].long()
bs_indices = torch.arange(bs, device=device)
position = obs["position_history"] / obs["target"].unsqueeze(-1) # [bs, num_step]
steps = (
torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()
/ obs["num_step"].unsqueeze(-1).float()
) # [bs, num_step]
priv = torch.stack((position.float(), steps), -1)
data_in = self.raw_fc(data)
data_out, _ = self.raw_rnn(data_in)
# as it is padded with zero in front, this should be last minute
data_out_slice = data_out[bs_indices, cur_tick]
priv_in = self.pri_fc(priv)
priv_out = self.pri_rnn(priv_in)[0]
priv_out = priv_out[bs_indices, cur_step]
sources = [data_out_slice, priv_out]
dir_out = self.dire_fc(torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float())
sources.append(dir_out)
return sources, data_out
def forward(self, batch: Batch) -> torch.Tensor:
"""
Input should be a dict (at least) containing:
- data_processed: [N, T, C]
- cur_step: [N] (int)
- cur_time: [N] (int)
- position_history: [N, S] (S is number of steps)
- target: [N]
- num_step: [N] (int)
- acquiring: [N] (0 or 1)
"""
inp = cast(FullHistoryObs, batch)
device = inp["data_processed"].device
sources, _ = self._source_features(inp, device)
assert len(sources) == self.num_sources
out = torch.cat(sources, -1)
return self.fc(out)

View File

@@ -0,0 +1,158 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from typing import Optional, cast
import numpy as np
import gym
import torch
import torch.nn as nn
from gym.spaces import Discrete
from tianshou.data import Batch, to_torch
from tianshou.policy import PPOPolicy, BasePolicy
__all__ = ["AllOne", "PPO"]
# baselines #
class NonlearnablePolicy(BasePolicy):
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
This could be moved outside in future.
"""
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
super().__init__()
def learn(self, batch, batch_size, repeat):
pass
def process_fn(self, batch, buffer, indice):
pass
class AllOne(NonlearnablePolicy):
"""Forward returns a batch full of 1.
Useful when implementing some baselines (e.g., TWAP).
"""
def forward(self, batch, state=None, **kwargs):
return Batch(act=np.full(len(batch), 1.0), state=state)
# ppo #
class PPOActor(nn.Module):
def __init__(self, extractor: nn.Module, action_dim: int):
super().__init__()
self.extractor = extractor
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
def forward(self, obs, state=None, info={}):
feature = self.extractor(to_torch(obs, device=auto_device(self)))
out = self.layer_out(feature)
return out, state
class PPOCritic(nn.Module):
def __init__(self, extractor: nn.Module):
super().__init__()
self.extractor = extractor
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
def forward(self, obs, state=None, info={}):
feature = self.extractor(to_torch(obs, device=auto_device(self)))
return self.value_out(feature).squeeze(dim=-1)
class PPO(PPOPolicy):
"""A wrapper of tianshou PPOPolicy.
Differences:
- Auto-create actor and critic network. Supports discrete action space only.
- Dedup common parameters between actor network and critic network
(not sure whether this is included in latest tianshou or not).
- Support a ``weight_file`` that supports loading checkpoint.
- Some parameters' default values are different from original.
"""
def __init__(
self,
network: nn.Module,
obs_space: gym.Space,
action_space: gym.Space,
lr: float,
weight_decay: float = 0.0,
discount_factor: float = 1.0,
max_grad_norm: float = 100.0,
reward_normalization: bool = True,
eps_clip: float = 0.3,
value_clip: float = True,
vf_coef: float = 1.0,
gae_lambda: float = 1.0,
max_batchsize: int = 256,
deterministic_eval: bool = True,
weight_file: Optional[Path] = None,
):
assert isinstance(action_space, Discrete)
actor = PPOActor(network, action_space.n)
critic = PPOCritic(network)
optimizer = torch.optim.Adam(
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
)
super().__init__(
actor,
critic,
optimizer,
torch.distributions.Categorical,
discount_factor=discount_factor,
max_grad_norm=max_grad_norm,
reward_normalization=reward_normalization,
eps_clip=eps_clip,
value_clip=value_clip,
vf_coef=vf_coef,
gae_lambda=gae_lambda,
max_batchsize=max_batchsize,
deterministic_eval=deterministic_eval,
observation_space=obs_space,
action_space=action_space,
)
if weight_file is not None:
load_weight(self, weight_file)
# utilities: these should be put in a separate (common) file. #
def auto_device(module: nn.Module) -> torch.device:
for param in module.parameters():
return param.device
return torch.device("cpu") # fallback to cpu
def load_weight(policy, path):
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
loaded_weight = torch.load(path, map_location="cpu")
try:
policy.load_state_dict(loaded_weight)
except RuntimeError:
# try again by loading the converted weight
# https://github.com/thu-ml/tianshou/issues/468
for k in list(loaded_weight):
loaded_weight["_actor_critic." + k] = loaded_weight[k]
policy.load_state_dict(loaded_weight)
def chain_dedup(*iterables):
seen = set()
for iterable in iterables:
for i in iterable:
if i not in seen:
seen.add(i)
yield i

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Placeholder for qlib-based simulator."""

View File

@@ -0,0 +1,403 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from pathlib import Path
from typing import NamedTuple, Any, TypeVar, cast
import numpy as np
import pandas as pd
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS
from qlib.rl.simulator import Simulator
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
from qlib.rl.utils import LogLevel
from qlib.typehint import TypedDict
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
class SAOEMetrics(TypedDict):
"""Metrics for SAOE accumulated for a "period".
It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute.
Warnings
--------
The type hints are for single elements. In lots of times, they can be vectorized.
For example, ``market_volume`` could be a list of float (or ndarray) rather tahn a single float.
"""
stock_id: str
"""Stock ID of this record."""
datetime: pd.Timestamp
"""Datetime of this record (this is index in the dataframe)."""
direction: int
"""Direction of the order. 0 for sell, 1 for buy."""
# Market information.
market_volume: float
"""(total) market volume traded in the period."""
market_price: float
"""Deal price. If it's a period of time, this is the average market deal price."""
# Strategy records.
amount: float
"""Total amount (volume) strategy intends to trade."""
inner_amount: float
"""Total amount that the lower-level strategy intends to trade
(might be larger than amount, e.g., to ensure ffr)."""
deal_amount: float
"""Amount that successfully takes effect (must be less than inner_amount)."""
trade_price: float
"""The average deal price for this strategy."""
trade_value: float
"""Total worth of trading. In the simple simulaton, trade_value = deal_amount * price."""
position: float
"""Position left after this "period"."""
# Accumulated metrics
ffr: float
"""Completed how much percent of the daily order."""
pa: float
"""Price advantage compared to baseline (i.e., trade with baseline market price).
The baseline is trade price when using TWAP strategy to execute this order.
Please note that there could be data leak here).
Unit is BP (basis point, 1/10000)."""
class SAOEState(NamedTuple):
"""Data structure holding a state for SAOE simulator."""
order: Order
"""The order we are dealing with."""
cur_time: pd.Timestamp
"""Current time, e.g., 9:30."""
position: float
"""Current remaining volume to execute."""
history_exec: pd.DataFrame
"""See :attr:`SingleAssetOrderExecution.history_exec`."""
history_steps: pd.DataFrame
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
metrics: SAOEMetrics | None
"""Daily metric, only available when the trading is in "done" state."""
backtest_data: IntradayBacktestData
"""Backtest data is included in the state.
Actually, only the time index of this data is needed, at this moment.
I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented.
Interpreter can use this as they wish, but they should be careful not to leak future data.
"""
ticks_per_step: int
"""How many ticks for each step."""
ticks_index: pd.DatetimeIndex
"""Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59]."""
ticks_for_order: pd.DatetimeIndex
"""Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44]."""
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""Single-asset order execution (SAOE) simulator.
As there's no "calendar" in the simple simulator, ticks are used to trade.
A tick is a record (a line) in the pickle-styled data file.
Each tick is considered as a individual trading opportunity.
If such fine granularity is not needed, use ``ticks_per_step`` to
lengthen the ticks for each step.
In each step, the traded amount are "equally" splitted to each tick,
then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``),
and if it's the last step, try to ensure all the amount to be executed.
Parameters
----------
initial
The seed to start an SAOE simulator is an order.
ticks_per_step
How many ticks per step.
data_dir
Path to load backtest data
vol_threshold
Maximum execution volume (divided by market execution volume).
"""
history_exec: pd.DataFrame
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
history_steps: pd.DataFrame
"""Positions at each step. The position before first step is also recorded.
See :class:`SAOEMetrics` for available columns."""
metrics: SAOEMetrics | None
"""Metrics. Only available when done."""
twap_price: float
"""This price is used to compute price advantage.
It"s defined as the average price in the period from order"s start time to end time."""
ticks_index: pd.DatetimeIndex
"""All available ticks for the day (not restricted to order)."""
ticks_for_order: pd.DatetimeIndex
"""Ticks that is available for trading (sliced by order)."""
def __init__(
self,
order: Order,
data_dir: Path,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: float | None = None,
) -> None:
self.order = order
self.ticks_per_step: int = ticks_per_step
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
self.backtest_data = load_intraday_backtest_data(
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
)
self.ticks_index = self.backtest_data.get_time_index()
# Get time index available for trading
self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)
self.cur_time = self.ticks_for_order[0]
# NOTE: astype(float) is necessary in some systems.
# this will align the precision with `.to_numpy()` in `_split_exec_vol`
self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean())
self.position = order.amount
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
# NOTE: can empty dataframe contain index?
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.metrics = None
self.market_price: np.ndarray | None = None
self.market_vol: np.ndarray | None = None
self.market_vol_limit: np.ndarray | None = None
def step(self, amount: float) -> None:
"""Execute one step or SAOE.
Parameters
----------
amount
The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
"""
assert not self.done()
self.market_price = self.market_vol = None # avoid misuse
exec_vol = self._split_exec_vol(amount)
assert self.market_price is not None and self.market_vol is not None
ticks_position = self.position - np.cumsum(exec_vol)
self.position -= exec_vol.sum()
if self.position < -EPS or (exec_vol < -EPS).any():
raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")
# Get time index available for this step
time_index = self._get_ticks_slice(self.cur_time, self._next_time())
self.history_exec = self._dataframe_append(
self.history_exec,
SAOEMetrics(
# It should have the same keys with SAOEMetrics,
# but the values do not necessarily have the annotated type.
# Some values could be vectorized (e.g., exec_vol).
stock_id=self.order.stock_id,
datetime=time_index,
direction=self.order.direction,
market_volume=self.market_vol,
market_price=self.market_price,
amount=exec_vol,
inner_amount=exec_vol,
deal_amount=exec_vol,
trade_price=self.market_price,
trade_value=self.market_price * exec_vol,
position=ticks_position,
ffr=exec_vol / self.order.amount,
pa=price_advantage(self.market_price, self.twap_price, self.order.direction),
),
)
self.history_steps = self._dataframe_append(
self.history_steps,
[self._metrics_collect(self.cur_time, self.market_vol, self.market_price, amount, exec_vol)],
)
if self.done():
if self.env is not None:
self.env.logger.add_any("history_steps", self.history_steps, loglevel=LogLevel.DEBUG)
self.env.logger.add_any("history_exec", self.history_exec, loglevel=LogLevel.DEBUG)
self.metrics = self._metrics_collect(
self.ticks_index[0], # start time
self.history_exec["market_volume"],
self.history_exec["market_price"],
self.history_steps["amount"].sum(),
self.history_exec["deal_amount"],
)
# NOTE (yuge): It looks to me that it's the "correct" decision to
# put all the logs here, because only components like simulators themselves
# have the knowledge about what could appear in the logs, and what's the format.
# But I admit it's not necessarily the most convenient way.
# I'll rethink about it when we have the second environment
# Maybe some APIs like self.logger.enable_auto_log() ?
if self.env is not None:
for key, value in self.metrics.items():
if isinstance(value, float):
self.env.logger.add_scalar(key, value)
else:
self.env.logger.add_any(key, value)
self.cur_time = self._next_time()
def get_state(self) -> SAOEState:
return SAOEState(
order=self.order,
cur_time=self.cur_time,
position=self.position,
history_exec=self.history_exec,
history_steps=self.history_steps,
metrics=self.metrics,
backtest_data=self.backtest_data,
ticks_per_step=self.ticks_per_step,
ticks_index=self.ticks_index,
ticks_for_order=self.ticks_for_order,
)
def done(self) -> bool:
return self.position < EPS or self.cur_time >= self.order.end_time
def _next_time(self) -> pd.Timestamp:
"""The "current time" (``cur_time``) for next step."""
# Look for next time on time index
current_loc = self.ticks_index.get_loc(self.cur_time)
next_loc = current_loc + self.ticks_per_step
# Calibrate the next location to multiple of ticks_per_step.
# This is to make sure that:
# as long as ticks_per_step is a multiple of something, each step won't cross morning and afternoon.
next_loc = next_loc - next_loc % self.ticks_per_step
if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.order.end_time:
return self.ticks_index[next_loc]
else:
return self.order.end_time
def _cur_duration(self) -> pd.Timedelta:
"""The "duration" of this step (step that is about to happen)."""
return self._next_time() - self.cur_time
def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray:
"""
Split the volume in each step into minutes, considering possible constraints.
This follows TWAP strategy.
"""
next_time = self._next_time()
# get the backtest data for next interval
self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
assert self.market_vol is not None and self.market_price is not None
# split the volume equally into each minute
exec_vol = np.repeat(exec_vol_sum / len(self.market_price), len(self.market_price))
# apply the volume threshold
market_vol_limit = self.vol_threshold * self.market_vol if self.vol_threshold is not None else np.inf
exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore
# Complete all the order amount at the last moment.
if next_time >= self.order.end_time:
exec_vol[-1] += self.position - exec_vol.sum()
exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore
return exec_vol
def _metrics_collect(
self,
datetime: pd.Timestamp,
market_vol: np.ndarray,
market_price: np.ndarray,
amount: float, # intended to trade such amount
exec_vol: np.ndarray,
) -> SAOEMetrics:
assert len(market_vol) == len(market_price) == len(exec_vol)
if np.abs(np.sum(exec_vol)) < EPS:
exec_avg_price = 0.0
else:
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
if hasattr(exec_avg_price, "item"): # could be numpy scalar
exec_avg_price = exec_avg_price.item() # type: ignore
return SAOEMetrics(
stock_id=self.order.stock_id,
datetime=datetime,
direction=self.order.direction,
market_volume=market_vol.sum(),
market_price=market_price.mean(),
amount=amount,
inner_amount=exec_vol.sum(),
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
trade_price=exec_avg_price,
trade_value=np.sum(market_price * exec_vol),
position=self.position,
ffr=float(exec_vol.sum() / self.order.amount),
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
)
def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex:
if not include_end:
end = end - ONE_SEC
return self.ticks_index[self.ticks_index.slice_indexer(start, end)]
@staticmethod
def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
# dataframe.append is deprecated
other_df = pd.DataFrame(other).set_index("datetime")
other_df.index.name = "datetime"
return pd.concat([df, other_df], axis=0)
_float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
def price_advantage(
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
) -> _float_or_ndarray:
if baseline_price == 0: # something is wrong with data. Should be nan here
if isinstance(exec_price, float):
return 0.0
else:
return np.zeros_like(exec_price)
if direction == OrderDir.BUY:
res = (1 - exec_price / baseline_price) * 10000
elif direction == OrderDir.SELL:
res = (exec_price / baseline_price - 1) * 10000
else:
raise ValueError(f"Unexpected order direction: {direction}")
res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
if res_wo_nan.size == 1:
return res_wo_nan.item()
else:
return cast(_float_or_ndarray, res_wo_nan)

84
qlib/rl/reward.py Normal file
View File

@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Generic, Any, TypeVar, TYPE_CHECKING
from qlib.typehint import final
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
SimulatorState = TypeVar("SimulatorState")
class Reward(Generic[SimulatorState]):
"""
Reward calculation component that takes a single argument: state of simulator. Returns a real number: reward.
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
"""
env: EnvWrapper | None = None
@final
def __call__(self, simulator_state: SimulatorState) -> float:
return self.reward(simulator_state)
def reward(self, simulator_state: SimulatorState) -> float:
"""Implement this method for your own reward."""
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
def log(self, name, value):
self.env.logger.add_scalar(name, value)
class RewardCombination(Reward):
"""Combination of multiple reward."""
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
self.rewards = rewards
def reward(self, simulator_state: Any) -> float:
total_reward = 0.0
for name, (reward_fn, weight) in self.rewards.items():
rew = reward_fn(simulator_state) * weight
total_reward += rew
self.log(name, rew)
return total_reward
# TODO:
# reward_factory is disabled for now
# _RegistryConfigReward = RegistryConfig[REWARDS]
# @configclass
# class _WeightedRewardConfig:
# weight: float
# reward: _RegistryConfigReward
# RewardConfig = Union[_RegistryConfigReward, Dict[str, Union[_RegistryConfigReward, _WeightedRewardConfig]]]
# def reward_factory(reward_config: RewardConfig) -> Reward:
# """
# Use this factory to instantiate the reward from config.
# Simply using ``reward_config.build()`` might not work because reward can have complex combinations.
# """
# if isinstance(reward_config, dict):
# # as reward combination
# rewards = {}
# for name, rew in reward_config.items():
# if not isinstance(rew, _WeightedRewardConfig):
# # default weight is 1.
# rew = _WeightedRewardConfig(weight=1., rew=rew)
# # no recursive build in this step
# rewards[name] = (rew.reward.build(), rew.weight)
# return RewardCombination(rewards)
# else:
# # single reward
# return reward_config.build()

12
qlib/rl/seed.py Normal file
View File

@@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Defines a set of initial state definitions and state-set definitions.
With single-asset order execution only, the only seed is order.
"""
from typing import TypeVar
InitialStateType = TypeVar("InitialStateType")
"""Type of data that creates the simulator."""

75
qlib/rl/simulator.py Normal file
View File

@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import TypeVar, Generic, Any, TYPE_CHECKING
from .seed import InitialStateType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
StateType = TypeVar("StateType")
"""StateType stores all the useful data in the simulation process
(as well as utilities to generate/retrieve data when needed)."""
ActType = TypeVar("ActType")
"""This ActType is the type of action at the simulator end."""
class Simulator(Generic[InitialStateType, StateType, ActType]):
"""
Simulator that resets with ``__init__``, and transits with ``step(action)``.
To make the data-flow clear, we make the following restrictions to Simulator:
1. The only way to modify the inner status of a simulator is by using ``step(action)``.
2. External modules can *read* the status of a simulator by using ``simulator.get_state()``,
and check whether the simulator is in the ending state by calling ``simulator.done()``.
A simulator is defined to be bounded with three types:
- *InitialStateType* that is the type of the data used to create the simulator.
- *StateType* that is the type of the **status** (state) of the simulator.
- *ActType* that is the type of the **action**, which is the input received in each step.
Different simulators might share the same StateType. For example, when they are dealing with the same task,
but with different simulation implementation. With the same type, they can safely share other components in the MDP.
Simulators are ephemeral. The lifecycle of a simulator starts with an initial state, and ends with the trajectory.
In another word, when the trajectory ends, simulator is recycled.
If simulators want to share context between (e.g., for speed-up purposes),
this could be done by accessing the weak reference of environment wrapper.
Attributes
----------
env
A reference of env-wrapper, which could be useful in some corner cases.
Simulators are discouraged to use this, because it's prone to induce errors.
"""
env: EnvWrapper | None = None
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
pass
def step(self, action: ActType) -> None:
"""Receives an action of ActType.
Simulator should update its internal state, and return None.
The updated state can be retrieved with ``simulator.get_state()``.
"""
raise NotImplementedError()
def get_state(self) -> StateType:
raise NotImplementedError()
def done(self) -> bool:
"""Check whether the simulator is in a "done" state.
When simulator is in a "done" state,
it should no longer receives any ``step`` request.
As simulators are ephemeral, to reset the simulator,
the old one should be destroyed and a new simulator can be created.
"""
raise NotImplementedError()

View File

@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .data_queue import *
from .env_wrapper import *
from .finite_env import *
from .log import *

179
qlib/rl/utils/data_queue.py Normal file
View File

@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import multiprocessing
import threading
import time
import warnings
from queue import Empty
from typing import TypeVar, Generic, Sequence, cast
from qlib.log import get_module_logger
_logger = get_module_logger(__name__)
T = TypeVar("T")
__all__ = ["DataQueue"]
class DataQueue(Generic[T]):
"""Main process (producer) produces data and stores them in a queue.
Sub-processes (consumers) can retrieve the data-points from the queue.
Data-points are generated via reading items from ``dataset``.
:class:`DataQueue` is ephemeral. You must create a new DataQueue
when the ``repeat`` is exhausted.
See the documents of :class:`qlib.rl.utils.FiniteVectorEnv` for more background.
Parameters
----------
dataset
The dataset to read data from. Must implement ``__len__`` and ``__getitem__``.
repeat
Iterate over the data-points for how many times. Use ``-1`` to iterate forever.
shuffle
If ``shuffle`` is true, the items will be read in random order.
producer_num_workers
Concurrent workers for data-loading.
queue_maxsize
Maximum items to put into queue before it jams.
Examples
--------
>>> data_queue = DataQueue(my_dataset)
>>> with data_queue:
... ...
In worker:
>>> for data in data_queue:
... print(data)
"""
def __init__(
self,
dataset: Sequence[T],
repeat: int = 1,
shuffle: bool = True,
producer_num_workers: int = 0,
queue_maxsize: int = 0,
):
if queue_maxsize == 0:
if os.cpu_count() is not None:
queue_maxsize = cast(int, os.cpu_count())
_logger.info(f"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming.")
else:
queue_maxsize = 1
_logger.warning(f"CPU count not available. Setting queue maxsize to 1.")
self.dataset: Sequence[T] = dataset
self.repeat: int = repeat
self.shuffle: bool = shuffle
self.producer_num_workers: int = producer_num_workers
self._activated: bool = False
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
self._done = multiprocessing.Value("i", 0)
def __enter__(self):
self.activate()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.cleanup()
def cleanup(self):
with self._done.get_lock():
self._done.value += 1
for repeat in range(500):
if repeat >= 1:
warnings.warn(f"After {repeat} cleanup, the queue is still not empty.", category=RuntimeWarning)
while not self._queue.empty():
try:
self._queue.get(block=False)
except Empty:
pass
# Sometimes when the queue gets emptied, more data have already been sent,
# and they are on the way into the queue.
# If these data didn't get consumed, it will jam the queue and make the process hang.
# We wait a second here for potential data arriving, and check again (for ``repeat`` times).
time.sleep(1.0)
if self._queue.empty():
break
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
def get(self, block=True):
if not hasattr(self, "_first_get"):
self._first_get = True
if self._first_get:
timeout = 5.0
self._first_get = False
else:
timeout = 0.5
while True:
try:
return self._queue.get(block=block, timeout=timeout)
except Empty:
if self._done.value:
raise StopIteration # pylint: disable=raise-missing-from
def put(self, obj, block=True, timeout=None):
return self._queue.put(obj, block=block, timeout=timeout)
def mark_as_done(self):
with self._done.get_lock():
self._done.value = 1
def done(self):
return self._done.value
def activate(self):
if self._activated:
raise ValueError("DataQueue can not activate twice.")
thread = threading.Thread(target=self._producer, daemon=True)
thread.start()
self._activated = True
return self
def __del__(self):
_logger.debug(f"__del__ of {__name__}.DataQueue")
self.cleanup()
def __iter__(self):
if not self._activated:
raise ValueError(
"Need to call activate() to launch a daemon worker " "to produce data into data queue before using it."
)
return self._consumer()
def _consumer(self):
while True:
try:
yield self.get()
except StopIteration:
_logger.debug("Data consumer timed-out from get.")
return
def _producer(self):
# pytorch dataloader is used here only because we need its sampler and multi-processing
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
dataloader = DataLoader(
cast(Dataset[T], self.dataset),
batch_size=None,
num_workers=self.producer_num_workers,
shuffle=self.shuffle,
collate_fn=lambda t: t, # identity collate fn
)
repeat = 10**18 if self.repeat == -1 else self.repeat
for _rep in range(repeat):
for data in dataloader:
if self._done.value:
# Already done.
return
self._queue.put(data)
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
self.mark_as_done()

View File

@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import weakref
from typing import Callable, Any, Iterable, Iterator, Generic, cast
import gym
from qlib.rl.aux_info import AuxiliaryInfoCollector
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
from qlib.rl.reward import Reward
from qlib.typehint import TypedDict
from .finite_env import generate_nan_observation
from .log import LogCollector, LogLevel
__all__ = ["InfoDict", "EnvWrapperStatus", "EnvWrapper"]
# in this case, there won't be any seed for simulator
SEED_INTERATOR_MISSING = "_missing_"
class InfoDict(TypedDict):
"""The type of dict that is used in the 4th return value of ``env.step()``."""
aux_info: dict
"""Any information depends on auxiliary info collector."""
log: dict[str, Any]
"""Collected by LogCollector."""
class EnvWrapperStatus(TypedDict):
"""
This is the status data structure used in EnvWrapper.
The fields here are in the semantics of RL.
For example, ``obs`` means the observation fed into policy.
``action`` means the raw action returned by policy.
"""
cur_step: int
done: bool
initial_state: Any | None
obs_history: list
action_history: list
reward_history: list
class EnvWrapper(
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
):
"""Qlib-based RL environment, subclassing ``gym.Env``.
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
This is what the framework of simulator - interpreter - policy looks like in RL training.
All the components other than policy needs to be assembled into a single object called "environment".
The "environment" are replicated into multiple workers, and (at least in tianshou's implementation),
one single policy (agent) plays against a batch of environments.
Parameters
----------
simulator_fn
A callable that is the simulator factory.
When ``seed_iterator`` is present, the factory should take one argument,
that is the seed (aka initial state).
Otherwise, it should take zero argument.
state_interpreter
State-observation converter.
action_interpreter
Policy-simulator action converter.
seed_iterator
An iterable of seed. With the help of :class:`qlib.rl.utils.DataQueue`,
environment workers in different processes can share one ``seed_iterator``.
reward_fn
A callable that accepts the StateType and returns a float (at least in single-agent case).
aux_info_collector
Collect auxiliary information. Could be useful in MARL.
logger
Log collector that collects the logs. The collected logs are sent back to main process,
via the return value of ``env.step()``.
Attributes
----------
status : EnvWrapperStatus
Status indicator. All terms are in *RL language*.
It can be used if users care about data on the RL side.
Can be none when no trajectory is available.
"""
simulator: Simulator[InitialStateType, StateType, ActType]
seed_iterator: str | Iterator[InitialStateType] | None
def __init__(
self,
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
state_interpreter: StateInterpreter[StateType, ObsType],
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
seed_iterator: Iterable[InitialStateType] | None,
reward_fn: Reward | None = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
logger: LogCollector | None = None,
):
# Assign weak reference to wrapper.
#
# Use weak reference here, because:
# 1. Logically, the other components should be able to live without an env_wrapper.
# For example, they might live in a strategy_wrapper in future.
# Therefore injecting a "hard" attribute called "env" is not appropripate.
# 2. When the environment gets destroyed, it gets destoryed.
# We don't want it to silently live inside some interpreters.
# 3. Avoid circular reference.
# 4. When the components get serialized, we can throw away the env without any burden.
# (though this part is not implemented yet)
for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
if obj is not None:
obj.env = weakref.proxy(self) # type: ignore
self.simulator_fn = simulator_fn
self.state_interpreter = state_interpreter
self.action_interpreter = action_interpreter
if seed_iterator is None:
# In this case, there won't be any seed for simulator
# We can't set it to None because None actually means something else.
# If `seed_iterator` is None, it means that it's exhausted.
self.seed_iterator = SEED_INTERATOR_MISSING
else:
self.seed_iterator = iter(seed_iterator)
self.reward_fn = reward_fn
self.aux_info_collector = aux_info_collector
self.logger: LogCollector = logger or LogCollector()
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
@property
def action_space(self):
return self.action_interpreter.action_space
@property
def observation_space(self):
return self.state_interpreter.observation_space
def reset(self, **kwargs: Any) -> ObsType:
"""
Try to get a state from state queue, and init the simulator with this state.
If the queue is exhausted, generate an invalid (nan) observation.
"""
try:
if self.seed_iterator is None:
raise RuntimeError("You can trying to get a state from a dead environment wrapper.")
# TODO: simulator/observation might need seed to prefetch something
# as only seed has the ability to do the work beforehands
# NOTE: though logger is reset here, logs in this function won't work,
# because we can't send them outside.
# See https://github.com/thu-ml/tianshou/issues/605
self.logger.reset()
if self.seed_iterator is SEED_INTERATOR_MISSING:
# no initial state
initial_state = None
self.simulator = cast(Callable[[], Simulator], self.simulator_fn)()
else:
initial_state = next(cast(Iterator[InitialStateType], self.seed_iterator))
self.simulator = self.simulator_fn(initial_state)
self.status = EnvWrapperStatus(
cur_step=0,
done=False,
initial_state=initial_state,
obs_history=[],
action_history=[],
reward_history=[],
)
self.simulator.env = cast(EnvWrapper, weakref.proxy(self))
sim_state = self.simulator.get_state()
obs = self.state_interpreter(sim_state)
self.status["obs_history"].append(obs)
return obs
except StopIteration:
# The environment should be recycled because it's in a dead state.
self.seed_iterator = None
return generate_nan_observation(self.observation_space)
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
"""Environment step.
See the code along with comments to get a sequence of things happening here.
"""
if self.seed_iterator is None:
raise RuntimeError("State queue is already exhausted, but the environment is still receiving action.")
# Clear the logged information from last step
self.logger.reset()
# Action is what we have got from policy
self.status["action_history"].append(policy_action)
action = self.action_interpreter(self.simulator.get_state(), policy_action)
# This update must be after action interpreter and before simulator.
self.status["cur_step"] += 1
# Use the converted action of update the simulator
self.simulator.step(action)
# Update "done" first, as this status might be used by reward_fn later
done = self.simulator.done()
self.status["done"] = done
# Get state and calculate observation
sim_state = self.simulator.get_state()
obs = self.state_interpreter(sim_state)
self.status["obs_history"].append(obs)
# Reward and extra info
if self.reward_fn is not None:
rew = self.reward_fn(sim_state)
else:
# No reward. Treated as 0.
rew = 0.0
self.status["reward_history"].append(rew)
if self.aux_info_collector is not None:
aux_info = self.aux_info_collector(sim_state)
else:
aux_info = {}
# Final logging stuff: RL-specific logs
if done:
self.logger.add_scalar("steps_per_episode", self.status["cur_step"])
self.logger.add_scalar("reward", rew)
self.logger.add_any("obs", obs, loglevel=LogLevel.DEBUG)
self.logger.add_any("policy_act", policy_action, loglevel=LogLevel.DEBUG)
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
return obs, rew, done, info_dict
def render(self):
raise NotImplementedError("Render is not implemented in EnvWrapper.")

337
qlib/rl/utils/finite_env.py Normal file
View File

@@ -0,0 +1,337 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This is to support finite env in vector env.
See https://github.com/thu-ml/tianshou/issues/322 for details.
"""
from __future__ import annotations
import copy
import warnings
from contextlib import contextmanager
import gym
import numpy as np
from typing import Any, Set, Callable, Type
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
from qlib.typehint import Literal
from .log import LogWriter
__all__ = [
"generate_nan_observation",
"check_nan_observation",
"FiniteVectorEnv",
"FiniteDummyVectorEnv",
"FiniteSubprocVectorEnv",
"FiniteShmemVectorEnv",
"FiniteEnvType",
"vectorize_env",
]
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
def fill_invalid(obj):
if isinstance(obj, (int, float, bool)):
return fill_invalid(np.array(obj))
if hasattr(obj, "dtype"):
if isinstance(obj, np.ndarray):
if np.issubdtype(obj.dtype, np.floating):
return np.full_like(obj, np.nan)
return np.full_like(obj, np.iinfo(obj.dtype).max)
# dealing with corner cases that numpy number is not supported by tianshou's sharray
return fill_invalid(np.array(obj))
elif isinstance(obj, dict):
return {k: fill_invalid(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [fill_invalid(v) for v in obj]
elif isinstance(obj, tuple):
return tuple(fill_invalid(v) for v in obj)
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
def is_invalid(arr):
if hasattr(arr, "dtype"):
if np.issubdtype(arr.dtype, np.floating):
return np.isnan(arr).all()
return (np.iinfo(arr.dtype).max == arr).all()
if isinstance(arr, dict):
return all(is_invalid(o) for o in arr.values())
if isinstance(arr, (list, tuple)):
return all(is_invalid(o) for o in arr)
if isinstance(arr, (int, float, bool, np.number)):
return is_invalid(np.array(arr))
return True
def generate_nan_observation(obs_space: gym.Space) -> Any:
"""The NaN observation that indicates the environment receives no seed.
We assume that obs is complex and there must be something like float.
Otherwise this logic doesn't work.
"""
sample = obs_space.sample()
sample = fill_invalid(sample)
return sample
def check_nan_observation(obs: Any) -> bool:
"""Check whether obs is generated by :func:`generate_nan_observation`."""
return is_invalid(obs)
class FiniteVectorEnv(BaseVectorEnv):
"""To allow the paralleled env workers consume a single DataQueue until it's exhausted.
See `tianshou issue #322 <https://github.com/thu-ml/tianshou/issues/322>`_.
The requirement is to make every possible seed (stored in :class:`qlib.rl.utils.DataQueue` in our case)
consumed by exactly one environment. This is not possible by tianshou's native VectorEnv and Collector,
because tianshou is unaware of this "exactly one" constraint, and might launch extra workers.
Consider a corner case, where concurrency is 2, but there is only one seed in DataQueue.
The reset of two workers must be both called according to the logic in collect.
The returned results of two workers are collected, regardless of what they are.
The problem is, one of the reset result must be invalid, or repeated,
because there's only one need in queue, and collector isn't aware of such situation.
Luckily, we can hack the vector env, and make a protocol between single env and vector env.
The single environment (should be :class:`qlib.rl.utils.EnvWrapper` in our case) is responsible for
reading from queue, and generate a special observation when the queue is exhausted. The special obs
is called "nan observation", because simply using none causes problems in shared-memory vector env.
:class:`FiniteVectorEnv` then read the observations from all workers, and select those non-nan
observation. It also maintains an ``_alive_env_ids`` to track which workers should never be
called again. When also the environments are exhausted, it will raise StopIteration exception.
The usage of this vector env in collector are two parts:
1. If the data queue is finite (usually when inference), collector should collect "infinity" number of
episodes, until the vector env exhausts by itself.
2. If the data queue is infinite (usually in training), collector can set number of episodes / steps.
In this case, data would be randomly ordered, and some repetitions wouldn't matter.
One extra function of this vector env is that it has a logger that explicitly collects logs
from child workers. See :class:`qlib.rl.utils.LogWriter`.
"""
def __init__(
self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
) -> None:
super().__init__(env_fns, **kwargs)
self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
self._alive_env_ids: Set[int] = set()
self._reset_alive_envs()
self._default_obs = self._default_info = self._default_rew = None
self._zombie = False
self._collector_guarded: bool = False
def _reset_alive_envs(self):
if not self._alive_env_ids:
# starting or running out
self._alive_env_ids = set(range(self.env_num))
# to workaround with tianshou's buffer and batch
def _set_default_obs(self, obs):
if obs is not None and self._default_obs is None:
self._default_obs = copy.deepcopy(obs)
def _set_default_info(self, info):
if info is not None and self._default_info is None:
self._default_info = copy.deepcopy(info)
def _set_default_rew(self, rew):
if rew is not None and self._default_rew is None:
self._default_rew = copy.deepcopy(rew)
def _get_default_obs(self):
return copy.deepcopy(self._default_obs)
def _get_default_info(self):
return copy.deepcopy(self._default_info)
def _get_default_rew(self):
return copy.deepcopy(self._default_rew)
# END
@staticmethod
def _postproc_env_obs(obs):
# reserved for shmem vector env to restore empty observation
if obs is None or check_nan_observation(obs):
return None
return obs
@contextmanager
def collector_guard(self):
"""Guard the collector. Recommended to guard every collect.
This guard is for two purposes.
1. Catch and ignore the StopIteration exception, which is the stopping signal
thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
2. Notify the loggers that the collect is done what it's done.
Examples
--------
>>> with finite_env.collector_guard():
... collector.collect(n_episode=INF)
"""
self._collector_guarded = True
try:
yield self
except StopIteration:
pass
finally:
self._collector_guarded = False
# At last trigger the loggers
for logger in self._logger:
logger.on_env_all_done()
def reset(self, id=None):
assert not self._zombie
# Check whether it's guarded by collector_guard()
if not self._collector_guarded:
warnings.warn(
"Collector is not guarded by FiniteEnv. "
"This may cause unexpected problems, like unexpected StopIteration exception, "
"or missing logs.",
RuntimeWarning,
)
id = self._wrap_id(id)
self._reset_alive_envs()
# ask super to reset alive envs and remap to current index
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
obs = [None] * len(id)
id2idx = {i: k for k, i in enumerate(id)}
if request_id:
for i, o in zip(request_id, super().reset(request_id)):
obs[id2idx[i]] = self._postproc_env_obs(o)
for i, o in zip(id, obs):
if o is None and i in self._alive_env_ids:
self._alive_env_ids.remove(i)
# logging
for i, o in zip(id, obs):
if i in self._alive_env_ids:
for logger in self._logger:
logger.on_env_reset(i, obs)
# fill empty observation with default(fake) observation
for o in obs:
self._set_default_obs(o)
for i, o in enumerate(obs):
if o is None:
obs[i] = self._get_default_obs()
if not self._alive_env_ids:
# comment this line so that the env becomes indisposable
# self.reset()
self._zombie = True
raise StopIteration
return np.stack(obs)
def step(self, action, id=None):
assert not self._zombie
id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
result = [[None, None, False, None] for _ in range(len(id))]
# ask super to step alive envs and remap to current index
if request_id:
valid_act = np.stack([action[id2idx[i]] for i in request_id])
for i, r in zip(request_id, zip(*super().step(valid_act, request_id))):
result[id2idx[i]] = list(r)
result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])
# logging
for i, r in zip(id, result):
if i in self._alive_env_ids:
for logger in self._logger:
logger.on_env_step(i, *r)
# fill empty observation/info with default(fake)
for _, r, ___, i in result:
self._set_default_info(i)
self._set_default_rew(r)
for i, r in enumerate(result):
if r[0] is None:
result[i][0] = self._get_default_obs()
if r[1] is None:
result[i][1] = self._get_default_rew()
if r[3] is None:
result[i][3] = self._get_default_info()
return list(map(np.stack, zip(*result)))
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
pass
class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
pass
class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
pass
def vectorize_env(
env_factory: Callable[..., gym.Env],
env_type: FiniteEnvType,
concurrency: int,
logger: LogWriter | list[LogWriter],
) -> FiniteVectorEnv:
"""Helper function to create a vector env.
Parameters
----------
env_factory
Callable to instantiate one single ``gym.Env``.
All concurrent workers will have the same ``env_factory``.
env_type
dummy or subproc or shmem. Corresponding to
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
concurrency
Concurrent environment workers.
logger
Log writers.
Warnings
--------
Please do not use lambda expression here for ``env_factory`` as it may create incorrectly-shared instances.
Don't do: ::
vectorize_env(lambda: EnvWrapper(...), ...)
Please do: ::
def env_factory(): ...
vectorize_env(env_factory, ...)
"""
env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
"dummy": FiniteDummyVectorEnv,
"subproc": FiniteSubprocVectorEnv,
"shmem": FiniteShmemVectorEnv,
}
finite_env_cls = env_type_cls_mapping[env_type]
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])

398
qlib/rl/utils/log.py Normal file
View File

@@ -0,0 +1,398 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Distributed logger for RL.
:class:`LogCollector` runs in every environment workers. It collects log info from simulator states,
and add them (as a dict) to auxiliary info returned for each step.
:class:`LogWriter` runs in the central worker. It decodes the dict collected by :class:`LogCollector`
in each worker, and writes them to console, log files, or tensorboard...
The two modules communicate by the "log" field in "info" returned by ``env.step()``.
"""
from __future__ import annotations
import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
if TYPE_CHECKING:
from .env_wrapper import InfoDict
__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
class LogLevel(IntEnum):
"""Log-levels for RL training.
The behavior of handling each log level depends on the implementation of :class:`LogWriter`.
"""
DEBUG = 10
"""If you only want to see the metric in debug mode."""
PERIODIC = 20
"""If you want to see the metric periodically."""
# FIXME: I haven't given much thought about this. Let's hold it for one iteration.
INFO = 30
"""Important log messages."""
CRITICAL = 40
"""LogWriter should always handle CRITICAL messages"""
class LogCollector:
"""Logs are first collected in each environment worker,
and then aggregated to stream at the central thread in vector env.
In :class:`LogCollector`, every metric is added to a dict, which needs to be ``reset()`` at each step.
The dict is sent via the ``info`` in ``env.step()``, and decoded by the :class:`LogWriter` at vector env.
``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
"""
_logged: dict[str, tuple[int, Any]]
_min_loglevel: int
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
self._min_loglevel = int(min_loglevel)
def reset(self):
"""Clear all collected contents."""
self._logged = {}
def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
if name in self._logged:
raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
self._logged[name] = (int(loglevel), metric)
def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Add a string with name into logged contents."""
if loglevel < self._min_loglevel:
return
if not isinstance(string, str):
raise TypeError(f"{string} is not a string.")
self._add_metric(name, string, loglevel)
def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Add a scalar with name into logged contents.
Scalar will be converted into a float.
"""
if loglevel < self._min_loglevel:
return
if hasattr(scalar, "item"):
# could be single-item number
scalar = scalar.item()
if not isinstance(scalar, (float, int)):
raise TypeError(f"{scalar} is not and can not be converted into float or integer.")
scalar = float(scalar)
self._add_metric(name, scalar, loglevel)
def add_array(
self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
) -> None:
"""Add an array with name into logging."""
if loglevel < self._min_loglevel:
return
if not isinstance(array, (np.ndarray, pd.DataFrame, pd.Series)):
raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
self._add_metric(name, array, loglevel)
def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Log something with any type.
As it's an "any" object, the only LogWriter accepting it is pickle.
Therefore pickle must be able to serialize it.
"""
if loglevel < self._min_loglevel:
return
# FIXME: detect and rescue object that could be scalar or array
self._add_metric(name, obj, loglevel)
def logs(self) -> dict[str, np.ndarray]:
return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
class LogWriter(Generic[ObsType, ActType]):
"""Base class for log writers, triggered at every reset and step by finite env.
What to do with a specific log depends on the implementation of subclassing :class:`LogWriter`.
The general principle is that, it should handle logs above its loglevel (inclusive),
and discard logs that are not acceptable. For instance, console loggers obviously can't handle an image.
"""
episode_count: int
"""Counter of episodes."""
step_count: int
"""Counter of steps."""
global_step: int
"""Counter of steps. Won"t be cleared in ``clear``."""
global_episode: int
"""Counter of episodes. Won"t be cleared in ``clear``."""
active_env_ids: Set[int]
"""Active environment ids in vector env."""
episode_lengths: dict[int, int]
"""Map from environment id to episode length."""
episode_rewards: dict[int, list[float]]
"""Map from environment id to episode total reward."""
episode_logs: dict[int, list]
"""Map from environment id to episode logs."""
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
self.loglevel = loglevel
self.global_step = 0
self.global_episode = 0
# Information, logs of one episode is stored here.
# This assumes that episode is not too long to fit into the memory.
self.episode_lengths = dict()
self.episode_rewards = dict()
self.episode_logs = dict()
self.clear()
def clear(self):
self.episode_count = self.step_count = 0
self.active_env_ids = set()
self.logs = []
def aggregation(self, array: Sequence[Any]) -> Any:
"""Aggregation function from step-wise to episode-wise.
If it's a sequence of float, take the mean.
Otherwise, take the first element.
"""
assert len(array) > 0, "The aggregated array must be not empty."
if all(isinstance(v, float) for v in array):
return np.mean(array)
else:
return array[0]
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
"""This is triggered at the end of each trajectory.
Parameters
----------
length
Length of this trajectory.
rewards
A list of rewards at each step of this episode.
contents
Logged contents for every steps.
"""
def log_step(self, reward: float, contents: dict[str, Any]) -> None:
"""This is triggered at each step.
Parameters
----------
reward
Reward for this step.
contents
Logged contents for this step.
"""
def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) -> None:
"""Callback for finite env, on each step."""
# Update counter
self.global_step += 1
self.step_count += 1
self.active_env_ids.add(env_id)
self.episode_lengths[env_id] += 1
# TODO: reward can be a list of list for MARL
self.episode_rewards[env_id].append(rew)
values: dict[str, Any] = {}
for key, (loglevel, value) in info["log"].items():
if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
values[key] = value
self.episode_logs[env_id].append(values)
self.log_step(rew, values)
if done:
# Update counter
self.global_episode += 1
self.episode_count += 1
self.log_episode(self.episode_lengths[env_id], self.episode_rewards[env_id], self.episode_logs[env_id])
def on_env_reset(self, env_id: int, obs: ObsType) -> None:
"""Callback for finite env.
Reset episode statistics. Nothing task-specific is logged here because of
`a limitation of tianshou <https://github.com/thu-ml/tianshou/issues/605>`__.
"""
self.episode_lengths[env_id] = 0
self.episode_rewards[env_id] = []
self.episode_logs[env_id] = []
def on_env_all_done(self) -> None:
"""All done. Time for cleanup."""
class ConsoleWriter(LogWriter):
"""Write log messages to console periodically.
It tracks an average meter for each metric, which is the average value since last ``clear()`` till now.
The display format for each metric is ``<name> <latest_value> (<average_value>)``.
Non-single-number metrics are auto skipped.
"""
prefix: str
"""Prefix can be set via ``writer.prefix``."""
def __init__(
self,
log_every_n_episode: int = 20,
total_episodes: int | None = None,
float_format: str = ":.4f",
counter_format: str = ":4d",
loglevel: int | LogLevel = LogLevel.PERIODIC,
):
super().__init__(loglevel)
# TODO: support log_every_n_step
self.log_every_n_episode = log_every_n_episode
self.total_episodes = total_episodes
self.counter_format = counter_format
self.float_format = float_format
self.prefix = ""
self.console_logger = get_module_logger(__name__, level=logging.INFO)
def clear(self):
super().clear()
# Clear average meters
self.metric_counts: dict[str, int] = defaultdict(int)
self.metric_sums: dict[str, float] = defaultdict(float)
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
# Aggregate step-wise to episode-wise
episode_wise_contents: dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
if isinstance(value, float):
episode_wise_contents[name].append(value)
# Generate log contents and track them in average-meter.
# This should be done at every step, regardless of periodic or not.
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
for name, value in logs.items():
self.metric_counts[name] += 1
self.metric_sums[name] += value
if self.episode_count % self.log_every_n_episode == 0 or self.episode_count == self.total_episodes:
# Only log periodically or at the end
self.console_logger.info(self.generate_log_message(logs))
def generate_log_message(self, logs: dict[str, float]) -> str:
if self.prefix:
msg_prefix = self.prefix + " "
else:
msg_prefix = ""
if self.total_episodes is None:
msg_prefix += "[Step {" + self.counter_format + "}]"
else:
msg_prefix += "[{" + self.counter_format + "}/" + str(self.total_episodes) + "]"
msg_prefix = msg_prefix.format(self.episode_count)
msg = ""
for name, value in logs.items():
# Double-space as delimiter
format_template = r" {} {" + self.float_format + "} ({" + self.float_format + "})"
msg += format_template.format(name, value, self.metric_sums[name] / self.metric_counts[name])
msg = msg_prefix + " " + msg
return msg
class CsvWriter(LogWriter):
"""Dump all episode metrics to a ``result.csv``.
This is not the correct implementation. It's only used for first iteration.
"""
SUPPORTED_TYPES = (float, str, pd.Timestamp)
all_records: list[dict[str, Any]]
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
super().__init__(loglevel)
self.output_dir = output_dir
self.output_dir.mkdir(exist_ok=True)
def clear(self):
super().clear()
self.all_records = []
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
# FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
episode_wise_contents: dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
if isinstance(value, self.SUPPORTED_TYPES):
episode_wise_contents[name].append(value)
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
self.all_records.append(logs)
def on_env_all_done(self) -> None:
# FIXME: this is temporary
pd.DataFrame.from_records(self.all_records).to_csv(self.output_dir / "result.csv", index=False)
# The following are not implemented yet.
class PickleWriter(LogWriter):
"""Dump logs to pickle files."""
class TensorboardWriter(LogWriter):
"""Write logs to event files that can be visualized with tensorboard."""
class MlflowWriter(LogWriter):
"""Add logs to mlflow."""
class LogBuffer(LogWriter):
"""Keep everything in memory."""

13
qlib/typehint.py Normal file
View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Commonly used types."""
import sys
__all__ = ["Literal", "TypedDict", "final"]
if sys.version_info >= (3, 8):
from typing import Literal, TypedDict, final # type: ignore # pylint: disable=no-name-in-module
else:
from typing_extensions import Literal, TypedDict, final

View File

@@ -134,7 +134,12 @@ setup(
"sphinx",
"sphinx_rtd_theme",
"pre-commit",
]
],
"rl": [
"tianshou",
"gym",
"torch",
],
},
include_package_data=True,
classifiers=[

10
tests/conftest.py Normal file
View File

@@ -0,0 +1,10 @@
import os
import sys
"""Ignore RL tests on non-linux platform."""
collect_ignore = []
if sys.platform != "linux":
for root, dirs, files in os.walk("rl"):
for file in files:
collect_ignore.append(os.path.join(root, file))

4
tests/pytest.ini Normal file
View File

@@ -0,0 +1,4 @@
[pytest]
filterwarnings =
ignore:.*rng.randint:DeprecationWarning
ignore:.*Casting input x to numpy array:UserWarning

View File

@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import multiprocessing
import time
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from qlib.rl.utils.data_queue import DataQueue
class DummyDataset(Dataset):
def __init__(self, length):
self.length = length
def __getitem__(self, index):
assert 0 <= index < self.length
return pd.DataFrame(np.random.randint(0, 100, size=(index + 1, 4)), columns=list("ABCD"))
def __len__(self):
return self.length
def _worker(dataloader, collector):
# for i in range(3):
for i, data in enumerate(dataloader):
collector.put(len(data))
def _queue_to_list(queue):
result = []
while not queue.empty():
result.append(queue.get())
return result
def test_pytorch_dataloader():
dataset = DummyDataset(100)
dataloader = DataLoader(dataset, batch_size=None, num_workers=1)
queue = multiprocessing.Queue()
_worker(dataloader, queue)
assert len(set(_queue_to_list(queue))) == 100
def test_multiprocess_shared_dataloader():
dataset = DummyDataset(100)
with DataQueue(dataset, producer_num_workers=1) as data_queue:
queue = multiprocessing.Queue()
processes = []
for _ in range(3):
processes.append(multiprocessing.Process(target=_worker, args=(data_queue, queue)))
processes[-1].start()
for p in processes:
p.join()
assert len(set(_queue_to_list(queue))) == 100
def test_exit_on_crash_finite():
def _exit_finite():
dataset = DummyDataset(100)
with DataQueue(dataset, producer_num_workers=4) as data_queue:
time.sleep(3)
raise ValueError
# https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess
process = multiprocessing.Process(target=_exit_finite)
process.start()
process.join()
def test_exit_on_crash_infinite():
def _exit_infinite():
dataset = DummyDataset(100)
with DataQueue(dataset, repeat=-1, queue_maxsize=100) as data_queue:
time.sleep(3)
raise ValueError
process = multiprocessing.Process(target=_exit_infinite)
process.start()
process.join()
if __name__ == "__main__":
test_multiprocess_shared_dataloader()

249
tests/rl/test_finite_env.py Normal file
View File

@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import Counter
import gym
import numpy as np
from tianshou.data import Batch, Collector
from tianshou.policy import BasePolicy
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from qlib.rl.utils.finite_env import (
LogWriter,
FiniteDummyVectorEnv,
FiniteShmemVectorEnv,
FiniteSubprocVectorEnv,
check_nan_observation,
generate_nan_observation,
)
_test_space = gym.spaces.Dict(
{
"sensors": gym.spaces.Dict(
{
"position": gym.spaces.Box(low=-100, high=100, shape=(3,)),
"velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)),
"front_cam": gym.spaces.Tuple(
(gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)))
),
"rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
}
),
"ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)),
"inner_state": gym.spaces.Dict(
{
"charge": gym.spaces.Discrete(100),
"system_checks": gym.spaces.MultiBinary(10),
"job_status": gym.spaces.Dict(
{
"task": gym.spaces.Discrete(5),
"progress": gym.spaces.Box(low=0, high=100, shape=()),
}
),
}
),
}
)
class FiniteEnv(gym.Env):
def __init__(self, dataset, num_replicas, rank):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
self.iterator = None
self.observation_space = gym.spaces.Discrete(255)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
if self.iterator is None:
self.iterator = iter(self.loader)
try:
self.current_sample, self.step_count = next(self.iterator)
self.current_step = 0
return self.current_sample
except StopIteration:
self.iterator = None
return generate_nan_observation(self.observation_space)
def step(self, action):
self.current_step += 1
assert self.current_step <= self.step_count
return (
0,
1.0,
self.current_step >= self.step_count,
{"sample": self.current_sample, "action": action, "metric": 2.0},
)
class FiniteEnvWithComplexObs(FiniteEnv):
def __init__(self, dataset, num_replicas, rank):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
self.iterator = None
self.observation_space = gym.spaces.Discrete(255)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
if self.iterator is None:
self.iterator = iter(self.loader)
try:
self.current_sample, self.step_count = next(self.iterator)
self.current_step = 0
return _test_space.sample()
except StopIteration:
self.iterator = None
return generate_nan_observation(self.observation_space)
def step(self, action):
self.current_step += 1
assert self.current_step <= self.step_count
return (
_test_space.sample(),
1.0,
self.current_step >= self.step_count,
{"sample": _test_space.sample(), "action": action, "metric": 2.0},
)
class DummyDataset(Dataset):
def __init__(self, length):
self.length = length
self.episodes = [3 * i % 5 + 1 for i in range(self.length)]
def __getitem__(self, index):
assert 0 <= index < self.length
return index, self.episodes[index]
def __len__(self):
return self.length
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.stack([1] * len(batch)))
def learn(self, batch):
pass
def _finite_env_factory(dataset, num_replicas, rank, complex=False):
if complex:
return lambda: FiniteEnvWithComplexObs(dataset, num_replicas, rank)
return lambda: FiniteEnv(dataset, num_replicas, rank)
class MetricTracker(LogWriter):
def __init__(self, length):
super().__init__()
self.counter = Counter()
self.finished = set()
self.length = length
def on_env_step(self, env_id, obs, rew, done, info):
assert rew == 1.0
index = info["sample"]
if done:
# assert index not in self.finished
self.finished.add(index)
self.counter[index] += 1
def validate(self):
assert len(self.finished) == self.length
for k, v in self.counter.items():
assert v == k * 3 % 5 + 1
class DoNothingTracker(LogWriter):
def on_env_step(self, *args, **kwargs):
pass
def test_finite_dummy_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_finite_shmem_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_finite_subproc_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_nan():
assert check_nan_observation(generate_nan_observation(_test_space))
assert not check_nan_observation(_test_space.sample())
def test_finite_dummy_vector_env_complex():
length = 100
dataset = DummyDataset(length)
envs = FiniteDummyVectorEnv(
DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
)
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
try:
test_collector.collect(n_step=10**18)
except StopIteration:
pass
def test_finite_shmem_vector_env_complex():
length = 100
dataset = DummyDataset(length)
envs = FiniteShmemVectorEnv(
DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
)
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
try:
test_collector.collect(n_step=10**18)
except StopIteration:
pass

156
tests/rl/test_logger.py Normal file
View File

@@ -0,0 +1,156 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from random import randint, choice
from pathlib import Path
import re
import gym
import numpy as np
import pandas as pd
from gym import spaces
from tianshou.data import Collector, Batch
from tianshou.policy import BasePolicy
from qlib.log import set_log_with_config
from qlib.config import C
from qlib.constant import INF
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.utils.data_queue import DataQueue
from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper
from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter
from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
def __init__(self):
self.logger = LogCollector()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.step_count = 0
return 0
def step(self, action: int):
self.logger.reset()
self.logger.add_scalar("reward", 42.0)
self.logger.add_scalar("a", randint(1, 10))
self.logger.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
if self.step_count >= 3:
done = choice([False, True])
else:
done = False
if 2 <= self.step_count <= 3:
self.logger.add_scalar("c", randint(11, 20))
self.step_count += 1
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.stack([1] * len(batch)))
def learn(self, batch):
pass
def test_simple_env_logger(caplog):
set_log_with_config(C.logging_config)
for venv_cls_name in ["dummy", "shmem", "subproc"]:
writer = ConsoleWriter()
csv_writer = CsvWriter(Path(__file__).parent / ".output")
venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer])
with venv.collector_guard():
collector = Collector(AnyPolicy(), venv)
collector.collect(n_episode=30)
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert output_file.columns.tolist() == ["reward", "a", "c"]
assert len(output_file) >= 30
line_counter = 0
for line in caplog.text.splitlines():
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3
class SimpleSimulator(Simulator[int, float, float]):
def __init__(self, initial: int, **kwargs) -> None:
self.initial = float(initial)
def step(self, action: float) -> None:
import torch
self.initial += action
self.env.logger.add_scalar("test_a", torch.tensor(233.0))
self.env.logger.add_scalar("test_b", np.array(200))
def get_state(self) -> float:
return self.initial
def done(self) -> bool:
return self.initial % 1 > 0.5
class DummyStateInterpreter(StateInterpreter[float, float]):
def interpret(self, state: float) -> float:
return state
@property
def observation_space(self) -> spaces.Box:
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
class DummyActionInterpreter(ActionInterpreter[float, int, float]):
def interpret(self, state: float, action: int) -> float:
return action / 100
@property
def action_space(self) -> spaces.Box:
return spaces.Discrete(5)
class RandomFivePolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.random.randint(5, size=len(batch)))
def learn(self, batch):
pass
def test_logger_with_env_wrapper():
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
env_wrapper_factory = lambda: EnvWrapper(
SimpleSimulator,
DummyStateInterpreter(),
DummyActionInterpreter(),
data_iterator,
logger=LogCollector(LogLevel.DEBUG),
)
# loglevel can be debug here because metrics can all dump into csv
# otherwise, csv writer might crash
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
with venv.collector_guard():
collector = Collector(RandomFivePolicy(), venv)
collector.collect(n_episode=INF * len(venv))
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(output_df) == 20
# obs has a increasing trend
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
assert (output_df["test_a"] == 233).all()
assert (output_df["test_b"] == 200).all()
assert "steps_per_episode" in output_df and "reward" in output_df

View File

@@ -0,0 +1,308 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from functools import partial
from pathlib import Path
from typing import NamedTuple
import numpy as np
import pandas as pd
import pytest
import torch
from tianshou.data import Batch
from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.entries.test import backtest
from qlib.rl.order_execution import *
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "intraday_saoe"
DATA_DIR = DATA_ROOT_DIR / "us"
BACKTEST_DATA_DIR = DATA_DIR / "backtest"
FEATURE_DATA_DIR = DATA_DIR / "processed"
ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
CN_DATA_DIR = DATA_ROOT_DIR / "cn"
CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
def test_pickle_data_inspect():
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
assert len(data) == 390
data = pickle_styled.load_intraday_processed_data(
DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index()
)
assert len(data.today) == len(data.yesterday) == 390
def test_simulator_first_step():
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
assert state.position == 30.0
simulator.step(15.0)
state = simulator.get_state()
assert len(state.history_exec) == 30
assert state.history_exec.index[0] == pd.Timestamp("2013-12-11 09:30:00")
assert state.history_exec["market_volume"].iloc[0] == 450072.0
assert abs(state.history_exec["market_price"].iloc[0] - 25.370001) < 1e-4
assert (state.history_exec["amount"] == 0.5).all()
assert (state.history_exec["deal_amount"] == 0.5).all()
assert abs(state.history_exec["trade_price"].iloc[0] - 25.370001) < 1e-4
assert abs(state.history_exec["trade_value"].iloc[0] - 12.68500) < 1e-4
assert state.history_exec["position"].iloc[0] == 29.5
assert state.history_exec["ffr"].iloc[0] == 1 / 60
assert state.history_steps["market_volume"].iloc[0] == 5041147.0
assert state.history_steps["amount"].iloc[0] == 15.0
assert state.history_steps["deal_amount"].iloc[0] == 15.0
assert state.history_steps["ffr"].iloc[0] == 0.5
assert (
state.history_steps["pa"].iloc[0]
== (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000
)
assert state.position == 15.0
assert state.cur_time == pd.Timestamp("2013-12-11 10:00:00")
def test_simulator_stop_twap():
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
for _ in range(13):
simulator.step(1.0)
state = simulator.get_state()
assert len(state.history_exec) == 390
assert (state.history_exec["deal_amount"] == 13 / 390).all()
assert state.history_steps["position"].iloc[0] == 12 and state.history_steps["position"].iloc[-1] == 0
assert (state.metrics["ffr"] - 1) < 1e-3
assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4
assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
assert state.position == 0.0
assert abs(state.metrics["trade_price"] - state.metrics["market_price"]) < 1e-4
assert abs(state.metrics["pa"]) < 1e-2
assert simulator.done()
def test_simulator_stop_early():
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
with pytest.raises(ValueError):
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator.step(2.0)
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator.step(1.0)
with pytest.raises(AssertionError):
simulator.step(1.0)
def test_simulator_start_middle():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
simulator.step(2.0)
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:30:00")
for _ in range(10):
simulator.step(1.0)
simulator.step(2.0)
assert len(simulator.history_exec) == 330
assert simulator.done()
assert abs(simulator.history_exec["amount"].iloc[-1] - (1 + 2 / 15)) < 1e-4
assert abs(simulator.metrics["ffr"] - 1) < 1e-4
def test_interpreter():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
# emulate a env status
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
interpreter_step = CurrentStepStateInterpreter(13)
interpreter_action = CategoricalActionInterpreter(20)
interpreter_action_twap = TwapRelativeActionInterpreter()
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
# first step
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 45
assert obs["cur_step"] == 0
assert obs["position"] == 15.0
assert obs["position_history"][0] == 15.0
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(45))
assert np.sum(obs["data_processed"][45:]) == 0
assert obs["data_processed_prev"].shape == (390, 5)
# first step: second interpreter
interpreter_step.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
obs = interpreter_step(simulator.get_state())
assert obs["acquiring"] == 1
assert obs["position"] == 15.0
# second step
simulator.step(5.0)
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 60
assert obs["cur_step"] == 1
assert obs["position"] == 10.0
assert obs["position_history"][:2].tolist() == [15.0, 10.0]
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(60))
assert np.sum(obs["data_processed"][60:]) == 0
# second step: action
action = interpreter_action(simulator.get_state(), 1)
assert action == 15 / 20
interpreter_action_twap.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)
)
action = interpreter_action_twap(simulator.get_state(), 1.5)
assert action == 1.5
# fast-forward
for _ in range(10):
simulator.step(0.0)
# last step
simulator.step(5.0)
interpreter.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs)
)
assert interpreter.env.status["done"]
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 375
assert obs["cur_step"] == 12
assert obs["position"] == 0.0
assert obs["position_history"][1:11].tolist() == [10.0] * 10
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(375))
assert np.sum(obs["data_processed"][375:]) == 0
def test_network_sanity():
# we won't check the correctness of networks here
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 390
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
action_interp = CategoricalActionInterpreter(13)
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
network = Recurrent(interpreter.observation_space)
policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3)
for i in range(14):
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
batch = Batch(obs=[obs])
output = policy(batch)
assert 0 <= output["act"].item() <= 13
if i < 13:
simulator.step(1.0)
else:
assert obs["cur_tick"] == 389
assert obs["cur_step"] == 12
assert obs["position_history"][-1] == 3
@pytest.mark.parametrize("finite_env_type", ["dummy", "subproc", "shmem"])
def test_twap_strategy(finite_env_type):
set_log_with_config(C.logging_config)
orders = pickle_styled.load_orders(ORDER_DIR)
assert len(orders) == 248
state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
action_interp = TwapRelativeActionInterpreter()
policy = AllOne(state_interp.observation_space, action_interp.action_space)
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
[ConsoleWriter(total_episodes=len(orders)), csv_writer],
concurrency=4,
finite_env_type=finite_env_type,
)
metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(metrics) == 248
assert np.isclose(metrics["ffr"].mean(), 1.0)
assert np.isclose(metrics["pa"].mean(), 0.0)
assert np.allclose(metrics["pa"], 0.0, atol=2e-3)
def test_cn_ppo_strategy():
set_log_with_config(C.logging_config)
# The data starts with 9:31 and ends with 15:00
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
action_interp = CategoricalActionInterpreter(4)
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
[ConsoleWriter(total_episodes=len(orders)), csv_writer],
concurrency=4,
)
metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(metrics) == len(orders)
assert np.isclose(metrics["ffr"].mean(), 1.0)
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)