mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Qlib RL framework (stage 1) - single-asset order execution (#1076)
* rl init * aux info * Reward config * update * simple * update saoe init * update simulator and seed * minor * minor * update sim * checkpoint * obs * Update interpreter * init qlib simulator * checkpoint * Refine codebase * checkpoint * checkpoint * Add one test * More tests * Simulator checkpoint * checkpoint * First-step tested * Checkpoint * Update data_queue API * Checkpoint * Update test * Move files * Checkpoint * Single-quote -> double-quote * Fix finite env tests * Tested with mypy * pep-574 * No call for env done * Update finite env docs * Fix csv writer * Refine tester * Update logger * Add another logger test * Checkpoint * Add network sanity test * steps per episode is not correct * Cleanup code, ready for PR * Reformat with black * Fix pylint for py37 * Fix lint * Fix lint * Fix flake * update mypy command * mypy * Update exclude pattern * Use pyproject.toml * test * . * . * Refactor pipeline * . * defaults run bash * . * Revert and skip follow_imports * Fix toml issue * fix mypy * . * . * . * Fix install * Minor fix * Fix test * Fix test * Remove requirements * Revert * fix tests * Fix lint * . * . * . * . * . * update install from source command * . * Fix data download * . * . * . * . * . * . * Fix py37 * Ignore tests on non-linux * resolve comments * fix tests * resolve comments * some typo * style updates * More comments * fix dummy * add warning * Align precision in some system * Added some impl notes Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
15
.github/workflows/test.yml
vendored
15
.github/workflows/test.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
- name: Make html with sphnix
|
||||
- name: Make html with sphinx
|
||||
run: |
|
||||
pip install -U sphinx
|
||||
pip install sphinx_rtd_theme readthedocs_sphinx_ext
|
||||
@@ -97,12 +97,21 @@ jobs:
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
pip install mypy
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
@@ -113,6 +122,7 @@ jobs:
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
|
||||
pip install gym tianshou torch
|
||||
pip install -e .
|
||||
|
||||
- name: Install test dependencies
|
||||
@@ -129,4 +139,3 @@ jobs:
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
|
||||
5
.github/workflows/test_macos.yml
vendored
5
.github/workflows/test_macos.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
@@ -65,6 +65,8 @@ jobs:
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
@@ -75,6 +77,7 @@ jobs:
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
|
||||
python -m pip install gym tianshou torch
|
||||
pip install -e .
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -27,6 +27,10 @@ examples/estimator/estimator_example/
|
||||
|
||||
*.egg-info/
|
||||
|
||||
# test related
|
||||
test-output.xml
|
||||
.output
|
||||
.data
|
||||
|
||||
# special software
|
||||
mlruns/
|
||||
@@ -34,6 +38,7 @@ mlruns/
|
||||
tags
|
||||
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
17
.mypy.ini
Normal file
17
.mypy.ini
Normal file
@@ -0,0 +1,17 @@
|
||||
[mypy]
|
||||
exclude = (?x)(
|
||||
^qlib/backtest
|
||||
| ^qlib/contrib
|
||||
| ^qlib/data
|
||||
| ^qlib/model
|
||||
| ^qlib/strategy
|
||||
| ^qlib/tests
|
||||
| ^qlib/utils
|
||||
| ^qlib/workflow
|
||||
| ^qlib/config\.py$
|
||||
| ^qlib/log\.py$
|
||||
| ^qlib/__init__\.py$
|
||||
)
|
||||
ignore_missing_imports = true
|
||||
disallow_incomplete_defs = true
|
||||
follow_imports = skip
|
||||
@@ -8,3 +8,6 @@ REG_TW = "tw"
|
||||
|
||||
# Epsilon for avoiding division by zero.
|
||||
EPS = 1e-12
|
||||
|
||||
# Infinity in integer
|
||||
INF = 10**18
|
||||
|
||||
@@ -61,7 +61,11 @@ def get_module_logger(module_name, level: Optional[int] = None) -> QlibLogger:
|
||||
if level is None:
|
||||
level = C.logging_level
|
||||
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
if not module_name.startswith("qlib."):
|
||||
# Add a prefix of qlib. when the requested ``module_name`` doesn't start with ``qlib.``.
|
||||
# If the module_name is already qlib.xxx, we do not format here. Otherwise, it will become qlib.qlib.xxx.
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
|
||||
# Get logger.
|
||||
module_logger = QlibLogger(module_name)
|
||||
module_logger.setLevel(level)
|
||||
|
||||
43
qlib/rl/aux_info.py
Normal file
43
qlib/rl/aux_info.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, TYPE_CHECKING, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
|
||||
|
||||
__all__ = ["AuxiliaryInfoCollector"]
|
||||
|
||||
AuxInfoType = TypeVar("AuxInfoType")
|
||||
|
||||
|
||||
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
    """Override this class to collect customized auxiliary information from environment."""

    # Defaults to ``None`` at class level.
    # NOTE(review): presumably assigned by the owning ``EnvWrapper`` at runtime -- confirm.
    env: EnvWrapper | None = None

    @final
    def __call__(self, simulator_state: StateType) -> AuxInfoType:
        """Entry point used by callers; marked ``@final`` and simply delegates to :meth:`collect`."""
        return self.collect(simulator_state)

    def collect(self, simulator_state: StateType) -> AuxInfoType:
        """Override this for customized auxiliary info.
        Usually useful in Multi-agent RL.

        Parameters
        ----------
        simulator_state
            Retrieved with ``simulator.get_state()``.

        Returns
        -------
        Auxiliary information.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must provide the implementation.
        """
        raise NotImplementedError("collect is not implemented!")
|
||||
8
qlib/rl/data/__init__.py
Normal file
8
qlib/rl/data/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Common utilities to handle ad-hoc-styled data.
|
||||
|
||||
Most of these snippets come from research projects (paper code).
|
||||
Please take caution when using them in production.
|
||||
"""
|
||||
257
qlib/rl/data/pickle_styled.py
Normal file
257
qlib/rl/data/pickle_styled.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""This module contains utilities to read financial data from pickle-styled files.
|
||||
|
||||
This is the format used in `OPD paper <https://seqml.github.io/opd/>`__. NOT the standard data format in qlib.
|
||||
|
||||
The data here are all wrapped with ``@lru_cache``, which saves the expensive IO cost to repetitively read the data.
|
||||
We also encourage users to use ``get_xxx_yyy`` rather than ``XxxYyy`` (although they are the same thing),
|
||||
because ``get_xxx_yyy`` is cache-optimized.
|
||||
|
||||
Note that these pickle files are dumped with Python 3.8. Python lower than 3.7 might not be able to load them.
|
||||
See `PEP 574 <https://peps.python.org/pep-0574/>`__ for details.
|
||||
|
||||
This file shows resemblance to qlib.backtest.high_performance_ds. We might merge those two in future.
|
||||
"""
|
||||
|
||||
# TODO: merge with qlib/backtest/high_performance_ds.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Sequence, cast
|
||||
from pathlib import Path
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import OrderDir, Order
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""Several ad-hoc deal price.
|
||||
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
|
||||
``bid_or_ask_fill``: Based on ``bid_or_ask``. If price is 0, use another price (``$ask0`` / ``$bid0``) instead.
|
||||
``close``: Use close price (``$close0``) as deal price.
|
||||
"""
|
||||
|
||||
|
||||
def _infer_processed_data_column_names(shape: int) -> list[str]:
|
||||
if shape == 16:
|
||||
return [
|
||||
"$open",
|
||||
"$high",
|
||||
"$low",
|
||||
"$close",
|
||||
"$vwap",
|
||||
"$bid",
|
||||
"$ask",
|
||||
"$volume",
|
||||
"$bidV",
|
||||
"$bidV1",
|
||||
"$bidV3",
|
||||
"$bidV5",
|
||||
"$askV",
|
||||
"$askV1",
|
||||
"$askV3",
|
||||
"$askV5",
|
||||
]
|
||||
if shape == 6:
|
||||
return ["$high", "$low", "$open", "$close", "$vwap", "$volume"]
|
||||
elif shape == 5:
|
||||
return ["$high", "$low", "$open", "$close", "$volume"]
|
||||
raise ValueError(f"Unrecognized data shape: {shape}")
|
||||
|
||||
|
||||
def _find_pickle(filename_without_suffix: Path) -> Path:
|
||||
suffix_list = [".pkl", ".pkl.backtest"]
|
||||
paths: List[Path] = []
|
||||
for suffix in suffix_list:
|
||||
path = filename_without_suffix.parent / (filename_without_suffix.name + suffix)
|
||||
if path.exists():
|
||||
paths.append(path)
|
||||
if not paths:
|
||||
raise FileNotFoundError(f"No file starting with '{filename_without_suffix}' found")
|
||||
if len(paths) > 1:
|
||||
raise ValueError(f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}")
|
||||
return paths[0]
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)  # 10 * 40M = 400MB
def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
    """Read the pickle resolved by ``_find_pickle`` for the given suffix-less path.

    Results are memoized by ``lru_cache`` (at most 10 distinct paths), so repeated
    calls with the same path return the same cached object instead of re-reading the file.
    """
    return pd.read_pickle(_find_pickle(filename_without_suffix))
|
||||
|
||||
|
||||
class IntradayBacktestData:
    """Raw intraday market data of one stock on one date, as used for backtesting."""

    def __init__(
        self,
        data_dir: Path,
        stock_id: str,
        date: pd.Timestamp,
        deal_price: DealPriceType = "close",
        order_dir: int | None = None,
    ):
        """Load the pickle of ``stock_id`` under ``data_dir`` and keep the rows of ``date``.

        Parameters
        ----------
        data_dir
            Directory holding one pickle per stock.
        stock_id
            Instrument identifier; also the pickle's base file name.
        date
            Trading date to select.
        deal_price
            Which price column :meth:`get_deal_price` uses.
        order_dir
            Order direction; needed for the bid/ask based deal prices.
        """
        frame = _read_pickle(data_dir / stock_id)
        # Per the original note, pandas >= 1.4 needs no extra ``droplevel([0, 2])`` after this selection.
        self.data: pd.DataFrame = frame.loc[pd.IndexSlice[stock_id, :, date]]
        self.deal_price_type: DealPriceType = deal_price
        self.order_dir: int | None = order_dir

    def __repr__(self):
        # Render the frame in its compact "info" form rather than dumping every row.
        display_options = ("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info")
        with pd.option_context(*display_options):
            return f"{self.__class__.__name__}({self.data})"

    def __len__(self):
        return len(self.data)

    def get_deal_price(self) -> pd.Series:
        """Return the deal-price series selected by ``deal_price_type``.

        See ``DealPriceType`` for the meaning of each option.

        Raises
        ------
        ValueError
            If a bid/ask based price is requested without an order direction,
            or if ``deal_price_type`` is not supported.
        """
        kind = self.deal_price_type

        if kind == "close":
            return self.data["$close0"]
        if kind not in ("bid_or_ask", "bid_or_ask_fill"):
            raise ValueError(f"Unsupported deal_price_type: {self.deal_price_type}")
        if self.order_dir is None:
            raise ValueError("Order direction cannot be none when deal_price_type is not close.")

        # Sell orders deal at the bid and buy orders at the ask; the opposite side is the fallback.
        if self.order_dir == OrderDir.SELL:
            primary_col, fallback_col = "$bid0", "$ask0"
        else:  # BUY
            primary_col, fallback_col = "$ask0", "$bid0"

        price = self.data[primary_col]
        if kind == "bid_or_ask_fill":
            # Zero prices are replaced by the price on the other side.
            price = price.replace(0, np.nan).fillna(self.data[fallback_col])
        return price

    def get_volume(self) -> pd.Series:
        """Return the ``$volume0`` column."""
        return self.data["$volume0"]

    def get_time_index(self) -> pd.DatetimeIndex:
        """Return the index of the underlying data, typed as a ``DatetimeIndex``."""
        return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
class IntradayProcessedData:
    """Processed market data after data cleanup and feature engineering.

    It contains both processed data for "today" and "yesterday", as some algorithms
    might use the market information of the previous day to assist decision making.
    """

    today: pd.DataFrame
    """Processed data for "today".
    Number of records must be ``time_length``, and columns must be ``feature_dim``."""

    yesterday: pd.DataFrame
    """Processed data for "yesterday".
    Number of records must be ``time_length``, and columns must be ``feature_dim``."""

    def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
        """Load the processed pickle of ``stock_id`` and split it into ``today`` / ``yesterday``.

        Parameters
        ----------
        data_dir
            Directory holding one pickle per stock.
        stock_id
            Instrument identifier; also the pickle's base file name.
        date
            Trading date to select.
        feature_dim
            Number of feature columns per day; also used to infer the column names.
        time_index
            Index assigned to the rebuilt frames in the legacy branch; its length is
            the expected number of records per day.
        """
        proc = _read_pickle(data_dir / stock_id)
        # We have to infer the names here because,
        # unfortunately they are not included in the original data.
        cnames = _infer_processed_data_column_names(feature_dim)

        time_length: int = len(time_index)

        try:
            # new data format
            proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
            # An AssertionError here is NOT caught below, so a shape mismatch propagates to the caller.
            assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
            proc_today = proc[cnames]
            # Yesterday's columns carry a "_1" suffix, which is stripped to match today's names.
            proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
        except (IndexError, KeyError):
            # legacy data
            proc = proc.loc[pd.IndexSlice[stock_id, date]]
            assert time_length * feature_dim * 2 == len(proc)
            # The flat record is split in half: first half is today, second half is yesterday.
            proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
            proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
            proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
            proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)

        self.today: pd.DataFrame = proc_today
        self.yesterday: pd.DataFrame = proc_yesterday
        # Final sanity checks shared by both formats.
        assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
        assert len(self.today) == len(self.yesterday) == time_length

    def __repr__(self):
        # Render both frames in their compact "info" form rather than dumping every row.
        with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
            return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)  # 100 * 50K = 5MB
def load_intraday_backtest_data(
    data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
) -> IntradayBacktestData:
    """Cached constructor for :class:`IntradayBacktestData`.

    All arguments are forwarded unchanged. Because of ``lru_cache``, every argument
    must be hashable, and repeated calls with equal arguments return the same instance.
    """
    return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached(  # type: ignore
    cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
    # Only the first three arguments identify a cache entry. The remaining ones are
    # swallowed by ``*_args`` / ``**_kwargs`` so that ``feature_dim`` and ``time_index``
    # may be passed either positionally or by keyword (``time_index`` is not hashable).
    key=lambda data_dir, stock_id, date, *_args, **_kwargs: hashkey(data_dir, stock_id, date),
)
def load_intraday_processed_data(
    data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
) -> IntradayProcessedData:
    """Cached constructor for :class:`IntradayProcessedData`.

    Parameters
    ----------
    data_dir
        Directory holding one pickle per stock.
    stock_id
        Instrument identifier.
    date
        Trading date to select.
    feature_dim
        Number of feature columns per day.
    time_index
        Intraday time index of the data.

    Returns
    -------
    The processed data. The cache key is ``(data_dir, stock_id, date)`` only, so a later
    call with the same three values but a different ``feature_dim`` or ``time_index``
    returns the instance created by the first call.
    """
    return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
def load_orders(
    order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
) -> Sequence[Order]:
    """Load orders, and set start time and end time for the orders.

    Parameters
    ----------
    order_path
        Either a single pickle file, or a directory whose entries are all pickle files.
        NOTE: the files are read with ``pd.read_pickle``; only load files from trusted sources.
    start_time
        Time of day used as the start of every order. Only hour, minute and second are used.
        Defaults to ``0:00:00``.
    end_time
        Time of day used as the end of every order. Only hour, minute and second are used.
        Defaults to ``23:59:59``.

    Returns
    -------
    One :class:`Order` per row with a positive amount. For a directory, files are read
    in sorted path order so that the result is reproducible.
    """

    start_time = start_time or pd.Timestamp("0:00:00")
    end_time = end_time or pd.Timestamp("23:59:59")

    if order_path.is_file():
        order_df = pd.read_pickle(order_path)
    else:
        # ``Path.iterdir()`` yields entries in arbitrary order; sort them so that the
        # resulting order list does not depend on the file system.
        order_frames = [pd.read_pickle(file) for file in sorted(order_path.iterdir())]
        order_df = pd.concat(order_frames)

    order_df = order_df.reset_index()

    # Legacy-style orders have "date" instead of "datetime"
    if "date" in order_df.columns:
        order_df = order_df.rename(columns={"date": "datetime"})

    # Sometimes "date" are str rather than Timestamp
    order_df["datetime"] = pd.to_datetime(order_df["datetime"])

    orders: List[Order] = []

    for _, row in order_df.iterrows():
        # filter out orders with non-positive amount
        if row["amount"] <= 0:
            continue
        orders.append(
            Order(
                row["instrument"],
                row["amount"],
                int(row["order_type"]),
                row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
                row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
            )
        )

    return orders
|
||||
7
qlib/rl/entries/__init__.py
Normal file
7
qlib/rl/entries/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Train, test, inference utilities.
|
||||
|
||||
The APIs in this directory are NOT considered final and are subject to change!
|
||||
"""
|
||||
99
qlib/rl/entries/test.py
Normal file
99
qlib/rl/entries/test.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Backtest with the parallelism provided by RL framework.

    Parameters
    ----------
    simulator_fn
        Callable receiving initial seed, returning a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to test against.
    logger
        Logger to record the backtest results. Logger must be present because
        without logger, all information will be lost.
    reward
        Optional reward function. For backtest, this is for testing the rewards
        and logging them only.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Parallel workers.
    """

    # To save bandwidth: the per-environment collector is given the lowest
    # log level among all writers (or the single writer's level).
    min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel

    def env_factory():
        # FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
        # and could be thread unsafe.
        # I'm not sure whether it's a design flaw.
        # I'll rethink about this when designing the trainer.

        if finite_env_type == "dummy":
            # We could only experience the "threading-unsafe" problem in dummy.
            # Each environment therefore gets its own deep copies.
            state = copy.deepcopy(state_interpreter)
            action = copy.deepcopy(action_interpreter)
            rew = copy.deepcopy(reward)
        else:
            state, action, rew = state_interpreter, action_interpreter, reward

        # ``seed_iterator`` is looked up when the factory is called; it is bound
        # by the ``with DataQueue(...)`` statement below.
        return EnvWrapper(
            simulator_fn,
            state,
            action,
            seed_iterator,
            rew,
            logger=LogCollector(min_loglevel=min_loglevel),
        )

    with DataQueue(initial_states) as seed_iterator:
        vector_env = vectorize_env(
            env_factory,
            finite_env_type,
            concurrency,
            logger,
        )

        policy.eval()

        with vector_env.collector_guard():
            test_collector = Collector(policy, vector_env)
            _logger.info("All ready. Start backtest.")
            # The step budget is effectively unbounded.
            # NOTE(review): presumably collection stops once the finite env has
            # consumed all initial states -- confirm against ``collector_guard``.
            test_collector.collect(n_step=INF * len(vector_env))
|
||||
4
qlib/rl/entries/train.py
Normal file
4
qlib/rl/entries/train.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TBD
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
from ..backtest.executor import BaseExecutor
|
||||
from .interpreter import StateInterpreter, ActionInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
|
||||
|
||||
class BaseRLEnv:
|
||||
"""Base environment for reinforcement learning"""
|
||||
|
||||
def reset(self, **kwargs):
|
||||
raise NotImplementedError("reset is not implemented!")
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
step method of rl env
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
action from rl policy
|
||||
|
||||
Returns
|
||||
-------
|
||||
env state to rl policy
|
||||
"""
|
||||
raise NotImplementedError("step is not implemented!")
|
||||
|
||||
|
||||
class QlibRLEnv:
|
||||
"""qlib-based RL env"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: BaseExecutor,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
executor : BaseExecutor
|
||||
qlib multi-level/single-level executor, which can be regarded as gamecore in RL
|
||||
"""
|
||||
self.executor = executor
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self.executor.reset(**kwargs)
|
||||
|
||||
|
||||
class QlibIntRLEnv(QlibRLEnv):
|
||||
"""(Qlib)-based RL (Env) with (Interpreter)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: BaseExecutor,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state_interpreter : Union[dict, StateInterpreter]
|
||||
interpreter that interprets the qlib execute result into rl env state.
|
||||
|
||||
action_interpreter : Union[dict, ActionInterpreter]
|
||||
interpreter that interprets the rl agent action into qlib order list
|
||||
"""
|
||||
super(QlibIntRLEnv, self).__init__(executor=executor)
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
step method of rl env, it run as following step:
|
||||
- Use `action_interpreter.interpret` method to interpret the agent action into order list
|
||||
- Execute the order list with qlib executor, and get the executed result
|
||||
- Use `state_interpreter.interpret` method to interpret the executed result into env state
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
action from rl policy
|
||||
|
||||
Returns
|
||||
-------
|
||||
env state to rl policy
|
||||
"""
|
||||
_interpret_decision = self.action_interpreter.interpret(action=action)
|
||||
_execute_result = self.executor.execute(trade_decision=_interpret_decision)
|
||||
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
|
||||
return _interpret_state
|
||||
@@ -1,47 +1,150 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
class BaseInterpreter:
|
||||
"""Base Interpreter"""
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Any
|
||||
|
||||
def interpret(self, **kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
import numpy as np
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType, ActType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
|
||||
ObsType = TypeVar("ObsType")
|
||||
PolicyActType = TypeVar("PolicyActType")
|
||||
|
||||
|
||||
class ActionInterpreter(BaseInterpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
class Interpreter:
|
||||
"""Interpreter is a medium between states produced by simulators and states needed by RL policies.
|
||||
Interpreters are two-way:
|
||||
|
||||
def interpret(self, action, **kwargs):
|
||||
"""interpret method
|
||||
1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`.
|
||||
2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`.
|
||||
|
||||
Inherit one of the two sub-classes to define your own interpreter.
|
||||
This super-class is only used for isinstance check.
|
||||
|
||||
Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx``
|
||||
in interpreter is anti-pattern. In future, we might support register some interpreter-related
|
||||
states by calling ``self.env.register_state()``, but it's not planned for first iteration.
|
||||
"""
|
||||
|
||||
|
||||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
raise NotImplementedError()
|
||||
|
||||
@final # no overridden
|
||||
def __call__(self, simulator_state: StateType) -> ObsType:
|
||||
obs = self.interpret(simulator_state)
|
||||
self.validate(obs)
|
||||
return obs
|
||||
|
||||
def validate(self, obs: ObsType) -> None:
|
||||
"""Validate whether an observation belongs to the pre-defined observation space."""
|
||||
_gym_space_contains(self.observation_space, obs)
|
||||
|
||||
def interpret(self, simulator_state: StateType) -> ObsType:
|
||||
"""Interpret the state of simulator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
rl agent action
|
||||
simulator_state
|
||||
Retrieved with ``simulator.get_state()``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
qlib orders
|
||||
|
||||
State needed by policy. Should conform with the state space defined in ``observation_space``.
|
||||
"""
|
||||
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class StateInterpreter(BaseInterpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
def interpret(self, execute_result, **kwargs):
|
||||
"""interpret method
|
||||
env: "EnvWrapper" | None = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
raise NotImplementedError()
|
||||
|
||||
@final # no overridden
|
||||
def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType:
|
||||
self.validate(action)
|
||||
obs = self.interpret(simulator_state, action)
|
||||
return obs
|
||||
|
||||
def validate(self, action: PolicyActType) -> None:
|
||||
"""Validate whether an action belongs to the pre-defined action space."""
|
||||
_gym_space_contains(self.action_space, action)
|
||||
|
||||
def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType:
|
||||
"""Convert the policy action to simulator action.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
execute_result :
|
||||
qlib execution result
|
||||
simulator_state
|
||||
Retrieved with ``simulator.get_state()``.
|
||||
action
|
||||
Raw action given by policy.
|
||||
|
||||
Returns
|
||||
----------
|
||||
rl env state
|
||||
-------
|
||||
The action needed by simulator,
|
||||
"""
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
def _gym_space_contains(space: gym.Space, x: Any) -> None:
    """Check that ``x`` belongs to ``space``, raising with diagnostics when it does not.

    This is a stricter counterpart of ``gym.Space.contains``: ``Dict`` and ``Tuple``
    spaces are validated member by member, so the error chain shows which
    sub-space rejected the sample.

    Raises
    ------
    GymSpaceValidationError
        If the sample is not contained in the space.
    """
    if isinstance(space, spaces.Dict):
        is_matching_dict = isinstance(x, dict) and len(x) == len(space)
        if not is_matching_dict:
            raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x)
        for key, child_space in space.spaces.items():
            if key not in x:
                raise GymSpaceValidationError(f"Key {key} not found in sample.", space, x)
            try:
                _gym_space_contains(child_space, x[key])
            except GymSpaceValidationError as err:
                raise GymSpaceValidationError(f"Subspace of key {key} validation error.", space, x) from err
        return

    if isinstance(space, spaces.Tuple):
        # Lists and ndarrays are promoted to tuples before the containment check.
        sample = tuple(x) if isinstance(x, (list, np.ndarray)) else x
        is_matching_tuple = isinstance(sample, tuple) and len(sample) == len(space)
        if not is_matching_tuple:
            raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, sample)
        for position, (child_space, element) in enumerate(zip(space, sample)):
            try:
                _gym_space_contains(child_space, element)
            except GymSpaceValidationError as err:
                raise GymSpaceValidationError(
                    f"Subspace of index {position} validation error.", space, sample
                ) from err
        return

    # Any other space: rely on gym's own check.
    if not space.contains(x):
        raise GymSpaceValidationError("Validation error reported by gym.", space, x)
|
||||
|
||||
|
||||
class GymSpaceValidationError(Exception):
    """Raised when a sample does not fit into a gym space.

    Parameters
    ----------
    message
        Human-readable description of what went wrong.
    space
        The space the sample was validated against.
    x
        The offending sample.
    """

    def __init__(self, message: str, space: gym.Space, x: Any) -> None:
        # Forward all three values to ``Exception`` so that ``args`` is populated
        # even when the caller uses keyword arguments. ``pickle`` / ``copy`` rebuild
        # an exception from ``args``; with empty ``args`` the rebuild would fail.
        super().__init__(message, space, x)
        self.message = message
        self.space = space
        self.x = x

    def __str__(self) -> str:
        return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
|
||||
|
||||
12
qlib/rl/order_execution/__init__.py
Normal file
12
qlib/rl/order_execution/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Currently it supports single-asset order execution.
|
||||
Multi-asset is on the way.
|
||||
"""
|
||||
|
||||
from .interpreter import *
|
||||
from .network import *
|
||||
from .policy import *
|
||||
from .simulator_simple import *
|
||||
222
qlib/rl/order_execution/interpreter.py
Normal file
222
qlib/rl/order_execution/interpreter.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .simulator_simple import SAOEState
|
||||
|
||||
__all__ = [
|
||||
"FullHistoryStateInterpreter",
|
||||
"CurrentStepStateInterpreter",
|
||||
"CategoricalActionInterpreter",
|
||||
"TwapRelativeActionInterpreter",
|
||||
]
|
||||
|
||||
|
||||
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
    """Convert ``value`` to 32-bit numeric numpy types, recursively.

    Parameters
    ----------
    value
        A scalar, a numpy array, a DataFrame, or a (possibly nested) dict of those.

    Returns
    -------
    ``float32`` arrays for floating inputs, ``int32`` arrays for integer / bool scalars and
    signed-integer arrays, a dict with canonicalized values for dicts.
    Anything else is returned untouched.
    """
    if isinstance(value, pd.DataFrame):
        # Route the extracted array through the same rules, so that a float64 frame
        # ends up as float32 like every other floating input.
        return canonicalize(value.to_numpy())
    if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == "f"):
        return np.array(value, dtype=np.float32)
    if isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == "i"):
        return np.array(value, dtype=np.int32)
    if isinstance(value, dict):
        return {k: canonicalize(v) for k, v in value.items()}
    return value
|
||||
|
||||
|
||||
class FullHistoryObs(TypedDict):
    """Keys of the observation produced by ``FullHistoryStateInterpreter``."""

    # Processed intraday features of the trading day.
    data_processed: Any
    # Processed intraday features of the previous day.
    data_processed_prev: Any
    # Whether the order is a buy order.
    acquiring: Any
    # Index of the current tick within the day.
    cur_tick: Any
    # Index of the current step.
    cur_step: Any
    # Total number of steps.
    num_step: Any
    # Amount the order intends to trade.
    target: Any
    # Amount still left to trade.
    position: Any
    # Position recorded at the start and after every finished step.
    position_history: Any
|
||||
|
||||
|
||||
class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
    """Observe the whole history: today (up to the current moment) and yesterday.

    Parameters
    ----------
    data_dir
        Path to load data after feature engineering.
    max_step
        Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
    data_ticks
        Equal to the total number of records. For example, in SAOE per minute,
        the total ticks is the length of day in minutes.
    data_dim
        Number of dimensions in data.
    """

    def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None:
        self.data_dir = data_dir
        self.max_step = max_step
        self.data_ticks = data_ticks
        self.data_dim = data_dim

    def interpret(self, state: SAOEState) -> FullHistoryObs:
        """Turn a simulator state into a canonicalized observation dict."""
        order = state.order
        processed = pickle_styled.load_intraday_processed_data(
            self.data_dir,
            order.stock_id,
            pd.Timestamp(order.start_time.date()),
            self.data_dim,
            state.ticks_index,
        )

        # Slot 0 holds the initial amount; the following slots hold the position
        # recorded after each step that has already been executed.
        steps_done = len(state.history_steps)
        position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
        position_history[0] = order.amount
        position_history[1 : steps_done + 1] = state.history_steps["position"].to_numpy()

        assert self.env is not None

        # The min, slice here are to make sure that indices fit into the range,
        # even after the final step of the simulator (in the done step),
        # to make network in policy happy.
        ticks_elapsed = np.sum(state.ticks_index < state.cur_time)
        raw_obs = {
            "data_processed": self._mask_future_info(processed.today, state.cur_time),
            "data_processed_prev": processed.yesterday,
            "acquiring": order.direction == order.BUY,
            "cur_tick": min(ticks_elapsed, self.data_ticks - 1),
            "cur_step": min(self.env.status["cur_step"], self.max_step - 1),
            "num_step": self.max_step,
            "target": order.amount,
            "position": state.position,
            "position_history": position_history[: self.max_step],
        }
        return cast(FullHistoryObs, canonicalize(raw_obs))

    @property
    def observation_space(self):
        """gym ``Dict`` space describing the output of :meth:`interpret`."""
        tick_feature_shape = (self.data_ticks, self.data_dim)
        last_tick = self.data_ticks - 1
        last_step = self.max_step - 1
        return spaces.Dict(
            {
                "data_processed": spaces.Box(-np.inf, np.inf, shape=tick_feature_shape),
                "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=tick_feature_shape),
                "acquiring": spaces.Discrete(2),
                "cur_tick": spaces.Box(0, last_tick, shape=(), dtype=np.int32),
                "cur_step": spaces.Box(0, last_step, shape=(), dtype=np.int32),
                # TODO: support arbitrary length index
                "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
                "target": spaces.Box(-EPS, np.inf, shape=()),
                "position": spaces.Box(-EPS, np.inf, shape=()),
                "position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)),
            }
        )

    @staticmethod
    def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame:
        """Return a deep copy of ``arr`` with rows from ``current`` onwards zeroed."""
        masked = arr.copy(deep=True)
        masked.loc[current:] = 0.0  # mask out data after this moment (inclusive)
        return masked
|
||||
|
||||
|
||||
class CurrentStateObs(TypedDict):
    """Keys of the observation produced by ``CurrentStepStateInterpreter``."""

    # Whether the order is a buy order.
    acquiring: bool
    # Index of the current step.
    cur_step: int
    # Total number of steps.
    num_step: int
    # Amount the order intends to trade.
    target: float
    # Amount still left to trade.
    position: float
|
||||
|
||||
|
||||
class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
    """The observation of current step.

    Used when policy only depends on the latest state, but not history.
    The key list is not full. You can add more if more information is needed by your policy.

    Parameters
    ----------
    max_step
        Total number of steps; reported as ``num_step`` and used to bound ``cur_step``.
    """

    def __init__(self, max_step: int):
        self.max_step = max_step

    @property
    def observation_space(self):
        """gym ``Dict`` space describing the output of :meth:`interpret`."""
        return spaces.Dict(
            {
                "acquiring": spaces.Discrete(2),
                "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
                "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
                "target": spaces.Box(-EPS, np.inf, shape=()),
                "position": spaces.Box(-EPS, np.inf, shape=()),
            }
        )

    def interpret(self, state: SAOEState) -> CurrentStateObs:
        """Build the observation for the current step from ``state``."""
        assert self.env is not None
        # NOTE(review): ``cur_step == max_step`` passes this assertion although the
        # observation space caps ``cur_step`` at ``max_step - 1`` -- confirm intended.
        assert self.env.status["cur_step"] <= self.max_step
        order = state.order
        return CurrentStateObs(
            {
                "acquiring": order.direction == order.BUY,
                "cur_step": self.env.status["cur_step"],
                "num_step": self.max_step,
                "target": order.amount,
                "position": state.position,
            }
        )
|
||||
|
||||
|
||||
class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
    """Convert a discrete policy action to a continuous action, then multiplied by ``order.amount``.

    Parameters
    ----------
    values
        It can be a list of length $L$: $[a_1, a_2, \\ldots, a_L]$.
        Then when policy gives decision $x$, $a_x$ times order amount is the output.
        It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,
        i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
    """

    def __init__(self, values: int | list[float]):
        if isinstance(values, int):
            num_bins = values
            values = [k / num_bins for k in range(num_bins + 1)]
        self.action_values = values

    @property
    def action_space(self) -> spaces.Discrete:
        """One discrete choice per configured ratio."""
        return spaces.Discrete(len(self.action_values))

    def interpret(self, state: SAOEState, action: int) -> float:
        """Return the amount to trade for ``action``, capped by the remaining position."""
        assert 0 <= action < len(self.action_values)
        requested = state.order.amount * self.action_values[action]
        return min(state.position, requested)
|
||||
|
||||
|
||||
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
    """Convert a continuous ratio to deal amount.

    The ratio is relative to TWAP on the remainder of the day.
    For example, there are 5 steps left, and the left position is 300.
    With TWAP strategy, in each position, 60 should be traded.
    When this interpreter receives action $a$, its output is $60 \\cdot a$.
    """

    @property
    def action_space(self) -> spaces.Box:
        """A non-negative scalar ratio."""
        return spaces.Box(0, np.inf, shape=(), dtype=np.float32)

    def interpret(self, state: SAOEState, action: float) -> float:
        """Scale the per-step TWAP volume by ``action``, capped by the remaining position."""
        assert self.env is not None
        estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
        steps_left = estimated_total_steps - self.env.status["cur_step"]
        twap_volume = state.position / steps_left
        return min(state.position, twap_volume * action)
|
||||
118
qlib/rl/order_execution/network.py
Normal file
118
qlib/rl/order_execution/network.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tianshou.data import Batch
|
||||
|
||||
from qlib.typehint import Literal
|
||||
from .interpreter import FullHistoryObs
|
||||
|
||||
__all__ = ["Recurrent"]
|
||||
|
||||
|
||||
class Recurrent(nn.Module):
    """The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.

    At every timestep the input of policy network is divided into two parts,
    the public variables and the private variables, which are handled by ``raw_rnn``
    and ``pri_rnn`` in this network, respectively.

    One minor difference is that, in this implementation, we don't assume the direction to be fixed.
    Thus, another ``dire_fc`` is added to produce an extra direction-related feature.

    Parameters
    ----------
    obs_space
        Mapping whose ``"data_processed"`` entry exposes ``.shape``; its last dimension
        is the number of input features per tick.
    hidden_dim
        Width of every hidden layer and RNN.
    output_dim
        Width of the feature vector returned by :meth:`forward`.
    rnn_type
        One of ``"rnn"``, ``"lstm"``, ``"gru"``.
    rnn_num_layers
        Number of stacked layers in each RNN.
    """

    def __init__(
        self,
        obs_space: FullHistoryObs,
        hidden_dim: int = 64,
        output_dim: int = 32,
        rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
        rnn_num_layers: int = 1,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_sources = 3

        rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}
        self.rnn_class = rnn_classes[rnn_type]
        self.rnn_layers = rnn_num_layers

        def make_rnn() -> nn.Module:
            return self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)

        # Sub-modules are created in a fixed order so that seeded initialisation is reproducible.
        self.raw_rnn = make_rnn()
        self.prev_rnn = make_rnn()
        self.pri_rnn = make_rnn()

        num_features = obs_space["data_processed"].shape[-1]
        self.raw_fc = nn.Sequential(nn.Linear(num_features, hidden_dim), nn.ReLU())
        self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU())
        self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU())

        self._init_extra_branches()

        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * self.num_sources, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
        )

    def _init_extra_branches(self):
        """Hook called between branch and head construction; does nothing here."""

    def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Compute one feature tensor per source, plus the full public-RNN output.

        Returns
        -------
        ``([public, private, direction], data_out)`` where ``data_out`` is the output of
        ``raw_rnn`` over the zero-padded tick sequence.
        """
        market = obs["data_processed"]
        bs, _, data_dim = market.size()
        zero_pad = torch.zeros(bs, 1, data_dim, device=device)
        data = torch.cat((zero_pad, market), 1)
        cur_step = obs["cur_step"].long()
        cur_tick = obs["cur_tick"].long()
        bs_indices = torch.arange(bs, device=device)

        position = obs["position_history"] / obs["target"].unsqueeze(-1)  # [bs, num_step]
        step_ids = torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()
        steps = step_ids / obs["num_step"].unsqueeze(-1).float()  # [bs, num_step]
        priv = torch.stack((position.float(), steps), -1)

        data_out, _ = self.raw_rnn(self.raw_fc(data))
        # as it is padded with zero in front, this should be last minute
        data_out_slice = data_out[bs_indices, cur_tick]

        priv_out = self.pri_rnn(self.pri_fc(priv))[0]
        priv_out = priv_out[bs_indices, cur_step]

        acquiring = obs["acquiring"]
        direction = torch.stack((acquiring, 1 - acquiring), -1).float()
        dir_out = self.dire_fc(direction)

        return [data_out_slice, priv_out, dir_out], data_out

    def forward(self, batch: Batch) -> torch.Tensor:
        """
        Input should be a dict (at least) containing:

        - data_processed: [N, T, C]
        - cur_step: [N] (int)
        - cur_tick: [N] (int)
        - position_history: [N, S] (S is number of steps)
        - target: [N]
        - num_step: [N] (int)
        - acquiring: [N] (0 or 1)
        """

        inp = cast(FullHistoryObs, batch)
        device = inp["data_processed"].device

        sources, _ = self._source_features(inp, device)
        assert len(sources) == self.num_sources

        return self.fc(torch.cat(sources, -1))
|
||||
158
qlib/rl/order_execution/policy.py
Normal file
158
qlib/rl/order_execution/policy.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym.spaces import Discrete
|
||||
from tianshou.data import Batch, to_torch
|
||||
from tianshou.policy import PPOPolicy, BasePolicy
|
||||
|
||||
__all__ = ["AllOne", "PPO"]
|
||||
|
||||
|
||||
# baselines #
|
||||
|
||||
|
||||
class NonlearnablePolicy(BasePolicy):
    """Tianshou's BasePolicy with no-op ``learn`` and ``process_fn``.

    This could be moved outside in future.

    Parameters
    ----------
    obs_space
        Observation space. Accepted for a uniform constructor signature; not used.
    action_space
        Action space. Accepted for a uniform constructor signature; not used.
    """

    def __init__(self, obs_space: gym.Space, action_space: gym.Space):
        super().__init__()

    def learn(self, batch, batch_size, repeat):
        """Nothing to learn; return an empty statistics dict.

        Returning a dict (rather than ``None``) honours the ``BasePolicy.learn``
        contract, so callers that iterate over the training statistics keep working.
        """
        return {}

    def process_fn(self, batch, buffer, indice):
        """No pre-processing; hand the batch back unchanged, as ``BasePolicy`` does by default."""
        return batch
|
||||
|
||||
|
||||
class AllOne(NonlearnablePolicy):
    """Forward returns a batch full of 1.

    Useful when implementing some baselines (e.g., TWAP).
    """

    def forward(self, batch, state=None, **kwargs):
        """Return one action of value ``1.0`` per element of ``batch``; ``state`` is passed through."""
        ones = np.full(shape=len(batch), fill_value=1.0)
        return Batch(act=ones, state=state)
|
||||
|
||||
|
||||
# ppo #
|
||||
|
||||
|
||||
class PPOActor(nn.Module):
    """Actor head: feature extractor followed by a linear layer and a softmax.

    Parameters
    ----------
    extractor
        Feature network; must expose an integer ``output_dim`` attribute.
    action_dim
        Number of discrete actions.
    """

    def __init__(self, extractor: nn.Module, action_dim: int):
        super().__init__()
        self.extractor = extractor
        self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))

    def forward(self, obs, state=None, info=None):
        """Return ``(action_probabilities, state)``.

        ``info`` is accepted for interface compatibility and ignored. It defaults to
        ``None`` instead of a shared mutable ``{}``.
        """
        feature = self.extractor(to_torch(obs, device=auto_device(self)))
        out = self.layer_out(feature)
        return out, state
|
||||
|
||||
|
||||
class PPOCritic(nn.Module):
    """Critic head: feature extractor followed by a linear layer producing one value.

    Parameters
    ----------
    extractor
        Feature network; must expose an integer ``output_dim`` attribute.
    """

    def __init__(self, extractor: nn.Module):
        super().__init__()
        self.extractor = extractor
        self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)

    def forward(self, obs, state=None, info=None):
        """Return the value estimate with the trailing singleton dimension squeezed.

        ``state`` and ``info`` are accepted for interface compatibility and ignored.
        ``info`` defaults to ``None`` instead of a shared mutable ``{}``.
        """
        feature = self.extractor(to_torch(obs, device=auto_device(self)))
        return self.value_out(feature).squeeze(dim=-1)
|
||||
|
||||
|
||||
class PPO(PPOPolicy):
    """A wrapper of tianshou PPOPolicy.

    Differences:

    - Auto-create actor and critic network. Supports discrete action space only.
    - Dedup common parameters between actor network and critic network
      (not sure whether this is included in latest tianshou or not).
    - Support a ``weight_file`` that supports loading checkpoint.
    - Some parameters' default values are different from original.

    Parameters
    ----------
    network
        Shared feature extractor used by both actor and critic.
    obs_space
        Observation space, forwarded to ``PPOPolicy``.
    action_space
        Action space; must be ``gym.spaces.Discrete``.
    lr, weight_decay
        Settings of the Adam optimizer.
    weight_file
        Optional checkpoint to load after construction.

    The remaining keyword parameters are forwarded to ``PPOPolicy`` unchanged.
    """

    def __init__(
        self,
        network: nn.Module,
        obs_space: gym.Space,
        action_space: gym.Space,
        lr: float,
        weight_decay: float = 0.0,
        discount_factor: float = 1.0,
        max_grad_norm: float = 100.0,
        reward_normalization: bool = True,
        eps_clip: float = 0.3,
        value_clip: bool = True,
        vf_coef: float = 1.0,
        gae_lambda: float = 1.0,
        max_batchsize: int = 256,
        deterministic_eval: bool = True,
        weight_file: Optional[Path] = None,
    ):
        assert isinstance(action_space, Discrete)
        actor = PPOActor(network, action_space.n)
        critic = PPOCritic(network)
        # Actor and critic share ``network``; dedup so each parameter reaches the optimizer once.
        optimizer = torch.optim.Adam(
            chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
        )
        super().__init__(
            actor,
            critic,
            optimizer,
            torch.distributions.Categorical,
            discount_factor=discount_factor,
            max_grad_norm=max_grad_norm,
            reward_normalization=reward_normalization,
            eps_clip=eps_clip,
            value_clip=value_clip,
            vf_coef=vf_coef,
            gae_lambda=gae_lambda,
            max_batchsize=max_batchsize,
            deterministic_eval=deterministic_eval,
            observation_space=obs_space,
            action_space=action_space,
        )
        if weight_file is not None:
            load_weight(self, weight_file)
|
||||
|
||||
|
||||
# utilities: these should be put in a separate (common) file. #
|
||||
|
||||
|
||||
def auto_device(module: nn.Module) -> torch.device:
    """Return the device of the first parameter of ``module``, or CPU if it has none."""
    first_param = next(iter(module.parameters()), None)
    if first_param is None:
        return torch.device("cpu")  # fallback to cpu
    return first_param.device
|
||||
|
||||
|
||||
def load_weight(policy, path):
    """Load a checkpoint at ``path`` into ``policy`` (in place).

    If the plain state dict is rejected, every entry is duplicated under the
    ``_actor_critic.`` prefix and loading is attempted once more.

    NOTE: ``torch.load`` unpickles the file; only load checkpoints from trusted sources.
    """
    assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
    state_dict = torch.load(path, map_location="cpu")
    try:
        policy.load_state_dict(state_dict)
    except RuntimeError:
        # try again by loading the converted weight
        # https://github.com/thu-ml/tianshou/issues/468
        prefixed = {"_actor_critic." + name: tensor for name, tensor in state_dict.items()}
        state_dict.update(prefixed)
        policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def chain_dedup(*iterables):
    """Yield the items of all ``iterables`` in order, skipping any item already yielded."""
    yielded = set()
    for source in iterables:
        for item in source:
            if item in yielded:
                continue
            yielded.add(item)
            yield item
|
||||
4
qlib/rl/order_execution/simulator_qlib.py
Normal file
4
qlib/rl/order_execution/simulator_qlib.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Placeholder for qlib-based simulator."""
|
||||
403
qlib/rl/order_execution/simulator_simple.py
Normal file
403
qlib/rl/order_execution/simulator_simple.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Any, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
|
||||
from qlib.rl.utils import LogLevel
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
|
||||
|
||||
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
|
||||
|
||||
|
||||
class SAOEMetrics(TypedDict):
    """Metrics for SAOE accumulated for a "period".
    It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute.

    Warnings
    --------
    The type hints are for single elements. In lots of times, they can be vectorized.
    For example, ``market_volume`` could be a list of float (or ndarray) rather than a single float.
    """

    stock_id: str
    """Stock ID of this record."""
    datetime: pd.Timestamp
    """Datetime of this record (this is index in the dataframe)."""
    direction: int
    """Direction of the order. 0 for sell, 1 for buy."""

    # Market information.
    market_volume: float
    """(total) market volume traded in the period."""
    market_price: float
    """Deal price. If it's a period of time, this is the average market deal price."""

    # Strategy records.

    amount: float
    """Total amount (volume) strategy intends to trade."""
    inner_amount: float
    """Total amount that the lower-level strategy intends to trade
    (might be larger than amount, e.g., to ensure ffr)."""

    deal_amount: float
    """Amount that successfully takes effect (must be less than inner_amount)."""
    trade_price: float
    """The average deal price for this strategy."""
    trade_value: float
    """Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
    position: float
    """Position left after this "period"."""

    # Accumulated metrics

    ffr: float
    """Completed how much percent of the daily order."""

    pa: float
    """Price advantage compared to baseline (i.e., trade with baseline market price).
    The baseline is trade price when using TWAP strategy to execute this order.
    Please note that there could be data leak here).
    Unit is BP (basis point, 1/10000)."""
|
||||
|
||||
|
||||
class SAOEState(NamedTuple):
    """Immutable snapshot of the SAOE simulator, as handed to interpreters."""

    order: Order
    """The order we are dealing with."""
    cur_time: pd.Timestamp
    """Current time, e.g., 9:30."""
    position: float
    """Current remaining volume to execute."""
    history_exec: pd.DataFrame
    """See :attr:`SingleAssetOrderExecution.history_exec`."""
    history_steps: pd.DataFrame
    """See :attr:`SingleAssetOrderExecution.history_steps`."""

    metrics: SAOEMetrics | None
    """Daily metric, only available when the trading is in "done" state."""

    backtest_data: IntradayBacktestData
    """Backtest data is included in the state.
    Actually, only the time index of this data is needed, at this moment.
    I include the full data so that algorithms (e.g., VWAP) that rely on the raw data can be implemented.
    Interpreters can use this as they wish, but they should be careful not to leak future data.
    """

    ticks_per_step: int
    """How many ticks for each step."""
    ticks_index: pd.DatetimeIndex
    """Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59]."""
    ticks_for_order: pd.DatetimeIndex
    """Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44]."""
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator.
|
||||
|
||||
As there's no "calendar" in the simple simulator, ticks are used to trade.
|
||||
A tick is a record (a line) in the pickle-styled data file.
|
||||
Each tick is considered as an individual trading opportunity.
|
||||
If such fine granularity is not needed, use ``ticks_per_step`` to
|
||||
lengthen the ticks for each step.
|
||||
|
||||
In each step, the traded amount is "equally" split across the ticks,
|
||||
then bounded by the maximum execution volume (i.e., ``vol_threshold``),
|
||||
and if it's the last step, try to ensure all the amount to be executed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
data_dir
|
||||
Path to load backtest data
|
||||
vol_threshold
|
||||
Maximum execution volume (divided by market execution volume).
|
||||
"""
|
||||
|
||||
history_exec: pd.DataFrame
|
||||
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
|
||||
|
||||
history_steps: pd.DataFrame
|
||||
"""Positions at each step. The position before first step is also recorded.
|
||||
See :class:`SAOEMetrics` for available columns."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
"""Metrics. Only available when done."""
|
||||
|
||||
twap_price: float
|
||||
"""This price is used to compute price advantage.
|
||||
It's defined as the average price in the period from order's start time to end time."""
|
||||
|
||||
ticks_index: pd.DatetimeIndex
|
||||
"""All available ticks for the day (not restricted to order)."""
|
||||
|
||||
ticks_for_order: pd.DatetimeIndex
|
||||
"""Ticks that is available for trading (sliced by order)."""
|
||||
|
||||
def __init__(
    self,
    order: Order,
    data_dir: Path,
    ticks_per_step: int = 30,
    deal_price_type: DealPriceType = "close",
    vol_threshold: float | None = None,
) -> None:
    """Load the backtest data for ``order`` and reset all bookkeeping.

    Parameters
    ----------
    order
        The order to execute.
    data_dir
        Path to load backtest data.
    ticks_per_step
        How many ticks per step.
    deal_price_type
        Passed to ``load_intraday_backtest_data`` to select the deal price.
    vol_threshold
        Maximum execution volume (divided by market execution volume). ``None`` disables the cap.
    """
    self.order = order
    self.ticks_per_step: int = ticks_per_step
    self.deal_price_type = deal_price_type
    self.vol_threshold = vol_threshold
    self.data_dir = data_dir

    trade_date = pd.Timestamp(order.start_time.date())
    self.backtest_data = load_intraday_backtest_data(
        self.data_dir, order.stock_id, trade_date, self.deal_price_type, order.direction
    )

    self.ticks_index = self.backtest_data.get_time_index()

    # Get time index available for trading
    self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)

    self.cur_time = self.ticks_for_order[0]
    # NOTE: astype(float) is necessary in some systems.
    # this will align the precision with `.to_numpy()` in `_split_exec_vol`
    order_prices = self.backtest_data.get_deal_price().loc[self.ticks_for_order]
    self.twap_price = float(order_prices.astype(float).mean())

    self.position = order.amount

    metric_keys = list(SAOEMetrics.__annotations__.keys())  # pylint: disable=no-member
    # NOTE: can empty dataframe contain index?
    self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
    self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
    self.metrics = None

    # Scratch buffers; price and volume are filled by ``_split_exec_vol`` on every step.
    self.market_price: np.ndarray | None = None
    self.market_vol: np.ndarray | None = None
    self.market_vol_limit: np.ndarray | None = None
|
||||
|
||||
def step(self, amount: float) -> None:
    """Execute one step of SAOE.

    Parameters
    ----------
    amount
        The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.

    Raises
    ------
    ValueError
        If the split produces a negative per-tick volume or drives the position below zero.
    """

    assert not self.done()

    # avoid misuse: both buffers must be refreshed by ``_split_exec_vol`` below
    self.market_price = None
    self.market_vol = None
    exec_vol = self._split_exec_vol(amount)
    assert self.market_price is not None and self.market_vol is not None

    ticks_position = self.position - np.cumsum(exec_vol)

    self.position -= exec_vol.sum()
    if self.position < -EPS or (exec_vol < -EPS).any():
        raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")

    # Get time index available for this step
    time_index = self._get_ticks_slice(self.cur_time, self._next_time())

    # It should have the same keys with SAOEMetrics,
    # but the values do not necessarily have the annotated type.
    # Some values could be vectorized (e.g., exec_vol).
    tick_records = SAOEMetrics(
        stock_id=self.order.stock_id,
        datetime=time_index,
        direction=self.order.direction,
        market_volume=self.market_vol,
        market_price=self.market_price,
        amount=exec_vol,
        inner_amount=exec_vol,
        deal_amount=exec_vol,
        trade_price=self.market_price,
        trade_value=self.market_price * exec_vol,
        position=ticks_position,
        ffr=exec_vol / self.order.amount,
        pa=price_advantage(self.market_price, self.twap_price, self.order.direction),
    )
    self.history_exec = self._dataframe_append(self.history_exec, tick_records)

    self.history_steps = self._dataframe_append(
        self.history_steps,
        [self._metrics_collect(self.cur_time, self.market_vol, self.market_price, amount, exec_vol)],
    )

    # NOTE(review): ``done()`` is evaluated before ``cur_time`` is advanced at the end of
    # this method -- confirm that a last step capped by ``vol_threshold`` (position left
    # over) still gets its final metrics collected.
    if self.done():
        if self.env is not None:
            self.env.logger.add_any("history_steps", self.history_steps, loglevel=LogLevel.DEBUG)
            self.env.logger.add_any("history_exec", self.history_exec, loglevel=LogLevel.DEBUG)

        self.metrics = self._metrics_collect(
            self.ticks_index[0],  # start time
            self.history_exec["market_volume"],
            self.history_exec["market_price"],
            self.history_steps["amount"].sum(),
            self.history_exec["deal_amount"],
        )

        # NOTE (yuge): It looks to me that it's the "correct" decision to
        # put all the logs here, because only components like simulators themselves
        # have the knowledge about what could appear in the logs, and what's the format.
        # But I admit it's not necessarily the most convenient way.
        # I'll rethink about it when we have the second environment
        # Maybe some APIs like self.logger.enable_auto_log() ?

        if self.env is not None:
            for key, value in self.metrics.items():
                # Floats are logged as scalars; everything else goes through the generic channel.
                if isinstance(value, float):
                    self.env.logger.add_scalar(key, value)
                else:
                    self.env.logger.add_any(key, value)

    self.cur_time = self._next_time()
|
||||
|
||||
def get_state(self) -> SAOEState:
    """Bundle the simulator's current attributes into a :class:`SAOEState`."""
    snapshot = dict(
        order=self.order,
        cur_time=self.cur_time,
        position=self.position,
        history_exec=self.history_exec,
        history_steps=self.history_steps,
        metrics=self.metrics,
        backtest_data=self.backtest_data,
        ticks_per_step=self.ticks_per_step,
        ticks_index=self.ticks_index,
        ticks_for_order=self.ticks_for_order,
    )
    return SAOEState(**snapshot)
|
||||
|
||||
def done(self) -> bool:
    """Whether the order is finished: nothing left to trade, or the end time has been reached."""
    exhausted = self.position < EPS
    if exhausted:
        return exhausted
    return self.cur_time >= self.order.end_time
|
||||
|
||||
def _next_time(self) -> pd.Timestamp:
|
||||
"""The "current time" (``cur_time``) for next step."""
|
||||
# Look for next time on time index
|
||||
current_loc = self.ticks_index.get_loc(self.cur_time)
|
||||
next_loc = current_loc + self.ticks_per_step
|
||||
|
||||
# Calibrate the next location to multiple of ticks_per_step.
|
||||
# This is to make sure that:
|
||||
# as long as ticks_per_step is a multiple of something, each step won't cross morning and afternoon.
|
||||
next_loc = next_loc - next_loc % self.ticks_per_step
|
||||
|
||||
if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.order.end_time:
|
||||
return self.ticks_index[next_loc]
|
||||
else:
|
||||
return self.order.end_time
|
||||
|
||||
def _cur_duration(self) -> pd.Timedelta:
    """Length of the step that is about to be executed.

    Measured from ``cur_time`` to the time returned by ``_next_time()``.
    """
    step_end = self._next_time()
    return step_end - self.cur_time
|
||||
|
||||
def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray:
    """Spread the volume of one step evenly over its ticks (TWAP), honouring constraints.

    As a side effect, ``self.market_vol`` and ``self.market_price`` are replaced
    with the market data of the interval that is about to be traded.

    Parameters
    ----------
    exec_vol_sum
        Total volume intended to be executed during this step.

    Returns
    -------
    The per-tick execution volume for the interval.
    """
    step_end = self._next_time()

    # Market data of the upcoming interval; the end bound is pulled back by one second.
    window = slice(self.cur_time, step_end - ONE_SEC)
    self.market_vol = self.backtest_data.get_volume().loc[window].to_numpy()
    self.market_price = self.backtest_data.get_deal_price().loc[window].to_numpy()

    assert self.market_vol is not None and self.market_price is not None

    # Equal share for every tick.
    n_ticks = len(self.market_price)
    exec_vol = np.repeat(exec_vol_sum / n_ticks, n_ticks)

    # Cap each tick at a fraction of the market volume, when a threshold is configured.
    if self.vol_threshold is None:
        market_vol_limit = np.inf
    else:
        market_vol_limit = self.vol_threshold * self.market_vol
    exec_vol = np.minimum(exec_vol, market_vol_limit)  # type: ignore

    # On the final step, push the whole remaining position into the last tick,
    # then re-apply the cap.
    if step_end >= self.order.end_time:
        exec_vol[-1] += self.position - exec_vol.sum()
        exec_vol = np.minimum(exec_vol, market_vol_limit)  # type: ignore

    return exec_vol
|
||||
|
||||
def _metrics_collect(
    self,
    datetime: pd.Timestamp,
    market_vol: np.ndarray,
    market_price: np.ndarray,
    amount: float,  # intended to trade such amount
    exec_vol: np.ndarray,
) -> SAOEMetrics:
    """Summarise one batch of executions into a ``SAOEMetrics`` record.

    Parameters
    ----------
    datetime
        Timestamp stored in the record.
    market_vol, market_price, exec_vol
        Per-tick market volume, market price and executed volume; equal lengths required.
    amount
        The amount that was intended to be traded.
    """
    assert len(market_vol) == len(market_price) == len(exec_vol)

    # With (almost) nothing executed there is no meaningful average price; report 0.
    nothing_executed = np.abs(np.sum(exec_vol)) < EPS
    exec_avg_price = 0.0
    if not nothing_executed:
        exec_avg_price = cast(float, np.average(market_price, weights=exec_vol))  # could be nan
        if hasattr(exec_avg_price, "item"):  # could be numpy scalar
            exec_avg_price = exec_avg_price.item()  # type: ignore

    executed_total = exec_vol.sum()

    return SAOEMetrics(
        stock_id=self.order.stock_id,
        datetime=datetime,
        direction=self.order.direction,
        market_volume=market_vol.sum(),
        market_price=market_price.mean(),
        amount=amount,
        inner_amount=executed_total,
        deal_amount=executed_total,  # in this simulator, there's no other restrictions
        trade_price=exec_avg_price,
        trade_value=np.sum(market_price * exec_vol),
        position=self.position,
        ffr=float(executed_total / self.order.amount),
        pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
    )
|
||||
|
||||
def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex:
    """Return the ticks of ``ticks_index`` lying between ``start`` and ``end``.

    Unless ``include_end`` is set, ``end`` is pulled back by one second before slicing.
    """
    last = end if include_end else end - ONE_SEC
    return self.ticks_index[self.ticks_index.slice_indexer(start, last)]
|
||||
|
||||
@staticmethod
def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
    """Append ``other`` to ``df`` as new rows and return the combined frame.

    ``other`` is converted to a DataFrame and indexed by its ``datetime`` column.
    ``pd.concat`` is used because ``DataFrame.append`` is deprecated.
    """
    new_rows = pd.DataFrame(other).set_index("datetime").rename_axis("datetime")
    return pd.concat([df, new_rows], axis=0)
|
||||
|
||||
|
||||
_float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
|
||||
|
||||
|
||||
def price_advantage(
    exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
) -> _float_or_ndarray:
    """Relative advantage of ``exec_price`` over ``baseline_price``, scaled by 10000.

    For a buy order a lower execution price is an advantage; for a sell order a
    higher one is. NaN results are replaced by 0, and a result holding exactly one
    element is returned as a plain Python scalar.

    Raises
    ------
    ValueError
        If ``direction`` is neither ``OrderDir.BUY`` nor ``OrderDir.SELL``
        (only checked when ``baseline_price`` is non-zero).
    """
    if baseline_price == 0:  # something is wrong with data. Should be nan here
        return 0.0 if isinstance(exec_price, float) else np.zeros_like(exec_price)

    if direction == OrderDir.BUY:
        advantage = 1 - exec_price / baseline_price
    elif direction == OrderDir.SELL:
        advantage = exec_price / baseline_price - 1
    else:
        raise ValueError(f"Unexpected order direction: {direction}")

    cleaned: np.ndarray = np.nan_to_num(advantage * 10000, nan=0.0)
    return cleaned.item() if cleaned.size == 1 else cast(_float_or_ndarray, cleaned)
|
||||
84
qlib/rl/reward.py
Normal file
84
qlib/rl/reward.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Any, TypeVar, TYPE_CHECKING
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
|
||||
SimulatorState = TypeVar("SimulatorState")


class Reward(Generic[SimulatorState]):
    """
    Reward calculation component that takes a single argument: state of simulator. Returns a real number: reward.

    Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
    """

    # Reference to the hosting environment wrapper. Stays ``None`` until something assigns it.
    env: EnvWrapper | None = None

    @final
    def __call__(self, simulator_state: SimulatorState) -> float:
        """Compute the reward for ``simulator_state`` by delegating to :meth:`reward`."""
        return self.reward(simulator_state)

    def reward(self, simulator_state: SimulatorState) -> float:
        """Implement this method for your own reward."""
        raise NotImplementedError("Implement reward calculation recipe in `reward()`.")

    def log(self, name: str, value: Any) -> None:
        """Record ``value`` under ``name`` as a scalar in the environment's logger.

        When the reward is not attached to an environment (``env`` is ``None``)
        there is nowhere to send the value, so the call is a no-op instead of
        failing on the missing ``logger`` attribute.
        """
        if self.env is None:
            return
        self.env.logger.add_scalar(name, value)
|
||||
|
||||
|
||||
class RewardCombination(Reward):
    """Weighted sum of several named rewards.

    ``rewards`` maps a name to a ``(reward, weight)`` pair. Each weighted
    contribution is also logged under its name.
    """

    def __init__(self, rewards: dict[str, tuple[Reward, float]]):
        self.rewards = rewards

    def reward(self, simulator_state: Any) -> float:
        """Evaluate every component on ``simulator_state`` and return the weighted total."""
        weighted_total = 0.0
        for label, component_and_weight in self.rewards.items():
            component, weight = component_and_weight
            contribution = component(simulator_state) * weight
            weighted_total += contribution
            self.log(label, contribution)
        return weighted_total
|
||||
|
||||
|
||||
# TODO:
|
||||
# reward_factory is disabled for now
|
||||
|
||||
# _RegistryConfigReward = RegistryConfig[REWARDS]
|
||||
|
||||
|
||||
# @configclass
|
||||
# class _WeightedRewardConfig:
|
||||
# weight: float
|
||||
# reward: _RegistryConfigReward
|
||||
|
||||
|
||||
# RewardConfig = Union[_RegistryConfigReward, Dict[str, Union[_RegistryConfigReward, _WeightedRewardConfig]]]
|
||||
|
||||
|
||||
# def reward_factory(reward_config: RewardConfig) -> Reward:
|
||||
# """
|
||||
# Use this factory to instantiate the reward from config.
|
||||
# Simply using ``reward_config.build()`` might not work because reward can have complex combinations.
|
||||
# """
|
||||
# if isinstance(reward_config, dict):
|
||||
# # as reward combination
|
||||
# rewards = {}
|
||||
# for name, rew in reward_config.items():
|
||||
# if not isinstance(rew, _WeightedRewardConfig):
|
||||
# # default weight is 1.
|
||||
# rew = _WeightedRewardConfig(weight=1., reward=rew)
|
||||
# # no recursive build in this step
|
||||
# rewards[name] = (rew.reward.build(), rew.weight)
|
||||
# return RewardCombination(rewards)
|
||||
# else:
|
||||
# # single reward
|
||||
# return reward_config.build()
|
||||
12
qlib/rl/seed.py
Normal file
12
qlib/rl/seed.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Defines a set of initial state definitions and state-set definitions.
|
||||
|
||||
With single-asset order execution only, the only seed is order.
|
||||
"""
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
InitialStateType = TypeVar("InitialStateType")
|
||||
"""Type of data that creates the simulator."""
|
||||
75
qlib/rl/simulator.py
Normal file
75
qlib/rl/simulator.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar, Generic, Any, TYPE_CHECKING
|
||||
|
||||
from .seed import InitialStateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
|
||||
StateType = TypeVar("StateType")
|
||||
"""StateType stores all the useful data in the simulation process
|
||||
(as well as utilities to generate/retrieve data when needed)."""
|
||||
|
||||
ActType = TypeVar("ActType")
|
||||
"""This ActType is the type of action at the simulator end."""
|
||||
|
||||
|
||||
class Simulator(Generic[InitialStateType, StateType, ActType]):
    """
    Simulator that resets with ``__init__``, and transits with ``step(action)``.

    To keep the data-flow clear, a simulator obeys the following restrictions:

    1. ``step(action)`` is the only way to modify the inner status of a simulator.
    2. External modules may only *read* the status, through ``simulator.get_state()``,
       and ask whether the simulator has reached its ending state through ``simulator.done()``.

    A simulator is bound to three types:

    - *InitialStateType*: the type of the data used to create the simulator.
    - *StateType*: the type of the **status** (state) of the simulator.
    - *ActType*: the type of the **action**, i.e. the input received in each step.

    Different simulators might share the same StateType, for example when they deal
    with the same task using different simulation implementations. With the same type,
    they can safely share other components in the MDP.

    Simulators are ephemeral. The lifecycle of a simulator starts with an initial state
    and ends with the trajectory. In other words, when the trajectory ends, the simulator
    is recycled. If simulators want to share context between each other (e.g., for
    speed-up purposes), this could be done by accessing the weak reference of the
    environment wrapper.

    Attributes
    ----------
    env
        A reference of env-wrapper, which could be useful in some corner cases.
        Simulators are discouraged to use this, because it's prone to induce errors.
    """

    env: EnvWrapper | None = None

    def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
        """Create the simulator from ``initial``. The base implementation does nothing."""

    def step(self, action: ActType) -> None:
        """Receive an action of ActType.

        The simulator should update its internal state and return None.
        The updated state can be retrieved with ``simulator.get_state()``.
        """
        raise NotImplementedError

    def get_state(self) -> StateType:
        """Return the current state of the simulator. Subclasses must override this."""
        raise NotImplementedError

    def done(self) -> bool:
        """Check whether the simulator is in a "done" state.

        A simulator in a "done" state should no longer receive any ``step`` request.
        As simulators are ephemeral, resetting means destroying the old simulator
        and creating a new one.
        """
        raise NotImplementedError
|
||||
7
qlib/rl/utils/__init__.py
Normal file
7
qlib/rl/utils/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_queue import *
|
||||
from .env_wrapper import *
|
||||
from .finite_env import *
|
||||
from .log import *
|
||||
179
qlib/rl/utils/data_queue.py
Normal file
179
qlib/rl/utils/data_queue.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import multiprocessing
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from queue import Empty
|
||||
from typing import TypeVar, Generic, Sequence, cast
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
__all__ = ["DataQueue"]
|
||||
|
||||
|
||||
class DataQueue(Generic[T]):
    """Main process (producer) produces data and stores them in a queue.
    Sub-processes (consumers) can retrieve the data-points from the queue.
    Data-points are generated via reading items from ``dataset``.

    :class:`DataQueue` is ephemeral. You must create a new DataQueue
    when the ``repeat`` is exhausted.

    See the documents of :class:`qlib.rl.utils.FiniteVectorEnv` for more background.

    Parameters
    ----------
    dataset
        The dataset to read data from. Must implement ``__len__`` and ``__getitem__``.
    repeat
        Iterate over the data-points for how many times. Use ``-1`` to iterate forever.
    shuffle
        If ``shuffle`` is true, the items will be read in random order.
    producer_num_workers
        Concurrent workers for data-loading.
    queue_maxsize
        Maximum items to put into queue before it jams.
        ``0`` (the default) picks the CPU count, or 1 when that is unavailable.

    Examples
    --------
    >>> data_queue = DataQueue(my_dataset)
    >>> with data_queue:
    ...     ...

    In worker:

    >>> for data in data_queue:
    ...     print(data)
    """

    def __init__(
        self,
        dataset: Sequence[T],
        repeat: int = 1,
        shuffle: bool = True,
        producer_num_workers: int = 0,
        queue_maxsize: int = 0,
    ):
        if queue_maxsize == 0:
            # Query the CPU count once and reuse the answer.
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                queue_maxsize = cpu_count
                _logger.info(f"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming.")
            else:
                queue_maxsize = 1
                _logger.warning("CPU count not available. Setting queue maxsize to 1.")

        self.dataset: Sequence[T] = dataset
        self.repeat: int = repeat
        self.shuffle: bool = shuffle
        self.producer_num_workers: int = producer_num_workers

        self._activated: bool = False
        # Whether ``get`` has not been called yet; the first call uses a longer timeout.
        self._first_get: bool = True
        self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
        # Non-zero once production has finished or the queue has been cleaned up.
        self._done = multiprocessing.Value("i", 0)

    def __enter__(self):
        self.activate()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()

    def cleanup(self):
        """Mark the queue as done and drain whatever is still inside it."""
        with self._done.get_lock():
            self._done.value += 1
        for repeat in range(500):
            if repeat >= 1:
                warnings.warn(f"After {repeat} cleanup, the queue is still not empty.", category=RuntimeWarning)
            while not self._queue.empty():
                try:
                    self._queue.get(block=False)
                except Empty:
                    pass
            # Sometimes when the queue gets emptied, more data have already been sent,
            # and they are on the way into the queue.
            # If these data didn't get consumed, it will jam the queue and make the process hang.
            # We wait a second here for potential data arriving, and check again (up to 500 rounds).
            time.sleep(1.0)
            if self._queue.empty():
                break
        _logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")

    def get(self, block=True):
        """Fetch one item from the queue.

        Keeps retrying while the queue is empty but not yet marked as done.

        Raises
        ------
        StopIteration
            If the queue is empty and has been marked as done.
        """
        # The very first call waits up to 5 seconds per attempt, later calls 0.5 seconds.
        if self._first_get:
            timeout = 5.0
            self._first_get = False
        else:
            timeout = 0.5
        while True:
            try:
                return self._queue.get(block=block, timeout=timeout)
            except Empty:
                if self._done.value:
                    raise StopIteration  # pylint: disable=raise-missing-from

    def put(self, obj, block=True, timeout=None):
        """Put ``obj`` into the underlying queue."""
        return self._queue.put(obj, block=block, timeout=timeout)

    def mark_as_done(self):
        """Flag that no more data will be produced."""
        with self._done.get_lock():
            self._done.value = 1

    def done(self):
        """Return the raw done flag (non-zero means done)."""
        return self._done.value

    def activate(self):
        """Start the daemon producer thread. May only be called once.

        Raises
        ------
        ValueError
            If the queue has already been activated.
        """
        if self._activated:
            raise ValueError("DataQueue can not activate twice.")
        thread = threading.Thread(target=self._producer, daemon=True)
        thread.start()
        self._activated = True
        return self

    def __del__(self):
        _logger.debug(f"__del__ of {__name__}.DataQueue")
        self.cleanup()

    def __iter__(self):
        if not self._activated:
            raise ValueError(
                "Need to call activate() to launch a daemon worker " "to produce data into data queue before using it."
            )
        return self._consumer()

    def _consumer(self):
        """Yield items from the queue until ``get`` signals exhaustion."""
        while True:
            try:
                yield self.get()
            except StopIteration:
                _logger.debug("Data consumer timed-out from get.")
                return

    def _producer(self):
        """Read the dataset ``repeat`` times and feed every item into the queue."""
        # pytorch dataloader is used here only because we need its sampler and multi-processing
        from torch.utils.data import DataLoader, Dataset  # pylint: disable=import-outside-toplevel

        dataloader = DataLoader(
            cast(Dataset[T], self.dataset),
            batch_size=None,
            num_workers=self.producer_num_workers,
            shuffle=self.shuffle,
            collate_fn=lambda t: t,  # identity collate fn
        )
        # ``-1`` means "forever"; approximated by a practically unreachable count.
        repeat = 10**18 if self.repeat == -1 else self.repeat
        for _rep in range(repeat):
            for data in dataloader:
                if self._done.value:
                    # Already done.
                    return
                self._queue.put(data)
            _logger.debug(f"Dataloader loop done. Repeat {_rep}.")
        self.mark_as_done()
|
||||
249
qlib/rl/utils/env_wrapper.py
Normal file
249
qlib/rl/utils/env_wrapper.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, Any, Iterable, Iterator, Generic, cast
|
||||
|
||||
import gym
|
||||
|
||||
from qlib.rl.aux_info import AuxiliaryInfoCollector
|
||||
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .finite_env import generate_nan_observation
|
||||
from .log import LogCollector, LogLevel
|
||||
|
||||
__all__ = ["InfoDict", "EnvWrapperStatus", "EnvWrapper"]
|
||||
|
||||
# in this case, there won't be any seed for simulator
|
||||
SEED_INTERATOR_MISSING = "_missing_"
|
||||
|
||||
|
||||
class InfoDict(TypedDict):
    """Shape of the dict returned as the 4th value of ``env.step()``."""

    #: Any information; its content depends on the auxiliary info collector.
    aux_info: dict
    #: Collected by LogCollector.
    log: dict[str, Any]
|
||||
|
||||
|
||||
class EnvWrapperStatus(TypedDict):
    """
    Status record maintained by EnvWrapper.

    All fields use RL semantics: ``obs`` is the observation fed into the policy,
    and ``action`` is the raw action returned by the policy.
    """

    cur_step: int
    done: bool
    initial_state: Any | None
    obs_history: list
    action_history: list
    reward_history: list
|
||||
|
||||
|
||||
class EnvWrapper(
    gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
):
    """Qlib-based RL environment, subclassing ``gym.Env``.
    A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.

    This is what the framework of simulator - interpreter - policy looks like in RL training.
    All the components other than policy need to be assembled into a single object called "environment".
    The "environment" is replicated into multiple workers, and (at least in tianshou's implementation),
    one single policy (agent) plays against a batch of environments.

    Parameters
    ----------
    simulator_fn
        A callable that is the simulator factory.
        When ``seed_iterator`` is present, the factory should take one argument,
        that is the seed (aka initial state).
        Otherwise, it should take zero argument.
    state_interpreter
        State-observation converter.
    action_interpreter
        Policy-simulator action converter.
    seed_iterator
        An iterable of seed. With the help of :class:`qlib.rl.utils.DataQueue`,
        environment workers in different processes can share one ``seed_iterator``.
    reward_fn
        A callable that accepts the StateType and returns a float (at least in single-agent case).
    aux_info_collector
        Collect auxiliary information. Could be useful in MARL.
    logger
        Log collector that collects the logs. The collected logs are sent back to main process,
        via the return value of ``env.step()``.

    Attributes
    ----------
    status : EnvWrapperStatus
        Status indicator. All terms are in *RL language*.
        It can be used if users care about data on the RL side.
        Can be none when no trajectory is available.
    """

    simulator: Simulator[InitialStateType, StateType, ActType]
    seed_iterator: str | Iterator[InitialStateType] | None

    def __init__(
        self,
        simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
        state_interpreter: StateInterpreter[StateType, ObsType],
        action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
        seed_iterator: Iterable[InitialStateType] | None,
        reward_fn: Reward | None = None,
        aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
        logger: LogCollector | None = None,
    ):
        # Assign weak reference to wrapper.
        #
        # Use weak reference here, because:
        # 1. Logically, the other components should be able to live without an env_wrapper.
        #    For example, they might live in a strategy_wrapper in future.
        #    Therefore injecting a "hard" attribute called "env" is not appropriate.
        # 2. When the environment gets destroyed, it gets destroyed.
        #    We don't want it to silently live inside some interpreters.
        # 3. Avoid circular reference.
        # 4. When the components get serialized, we can throw away the env without any burden.
        #    (though this part is not implemented yet)
        for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
            if obj is not None:
                obj.env = weakref.proxy(self)  # type: ignore

        self.simulator_fn = simulator_fn
        self.state_interpreter = state_interpreter
        self.action_interpreter = action_interpreter

        if seed_iterator is None:
            # In this case, there won't be any seed for simulator
            # We can't set it to None because None actually means something else.
            # If `seed_iterator` is None, it means that it's exhausted.
            self.seed_iterator = SEED_INTERATOR_MISSING
        else:
            self.seed_iterator = iter(seed_iterator)
        self.reward_fn = reward_fn

        self.aux_info_collector = aux_info_collector
        self.logger: LogCollector = logger or LogCollector()
        self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)

    @property
    def action_space(self):
        return self.action_interpreter.action_space

    @property
    def observation_space(self):
        return self.state_interpreter.observation_space

    def reset(self, **kwargs: Any) -> ObsType:
        """
        Try to get a state from state queue, and init the simulator with this state.
        If the queue is exhausted, generate an invalid (nan) observation.

        Raises
        ------
        RuntimeError
            If the seed iterator was already found exhausted by an earlier ``reset``.
        """
        if self.seed_iterator is None:
            raise RuntimeError("You are trying to get a state from a dead environment wrapper.")

        # TODO: simulator/observation might need seed to prefetch something
        # as only seed has the ability to do the work beforehand

        # NOTE: though logger is reset here, logs in this function won't work,
        # because we can't send them outside.
        # See https://github.com/thu-ml/tianshou/issues/605
        self.logger.reset()

        if self.seed_iterator is SEED_INTERATOR_MISSING:
            # no initial state
            initial_state = None
            self.simulator = cast(Callable[[], Simulator], self.simulator_fn)()
        else:
            # Only exhaustion of the seed iterator itself marks the env as dead.
            # A StopIteration escaping from the simulator factory or the interpreter
            # is a genuine error and must not be mistaken for it.
            try:
                initial_state = next(cast(Iterator[InitialStateType], self.seed_iterator))
            except StopIteration:
                # The environment should be recycled because it's in a dead state.
                self.seed_iterator = None
                return generate_nan_observation(self.observation_space)
            self.simulator = self.simulator_fn(initial_state)

        self.status = EnvWrapperStatus(
            cur_step=0,
            done=False,
            initial_state=initial_state,
            obs_history=[],
            action_history=[],
            reward_history=[],
        )

        self.simulator.env = cast(EnvWrapper, weakref.proxy(self))

        sim_state = self.simulator.get_state()
        obs = self.state_interpreter(sim_state)

        self.status["obs_history"].append(obs)

        return obs

    def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
        """Environment step.

        See the code along with comments to get a sequence of things happening here.

        Raises
        ------
        RuntimeError
            If the seed iterator has already been exhausted.
        """

        if self.seed_iterator is None:
            raise RuntimeError("State queue is already exhausted, but the environment is still receiving action.")

        # Clear the logged information from last step
        self.logger.reset()

        # Action is what we have got from policy
        self.status["action_history"].append(policy_action)
        action = self.action_interpreter(self.simulator.get_state(), policy_action)

        # This update must be after action interpreter and before simulator.
        self.status["cur_step"] += 1

        # Use the converted action to update the simulator
        self.simulator.step(action)

        # Update "done" first, as this status might be used by reward_fn later
        done = self.simulator.done()
        self.status["done"] = done

        # Get state and calculate observation
        sim_state = self.simulator.get_state()
        obs = self.state_interpreter(sim_state)
        self.status["obs_history"].append(obs)

        # Reward and extra info
        if self.reward_fn is not None:
            rew = self.reward_fn(sim_state)
        else:
            # No reward. Treated as 0.
            rew = 0.0
        self.status["reward_history"].append(rew)

        if self.aux_info_collector is not None:
            aux_info = self.aux_info_collector(sim_state)
        else:
            aux_info = {}

        # Final logging stuff: RL-specific logs
        if done:
            self.logger.add_scalar("steps_per_episode", self.status["cur_step"])
        self.logger.add_scalar("reward", rew)
        self.logger.add_any("obs", obs, loglevel=LogLevel.DEBUG)
        self.logger.add_any("policy_act", policy_action, loglevel=LogLevel.DEBUG)

        info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
        return obs, rew, done, info_dict

    def render(self):
        raise NotImplementedError("Render is not implemented in EnvWrapper.")
|
||||
337
qlib/rl/utils/finite_env.py
Normal file
337
qlib/rl/utils/finite_env.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This is to support finite env in vector env.
|
||||
See https://github.com/thu-ml/tianshou/issues/322 for details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import Any, Set, Callable, Type
|
||||
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
||||
|
||||
from qlib.typehint import Literal
|
||||
from .log import LogWriter
|
||||
|
||||
__all__ = [
|
||||
"generate_nan_observation",
|
||||
"check_nan_observation",
|
||||
"FiniteVectorEnv",
|
||||
"FiniteDummyVectorEnv",
|
||||
"FiniteSubprocVectorEnv",
|
||||
"FiniteShmemVectorEnv",
|
||||
"FiniteEnvType",
|
||||
"vectorize_env",
|
||||
]
|
||||
|
||||
|
||||
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
|
||||
|
||||
|
||||
def fill_invalid(obj):
    """Build a value shaped like ``obj`` in which every entry is an "invalid" marker.

    Markers per dtype: NaN for floating-point data, the maximum representable value
    for integer data, and ``True`` (the largest boolean) for boolean data.
    Python scalars and non-ndarray objects exposing ``dtype`` are converted to numpy
    arrays first; dicts, lists and tuples are processed recursively, keeping their
    container type.

    Raises
    ------
    ValueError
        If ``obj`` (or one of its leaves) is of an unsupported type or dtype.
    """
    if isinstance(obj, (int, float, bool)):
        return fill_invalid(np.array(obj))
    if hasattr(obj, "dtype"):
        if isinstance(obj, np.ndarray):
            if np.issubdtype(obj.dtype, np.floating):
                return np.full_like(obj, np.nan)
            if np.issubdtype(obj.dtype, np.bool_):
                # ``np.iinfo`` rejects boolean dtypes, so booleans need their own marker.
                return np.full_like(obj, True)
            return np.full_like(obj, np.iinfo(obj.dtype).max)
        # dealing with corner cases that numpy number is not supported by tianshou's sharray
        return fill_invalid(np.array(obj))
    if isinstance(obj, dict):
        return {k: fill_invalid(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [fill_invalid(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(fill_invalid(v) for v in obj)
    raise ValueError(f"Unsupported value to fill with invalid: {obj}")
|
||||
|
||||
|
||||
def is_invalid(arr):
    """Tell whether ``arr`` consists entirely of "invalid" markers.

    Markers per dtype: NaN for floating-point data, the maximum representable value
    for integer data, and ``True`` for boolean data. Dicts, lists and tuples are
    invalid when all of their members are. Values of any other type are treated
    as invalid.
    """
    if hasattr(arr, "dtype"):
        if np.issubdtype(arr.dtype, np.floating):
            return bool(np.isnan(arr).all())
        if np.issubdtype(arr.dtype, np.bool_):
            # ``np.iinfo`` rejects boolean dtypes; ``True`` plays the role of the maximum.
            return bool(np.asarray(arr).all())
        return bool((np.iinfo(arr.dtype).max == arr).all())
    if isinstance(arr, dict):
        return all(is_invalid(o) for o in arr.values())
    if isinstance(arr, (list, tuple)):
        return all(is_invalid(o) for o in arr)
    if isinstance(arr, (int, float, bool, np.number)):
        return is_invalid(np.array(arr))
    return True
|
||||
|
||||
|
||||
def generate_nan_observation(obs_space: gym.Space) -> Any:
    """Produce the "NaN observation" signalling that the environment received no seed.

    A sample is drawn from ``obs_space`` and every entry of it is overwritten
    with an invalid marker.

    We assume that obs is complex and there must be something like float.
    Otherwise this logic doesn't work.
    """
    return fill_invalid(obs_space.sample())
|
||||
|
||||
|
||||
def check_nan_observation(obs: Any) -> bool:
    """Tell whether ``obs`` looks like the output of :func:`generate_nan_observation`."""
    all_invalid = is_invalid(obs)
    return all_invalid
|
||||
|
||||
|
||||
class FiniteVectorEnv(BaseVectorEnv):
|
||||
"""To allow the paralleled env workers consume a single DataQueue until it's exhausted.
|
||||
|
||||
See `tianshou issue #322 <https://github.com/thu-ml/tianshou/issues/322>`_.
|
||||
|
||||
The requirement is to make every possible seed (stored in :class:`qlib.rl.utils.DataQueue` in our case)
|
||||
consumed by exactly one environment. This is not possible by tianshou's native VectorEnv and Collector,
|
||||
because tianshou is unaware of this "exactly one" constraint, and might launch extra workers.
|
||||
|
||||
Consider a corner case, where concurrency is 2, but there is only one seed in DataQueue.
|
||||
The reset of two workers must be both called according to the logic in collect.
|
||||
The returned results of two workers are collected, regardless of what they are.
|
||||
The problem is, one of the reset result must be invalid, or repeated,
|
||||
because there's only one need in queue, and collector isn't aware of such situation.
|
||||
|
||||
Luckily, we can hack the vector env, and make a protocol between single env and vector env.
|
||||
The single environment (should be :class:`qlib.rl.utils.EnvWrapper` in our case) is responsible for
|
||||
reading from queue, and generate a special observation when the queue is exhausted. The special obs
|
||||
is called "nan observation", because simply using none causes problems in shared-memory vector env.
|
||||
:class:`FiniteVectorEnv` then read the observations from all workers, and select those non-nan
|
||||
observation. It also maintains an ``_alive_env_ids`` to track which workers should never be
|
||||
called again. When also the environments are exhausted, it will raise StopIteration exception.
|
||||
|
||||
The usage of this vector env in collector are two parts:
|
||||
|
||||
1. If the data queue is finite (usually when inference), collector should collect "infinity" number of
|
||||
episodes, until the vector env exhausts by itself.
|
||||
2. If the data queue is infinite (usually in training), collector can set number of episodes / steps.
|
||||
In this case, data would be randomly ordered, and some repetitions wouldn't matter.
|
||||
|
||||
One extra function of this vector env is that it has a logger that explicitly collects logs
|
||||
from child workers. See :class:`qlib.rl.utils.LogWriter`.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    logger: LogWriter | list[LogWriter],
    env_fns: list[Callable[..., gym.Env]],
    **kwargs: Any,
) -> None:
    """Build the vector env and initialise the finite-env bookkeeping.

    Parameters
    ----------
    logger
        One log writer or a list of them; a single writer is wrapped into a list.
    env_fns
        Factories of the single environments, forwarded to the parent vector env.
    kwargs
        Extra keyword arguments forwarded to the parent vector env.
    """
    super().__init__(env_fns, **kwargs)

    if isinstance(logger, list):
        writers = logger
    else:
        writers = [logger]
    self._logger: list[LogWriter] = writers

    # Once every worker is exhausted the env becomes a "zombie" and must not be used again.
    self._zombie = False
    # Flipped by ``collector_guard()``; ``reset()`` warns when it is not set.
    self._collector_guarded: bool = False

    # Placeholders used to fill the slots of exhausted workers.
    self._default_obs = None
    self._default_info = None
    self._default_rew = None

    # Initially every worker is alive.
    self._alive_env_ids: Set[int] = set()
    self._reset_alive_envs()
|
||||
|
||||
def _reset_alive_envs(self):
    """Mark every worker as alive, but only if no worker is currently alive."""
    if self._alive_env_ids:
        return
    # starting or running out
    self._alive_env_ids = set(range(self.env_num))
|
||||
|
||||
# to workaround with tianshou's buffer and batch
|
||||
def _set_default_obs(self, obs):
    """Keep a deep copy of the first non-``None`` observation as the placeholder."""
    if self._default_obs is not None or obs is None:
        return
    self._default_obs = copy.deepcopy(obs)

def _set_default_info(self, info):
    """Keep a deep copy of the first non-``None`` info as the placeholder."""
    if self._default_info is not None or info is None:
        return
    self._default_info = copy.deepcopy(info)

def _set_default_rew(self, rew):
    """Keep a deep copy of the first non-``None`` reward as the placeholder."""
    if self._default_rew is not None or rew is None:
        return
    self._default_rew = copy.deepcopy(rew)

def _get_default_obs(self):
    """Return a fresh deep copy of the placeholder observation."""
    placeholder = self._default_obs
    return copy.deepcopy(placeholder)

def _get_default_info(self):
    """Return a fresh deep copy of the placeholder info."""
    placeholder = self._default_info
    return copy.deepcopy(placeholder)

def _get_default_rew(self):
    """Return a fresh deep copy of the placeholder reward."""
    placeholder = self._default_rew
    return copy.deepcopy(placeholder)
|
||||
|
||||
# END
|
||||
|
||||
@staticmethod
def _postproc_env_obs(obs):
    """Map a missing or "nan" observation to ``None``; return any other observation unchanged."""
    # reserved for shmem vector env to restore empty observation
    is_empty = obs is None or check_nan_observation(obs)
    return None if is_empty else obs
|
||||
|
||||
@contextmanager
def collector_guard(self):
    """Guard the collector. Recommended to guard every collect.

    The guard serves two purposes.

    1. Catch and ignore the StopIteration exception, which is the stopping signal
       thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
    2. Notify the loggers that the collect is done when it's done.

    Examples
    --------
    >>> with finite_env.collector_guard():
    ...     collector.collect(n_episode=INF)
    """
    self._collector_guarded = True

    try:
        try:
            yield self
        finally:
            # The flag is cleared no matter how the guarded body exits.
            self._collector_guarded = False
    except StopIteration:
        # Expected stopping signal: swallow it.
        pass

    # At last trigger the loggers (not reached when another exception propagates).
    for writer in self._logger:
        writer.on_env_all_done()
|
||||
|
||||
def reset(self, id=None):
    """Reset the requested workers, skipping those that are already exhausted.

    Only workers in ``_alive_env_ids`` are forwarded to the parent ``reset``. A worker whose
    reset yields no observation (``None`` or a "nan observation") is removed from the alive set,
    and its slot in the returned batch is filled with a copy of the default observation.

    Parameters
    ----------
    id
        Worker id(s) to reset, normalised by ``_wrap_id`` (provided by the parent class).

    Returns
    -------
    The stacked observations, one per requested worker.

    Raises
    ------
    StopIteration
        When no worker is alive after this reset. The env is marked as zombie and
        must not be used afterwards.
    """
    assert not self._zombie

    # Check whether it's guarded by collector_guard()
    if not self._collector_guarded:
        warnings.warn(
            "Collector is not guarded by FiniteEnv. "
            "This may cause unexpected problems, like unexpected StopIteration exception, "
            "or missing logs.",
            RuntimeWarning,
        )

    id = self._wrap_id(id)
    self._reset_alive_envs()

    # ask super to reset alive envs and remap to current index
    request_id = [i for i in id if i in self._alive_env_ids]
    obs = [None] * len(id)
    id2idx = {i: k for k, i in enumerate(id)}
    if request_id:
        for i, o in zip(request_id, super().reset(request_id)):
            obs[id2idx[i]] = self._postproc_env_obs(o)

    # a worker that produced no observation is exhausted
    for i, o in zip(id, obs):
        if o is None and i in self._alive_env_ids:
            self._alive_env_ids.remove(i)

    # logging: each writer receives the observation of that worker only
    # (previously the whole list of observations was passed by mistake)
    for i, o in zip(id, obs):
        if i in self._alive_env_ids:
            for logger in self._logger:
                logger.on_env_reset(i, o)

    # fill empty observation with default(fake) observation
    for o in obs:
        self._set_default_obs(o)
    for i, o in enumerate(obs):
        if o is None:
            obs[i] = self._get_default_obs()

    if not self._alive_env_ids:
        # comment this line so that the env becomes indisposable
        # self.reset()
        self._zombie = True
        raise StopIteration

    return np.stack(obs)
|
||||
|
||||
def step(self, action, id=None):
    """Step the requested workers, skipping those that are already exhausted.

    Alive workers are stepped through the parent ``step``; every other slot keeps the
    ``[None, None, False, None]`` placeholder, whose ``None`` entries are finally replaced by
    copies of the default observation / reward / info.

    Parameters
    ----------
    action
        Actions aligned with ``id`` (one per requested worker).
    id
        Worker id(s) to step, normalised by ``_wrap_id`` (provided by the parent class).

    Returns
    -------
    A list ``[obs, rew, done, info]`` where each item is stacked over the requested workers.
    """
    assert not self._zombie

    env_ids = self._wrap_id(id)
    slot_of = {env_id: slot for slot, env_id in enumerate(env_ids)}
    alive_ids = [env_id for env_id in env_ids if env_id in self._alive_env_ids]
    results = [[None, None, False, None] for _ in range(len(env_ids))]

    # ask super to step alive envs and remap to current index
    if alive_ids:
        alive_actions = np.stack([action[slot_of[env_id]] for env_id in alive_ids])
        stepped = super().step(alive_actions, alive_ids)
        for env_id, transition in zip(alive_ids, zip(*stepped)):
            entry = list(transition)
            entry[0] = self._postproc_env_obs(entry[0])
            results[slot_of[env_id]] = entry

    # logging
    for env_id, entry in zip(env_ids, results):
        if env_id not in self._alive_env_ids:
            continue
        for writer in self._logger:
            writer.on_env_step(env_id, *entry)

    # remember the first real info / reward as defaults
    for _obs, rew, _done, info in results:
        self._set_default_info(info)
        self._set_default_rew(rew)

    # fill empty observation/info with default(fake)
    for entry in results:
        if entry[0] is None:
            entry[0] = self._get_default_obs()
        if entry[1] is None:
            entry[1] = self._get_default_rew()
        if entry[3] is None:
            entry[3] = self._get_default_info()

    return list(map(np.stack, zip(*results)))
|
||||
|
||||
|
||||
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
    """:class:`FiniteVectorEnv` combined with ``DummyVectorEnv``."""
|
||||
|
||||
|
||||
class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
    """:class:`FiniteVectorEnv` combined with ``SubprocVectorEnv``."""
|
||||
|
||||
|
||||
class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
    """:class:`FiniteVectorEnv` combined with ``ShmemVectorEnv``."""
|
||||
|
||||
|
||||
def vectorize_env(
    env_factory: Callable[..., gym.Env],
    env_type: FiniteEnvType,
    concurrency: int,
    logger: LogWriter | list[LogWriter],
) -> FiniteVectorEnv:
    """Create a finite vector env out of one environment factory.

    Parameters
    ----------
    env_factory
        Callable to instantiate one single ``gym.Env``.
        The same factory is used for all concurrent workers.
    env_type
        One of ``"dummy"``, ``"subproc"`` or ``"shmem"``, corresponding to
        `parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
    concurrency
        Number of concurrent environment workers.
    logger
        Log writer(s) handed to the vector env.

    Warnings
    --------
    Please do not use a lambda expression for ``env_factory``, as it may create
    incorrectly-shared instances.

    Don't do: ::

        vectorize_env(lambda: EnvWrapper(...), ...)

    Please do: ::

        def env_factory(): ...
        vectorize_env(env_factory, ...)
    """
    env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = dict(
        dummy=FiniteDummyVectorEnv,
        subproc=FiniteSubprocVectorEnv,
        shmem=FiniteShmemVectorEnv,
    )
    finite_env_cls = env_type_cls_mapping[env_type]

    # Every worker shares the very same factory callable.
    env_fns = [env_factory] * concurrency
    return finite_env_cls(logger, env_fns)
|
||||
398
qlib/rl/utils/log.py
Normal file
398
qlib/rl/utils/log.py
Normal file
@@ -0,0 +1,398 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Distributed logger for RL.
|
||||
|
||||
:class:`LogCollector` runs in every environment worker. It collects log info from simulator states,
|
||||
and adds them (as a dict) to auxiliary info returned for each step.
|
||||
|
||||
:class:`LogWriter` runs in the central worker. It decodes the dict collected by :class:`LogCollector`
|
||||
in each worker, and writes them to console, log files, or tensorboard...
|
||||
|
||||
The two modules communicate by the "log" field in "info" returned by ``env.step()``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .env_wrapper import InfoDict
|
||||
|
||||
|
||||
__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
|
||||
|
||||
ObsType = TypeVar("ObsType")
|
||||
ActType = TypeVar("ActType")
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
    """Log-levels for RL training.

    How each level is handled depends on the implementation of :class:`LogWriter`.
    """

    #: If you only want to see the metric in debug mode.
    DEBUG = 10

    #: If you want to see the metric periodically.
    PERIODIC = 20

    # FIXME: I haven't given much thought about this. Let's hold it for one iteration.

    #: Important log messages.
    INFO = 30

    #: LogWriter should always handle CRITICAL messages
    CRITICAL = 40
|
||||
|
||||
|
||||
class LogCollector:
    """Logs are first collected in each environment worker,
    and then aggregated to stream at the central thread in vector env.

    In :class:`LogCollector`, every metric is added to a dict, which needs to be ``reset()`` at each step.
    The dict is sent via the ``info`` in ``env.step()``, and decoded by the :class:`LogWriter` at vector env.

    ``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
    """

    _logged: dict[str, tuple[int, Any]]
    _min_loglevel: int

    def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
        self._min_loglevel = int(min_loglevel)
        # Start from an empty store so that metrics can be added before the first ``reset()``
        # (previously this raised AttributeError).
        self._logged = {}

    def reset(self):
        """Clear all collected contents."""
        self._logged = {}

    def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
        """Store ``(loglevel, metric)`` under ``name``; a name may be used once between resets.

        Raises
        ------
        ValueError
            If ``name`` has already been added since the last reset.
        """
        if name in self._logged:
            raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
        self._logged[name] = (int(loglevel), metric)

    def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
        """Add a string with name into logged contents.

        Silently ignored when ``loglevel`` is below ``min_loglevel``.
        Raises ``TypeError`` if ``string`` is not a ``str``.
        """
        if loglevel < self._min_loglevel:
            return
        if not isinstance(string, str):
            raise TypeError(f"{string} is not a string.")
        self._add_metric(name, string, loglevel)

    def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
        """Add a scalar with name into logged contents.
        Scalar will be converted into a float.

        Silently ignored when ``loglevel`` is below ``min_loglevel``.
        Raises ``TypeError`` if the value is neither a float nor an integer.
        """
        if loglevel < self._min_loglevel:
            return

        if hasattr(scalar, "item"):
            # could be single-item number
            scalar = scalar.item()
        if not isinstance(scalar, (float, int)):
            raise TypeError(f"{scalar} is not and can not be converted into float or integer.")
        scalar = float(scalar)
        self._add_metric(name, scalar, loglevel)

    def add_array(
        self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
    ) -> None:
        """Add an array with name into logging.

        Silently ignored when ``loglevel`` is below ``min_loglevel``.
        Raises ``TypeError`` if ``array`` is not an ndarray, DataFrame or Series.
        """
        if loglevel < self._min_loglevel:
            return

        if not isinstance(array, (np.ndarray, pd.DataFrame, pd.Series)):
            raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
        self._add_metric(name, array, loglevel)

    def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
        """Log something with any type.

        As it's an "any" object, the only LogWriter accepting it is pickle.
        Therefore pickle must be able to serialize it.
        """
        if loglevel < self._min_loglevel:
            return

        # FIXME: detect and rescue object that could be scalar or array

        self._add_metric(name, obj, loglevel)

    def logs(self) -> dict[str, np.ndarray]:
        """Return the collected ``(loglevel, metric)`` pairs, each converted to an object array."""
        return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
|
||||
|
||||
|
||||
class LogWriter(Generic[ObsType, ActType]):
    """Base class for log writers, triggered at every reset and step by finite env.

    What happens to a specific log is up to the subclass of :class:`LogWriter`.
    The general principle is that a writer handles logs at or above its loglevel,
    and discards logs it cannot accept. For instance, console loggers obviously can't handle an image.
    """

    #: Counter of episodes.
    episode_count: int

    #: Counter of steps.
    step_count: int

    #: Counter of steps. Won't be cleared in ``clear``.
    global_step: int

    #: Counter of episodes. Won't be cleared in ``clear``.
    global_episode: int

    #: Active environment ids in vector env.
    active_env_ids: Set[int]

    #: Map from environment id to episode length.
    episode_lengths: dict[int, int]

    #: Map from environment id to the rewards of each step in its current episode.
    episode_rewards: dict[int, list[float]]

    #: Map from environment id to episode logs.
    episode_logs: dict[int, list]

    def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
        self.loglevel = loglevel

        self.global_episode = 0
        self.global_step = 0

        # Information, logs of one episode is stored here.
        # This assumes that episode is not too long to fit into the memory.
        self.episode_lengths = {}
        self.episode_rewards = {}
        self.episode_logs = {}

        self.clear()

    def clear(self):
        """Reset the clearable counters, the active env ids and ``logs``."""
        self.episode_count = 0
        self.step_count = 0
        self.active_env_ids = set()
        self.logs = []

    def aggregation(self, array: Sequence[Any]) -> Any:
        """Aggregation function from step-wise to episode-wise.

        A sequence consisting only of floats is averaged;
        for anything else the first element is returned.
        """
        assert len(array) > 0, "The aggregated array must be not empty."
        if not all(isinstance(v, float) for v in array):
            return array[0]
        return np.mean(array)

    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
        """This is triggered at the end of each trajectory.

        Parameters
        ----------
        length
            Length of this trajectory.
        rewards
            A list of rewards at each step of this episode.
        contents
            Logged contents for every step.
        """

    def log_step(self, reward: float, contents: dict[str, Any]) -> None:
        """This is triggered at each step.

        Parameters
        ----------
        reward
            Reward for this step.
        contents
            Logged contents for this step.
        """

    def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) -> None:
        """Callback for finite env, on each step."""

        # Update counter
        self.global_step += 1
        self.step_count += 1

        self.active_env_ids.add(env_id)
        self.episode_lengths[env_id] += 1
        # TODO: reward can be a list of list for MARL
        self.episode_rewards[env_id].append(rew)

        # Keep only the entries at or above this writer's loglevel.
        # FIXME: this is actually incorrect (see last FIXME)
        values: dict[str, Any] = {
            key: value for key, (loglevel, value) in info["log"].items() if loglevel >= self.loglevel
        }
        self.episode_logs[env_id].append(values)

        self.log_step(rew, values)

        if not done:
            return

        # Update counter
        self.global_episode += 1
        self.episode_count += 1

        self.log_episode(self.episode_lengths[env_id], self.episode_rewards[env_id], self.episode_logs[env_id])

    def on_env_reset(self, env_id: int, obs: ObsType) -> None:
        """Callback for finite env.

        Reset episode statistics. Nothing task-specific is logged here because of
        `a limitation of tianshou <https://github.com/thu-ml/tianshou/issues/605>`__.
        """
        self.episode_lengths[env_id] = 0
        self.episode_rewards[env_id] = []
        self.episode_logs[env_id] = []

    def on_env_all_done(self) -> None:
        """All done. Time for cleanup."""
|
||||
|
||||
|
||||
class ConsoleWriter(LogWriter):
    """Write log messages to console periodically.

    An average meter is kept for each metric: the mean of its values since the last ``clear()``.
    Each metric is displayed as ``<name> <latest_value> (<average_value>)``.

    Non-single-number metrics are auto skipped.
    """

    #: Prefix can be set via ``writer.prefix``.
    prefix: str

    def __init__(
        self,
        log_every_n_episode: int = 20,
        total_episodes: int | None = None,
        float_format: str = ":.4f",
        counter_format: str = ":4d",
        loglevel: int | LogLevel = LogLevel.PERIODIC,
    ):
        super().__init__(loglevel)
        # TODO: support log_every_n_step
        self.total_episodes = total_episodes
        self.log_every_n_episode = log_every_n_episode

        self.float_format = float_format
        self.counter_format = counter_format

        self.prefix = ""

        self.console_logger = get_module_logger(__name__, level=logging.INFO)

    def clear(self):
        """Clear the counters of the base class and the average meters."""
        super().clear()
        # Clear average meters
        self.metric_counts: dict[str, int] = defaultdict(int)
        self.metric_sums: dict[str, float] = defaultdict(float)

    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
        """Aggregate the float metrics of one episode, update the meters and maybe print them."""
        # Aggregate step-wise to episode-wise; only float values are kept.
        episode_wise_contents: dict[str, list] = defaultdict(list)
        for step_contents in contents:
            for name, value in step_contents.items():
                if not isinstance(value, float):
                    continue
                episode_wise_contents[name].append(value)

        # Generate log contents and track them in average-meter.
        # This should be done at every step, regardless of periodic or not.
        logs: dict[str, float] = {
            name: self.aggregation(values) for name, values in episode_wise_contents.items()  # type: ignore
        }

        for name, value in logs.items():
            self.metric_counts[name] += 1
            self.metric_sums[name] += value

        # Only log periodically or at the end
        is_periodic = self.episode_count % self.log_every_n_episode == 0
        if is_periodic or self.episode_count == self.total_episodes:
            self.console_logger.info(self.generate_log_message(logs))

    def generate_log_message(self, logs: dict[str, float]) -> str:
        """Render the counter prefix followed by ``<name> <latest> (<average>)`` for every metric."""
        head = self.prefix + " " if self.prefix else ""
        if self.total_episodes is None:
            counter_template = "[Step {" + self.counter_format + "}]"
        else:
            counter_template = "[{" + self.counter_format + "}/" + str(self.total_episodes) + "]"
        msg_prefix = (head + counter_template).format(self.episode_count)

        # NOTE(review): the original comment called the delimiter a double-space, but the literal
        # seen here starts with a single space; it is kept verbatim, confirm against upstream.
        metric_template = r" {} {" + self.float_format + "} ({" + self.float_format + "})"
        rendered = [
            metric_template.format(name, value, self.metric_sums[name] / self.metric_counts[name])
            for name, value in logs.items()
        ]

        return msg_prefix + " " + "".join(rendered)
|
||||
|
||||
|
||||
class CsvWriter(LogWriter):
    """Dump all episode metrics to a ``result.csv``.

    This is not the correct implementation. It's only used for first iteration.
    """

    SUPPORTED_TYPES = (float, str, pd.Timestamp)

    all_records: list[dict[str, Any]]

    def __init__(self, output_dir: str | Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
        """
        Parameters
        ----------
        output_dir
            Directory that will receive ``result.csv``. It is created, together with any
            missing parent directories, if it does not exist. Plain strings are accepted too.
        loglevel
            Minimum level of the logs handled by this writer.
        """
        super().__init__(loglevel)
        # Normalise to ``Path`` so that string paths work, and create missing parents as well.
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def clear(self):
        """Clear the counters of the base class and drop all collected records."""
        super().clear()
        self.all_records = []

    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
        """Aggregate the supported metrics of one episode into one record."""
        # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
        episode_wise_contents: dict[str, list] = defaultdict(list)

        for step_contents in contents:
            for name, value in step_contents.items():
                if isinstance(value, self.SUPPORTED_TYPES):
                    episode_wise_contents[name].append(value)

        # Values are floats, strings or timestamps (see ``SUPPORTED_TYPES``), hence ``Any``.
        logs: dict[str, Any] = {}
        for name, values in episode_wise_contents.items():
            logs[name] = self.aggregation(values)

        self.all_records.append(logs)

    def on_env_all_done(self) -> None:
        """Write all collected records to ``<output_dir>/result.csv``."""
        # FIXME: this is temporary
        pd.DataFrame.from_records(self.all_records).to_csv(self.output_dir / "result.csv", index=False)
|
||||
|
||||
|
||||
# The following are not implemented yet.
|
||||
|
||||
|
||||
class PickleWriter(LogWriter):
    """Dump logs to pickle files. (Placeholder: not implemented yet.)"""


class TensorboardWriter(LogWriter):
    """Write logs to event files that can be visualized with tensorboard. (Placeholder: not implemented yet.)"""


class MlflowWriter(LogWriter):
    """Add logs to mlflow. (Placeholder: not implemented yet.)"""


class LogBuffer(LogWriter):
    """Keep everything in memory. (Placeholder: not implemented yet.)"""
|
||||
13
qlib/typehint.py
Normal file
13
qlib/typehint.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Commonly used types."""
|
||||
|
||||
import sys

__all__ = ["Literal", "TypedDict", "final"]

# Interpreters older than 3.8 take these names from the ``typing_extensions`` backport;
# newer ones take them from the standard ``typing`` module.
if sys.version_info < (3, 8):
    from typing_extensions import Literal, TypedDict, final
else:
    from typing import Literal, TypedDict, final  # type: ignore # pylint: disable=no-name-in-module
|
||||
7
setup.py
7
setup.py
@@ -134,7 +134,12 @@ setup(
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"pre-commit",
|
||||
]
|
||||
],
|
||||
"rl": [
|
||||
"tianshou",
|
||||
"gym",
|
||||
"torch",
|
||||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
|
||||
10
tests/conftest.py
Normal file
10
tests/conftest.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Ignore RL tests on non-linux platform."""

import os
import sys

collect_ignore = []

if sys.platform != "linux":
    # pytest resolves ``collect_ignore`` entries relative to the directory of this conftest.py,
    # so the walk is anchored there rather than at the current working directory
    # (otherwise nothing is ignored unless pytest is launched from inside this directory).
    _here = os.path.dirname(os.path.abspath(__file__))
    for root, _dirs, files in os.walk(os.path.join(_here, "rl")):
        for file in files:
            collect_ignore.append(os.path.relpath(os.path.join(root, file), _here))
|
||||
4
tests/pytest.ini
Normal file
4
tests/pytest.ini
Normal file
@@ -0,0 +1,4 @@
|
||||
[pytest]
|
||||
filterwarnings =
|
||||
ignore:.*rng.randint:DeprecationWarning
|
||||
ignore:.*Casting input x to numpy array:UserWarning
|
||||
88
tests/rl/test_data_queue.py
Normal file
88
tests/rl/test_data_queue.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from qlib.rl.utils.data_queue import DataQueue
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
    """Dataset whose ``index``-th item is a random integer DataFrame with ``index + 1`` rows."""

    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        assert 0 <= index < self.length
        # The number of rows identifies the item: item ``index`` has ``index + 1`` rows.
        values = np.random.randint(0, 100, size=(index + 1, 4))
        return pd.DataFrame(values, columns=list("ABCD"))
|
||||
|
||||
|
||||
def _worker(dataloader, collector):
    """Put the length of every item produced by ``dataloader`` into ``collector``."""
    for data in dataloader:
        collector.put(len(data))
|
||||
|
||||
|
||||
def _queue_to_list(queue):
    """Drain ``queue`` until it reports empty and return the items in retrieval order."""
    drained = []
    while True:
        if queue.empty():
            break
        drained.append(queue.get())
    return drained
|
||||
|
||||
|
||||
def test_pytorch_dataloader():
    """A plain ``DataLoader`` over 100 items must yield 100 distinct item lengths."""
    collected = multiprocessing.Queue()
    loader = DataLoader(DummyDataset(100), batch_size=None, num_workers=1)

    _worker(loader, collected)

    lengths = _queue_to_list(collected)
    assert len(set(lengths)) == 100
|
||||
|
||||
|
||||
def test_multiprocess_shared_dataloader():
    """Three processes reading one ``DataQueue`` must together see 100 distinct item lengths."""
    dataset = DummyDataset(100)
    with DataQueue(dataset, producer_num_workers=1) as data_queue:
        queue = multiprocessing.Queue()

        consumers = []
        for _ in range(3):
            consumer = multiprocessing.Process(target=_worker, args=(data_queue, queue))
            consumers.append(consumer)
            consumer.start()

        for consumer in consumers:
            consumer.join()

        lengths = _queue_to_list(queue)
        assert len(set(lengths)) == 100
|
||||
|
||||
|
||||
def test_exit_on_crash_finite():
    """Run a finite ``DataQueue`` in a child process that raises while the queue is open.

    The test only starts and joins the child; nothing is asserted about its exit code.
    NOTE(review): presumably the point is that ``join()`` returns instead of hanging - confirm.
    """

    def _exit_finite():
        dataset = DummyDataset(100)

        with DataQueue(dataset, producer_num_workers=4) as data_queue:
            # give the queue some time to run before crashing inside the context
            time.sleep(3)
            raise ValueError

        # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess

    process = multiprocessing.Process(target=_exit_finite)
    process.start()
    process.join()
|
||||
|
||||
|
||||
def test_exit_on_crash_infinite():
    """Same scenario as the finite crash test, with ``repeat=-1`` and a bounded queue size.

    The test only starts and joins the child; nothing is asserted about its exit code.
    NOTE(review): presumably the point is that ``join()`` returns instead of hanging - confirm.
    """

    def _exit_infinite():
        dataset = DummyDataset(100)
        with DataQueue(dataset, repeat=-1, queue_maxsize=100) as data_queue:
            # give the queue some time to run before crashing inside the context
            time.sleep(3)
            raise ValueError

    process = multiprocessing.Process(target=_exit_infinite)
    process.start()
    process.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_multiprocess_shared_dataloader()
|
||||
249
tests/rl/test_finite_env.py
Normal file
249
tests/rl/test_finite_env.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from collections import Counter
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from tianshou.data import Batch, Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
from qlib.rl.utils.finite_env import (
|
||||
LogWriter,
|
||||
FiniteDummyVectorEnv,
|
||||
FiniteShmemVectorEnv,
|
||||
FiniteSubprocVectorEnv,
|
||||
check_nan_observation,
|
||||
generate_nan_observation,
|
||||
)
|
||||
|
||||
|
||||
# Deeply nested observation space (Dict / Tuple / Box / Discrete / MultiDiscrete / MultiBinary),
# used below to exercise nan-observation handling on complex observations.
_test_space = gym.spaces.Dict(
    {
        "sensors": gym.spaces.Dict(
            {
                "position": gym.spaces.Box(low=-100, high=100, shape=(3,)),
                "velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)),
                "front_cam": gym.spaces.Tuple(
                    (gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)))
                ),
                "rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
            }
        ),
        "ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)),
        "inner_state": gym.spaces.Dict(
            {
                "charge": gym.spaces.Discrete(100),
                "system_checks": gym.spaces.MultiBinary(10),
                "job_status": gym.spaces.Dict(
                    {
                        "task": gym.spaces.Discrete(5),
                        "progress": gym.spaces.Box(low=0, high=100, shape=()),
                    }
                ),
            }
        ),
    }
)
|
||||
|
||||
|
||||
class FiniteEnv(gym.Env):
    """Test environment that replays ``(sample, step_count)`` pairs from a dataset shard.

    When the shard is exhausted, ``reset()`` returns a nan observation.
    """

    def __init__(self, dataset, num_replicas, rank):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        sampler = DistributedSampler(dataset, num_replicas, rank)
        self.loader = DataLoader(dataset, sampler=sampler, batch_size=None)
        self.iterator = None
        self.observation_space = gym.spaces.Discrete(255)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Start the next episode, or signal exhaustion with a nan observation."""
        if self.iterator is None:
            self.iterator = iter(self.loader)
        try:
            sample, n_steps = next(self.iterator)
        except StopIteration:
            # Drop the iterator so that a later reset starts over.
            self.iterator = None
            return generate_nan_observation(self.observation_space)
        self.current_sample, self.step_count = sample, n_steps
        self.current_step = 0
        return self.current_sample

    def step(self, action):
        """Advance one step; the episode is done after ``step_count`` steps."""
        self.current_step += 1
        assert self.current_step <= self.step_count
        done = self.current_step >= self.step_count
        info = {"sample": self.current_sample, "action": action, "metric": 2.0}
        return 0, 1.0, done, info
|
||||
|
||||
|
||||
class FiniteEnvWithComplexObs(FiniteEnv):
    """Variant of :class:`FiniteEnv` whose observations and info samples come from ``_test_space``.

    The constructor is inherited from :class:`FiniteEnv`, which sets the same attributes.
    """

    def reset(self):
        """Start the next episode, or signal exhaustion with a nan observation."""
        if self.iterator is None:
            self.iterator = iter(self.loader)
        try:
            sample, n_steps = next(self.iterator)
        except StopIteration:
            # Drop the iterator so that a later reset starts over.
            self.iterator = None
            return generate_nan_observation(self.observation_space)
        self.current_sample, self.step_count = sample, n_steps
        self.current_step = 0
        return _test_space.sample()

    def step(self, action):
        """Advance one step; the episode is done after ``step_count`` steps."""
        self.current_step += 1
        assert self.current_step <= self.step_count
        # Sample in the same order as the returned tuple is laid out: observation first, then info.
        observation = _test_space.sample()
        done = self.current_step >= self.step_count
        info = {"sample": _test_space.sample(), "action": action, "metric": 2.0}
        return observation, 1.0, done, info
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
    """Dataset of ``(index, episode_length)`` pairs, where the length is ``3 * index % 5 + 1``."""

    def __init__(self, length):
        self.length = length
        self.episodes = [(3 * i) % 5 + 1 for i in range(length)]

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        assert 0 <= index < self.length
        episode_length = self.episodes[index]
        return index, episode_length
|
||||
|
||||
|
||||
class AnyPolicy(BasePolicy):
    """Policy that always answers with action ``1`` and learns nothing."""

    def forward(self, batch, state=None):
        actions = np.stack([1] * len(batch))
        return Batch(act=actions)

    def learn(self, batch):
        return None
|
||||
|
||||
|
||||
def _finite_env_factory(dataset, num_replicas, rank, complex=False):
    """Return a zero-argument factory of the plain or the complex-observation test env."""
    env_cls = FiniteEnvWithComplexObs if complex else FiniteEnv
    return lambda: env_cls(dataset, num_replicas, rank)
|
||||
|
||||
|
||||
class MetricTracker(LogWriter):
    """Log writer that counts the steps per sample and remembers the finished samples."""

    def __init__(self, length):
        super().__init__()
        self.length = length
        self.finished = set()
        self.counter = Counter()

    def on_env_step(self, env_id, obs, rew, done, info):
        assert rew == 1.0
        index = info["sample"]
        self.counter[index] += 1
        if done:
            # assert index not in self.finished
            self.finished.add(index)

    def validate(self):
        """Check that all samples finished and each took ``index * 3 % 5 + 1`` steps."""
        assert len(self.finished) == self.length
        for index, n_steps in self.counter.items():
            expected_steps = index * 3 % 5 + 1
            assert n_steps == expected_steps
|
||||
|
||||
|
||||
class DoNothingTracker(LogWriter):
    """Log writer whose step callback ignores everything."""

    def on_env_step(self, *args, **kwargs):
        return None
|
||||
|
||||
|
||||
def test_finite_dummy_vector_env():
    """Collect from a finite dummy vector env until it is exhausted, then validate the metrics."""
    length = 100
    dataset = DummyDataset(length)
    env_fns = [_finite_env_factory(dataset, 5, rank) for rank in range(5)]
    envs = FiniteDummyVectorEnv(MetricTracker(length), env_fns)
    # Pretend to be inside ``collector_guard()`` so that ``reset()`` does not warn.
    envs._collector_guarded = True
    test_collector = Collector(AnyPolicy(), envs, exploration_noise=True)

    tracker = MetricTracker(length)
    envs._logger = [tracker]
    try:
        test_collector.collect(n_step=10**18)
    except StopIteration:
        tracker.validate()
|
||||
|
||||
|
||||
def test_finite_shmem_vector_env():
    """Drain a ``FiniteShmemVectorEnv`` over a 100-sample dataset and check the step counts.

    Five env replicas share the dataset. Collection is expected to end with ``StopIteration``;
    afterwards the tracker must have seen every sample finish with the expected number of steps.
    """
    length = 100
    dataset = DummyDataset(length)
    envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
    envs._collector_guarded = True
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    for _ in range(1):
        # Install a fresh tracker so only this collection run is measured.
        envs._logger = [MetricTracker(length)]
        try:
            test_collector.collect(n_step=10**18)
        except StopIteration:
            pass
        # Validate unconditionally: previously this ran only inside the ``except`` branch, so the
        # test passed without checking anything whenever ``StopIteration`` was not raised.
        envs._logger[0].validate()
|
||||
|
||||
|
||||
def test_finite_subproc_vector_env():
    """Drain a ``FiniteSubprocVectorEnv`` over a 100-sample dataset and check the step counts.

    Five env replicas share the dataset. Collection is expected to end with ``StopIteration``;
    afterwards the tracker must have seen every sample finish with the expected number of steps.
    """
    length = 100
    dataset = DummyDataset(length)
    envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
    envs._collector_guarded = True
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    for _ in range(1):
        # Install a fresh tracker so only this collection run is measured.
        envs._logger = [MetricTracker(length)]
        try:
            test_collector.collect(n_step=10**18)
        except StopIteration:
            pass
        # Validate unconditionally: previously this ran only inside the ``except`` branch, so the
        # test passed without checking anything whenever ``StopIteration`` was not raised.
        envs._logger[0].validate()
|
||||
|
||||
|
||||
def test_nan():
    """A generated NaN observation is detected as NaN; a sampled observation is not."""
    nan_observation = generate_nan_observation(_test_space)
    assert check_nan_observation(nan_observation)

    sampled_observation = _test_space.sample()
    assert not check_nan_observation(sampled_observation)
|
||||
|
||||
|
||||
def test_finite_dummy_vector_env_complex():
    """Smoke test: a ``FiniteDummyVectorEnv`` built with ``complex=True`` envs can be drained."""
    sample_total = 100
    replica_total = 5
    dataset = DummyDataset(sample_total)
    env_factories = [
        _finite_env_factory(dataset, replica_total, rank, complex=True) for rank in range(replica_total)
    ]
    envs = FiniteDummyVectorEnv(DoNothingTracker(), env_factories)
    envs._collector_guarded = True
    test_collector = Collector(AnyPolicy(), envs, exploration_noise=True)

    try:
        test_collector.collect(n_step=10**18)
    except StopIteration:
        # Tolerated: the test only requires that collection does not fail in any other way.
        pass
|
||||
|
||||
|
||||
def test_finite_shmem_vector_env_complex():
    """Smoke test: a ``FiniteShmemVectorEnv`` built with ``complex=True`` envs can be drained."""
    sample_total = 100
    replica_total = 5
    dataset = DummyDataset(sample_total)
    env_factories = [
        _finite_env_factory(dataset, replica_total, rank, complex=True) for rank in range(replica_total)
    ]
    envs = FiniteShmemVectorEnv(DoNothingTracker(), env_factories)
    envs._collector_guarded = True
    test_collector = Collector(AnyPolicy(), envs, exploration_noise=True)

    try:
        test_collector.collect(n_step=10**18)
    except StopIteration:
        # Tolerated: the test only requires that collection does not fail in any other way.
        pass
|
||||
156
tests/rl/test_logger.py
Normal file
156
tests/rl/test_logger.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from random import randint, choice
|
||||
from pathlib import Path
|
||||
|
||||
import re
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
from tianshou.data import Collector, Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.config import C
|
||||
from qlib.constant import INF
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.utils.data_queue import DataQueue
|
||||
from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper
|
||||
from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter
|
||||
from qlib.rl.utils.finite_env import vectorize_env
|
||||
|
||||
|
||||
class SimpleEnv(gym.Env[int, int]):
    """Minimal gym env that logs a fixed reward plus a few random metrics on every step."""

    def __init__(self):
        """Create the log collector and the two-valued observation/action spaces."""
        self.logger = LogCollector()
        self.observation_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Restart the step counter and return the initial observation ``0``."""
        self.step_count = 0
        return 0

    def step(self, action: int):
        """Log this step's metrics and return ``(1, 42.0, done, InfoDict)``.

        ``done`` is always ``False`` for the first three steps and a coin flip afterwards.
        The scalar ``"c"`` is logged only on the steps where ``step_count`` is 2 or 3.
        """
        log = self.logger
        log.reset()

        log.add_scalar("reward", 42.0)
        log.add_scalar("a", randint(1, 10))
        log.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]}))

        # Draw the termination flag before the optional "c" metric, as the original does.
        done = choice([False, True]) if self.step_count >= 3 else False

        if self.step_count in (2, 3):
            log.add_scalar("c", randint(11, 20))

        self.step_count += 1

        info = InfoDict(log=log.logs(), aux_info={})
        return 1, 42.0, done, info
|
||||
|
||||
|
||||
class AnyPolicy(BasePolicy):
    """Trivial policy for the tests: always answers with action ``1``."""

    def forward(self, batch, state=None):
        """Return a ``Batch`` whose ``act`` stacks one ``1`` per entry of ``batch``."""
        constant_actions = [1 for _ in range(len(batch))]
        return Batch(act=np.stack(constant_actions))

    def learn(self, batch):
        """Do nothing; this policy has no parameters to update."""
        return None
|
||||
|
||||
|
||||
def test_simple_env_logger(caplog):
    """Collect 30 episodes of ``SimpleEnv`` per vector-env type and check CSV and console logs.

    For each of the ``dummy``, ``shmem`` and ``subproc`` env types, the CSV output must contain
    exactly the scalar columns ``reward``, ``a`` and ``c`` and at least 30 rows. Every non-empty
    captured log line must match the expected console format, and there must be at least three.
    """
    set_log_with_config(C.logging_config)
    output_dir = Path(__file__).parent / ".output"

    for venv_cls_name in ["dummy", "shmem", "subproc"]:
        writer = ConsoleWriter()
        csv_writer = CsvWriter(output_dir)
        venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer])
        with venv.collector_guard():
            collector = Collector(AnyPolicy(), venv)
            collector.collect(n_episode=30)

        output_file = pd.read_csv(output_dir / "result.csv")
        assert output_file.columns.tolist() == ["reward", "a", "c"]
        assert len(output_file) >= 30

    # The dot inside "(42.0000)" is now escaped like the first one, so it only matches a literal
    # decimal point. The pattern is compiled once instead of on every line.
    expected_line = re.compile(
        r".*reward 42\.0000 \(42\.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)"
    )
    line_counter = 0
    for line in caplog.text.splitlines():
        line = line.strip()
        if line:
            line_counter += 1
            assert expected_line.match(line)
    assert line_counter >= 3
|
||||
|
||||
|
||||
class SimpleSimulator(Simulator[int, float, float]):
    """Simulator whose whole state is a float that grows by the action at every step."""

    def __init__(self, initial: int, **kwargs) -> None:
        """Store ``initial`` as a float; extra keyword arguments are ignored."""
        self.initial = float(initial)

    def step(self, action: float) -> None:
        """Add ``action`` to the state and log two constant scalars through the env logger."""
        import torch

        self.initial = self.initial + action
        env_logger = self.env.logger
        # One torch tensor and one numpy array are logged as scalars.
        env_logger.add_scalar("test_a", torch.tensor(233.0))
        env_logger.add_scalar("test_b", np.array(200))

    def get_state(self) -> float:
        """Return the current float state."""
        return self.initial

    def done(self) -> bool:
        """Report completion once the fractional part of the state exceeds 0.5."""
        fractional_part = self.initial % 1
        return fractional_part > 0.5
|
||||
|
||||
|
||||
class DummyStateInterpreter(StateInterpreter[float, float]):
    """State interpreter that passes the simulator state through unchanged."""

    def interpret(self, state: float) -> float:
        """Return ``state`` as the observation."""
        observation = state
        return observation

    @property
    def observation_space(self) -> spaces.Box:
        """Scalar (shape ``()``) float32 box covering ``[0, inf)``."""
        return spaces.Box(low=0, high=np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
|
||||
class DummyActionInterpreter(ActionInterpreter[float, int, float]):
    """Action interpreter that scales an integer policy action down by a factor of 100."""

    def interpret(self, state: float, action: int) -> float:
        """Convert the policy action into the simulator action ``action / 100``.

        ``state`` is accepted for interface compatibility and is not used.
        """
        return action / 100

    @property
    def action_space(self) -> spaces.Discrete:
        """Discrete space with five actions.

        The annotation used to say ``spaces.Box`` although a ``spaces.Discrete`` is returned.
        """
        return spaces.Discrete(5)
|
||||
|
||||
|
||||
class RandomFivePolicy(BasePolicy):
    """Policy that picks a uniformly random integer action in ``[0, 5)`` for every sample."""

    def forward(self, batch, state=None):
        """Return a ``Batch`` with one random action per entry of ``batch``."""
        batch_size = len(batch)
        random_actions = np.random.randint(0, 5, size=batch_size)
        return Batch(act=random_actions)

    def learn(self, batch):
        """Do nothing; this policy has no parameters to update."""
        return None
|
||||
|
||||
|
||||
def test_logger_with_env_wrapper():
    """Run 20 queued initial states through ``EnvWrapper`` envs and check the CSV log."""
    output_dir = Path(__file__).parent / ".output"

    with DataQueue(list(range(20)), shuffle=False) as data_iterator:

        def env_wrapper_factory():
            # Each env gets its own interpreters and log collector but shares the data queue.
            return EnvWrapper(
                SimpleSimulator,
                DummyStateInterpreter(),
                DummyActionInterpreter(),
                data_iterator,
                logger=LogCollector(LogLevel.DEBUG),
            )

        # The log level can be DEBUG here because all metrics can be dumped into the CSV;
        # otherwise the CSV writer might crash.
        csv_writer = CsvWriter(output_dir, loglevel=LogLevel.DEBUG)
        venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
        with venv.collector_guard():
            collector = Collector(RandomFivePolicy(), venv)
            collector.collect(n_episode=INF * len(venv))

    output_df = pd.read_csv(output_dir / "result.csv")
    assert len(output_df) == 20

    # obs has an increasing trend: the second half sums to more than the first half.
    obs_values = output_df["obs"].to_numpy()
    assert obs_values[:10].sum() < obs_values[10:].sum()

    assert (output_df["test_a"] == 233).all()
    assert (output_df["test_b"] == 200).all()
    assert "steps_per_episode" in output_df
    assert "reward" in output_df
|
||||
308
tests/rl/test_saoe_simple.py
Normal file
308
tests/rl/test_saoe_simple.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from tianshou.data import Batch
|
||||
|
||||
from qlib.backtest import Order
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.entries.test import backtest
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
|
||||
|
||||
# Fixture root: <directory two levels above this file>/.data/rl/intraday_saoe
DATA_ROOT_DIR = Path(__file__).parent.parent.joinpath(".data", "rl", "intraday_saoe")

# Directories under the "us" subtree.
DATA_DIR = DATA_ROOT_DIR.joinpath("us")
BACKTEST_DATA_DIR = DATA_DIR.joinpath("backtest")
FEATURE_DATA_DIR = DATA_DIR.joinpath("processed")
ORDER_DIR = DATA_DIR.joinpath("order", "valid_bidir")

# Directories under the "cn" subtree.
CN_DATA_DIR = DATA_ROOT_DIR.joinpath("cn")
CN_BACKTEST_DATA_DIR = CN_DATA_DIR.joinpath("backtest")
CN_FEATURE_DATA_DIR = CN_DATA_DIR.joinpath("processed")
CN_ORDER_DIR = CN_DATA_DIR.joinpath("order", "test")
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR.joinpath("weights")
|
||||
|
||||
|
||||
def test_pickle_data_inspect():
    """Backtest and processed intraday data for AAL on 2013-12-11 both expose 390 rows."""
    backtest_data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
    assert len(backtest_data) == 390

    time_index = backtest_data.get_time_index()
    processed_data = pickle_styled.load_intraday_processed_data(
        DATA_DIR / "processed", "AAL", "2013-12-11", 5, time_index
    )
    assert len(processed_data.today) == 390
    assert len(processed_data.yesterday) == 390
|
||||
|
||||
|
||||
def test_simulator_first_step():
    """Check the simulator state before and after the first step of a 30-unit AAL order."""
    day_start = pd.Timestamp("2013-12-11 00:00:00")
    day_end = pd.Timestamp("2013-12-11 23:59:59")
    order = Order("AAL", 30.0, 0, day_start, day_end)

    simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
    initial_state = simulator.get_state()
    assert initial_state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
    assert initial_state.position == 30.0

    simulator.step(15.0)
    state = simulator.get_state()

    # Per-tick execution records produced by the first step.
    exec_history = state.history_exec
    first_exec = exec_history.iloc[0]
    assert len(exec_history) == 30
    assert exec_history.index[0] == pd.Timestamp("2013-12-11 09:30:00")
    assert exec_history["market_volume"].iloc[0] == 450072.0
    assert abs(exec_history["market_price"].iloc[0] - 25.370001) < 1e-4
    assert (exec_history["amount"] == 0.5).all()
    assert (exec_history["deal_amount"] == 0.5).all()
    assert abs(exec_history["trade_price"].iloc[0] - 25.370001) < 1e-4
    assert abs(exec_history["trade_value"].iloc[0] - 12.68500) < 1e-4
    assert exec_history["position"].iloc[0] == 29.5
    assert exec_history["ffr"].iloc[0] == 1 / 60
    del first_exec

    # Per-step aggregate record.
    step_history = state.history_steps
    assert step_history["market_volume"].iloc[0] == 5041147.0
    assert step_history["amount"].iloc[0] == 15.0
    assert step_history["deal_amount"].iloc[0] == 15.0
    assert step_history["ffr"].iloc[0] == 0.5
    expected_pa = (step_history["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000
    assert step_history["pa"].iloc[0] == expected_pa

    assert state.position == 15.0
    assert state.cur_time == pd.Timestamp("2013-12-11 10:00:00")
|
||||
|
||||
|
||||
def test_simulator_stop_twap():
    """Execute a 13-unit AAL order in 13 equal steps and check the final metrics."""
    order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))

    simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
    for _ in range(13):
        simulator.step(1.0)

    state = simulator.get_state()
    assert len(state.history_exec) == 390
    assert (state.history_exec["deal_amount"] == 13 / 390).all()
    assert state.history_steps["position"].iloc[0] == 12 and state.history_steps["position"].iloc[-1] == 0

    # Use the absolute deviation: without abs() this held for any ffr below 1.001,
    # so an under-filled order would not have been detected.
    assert abs(state.metrics["ffr"] - 1) < 1e-3
    assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4
    assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
    assert state.position == 0.0
    assert abs(state.metrics["trade_price"] - state.metrics["market_price"]) < 1e-4
    assert abs(state.metrics["pa"]) < 1e-2

    assert simulator.done()
|
||||
|
||||
|
||||
def test_simulator_stop_early():
    """Over-sized steps and steps after the order is exhausted must be rejected.

    Stepping 2.0 on a 1.0 order is expected to raise ``ValueError``; stepping again after the
    full amount was executed is expected to raise ``AssertionError``.
    """
    order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))

    # Build the simulator outside the ``raises`` block so that only ``step`` can satisfy the
    # expectation; a ``ValueError`` from the constructor would otherwise go unnoticed.
    simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
    with pytest.raises(ValueError):
        simulator.step(2.0)

    # Start from a fresh simulator for the second scenario.
    simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
    simulator.step(1.0)

    with pytest.raises(AssertionError):
        simulator.step(1.0)
|
||||
|
||||
|
||||
def test_simulator_start_middle():
    """An order whose window opens at 10:15 is executed over 330 ticks in 12 steps."""
    window_start = pd.Timestamp("2013-12-11 10:15:00")
    window_end = pd.Timestamp("2013-12-11 15:44:59")
    order = Order("AAL", 15.0, 1, window_start, window_end)

    simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
    assert len(simulator.ticks_for_order) == 330
    assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")

    simulator.step(2.0)
    assert simulator.cur_time == pd.Timestamp("2013-12-11 10:30:00")

    # Ten steps of 1.0 followed by a final step of 2.0.
    for step_amount in [1.0] * 10 + [2.0]:
        simulator.step(step_amount)

    assert len(simulator.history_exec) == 330
    assert simulator.done()
    last_amount = simulator.history_exec["amount"].iloc[-1]
    assert abs(last_amount - (1 + 2 / 15)) < 1e-4
    assert abs(simulator.metrics["ffr"] - 1) < 1e-4
|
||||
|
||||
|
||||
def test_interpreter():
    """Exercise the state and action interpreters at the first, second and last step of an order."""
    order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))

    simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
    assert len(simulator.ticks_for_order) == 330
    assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")

    # Emulate an env wrapper: the interpreters only get a `status` attribute here.
    class EmulateEnvWrapper(NamedTuple):
        status: EnvWrapperStatus

    interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
    interpreter_step = CurrentStepStateInterpreter(13)
    interpreter_action = CategoricalActionInterpreter(20)
    interpreter_action_twap = TwapRelativeActionInterpreter()

    wrapper_status_kwargs = {
        "initial_state": order,
        "obs_history": [],
        "action_history": [],
        "reward_history": [],
    }

    def emulated_env(cur_step, done=False):
        # Build the stand-in env wrapper carrying the given step counter and done flag.
        return EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=cur_step, done=done, **wrapper_status_kwargs))

    # first step
    interpreter.env = emulated_env(0)

    obs = interpreter(simulator.get_state())
    assert obs["cur_tick"] == 45
    assert obs["cur_step"] == 0
    assert obs["position"] == 15.0
    assert obs["position_history"][0] == 15.0
    assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(45))
    assert np.sum(obs["data_processed"][45:]) == 0
    assert obs["data_processed_prev"].shape == (390, 5)

    # first step: second interpreter
    interpreter_step.env = emulated_env(0)

    obs = interpreter_step(simulator.get_state())
    assert obs["acquiring"] == 1
    assert obs["position"] == 15.0

    # second step
    simulator.step(5.0)
    interpreter.env = emulated_env(1)

    obs = interpreter(simulator.get_state())
    assert obs["cur_tick"] == 60
    assert obs["cur_step"] == 1
    assert obs["position"] == 10.0
    assert obs["position_history"][:2].tolist() == [15.0, 10.0]
    assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(60))
    assert np.sum(obs["data_processed"][60:]) == 0

    # second step: action
    action = interpreter_action(simulator.get_state(), 1)
    assert action == 15 / 20

    interpreter_action_twap.env = emulated_env(1)
    action = interpreter_action_twap(simulator.get_state(), 1.5)
    assert action == 1.5

    # fast-forward
    for _ in range(10):
        simulator.step(0.0)

    # last step
    simulator.step(5.0)
    interpreter.env = emulated_env(12, done=simulator.done())

    assert interpreter.env.status["done"]

    obs = interpreter(simulator.get_state())
    assert obs["cur_tick"] == 375
    assert obs["cur_step"] == 12
    assert obs["position"] == 0.0
    assert obs["position_history"][1:11].tolist() == [10.0] * 10
    assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(375))
    assert np.sum(obs["data_processed"][375:]) == 0
|
||||
|
||||
|
||||
def test_network_sanity():
    """Feed interpreter observations through a recurrent PPO policy for a whole order.

    Only the plumbing is checked (the action range and the final observation), not the
    correctness of the networks.
    """
    order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))

    simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
    assert len(simulator.ticks_for_order) == 390

    class EmulateEnvWrapper(NamedTuple):
        status: EnvWrapperStatus

    interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
    action_interp = CategoricalActionInterpreter(13)

    wrapper_status_kwargs = {
        "initial_state": order,
        "obs_history": [],
        "action_history": [],
        "reward_history": [],
    }

    network = Recurrent(interpreter.observation_space)
    policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3)

    last_index = 13
    for i in range(last_index + 1):
        status = EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs)
        interpreter.env = EmulateEnvWrapper(status=status)
        obs = interpreter(simulator.get_state())
        output = policy(Batch(obs=[obs]))
        assert 0 <= output["act"].item() <= 13

        if i == last_index:
            # After the 13th simulator step, inspect the final observation instead of stepping.
            assert obs["cur_tick"] == 389
            assert obs["cur_step"] == 12
            assert obs["position_history"][-1] == 3
        else:
            simulator.step(1.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("finite_env_type", ["dummy", "subproc", "shmem"])
def test_twap_strategy(finite_env_type):
    """Backtest the ``AllOne`` policy with TWAP-relative actions on the 248 validation orders.

    Every order must be fully filled (mean ``ffr`` of 1) with a price advantage close to zero.
    """
    set_log_with_config(C.logging_config)
    orders = pickle_styled.load_orders(ORDER_DIR)
    assert len(orders) == 248

    state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
    action_interp = TwapRelativeActionInterpreter()
    policy = AllOne(state_interp.observation_space, action_interp.action_space)

    output_dir = Path(__file__).parent / ".output"
    csv_writer = CsvWriter(output_dir)
    simulator_factory = partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30)
    writers = [ConsoleWriter(total_episodes=len(orders)), csv_writer]

    backtest(
        simulator_factory,
        state_interp,
        action_interp,
        orders,
        policy,
        writers,
        concurrency=4,
        finite_env_type=finite_env_type,
    )

    metrics = pd.read_csv(output_dir / "result.csv")
    assert len(metrics) == 248
    assert np.isclose(metrics["ffr"].mean(), 1.0)
    assert np.isclose(metrics["pa"].mean(), 0.0)
    assert np.allclose(metrics["pa"], 0.0, atol=2e-3)
|
||||
|
||||
|
||||
def test_cn_ppo_strategy():
    """Backtest a pre-trained recurrent PPO policy on the 40 "cn" test orders.

    The aggregated metrics are compared against fixed reference values.
    """
    set_log_with_config(C.logging_config)

    # The data starts with 9:31 and ends with 15:00
    window_start = pd.Timestamp("9:31")
    window_end = pd.Timestamp("14:58")
    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=window_start, end_time=window_end)
    assert len(orders) == 40

    state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
    action_interp = CategoricalActionInterpreter(4)
    network = Recurrent(state_interp.observation_space)
    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)

    # Load the stored weights onto the CPU.
    weights = torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu")
    policy.load_state_dict(weights)

    output_dir = Path(__file__).parent / ".output"
    csv_writer = CsvWriter(output_dir)
    simulator_factory = partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30)
    writers = [ConsoleWriter(total_episodes=len(orders)), csv_writer]

    backtest(
        simulator_factory,
        state_interp,
        action_interp,
        orders,
        policy,
        writers,
        concurrency=4,
    )

    metrics = pd.read_csv(output_dir / "result.csv")
    assert len(metrics) == len(orders)

    expected_means = [
        ("ffr", 1.0),
        ("pa", -16.21578303474833),
        ("market_price", 58.68277690875527),
        ("trade_price", 58.76063985000002),
    ]
    for column, expected in expected_means:
        assert np.isclose(metrics[column].mean(), expected)
|
||||
Reference in New Issue
Block a user