mirror of https://github.com/microsoft/qlib.git synced 2026-06-13 01:11:00 +08:00

Compare commits

...

18 Commits

Author SHA1 Message Date
you-n-g
a02ac95538 add gym (#1104) 2022-05-21 23:50:18 +08:00
you-n-g
cc94c32db6 init_instance_by_config enhancement (#1103)
* fix SepDataFrame when we del it to empty

* init_instance_by_config enhancement

* Update test_sepdf.py
2022-05-21 20:16:22 +08:00
Yuge Zhang
9a40fd3cdc Qlib RL framework (stage 1) - single-asset order execution (#1076)
* rl init

* aux info

* Reward config

* update

* simple

* update saoe init

* update simulator and seed

* minor

* minor

* update sim

* checkpoint

* obs

* Update interpreter

* init qlib simulator

* checkpoint

* Refine codebase

* checkpoint

* checkpoint

* Add one test

* More tests

* Simulator checkpoint

* checkpoint

* First-step tested

* Checkpoint

* Update data_queue API

* Checkpoint

* Update test

* Move files

* Checkpoint

* Single-quote -> double-quote

* Fix finite env tests

* Tested with mypy

* pep-574

* No call for env done

* Update finite env docs

* Fix csv writer

* Refine tester

* Update logger

* Add another logger test

* Checkpoint

* Add network sanity test

* steps per episode is not correct

* Cleanup code, ready for PR

* Reformat with black

* Fix pylint for py37

* Fix lint

* Fix lint

* Fix flake

* update mypy command

* mypy

* Update exclude pattern

* Use pyproject.toml

* test

* .

* .

* Refactor pipeline

* .

* defaults run bash

* .

* Revert and skip follow_imports

* Fix toml issue

* fix mypy

* .

* .

* .

* Fix install

* Minor fix

* Fix test

* Fix test

* Remove requirements

* Revert

* fix tests

* Fix lint

* .

* .

* .

* .

* .

* update install from source command

* .

* Fix data download

* .

* .

* .

* .

* .

* .

* Fix py37

* Ignore tests on non-linux

* resolve comments

* fix tests

* resolve comments

* some typo

* style updates

* More comments

* fix dummy

* add warning

* Align precision in some system

* Added some impl notes

Co-authored-by: Young <afe.young@gmail.com>
2022-05-21 18:19:24 +08:00
you-n-g
c4281121e3 Update README.md (#1091)
* Update README.md

* Fix typo
2022-05-08 20:19:19 +08:00
Linlang
2de9903200 fix_issue_1060 (#1092)
* fix_issue_1060

* fix_import_error
2022-05-07 20:59:06 +08:00
Linlang
2cf842bcfe add_test_pit (#1089)
* add_test_pit

* add_test_pit_to_tests

* add_baostock_to_setup

* add_pip_to_CI

Co-authored-by: Linlang Lv (iSoftStone) <v-linlanglv@microsoft.com>
2022-05-06 16:47:20 +08:00
you-n-g
9e381493c2 Add instructions to add models (#1088) 2022-05-05 21:27:24 +08:00
Chia-hung Tai
a73b60d05a Update detailed_workflow.ipynb (#1084)
time_per_step bug.
2022-05-03 15:11:27 +08:00
you-n-g
64979ad769 Yahoo data Docs (#1077) 2022-04-29 17:24:53 +08:00
you-n-g
c5cf8fb9cc fix est_sepdf.py with black 2022-04-29 17:21:20 +08:00
Linlang
5d579d1a20 fix_macos_CI (#1081)
Co-authored-by: Linlang Lv (iSoftStone) <v-linlanglv@microsoft.com>
2022-04-29 17:04:28 +08:00
you-n-g
3c9c76b384 fix SepDataFrame when we del it to empty (#1082) 2022-04-29 14:29:17 +08:00
you-n-g
9d0a8f61d1 Make sepdf more like DataFrame (#1080) 2022-04-28 19:13:45 +08:00
Linlang
701b18af1b fix_issue_715 (#1070)
* fix_issue_715

* fix_issue_1065

Co-authored-by: Linlang Lv (iSoftStone) <v-linlanglv@microsoft.com>
2022-04-28 16:09:31 +08:00
Hubedge
84ff662a26 Fixed pandas FutureWarning (#1073)
* Fixed pandas FutureWarning

`FutureWarning: Passing a set as an indexer is deprecated and will raise in a future version. Use a list instead.`

* fixed another pandas FutureWarning

```
scripts/data_collector/index.py:228: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
  new_df = new_df.append(_tmp_df, sort=False)
```

* fixed more pandas futurewarnings
2022-04-27 18:43:26 +08:00
金戈
00e40e775b Fixed typos in workflow.rst (#1068)
* Update workflow.rst

Fixed a typo. `please refer to Qlib Model` should be `please refer to Qlib Data` in Dataset section.

* Fix typo. `preprossing` should be `preprocessing`

* Update data.rst

Remove extra `of`.
2022-04-27 18:36:47 +08:00
code-review-doctor
45fe5e6974 Fix issue probably-meant-fstring found at https://codereview.doctor (#1072) 2022-04-25 16:12:40 +08:00
you-n-g
366a9c33f3 Bump to Dev Version 2022-04-25 16:11:47 +08:00
61 changed files with 3960 additions and 245 deletions

View File

@@ -35,7 +35,7 @@ jobs:
pip install numpy==1.19.5 ruamel.yaml
pip install pyqlib --ignore-installed
- name: Make html with sphnix
- name: Make html with sphinx
run: |
pip install -U sphinx
pip install sphinx_rtd_theme readthedocs_sphinx_ext
@@ -97,12 +97,21 @@ jobs:
run: |
pip install --upgrade pip
pip install flake8
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
# https://github.com/python/mypy/issues/10600
- name: Check Qlib with mypy
run: |
pip install mypy
mypy qlib --install-types --non-interactive || true
mypy qlib
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
run: |
@@ -113,6 +122,7 @@ jobs:
- name: Install Qlib from source
run: |
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
@@ -122,10 +132,10 @@ jobs:
- name: Unit tests with Pytest
run: |
pip install -r scripts/data_collector/pit/requirements.txt
cd tests
python -m pytest . --durations=10
- name: Test workflow by config (install from source)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -35,11 +35,10 @@ jobs:
# Test Qlib installed with pip
- name: Check Qlib with flake8
run: |
pip install --upgrade pip
pip install flake8
cd ..
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 qlib
run: |
pip install --upgrade pip
pip install flake8
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
- name: Install Qlib with pip
run: |
@@ -66,6 +65,8 @@ jobs:
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
@@ -76,6 +77,7 @@ jobs:
python -m pip install --upgrade cython
python -m pip install numpy jupyter jupyter_contrib_nbextensions
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
python -m pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
run: |
@@ -84,6 +86,7 @@ jobs:
python -m pip install black pytest
- name: Unit tests with Pytest
run: |
pip install -r scripts/data_collector/pit/requirements.txt
cd tests
python -m pytest . --durations=0
- name: Test workflow by config (install from source)

5
.gitignore vendored
View File

@@ -27,6 +27,10 @@ examples/estimator/estimator_example/
*.egg-info/
# test related
test-output.xml
.output
.data
# special software
mlruns/
@@ -34,6 +38,7 @@ mlruns/
tags
.pytest_cache/
.mypy_cache/
.vscode/
*.swp

17
.mypy.ini Normal file
View File

@@ -0,0 +1,17 @@
[mypy]
exclude = (?x)(
^qlib/backtest
| ^qlib/contrib
| ^qlib/data
| ^qlib/model
| ^qlib/strategy
| ^qlib/tests
| ^qlib/utils
| ^qlib/workflow
| ^qlib/config\.py$
| ^qlib/log\.py$
| ^qlib/__init__\.py$
)
ignore_missing_imports = true
disallow_incomplete_defs = true
follow_imports = skip
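A quick way to see which files the `exclude` pattern above still leaves for type checking is to try it with Python's `re` module. This is only a sketch: it assumes the pattern is searched against forward-slash relative paths, and the helper name `is_excluded` and the sample paths are illustrative, not part of the commit.

```python
import re

# Pattern copied from the `exclude` option in .mypy.ini above. The leading (?x)
# switches on verbose mode, so the line breaks and spaces are ignored.
EXCLUDE = re.compile(
    r"""(?x)(
    ^qlib/backtest
    | ^qlib/contrib
    | ^qlib/data
    | ^qlib/model
    | ^qlib/strategy
    | ^qlib/tests
    | ^qlib/utils
    | ^qlib/workflow
    | ^qlib/config\.py$
    | ^qlib/log\.py$
    | ^qlib/__init__\.py$
    )"""
)


def is_excluded(path: str) -> bool:
    """Return True when the (hypothetical) relative path matches the pattern."""
    return EXCLUDE.search(path) is not None


for sample in (
    "qlib/backtest/executor.py",
    "qlib/config.py",
    "qlib/rl/aux_info.py",
    "qlib/rl/data/__init__.py",
):
    print(sample, "excluded" if is_excluded(sample) else "checked")
```

Because every alternative is anchored with `^`, `qlib/rl/data/__init__.py` is not caught by the `^qlib/data` branch, so the new `qlib/rl` package stays in scope.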

View File

@@ -32,7 +32,7 @@ Recent released features
| High-frequency data processing example | :hammer: [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
| High-frequency trading example | :chart_with_upwards_trend: [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
| High-frequency data(1min) | :rice: [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
Features released before 2021 are not listed here.
@@ -474,7 +474,7 @@ If you don't know how to start to contribute, you can refer to the following exa
| Docs | [Improve docs quality](https://github.com/microsoft/qlib/pull/797/files) ; [Fix a typo](https://github.com/microsoft/qlib/pull/774) |
| Feature | Implement a [requested feature](https://github.com/microsoft/qlib/projects) like [this](https://github.com/microsoft/qlib/pull/754); [Refactor interfaces](https://github.com/microsoft/qlib/pull/539/files) |
| Dataset | [Add a dataset](https://github.com/microsoft/qlib/pull/733) |
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689) |
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689), [some instructions to contribute models](https://github.com/microsoft/qlib/tree/main/examples/benchmarks#contributing) |
[Good first issues](https://github.com/microsoft/qlib/labels/good%20first%20issue) are labelled to indicate that they are an easy place to start contributing.

View File

@@ -437,7 +437,7 @@ Dataset
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
The motivation of this module is that we want to maximize the flexibility of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
If a user's model needs to process its data in a different way, the user can implement their own ``Dataset`` class. If the model's
data processing is not special, ``DatasetH`` can be used directly.

View File

@@ -104,7 +104,7 @@ Graphical Result
- Axis Y:
- `ic`
The `Pearson correlation coefficient` series between `label` and `prediction score`.
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
In the above example, the `label` is formulated as `Ref($close, -2)/Ref($close, -1)-1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
- `rank_ic`
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
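To make the corrected label concrete: assuming a negative offset in `Ref` looks forward in time, `Ref($close, -2)/Ref($close, -1)-1` on day `t` is the return from the close of day `t+1` to the close of day `t+2`. The sketch below reproduces that arithmetic on a plain list; the function name `label` and the sample prices are made up for illustration.

```python
from typing import List, Optional


def label(close: List[float]) -> List[Optional[float]]:
    """For each day t, compute close[t+2] / close[t+1] - 1.

    The last two days have no label because the future closes are missing.
    """
    n = len(close)
    return [close[t + 2] / close[t + 1] - 1 if t + 2 < n else None for t in range(n)]


closes = [10.0, 11.0, 12.1, 12.1]  # illustrative prices only
labels = label(closes)
print(labels)
```

The first label depends only on days 1 and 2, not on day 0, which is the difference from the earlier `Ref($close, -1)/$close - 1` wording.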

View File

@@ -233,7 +233,7 @@ The meaning of each field is as follows:
Dataset Section
~~~~~~~~~~~~~~~~~~~~
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Model <../component/data.html#dataset>`_.
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Data <../component/data.html#dataset>`_.
The keywords arguments configuration of the ``DataHandler`` is as follows:
@@ -248,7 +248,7 @@ The keywords arguments configuration of the ``DataHandler`` is as follows:
Users can refer to the document of `DataHandler <../component/data.html#datahandler>`_ for more information about the meaning of each field in the configuration.
Here is the configuration for the ``Dataset`` module which will take care of data preprossing and slicing during the training and testing phase.
Here is the configuration for the ``Dataset`` module which will take care of data preprocessing and slicing during the training and testing phase.
.. code-block:: YAML

View File

@@ -6,3 +6,4 @@
[https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
- NOTE: Current version of implementation is just a simplified version of ALSTM. It is an LSTM with attention.

View File

@@ -78,3 +78,20 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
- The metrics can be categorized into two
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
# Contributing
Your contributions to new models are highly welcome!
If you want to contribute your new models, you can follow the steps below.
1. Create a folder for your model
2. The folder contains the following items (you can refer to [this example](https://github.com/microsoft/qlib/tree/main/examples/benchmarks/TCTS)).
- `requirements.txt`: required dependencies.
- `README.md`: a brief introduction to your models
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can be read by `qrun`. You are encouraged to run your model on all datasets.
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
4. Please update your results in the benchmark tables, e.g. [Alpha360](#alpha360-dataset), [Alpha158](#alpha158-dataset) (the values of each metric are the mean and std calculated over 20 runs with different random seeds; if you don't have enough computational resources, you can ask for help in the PR).
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
Finally, you can send a PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))

View File

@@ -967,10 +967,10 @@
"###################################\n",
"port_analysis_config = {\n",
" \"executor\": {\n",
" \"time_per_step\"\n",
" \"class\": \"SimulatorExecutor\",\n",
" \"module_path\": \"qlib.backtest.executor\",\n",
" \"kwargs\": {: \"day\",\n",
" \"kwargs\": {\n",
" \"time_per_step\": \"day\",\n",
" \"generate_portfolio_metrics\": True,\n",
" },\n",
" },\n",

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.8.5"
__version__ = "0.8.5.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -8,3 +8,6 @@ REG_TW = "tw"
# Epsilon for avoiding division by zero.
EPS = 1e-12
# Infinity in integer
INF = 10**18

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Dict, Iterable
from typing import Dict, Iterable, Union
def align_index(df_dict, join):
@@ -24,6 +24,10 @@ class SepDataFrame:
SepDataFrame tries to act like a DataFrame whose column with multiindex
"""
# TODO:
# SepDataFrame tries to behave like a pandas DataFrame, but it is still not the same.
# Contributions are welcome to make it more complete.
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
"""
initialize the data based on the dataframe dictionary
@@ -77,14 +81,37 @@ class SepDataFrame:
def _update_join(self):
if self.join not in self:
self.join = next(iter(self._df_dict.keys()))
if len(self._df_dict) > 0:
self.join = next(iter(self._df_dict.keys()))
else:
# NOTE: this will change the behavior of previous reindex when all the keys are empty
self.join = None
def __getitem__(self, item):
# TODO: behave more like pandas when multiindex
return self._df_dict[item]
def __setitem__(self, item: str, df: pd.DataFrame):
def __setitem__(self, item: str, df: Union[pd.DataFrame, pd.Series]):
# TODO: consider the join behavior
self._df_dict[item] = df
if not isinstance(item, tuple):
self._df_dict[item] = df
else:
# NOTE: corner case of MultiIndex
_df_dict_key, *col_name = item
col_name = tuple(col_name)
if _df_dict_key in self._df_dict:
if len(col_name) == 1:
col_name = col_name[0]
self._df_dict[_df_dict_key][col_name] = df
else:
if isinstance(df, pd.Series):
if len(col_name) == 1:
col_name = col_name[0]
self._df_dict[_df_dict_key] = df.to_frame(col_name)
else:
df_copy = df.copy() # avoid changing df
df_copy.columns = pd.MultiIndex.from_tuples([(*col_name, *idx) for idx in df.columns.to_list()])
self._df_dict[_df_dict_key] = df_copy
def __delitem__(self, item: str):
del self._df_dict[item]
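The tuple branch of `__setitem__` above hinges on how the key is split into a group key and a column name. The following is a minimal sketch of just that splitting step, written without pandas so it runs anywhere; `split_key` is a hypothetical helper, and it mirrors the existing-key and `pd.Series` branches, where a one-element column tuple is collapsed to a scalar.

```python
from typing import Hashable, Tuple, Union


def split_key(item: tuple) -> Tuple[Hashable, Union[Hashable, tuple]]:
    """Split a MultiIndex-style key into (group key, column name)."""
    group, *col = item  # first element selects the sub-DataFrame
    col_name: Union[Hashable, tuple] = tuple(col)
    if len(col_name) == 1:
        # A single remaining level is used as a plain column name.
        col_name = col_name[0]
    return group, col_name


print(split_key(("feature", "close")))
print(split_key(("feature", "price", "close")))
```

In the DataFrame branch of the diff the column tuple is kept as-is and prepended to each existing column to build a `MultiIndex`, so the collapse shown here does not apply there.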

View File

@@ -123,7 +123,7 @@ def pred_autocorr(pred: pd.Series, lag=1, inst_col="instrument", date_col="datet
"""
if isinstance(pred, pd.DataFrame):
pred = pred.iloc[:, 0]
get_module_logger("pred_autocorr").warning("Only the first column in {pred.columns} of `pred` is kept")
get_module_logger("pred_autocorr").warning(f"Only the first column in {pred.columns} of `pred` is kept")
pred_ustk = pred.sort_index().unstack(inst_col)
corr_s = {}
for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()):
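The one-character change in this hunk matters because, without the `f` prefix, Python treats the braces as literal text. A small sketch of the difference, using a plain list named `cols` as a stand-in for the column index (an assumption made so the example runs without pandas):

```python
cols = ["score", "score_2"]  # stand-in for the columns of `pred`

# Without the prefix the placeholder is emitted verbatim.
without_prefix = "Only the first column in {cols} of `pred` is kept"

# With the prefix the expression inside the braces is evaluated.
with_prefix = f"Only the first column in {cols} of `pred` is kept"

print(without_prefix)
print(with_prefix)
```

One thing that may be worth checking separately: in the hunk above `pred` is reassigned to its first column before the message is built, so whether `pred.columns` is still available at that point depends on what `iloc[:, 0]` returns. The sketch sidesteps this by capturing the columns in a separate name first.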

View File

@@ -144,7 +144,7 @@ class ADARNN(Model):
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.cuda()
self.model.to(self.device)
@property
def use_gpu(self):
@@ -153,7 +153,7 @@ class ADARNN(Model):
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
self.model.train()
criterion = nn.MSELoss()
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
len_loader = np.inf
for loader in train_loader_list:
if len(loader) < len_loader:
@@ -165,7 +165,7 @@ class ADARNN(Model):
list_label = []
for data in data_all:
# feature :[36, 24, 6]
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float()
list_feat.append(feature)
list_label.append(label_reg)
flag = False
@@ -179,7 +179,7 @@ class ADARNN(Model):
if flag:
continue
total_loss = torch.zeros(1).cuda()
total_loss = torch.zeros(1).to(self.device)
for i, n in enumerate(index):
feature_s = list_feat[n[0]]
feature_t = list_feat[n[1]]
@@ -325,7 +325,7 @@ class ADARNN(Model):
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.model.predict(x_batch).detach().cpu().numpy()
@@ -335,7 +335,7 @@ class ADARNN(Model):
return pd.Series(np.concatenate(preds), index=index)
def transform_type(self, init_weight):
weight = torch.ones(self.num_layers, self.len_seq).cuda()
weight = torch.ones(self.num_layers, self.len_seq).to(self.device)
for i in range(self.num_layers):
for j in range(self.len_seq):
weight[i, j] = init_weight[i][j].item()
@@ -389,6 +389,7 @@ class AdaRNN(nn.Module):
len_seq=9,
model_type="AdaRNN",
trans_loss="mmd",
GPU=0,
):
super(AdaRNN, self).__init__()
self.use_bottleneck = use_bottleneck
@@ -399,6 +400,7 @@ class AdaRNN(nn.Module):
self.model_type = model_type
self.trans_loss = trans_loss
self.len_seq = len_seq
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
in_size = self.n_input
features = nn.ModuleList()
@@ -455,7 +457,7 @@ class AdaRNN(nn.Module):
out_list_all, out_weight_list = out[1], out[2]
out_list_s, out_list_t = self.get_features(out_list_all)
loss_transfer = torch.zeros((1,)).cuda()
loss_transfer = torch.zeros((1,)).to(self.device)
for i, n in enumerate(out_list_s):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
h_start = 0
@@ -516,12 +518,12 @@ class AdaRNN(nn.Module):
out_list_all = out[1]
out_list_s, out_list_t = self.get_features(out_list_all)
loss_transfer = torch.zeros((1,)).cuda()
loss_transfer = torch.zeros((1,)).to(self.device)
if weight_mat is None:
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device)
else:
weight = weight_mat
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
for i, n in enumerate(out_list_s):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
for j in range(self.len_seq):
@@ -553,12 +555,13 @@ class AdaRNN(nn.Module):
class TransferLoss:
def __init__(self, loss_type="cosine", input_dim=512):
def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
"""
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
"""
self.loss_type = loss_type
self.input_dim = input_dim
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
def compute(self, X, Y):
"""Compute adaptation loss
@@ -574,7 +577,7 @@ class TransferLoss:
mmdloss = MMD_loss(kernel_type="linear")
loss = mmdloss(X, Y)
elif self.loss_type == "coral":
loss = CORAL(X, Y)
loss = CORAL(X, Y, self.device)
elif self.loss_type in ("cosine", "cos"):
loss = 1 - cosine(X, Y)
elif self.loss_type == "kl":
@@ -582,10 +585,10 @@ class TransferLoss:
elif self.loss_type == "js":
loss = js(X, Y)
elif self.loss_type == "mine":
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device)
loss = mine_model(X, Y)
elif self.loss_type == "adv":
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32)
elif self.loss_type == "mmd_rbf":
mmdloss = MMD_loss(kernel_type="rbf")
loss = mmdloss(X, Y)
@@ -630,12 +633,12 @@ class Discriminator(nn.Module):
return x
def adv(source, target, input_dim=256, hidden_dim=512):
def adv(source, target, device, input_dim=256, hidden_dim=512):
domain_loss = nn.BCELoss()
# !!! Pay attention to .cuda !!!
adv_net = Discriminator(input_dim, hidden_dim).cuda()
domain_src = torch.ones(len(source)).cuda()
domain_tar = torch.zeros(len(target)).cuda()
adv_net = Discriminator(input_dim, hidden_dim).to(device)
domain_src = torch.ones(len(source)).to(device)
domain_tar = torch.zeros(len(target)).to(device)
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
reverse_src = ReverseLayerF.apply(source, 1)
reverse_tar = ReverseLayerF.apply(target, 1)
@@ -646,16 +649,16 @@ def adv(source, target, input_dim=256, hidden_dim=512):
return loss
def CORAL(source, target):
def CORAL(source, target, device):
d = source.size(1)
ns, nt = source.size(0), target.size(0)
# source covariance
tmp_s = torch.ones((1, ns)).cuda() @ source
tmp_s = torch.ones((1, ns)).to(device) @ source
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
# target covariance
tmp_t = torch.ones((1, nt)).cuda() @ target
tmp_t = torch.ones((1, nt)).to(device) @ target
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
# frobenius norm
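The hunks in this file replace hard-coded `.cuda()` calls with `.to(self.device)`, where the device is chosen once from the `GPU` argument. The selection rule can be sketched without PyTorch; here `cuda_available` is a stand-in for `torch.cuda.is_available()` and `device_name` is a hypothetical helper that returns the string that would be passed to `torch.device`.

```python
def device_name(gpu: int, cuda_available: bool) -> str:
    """Mirror the selection expression from the diff.

    The conditional expression binds more loosely than `%`, so the format
    is applied only on the CUDA branch.
    """
    return "cuda:%d" % gpu if cuda_available and gpu >= 0 else "cpu"


for gpu, available in [(0, True), (1, True), (0, False), (-1, True)]:
    print(gpu, available, device_name(gpu, available))
```

A negative `GPU` value therefore forces CPU even when CUDA is present, which is what lets the model run on machines without a GPU.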

View File

@@ -68,9 +68,9 @@ def parse_position(position: dict = None) -> pd.DataFrame:
if not _trading_day_sell_df.empty:
_trading_day_sell_df["status"] = -1
_trading_day_sell_df["date"] = _trading_date
_trading_day_df = _trading_day_df.append(_trading_day_sell_df, sort=False)
_trading_day_df = pd.concat([_trading_day_df, _trading_day_sell_df], sort=False)
result_df = result_df.append(_trading_day_df, sort=True)
result_df = pd.concat([result_df, _trading_day_df], sort=True)
previous_data = dict(
date=_trading_date,

View File

@@ -85,7 +85,7 @@ def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd
# _m_report_long_short,
pd.Timestamp(year=gp_m[0], month=gp_m[1], day=month_days),
)
_monthly_df = _monthly_df.append(_temp_df, sort=False)
_monthly_df = pd.concat([_monthly_df, _temp_df], sort=False)
return _monthly_df

View File

@@ -61,7 +61,11 @@ def get_module_logger(module_name, level: Optional[int] = None) -> QlibLogger:
if level is None:
level = C.logging_level
module_name = "qlib.{}".format(module_name)
if not module_name.startswith("qlib."):
# Add a prefix of qlib. when the requested ``module_name`` doesn't start with ``qlib.``.
# If the module_name is already qlib.xxx, we do not format here. Otherwise, it will become qlib.qlib.xxx.
module_name = "qlib.{}".format(module_name)
# Get logger.
module_logger = QlibLogger(module_name)
module_logger.setLevel(level)
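The effect of the new guard is that prefixing becomes idempotent. A minimal sketch of just the naming rule, with `qualify` as a hypothetical helper name (the real function goes on to build a `QlibLogger`):

```python
def qualify(module_name: str) -> str:
    """Add the "qlib." prefix unless the name already carries it."""
    if not module_name.startswith("qlib."):
        module_name = "qlib.{}".format(module_name)
    return module_name


print(qualify("pred_autocorr"))
print(qualify("qlib.rl"))
# Applying the rule twice gives the same result as applying it once.
print(qualify(qualify("workflow")))
```

Note that the check is for the prefix including the dot, so the bare name `qlib` is still prefixed under this rule.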

43
qlib/rl/aux_info.py Normal file
View File

@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Generic, TYPE_CHECKING, TypeVar
from qlib.typehint import final
from .simulator import StateType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
__all__ = ["AuxiliaryInfoCollector"]
AuxInfoType = TypeVar("AuxInfoType")
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
"""Override this class to collect customized auxiliary information from environment."""
env: EnvWrapper | None = None
@final
def __call__(self, simulator_state: StateType) -> AuxInfoType:
return self.collect(simulator_state)
def collect(self, simulator_state: StateType) -> AuxInfoType:
"""Override this for customized auxiliary info.
Usually useful in Multi-agent RL.
Parameters
----------
simulator_state
Retrieved with ``simulator.get_state()``.
Returns
-------
Auxiliary information.
"""
raise NotImplementedError("collect is not implemented!")

8
qlib/rl/data/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Common utilities to handle ad-hoc-styled data.
Most of these snippets come from research projects (paper code).
Please use caution when using them in production.
"""

View File

@@ -0,0 +1,257 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""This module contains utilities to read financial data from pickle-styled files.
This is the format used in `OPD paper <https://seqml.github.io/opd/>`__. NOT the standard data format in qlib.
The data here are all wrapped with ``@lru_cache``, which saves the expensive IO cost to repetitively read the data.
We also encourage users to use ``get_xxx_yyy`` rather than ``XxxYyy`` (although they are the same thing),
because ``get_xxx_yyy`` is cache-optimized.
Note that these pickle files are dumped with Python 3.8. Python lower than 3.7 might not be able to load them.
See `PEP 574 <https://peps.python.org/pep-0574/>`__ for details.
This file shows resemblance to qlib.backtest.high_performance_ds. We might merge those two in the future.
"""
# TODO: merge with qlib/backtest/high_performance_ds.py
from __future__ import annotations
from functools import lru_cache
from typing import List, Sequence, cast
from pathlib import Path
import cachetools
import numpy as np
import pandas as pd
from cachetools.keys import hashkey
from qlib.backtest.decision import OrderDir, Order
from qlib.typehint import Literal
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
"""Several ad-hoc deal prices.
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
``bid_or_ask_fill``: Based on ``bid_or_ask``. If price is 0, use another price (``$ask0`` / ``$bid0``) instead.
``close``: Use close price (``$close0``) as deal price.
"""
def _infer_processed_data_column_names(shape: int) -> list[str]:
if shape == 16:
return [
"$open",
"$high",
"$low",
"$close",
"$vwap",
"$bid",
"$ask",
"$volume",
"$bidV",
"$bidV1",
"$bidV3",
"$bidV5",
"$askV",
"$askV1",
"$askV3",
"$askV5",
]
if shape == 6:
return ["$high", "$low", "$open", "$close", "$vwap", "$volume"]
elif shape == 5:
return ["$high", "$low", "$open", "$close", "$volume"]
raise ValueError(f"Unrecognized data shape: {shape}")
def _find_pickle(filename_without_suffix: Path) -> Path:
suffix_list = [".pkl", ".pkl.backtest"]
paths: List[Path] = []
for suffix in suffix_list:
path = filename_without_suffix.parent / (filename_without_suffix.name + suffix)
if path.exists():
paths.append(path)
if not paths:
raise FileNotFoundError(f"No file starting with '{filename_without_suffix}' found")
if len(paths) > 1:
raise ValueError(f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}")
return paths[0]
@lru_cache(maxsize=10) # 10 * 40M = 400MB
def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
return pd.read_pickle(_find_pickle(filename_without_suffix))
class IntradayBacktestData:
"""Raw market data that is often used in backtesting (thus called BacktestData)."""
def __init__(
self,
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int | None = None,
):
backtest = _read_pickle(data_dir / stock_id)
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
# No longer needed for pandas >= 1.4
# backtest = backtest.droplevel([0, 2])
self.data: pd.DataFrame = backtest
self.deal_price_type: DealPriceType = deal_price
self.order_dir: int | None = order_dir
def __repr__(self):
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.data})"
def __len__(self):
return len(self.data)
def get_deal_price(self) -> pd.Series:
"""Return a pandas series that can be indexed with time.
See :attribute:`DealPriceType` for details."""
if self.deal_price_type in ("bid_or_ask", "bid_or_ask_fill"):
if self.order_dir is None:
raise ValueError("Order direction cannot be none when deal_price_type is not close.")
if self.order_dir == OrderDir.SELL:
col = "$bid0"
else: # BUY
col = "$ask0"
elif self.deal_price_type == "close":
col = "$close0"
else:
raise ValueError(f"Unsupported deal_price_type: {self.deal_price_type}")
price = self.data[col]
if self.deal_price_type == "bid_or_ask_fill":
if self.order_dir == OrderDir.SELL:
fill_col = "$ask0"
else:
fill_col = "$bid0"
price = price.replace(0, np.nan).fillna(self.data[fill_col])
return price
def get_volume(self) -> pd.Series:
"""Return a volume series that can be indexed with time."""
return self.data["$volume0"]
def get_time_index(self) -> pd.DatetimeIndex:
return cast(pd.DatetimeIndex, self.data.index)
class IntradayProcessedData:
"""Processed market data after data cleanup and feature engineering.
It contains both processed data for "today" and "yesterday", as some algorithms
might use the market information of the previous day to assist decision making.
"""
today: pd.DataFrame
"""Processed data for "today".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
yesterday: pd.DataFrame
"""Processed data for "yesterday".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
proc = _read_pickle(data_dir / stock_id)
# We have to infer the names here because,
# unfortunately, they are not included in the original data.
cnames = _infer_processed_data_column_names(feature_dim)
time_length: int = len(time_index)
try:
# new data format
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
proc_today = proc[cnames]
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
except (IndexError, KeyError):
# legacy data
proc = proc.loc[pd.IndexSlice[stock_id, date]]
assert time_length * feature_dim * 2 == len(proc)
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
self.today: pd.DataFrame = proc_today
self.yesterday: pd.DataFrame = proc_yesterday
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
assert len(self.today) == len(self.yesterday) == time_length
def __repr__(self):
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@lru_cache(maxsize=100) # 100 * 50K = 5MB
def load_intraday_backtest_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
) -> IntradayBacktestData:
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
)
def load_intraday_processed_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
) -> IntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
def load_orders(
    order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
) -> Sequence[Order]:
    """Load orders, and set start time and end time for the orders."""
    start_time = start_time or pd.Timestamp("0:00:00")
    end_time = end_time or pd.Timestamp("23:59:59")
    if order_path.is_file():
        order_df = pd.read_pickle(order_path)
    else:
        order_df = []
        for file in order_path.iterdir():
            order_data = pd.read_pickle(file)
            order_df.append(order_data)
        order_df = pd.concat(order_df)
    order_df = order_df.reset_index()
    # Legacy-style orders have "date" instead of "datetime"
    if "date" in order_df.columns:
        order_df = order_df.rename(columns={"date": "datetime"})
    # Sometimes "datetime" values are str rather than Timestamp
    order_df["datetime"] = pd.to_datetime(order_df["datetime"])
    orders: List[Order] = []
    for _, row in order_df.iterrows():
        # filter out orders with amount <= 0
        if row["amount"] <= 0:
            continue
        orders.append(
            Order(
                row["instrument"],
                row["amount"],
                int(row["order_type"]),
                row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
                row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
            )
        )
    return orders
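The `bid_or_ask_fill` branch of `get_deal_price` above can be illustrated in isolation. The following is a minimal sketch, assuming a toy quote table with only `$bid0` and `$ask0` columns; the helper name `fill_deal_price` and the sample values are hypothetical and not part of this commit.

```python
import numpy as np
import pandas as pd


def fill_deal_price(data: pd.DataFrame, is_sell: bool) -> pd.Series:
    """Use the bid (sell) or ask (buy) quote; where it is 0, fall back to the opposite side."""
    col, fill_col = ("$bid0", "$ask0") if is_sell else ("$ask0", "$bid0")
    # A quote of 0 is treated as missing and filled from the other side of the book.
    return data[col].replace(0, np.nan).fillna(data[fill_col])


# Hypothetical three-minute quote table.
index = pd.date_range("2022-05-20 09:30", periods=3, freq="min")
quotes = pd.DataFrame({"$bid0": [10.0, 0.0, 10.2], "$ask0": [10.1, 10.3, 0.0]}, index=index)

sell_price = fill_deal_price(quotes, is_sell=True)
buy_price = fill_deal_price(quotes, is_sell=False)
```

In this toy table, the missing bid at 09:31 is filled with the ask for a sell order, and the missing ask at 09:32 is filled with the bid for a buy order.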


@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Train, test, inference utilities.
The APIs in this directory are NOT considered final and are subject to change!
"""

qlib/rl/entries/test.py Normal file

@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import Callable, Sequence
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.log import get_module_logger
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
_logger = get_module_logger(__name__)
def backtest(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | list[LogWriter],
reward: Reward | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
) -> None:
"""Backtest with the parallelism provided by RL framework.
Parameters
----------
simulator_fn
Callable receiving initial seed, returning a simulator.
state_interpreter
Interprets the state of simulators.
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to test against.
logger
Logger to record the backtest results. Logger must be present because
without logger, all information will be lost.
reward
Optional reward function. For backtest, this is for testing the rewards
and logging them only.
finite_env_type
Type of finite env implementation.
concurrency
Parallel workers.
"""
# To save bandwidth
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
def env_factory():
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
# and could be thread unsafe.
# I'm not sure whether it's a design flaw.
# I'll rethink about this when designing the trainer.
if finite_env_type == "dummy":
# We could only experience the "threading-unsafe" problem in dummy.
state = copy.deepcopy(state_interpreter)
action = copy.deepcopy(action_interpreter)
rew = copy.deepcopy(reward)
else:
state, action, rew = state_interpreter, action_interpreter, reward
return EnvWrapper(
simulator_fn,
state,
action,
seed_iterator,
rew,
logger=LogCollector(min_loglevel=min_loglevel),
)
with DataQueue(initial_states) as seed_iterator:
vector_env = vectorize_env(
env_factory,
finite_env_type,
concurrency,
logger,
)
policy.eval()
with vector_env.collector_guard():
test_collector = Collector(policy, vector_env)
_logger.info("All ready. Start backtest.")
test_collector.collect(n_step=INF * len(vector_env))

qlib/rl/entries/train.py Normal file

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TBD


@@ -1,94 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from ..backtest.executor import BaseExecutor
from .interpreter import StateInterpreter, ActionInterpreter
from ..utils import init_instance_by_config
class BaseRLEnv:
"""Base environment for reinforcement learning"""
def reset(self, **kwargs):
raise NotImplementedError("reset is not implemented!")
def step(self, action):
"""
step method of rl env
Parameters
----------
action :
action from rl policy
Returns
-------
env state to rl policy
"""
raise NotImplementedError("step is not implemented!")
class QlibRLEnv:
"""qlib-based RL env"""
def __init__(
self,
executor: BaseExecutor,
):
"""
Parameters
----------
executor : BaseExecutor
qlib multi-level/single-level executor, which can be regarded as gamecore in RL
"""
self.executor = executor
def reset(self, **kwargs):
self.executor.reset(**kwargs)
class QlibIntRLEnv(QlibRLEnv):
"""(Qlib)-based RL (Env) with (Interpreter)"""
def __init__(
self,
executor: BaseExecutor,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
):
"""
Parameters
----------
state_interpreter : Union[dict, StateInterpreter]
interpreter that interprets the qlib execute result into rl env state.
action_interpreter : Union[dict, ActionInterpreter]
interpreter that interprets the rl agent action into qlib order list
"""
super(QlibIntRLEnv, self).__init__(executor=executor)
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def step(self, action):
"""
step method of rl env, it run as following step:
- Use `action_interpreter.interpret` method to interpret the agent action into order list
- Execute the order list with qlib executor, and get the executed result
- Use `state_interpreter.interpret` method to interpret the executed result into env state
Parameters
----------
action :
action from rl policy
Returns
-------
env state to rl policy
"""
_interpret_decision = self.action_interpreter.interpret(action=action)
_execute_result = self.executor.execute(trade_decision=_interpret_decision)
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
return _interpret_state


@@ -1,47 +1,150 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
class BaseInterpreter:
"""Base Interpreter"""
from typing import TYPE_CHECKING, TypeVar, Generic, Any
def interpret(self, **kwargs):
raise NotImplementedError("interpret is not implemented!")
import numpy as np
from qlib.typehint import final
from .simulator import StateType, ActType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
import gym
from gym import spaces
ObsType = TypeVar("ObsType")
PolicyActType = TypeVar("PolicyActType")
class ActionInterpreter(BaseInterpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
class Interpreter:
"""Interpreter is a medium between states produced by simulators and states needed by RL policies.
Interpreters are two-way:
def interpret(self, action, **kwargs):
"""interpret method
1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`.
2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`.
Inherit one of the two sub-classes to define your own interpreter.
This super-class is only used for isinstance check.
Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx``
in an interpreter is an anti-pattern. In the future, we might support registering some interpreter-related
states by calling ``self.env.register_state()``, but it's not planned for the first iteration.
"""
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
env: EnvWrapper | None = None
@property
def observation_space(self) -> gym.Space:
raise NotImplementedError()
@final  # must not be overridden
def __call__(self, simulator_state: StateType) -> ObsType:
obs = self.interpret(simulator_state)
self.validate(obs)
return obs
def validate(self, obs: ObsType) -> None:
"""Validate whether an observation belongs to the pre-defined observation space."""
_gym_space_contains(self.observation_space, obs)
def interpret(self, simulator_state: StateType) -> ObsType:
"""Interpret the state of simulator.
Parameters
----------
action :
rl agent action
simulator_state
Retrieved with ``simulator.get_state()``.
Returns
-------
qlib orders
State needed by policy. Should conform with the state space defined in ``observation_space``.
"""
raise NotImplementedError("interpret is not implemented!")
class StateInterpreter(BaseInterpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
def interpret(self, execute_result, **kwargs):
"""interpret method
env: "EnvWrapper" | None = None
@property
def action_space(self) -> gym.Space:
raise NotImplementedError()
@final  # must not be overridden
def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType:
self.validate(action)
obs = self.interpret(simulator_state, action)
return obs
def validate(self, action: PolicyActType) -> None:
"""Validate whether an action belongs to the pre-defined action space."""
_gym_space_contains(self.action_space, action)
def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType:
"""Convert the policy action to simulator action.
Parameters
----------
execute_result :
qlib execution result
simulator_state
Retrieved with ``simulator.get_state()``.
action
Raw action given by policy.
Returns
----------
rl env state
-------
The action needed by simulator.
"""
raise NotImplementedError("interpret is not implemented!")
def _gym_space_contains(space: gym.Space, x: Any) -> None:
    """Strengthened version of gym.Space.contains.
    Giving more diagnostic information on why validation fails.
    Throw exception rather than returning true or false.
    """
    if isinstance(space, spaces.Dict):
        if not isinstance(x, dict) or len(x) != len(space):
            raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x)
        for k, subspace in space.spaces.items():
            if k not in x:
                raise GymSpaceValidationError(f"Key {k} not found in sample.", space, x)
            try:
                _gym_space_contains(subspace, x[k])
            except GymSpaceValidationError as e:
                raise GymSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e
    elif isinstance(space, spaces.Tuple):
        if isinstance(x, (list, np.ndarray)):
            x = tuple(x)  # Promote list and ndarray to tuple for contains check
        if not isinstance(x, tuple) or len(x) != len(space):
            raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, x)
        for i, (subspace, part) in enumerate(zip(space, x)):
            try:
                _gym_space_contains(subspace, part)
            except GymSpaceValidationError as e:
                raise GymSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e
    else:
        if not space.contains(x):
            raise GymSpaceValidationError("Validation error reported by gym.", space, x)
class GymSpaceValidationError(Exception):
def __init__(self, message: str, space: gym.Space, x: Any):
self.message = message
self.space = space
self.x = x
def __str__(self):
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
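The recursive validation in `_gym_space_contains` re-raises at every nesting level with `raise ... from e`, so the final exception carries a chain that points from the outermost space down to the failing leaf. The sketch below imitates that pattern without depending on gym: a "space" here is either a plain `dict` of sub-spaces or a predicate function. The names `SpaceValidationError`, `check_contains` and `error_chain` are hypothetical stand-ins, not part of this commit.

```python
from typing import Any, List


class SpaceValidationError(Exception):
    def __init__(self, message: str, space: Any, x: Any):
        super().__init__(message)
        self.message = message
        self.space = space
        self.x = x


def check_contains(space: Any, x: Any) -> None:
    """Raise SpaceValidationError if ``x`` does not fit ``space`` (dict of sub-spaces or predicate)."""
    if isinstance(space, dict):
        if not isinstance(x, dict) or len(x) != len(space):
            raise SpaceValidationError("Sample must be a dict with same length as space.", space, x)
        for key, subspace in space.items():
            if key not in x:
                raise SpaceValidationError(f"Key {key} not found in sample.", space, x)
            try:
                check_contains(subspace, x[key])
            except SpaceValidationError as e:
                # Re-raise with context so that the path to the failing leaf is preserved.
                raise SpaceValidationError(f"Subspace of key {key} validation error.", space, x) from e
    elif not space(x):
        raise SpaceValidationError("Validation error reported by leaf check.", space, x)


def error_chain(error: BaseException) -> List[str]:
    """Collect messages from the outermost error down to the root cause."""
    messages = []
    while error is not None:
        messages.append(str(error))
        error = error.__cause__
    return messages


# Hypothetical nested space: cur_step must lie in [0, 13).
space = {"position": lambda v: v >= 0, "inner": {"cur_step": lambda v: 0 <= v < 13}}
sample = {"position": 5.0, "inner": {"cur_step": 13}}

try:
    check_contains(space, sample)
    chain = []
except SpaceValidationError as err:
    chain = error_chain(err)
```

Walking `__cause__` recovers the path `inner` then `cur_step` then the leaf failure, which is the extra diagnostic information a plain boolean `contains` cannot give.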


@@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Currently it supports single-asset order execution.
Multi-asset is on the way.
"""
from .interpreter import *
from .network import *
from .policy import *
from .simulator_simple import *


@@ -0,0 +1,222 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import math
from pathlib import Path
from typing import Any, cast
import numpy as np
import pandas as pd
from gym import spaces
from qlib.constant import EPS
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.data import pickle_styled
from qlib.typehint import TypedDict
from .simulator_simple import SAOEState
__all__ = [
"FullHistoryStateInterpreter",
"CurrentStepStateInterpreter",
"CategoricalActionInterpreter",
"TwapRelativeActionInterpreter",
]
def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
    """To 32-bit numeric types. Recursively."""
    if isinstance(value, pd.DataFrame):
        return value.to_numpy()
    if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == "f"):
        return np.array(value, dtype=np.float32)
    elif isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == "i"):
        return np.array(value, dtype=np.int32)
    elif isinstance(value, dict):
        return {k: canonicalize(v) for k, v in value.items()}
    else:
        return value
class FullHistoryObs(TypedDict):
data_processed: Any
data_processed_prev: Any
acquiring: Any
cur_tick: Any
cur_step: Any
num_step: Any
target: Any
position: Any
position_history: Any
class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
"""The observation of all the history, including today (until this moment), and yesterday.
Parameters
----------
data_dir
Path to load data after feature engineering.
max_step
Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
data_ticks
Equal to the total number of records. For example, in SAOE per minute,
the total ticks is the length of day in minutes.
data_dim
Number of dimensions in data.
"""
def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None:
self.data_dir = data_dir
self.max_step = max_step
self.data_ticks = data_ticks
self.data_dim = data_dim
def interpret(self, state: SAOEState) -> FullHistoryObs:
processed = pickle_styled.load_intraday_processed_data(
self.data_dir,
state.order.stock_id,
pd.Timestamp(state.order.start_time.date()),
self.data_dim,
state.ticks_index,
)
position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
position_history[0] = state.order.amount
position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()
assert self.env is not None
# The min, slice here are to make sure that indices fit into the range,
# even after the final step of the simulator (in the done step),
# to make network in policy happy.
return cast(
FullHistoryObs,
canonicalize(
{
"data_processed": self._mask_future_info(processed.today, state.cur_time),
"data_processed_prev": processed.yesterday,
"acquiring": state.order.direction == state.order.BUY,
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
"position_history": position_history[: self.max_step],
}
),
)
@property
def observation_space(self):
space = {
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
"acquiring": spaces.Discrete(2),
"cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32),
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
# TODO: support arbitrary length index
"num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
"target": spaces.Box(-EPS, np.inf, shape=()),
"position": spaces.Box(-EPS, np.inf, shape=()),
"position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)),
}
return spaces.Dict(space)
@staticmethod
def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame:
arr = arr.copy(deep=True)
arr.loc[current:] = 0.0 # mask out data after this moment (inclusive)
return arr
class CurrentStateObs(TypedDict):
acquiring: bool
cur_step: int
num_step: int
target: float
position: float
class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
"""The observation of current step.
Used when policy only depends on the latest state, but not history.
The key list is not full. You can add more if more information is needed by your policy.
"""
def __init__(self, max_step: int):
self.max_step = max_step
@property
def observation_space(self):
space = {
"acquiring": spaces.Discrete(2),
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
"num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
"target": spaces.Box(-EPS, np.inf, shape=()),
"position": spaces.Box(-EPS, np.inf, shape=()),
}
return spaces.Dict(space)
def interpret(self, state: SAOEState) -> CurrentStateObs:
assert self.env is not None
assert self.env.status["cur_step"] <= self.max_step
obs = CurrentStateObs(
{
"acquiring": state.order.direction == state.order.BUY,
"cur_step": self.env.status["cur_step"],
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
}
)
return obs
class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
"""Convert a discrete policy action to a continuous action, then multiplied by ``order.amount``.
Parameters
----------
values
It can be a list of length $L$: $[a_1, a_2, \\ldots, a_L]$.
Then when policy gives decision $x$, $a_x$ times order amount is the output.
It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated,
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
"""
def __init__(self, values: int | list[float]):
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values
@property
def action_space(self) -> spaces.Discrete:
return spaces.Discrete(len(self.action_values))
def interpret(self, state: SAOEState, action: int) -> float:
assert 0 <= action < len(self.action_values)
return min(state.position, state.order.amount * self.action_values[action])
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
"""Convert a continuous ratio to deal amount.
The ratio is relative to TWAP on the remainder of the day.
For example, there are 5 steps left, and the remaining position is 300.
With TWAP strategy, in each step, 60 should be traded.
When this interpreter receives action $a$, its output is $60 \\cdot a$.
"""
@property
def action_space(self) -> spaces.Box:
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
def interpret(self, state: SAOEState, action: float) -> float:
assert self.env is not None
estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
return min(state.position, twap_volume * action)
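The arithmetic of the two action interpreters above can be checked with plain numbers. This is a minimal sketch that mirrors the two `interpret` methods as free functions; the function names and the sample numbers (390 ticks, 30 ticks per step, an order of 1000) are assumptions chosen for illustration only.

```python
import math
from typing import List, Union


def categorical_amount(values: Union[int, List[float]], action: int, order_amount: float, position: float) -> float:
    """Discrete action -> fraction of the order amount, capped by the remaining position."""
    if isinstance(values, int):
        # An integer n expands to [0, 1/n, 2/n, ..., n/n].
        values = [i / values for i in range(values + 1)]
    return min(position, order_amount * values[action])


def twap_relative_amount(position: float, num_ticks: int, ticks_per_step: int, cur_step: int, action: float) -> float:
    """Continuous ratio -> multiple of the TWAP volume over the remaining steps, capped by position."""
    estimated_total_steps = math.ceil(num_ticks / ticks_per_step)
    twap_volume = position / (estimated_total_steps - cur_step)
    return min(position, twap_volume * action)


# Categorical: with values=4, action 1 means 1/4 of a 1000-share order.
quarter = categorical_amount(4, 1, order_amount=1000.0, position=1000.0)
# Half of the order is 500, but only 300 remain, so the result is capped.
capped = categorical_amount(4, 2, order_amount=1000.0, position=300.0)

# TWAP-relative: 390 ticks / 30 per step = 13 steps; at step 8 there are 5 left.
twap_one = twap_relative_amount(300.0, 390, 30, cur_step=8, action=1.0)
twap_fast = twap_relative_amount(300.0, 390, 30, cur_step=8, action=1.5)
twap_all = twap_relative_amount(300.0, 390, 30, cur_step=8, action=10.0)
```

With 5 steps left and a remaining position of 300, an action of 1.0 reproduces the TWAP slice of 60, and any action large enough simply trades the whole remaining position.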


@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
import torch
import torch.nn as nn
from tianshou.data import Batch
from qlib.typehint import Literal
from .interpreter import FullHistoryObs
__all__ = ["Recurrent"]
class Recurrent(nn.Module):
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
At every timestep the input of policy network is divided into two parts,
the public variables and the private variables, which are handled by ``raw_rnn``
and ``pri_rnn`` in this network, respectively.
One minor difference is that, in this implementation, we don't assume the direction to be fixed.
Thus, another ``dire_fc`` is added to produce an extra direction-related feature.
"""
def __init__(
self,
obs_space: FullHistoryObs,
hidden_dim: int = 64,
output_dim: int = 32,
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
rnn_num_layers: int = 1,
):
super().__init__()
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.num_sources = 3
rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}
self.rnn_class = rnn_classes[rnn_type]
self.rnn_layers = rnn_num_layers
self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
self.raw_fc = nn.Sequential(nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU())
self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU())
self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
self._init_extra_branches()
self.fc = nn.Sequential(
nn.Linear(hidden_dim * self.num_sources, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
nn.ReLU(),
)
def _init_extra_branches(self):
pass
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
bs, _, data_dim = obs["data_processed"].size()
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
cur_step = obs["cur_step"].long()
cur_tick = obs["cur_tick"].long()
bs_indices = torch.arange(bs, device=device)
position = obs["position_history"] / obs["target"].unsqueeze(-1) # [bs, num_step]
steps = (
torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()
/ obs["num_step"].unsqueeze(-1).float()
) # [bs, num_step]
priv = torch.stack((position.float(), steps), -1)
data_in = self.raw_fc(data)
data_out, _ = self.raw_rnn(data_in)
# as it is padded with zero in front, this should be last minute
data_out_slice = data_out[bs_indices, cur_tick]
priv_in = self.pri_fc(priv)
priv_out = self.pri_rnn(priv_in)[0]
priv_out = priv_out[bs_indices, cur_step]
sources = [data_out_slice, priv_out]
dir_out = self.dire_fc(torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float())
sources.append(dir_out)
return sources, data_out
def forward(self, batch: Batch) -> torch.Tensor:
"""
Input should be a dict (at least) containing:
- data_processed: [N, T, C]
- cur_step: [N] (int)
- cur_tick: [N] (int)
- position_history: [N, S] (S is number of steps)
- target: [N]
- num_step: [N] (int)
- acquiring: [N] (0 or 1)
"""
inp = cast(FullHistoryObs, batch)
device = inp["data_processed"].device
sources, _ = self._source_features(inp, device)
assert len(sources) == self.num_sources
out = torch.cat(sources, -1)
return self.fc(out)
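`_source_features` above pads the tick sequence with one zero row in front and then picks one RNN output per sample with `data_out[bs_indices, cur_tick]`. The sketch below shows only that indexing step, using NumPy as a stand-in for torch (the integer-array indexing is assumed to behave the same way for this case); the shapes and values are made up for illustration.

```python
import numpy as np

batch_size, num_ticks, hidden = 2, 4, 3

# Stand-in for RNN outputs over the zero-padded sequence: shape [bs, T + 1, H].
rnn_out = np.arange(batch_size * (num_ticks + 1) * hidden, dtype=np.float32)
rnn_out = rnn_out.reshape(batch_size, num_ticks + 1, hidden)

cur_tick = np.array([0, 3])  # one tick index per sample
bs_indices = np.arange(batch_size)

# Pairing the two index arrays selects rnn_out[i, cur_tick[i]] for each sample i.
selected = rnn_out[bs_indices, cur_tick]  # shape [bs, H]
```

Because of the padding row, position `cur_tick` in the padded sequence corresponds to the output after the ticks strictly before `cur_tick` have been consumed, which matches the "last minute" comment in the source.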


@@ -0,0 +1,158 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from typing import Optional, cast
import numpy as np
import gym
import torch
import torch.nn as nn
from gym.spaces import Discrete
from tianshou.data import Batch, to_torch
from tianshou.policy import PPOPolicy, BasePolicy
__all__ = ["AllOne", "PPO"]
# baselines #
class NonlearnablePolicy(BasePolicy):
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
This could be moved outside in future.
"""
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
super().__init__()
def learn(self, batch, batch_size, repeat):
pass
def process_fn(self, batch, buffer, indice):
pass
class AllOne(NonlearnablePolicy):
"""Forward returns a batch full of 1.
Useful when implementing some baselines (e.g., TWAP).
"""
def forward(self, batch, state=None, **kwargs):
return Batch(act=np.full(len(batch), 1.0), state=state)
# ppo #
class PPOActor(nn.Module):
def __init__(self, extractor: nn.Module, action_dim: int):
super().__init__()
self.extractor = extractor
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
def forward(self, obs, state=None, info={}):
feature = self.extractor(to_torch(obs, device=auto_device(self)))
out = self.layer_out(feature)
return out, state
class PPOCritic(nn.Module):
def __init__(self, extractor: nn.Module):
super().__init__()
self.extractor = extractor
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
def forward(self, obs, state=None, info={}):
feature = self.extractor(to_torch(obs, device=auto_device(self)))
return self.value_out(feature).squeeze(dim=-1)
class PPO(PPOPolicy):
"""A wrapper of tianshou PPOPolicy.
Differences:
- Auto-create actor and critic network. Supports discrete action space only.
- Dedup common parameters between actor network and critic network
(not sure whether this is included in latest tianshou or not).
- Support a ``weight_file`` that supports loading checkpoint.
- Some parameters' default values are different from original.
"""
def __init__(
self,
network: nn.Module,
obs_space: gym.Space,
action_space: gym.Space,
lr: float,
weight_decay: float = 0.0,
discount_factor: float = 1.0,
max_grad_norm: float = 100.0,
reward_normalization: bool = True,
eps_clip: float = 0.3,
value_clip: bool = True,
vf_coef: float = 1.0,
gae_lambda: float = 1.0,
max_batchsize: int = 256,
deterministic_eval: bool = True,
weight_file: Optional[Path] = None,
):
assert isinstance(action_space, Discrete)
actor = PPOActor(network, action_space.n)
critic = PPOCritic(network)
optimizer = torch.optim.Adam(
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
)
super().__init__(
actor,
critic,
optimizer,
torch.distributions.Categorical,
discount_factor=discount_factor,
max_grad_norm=max_grad_norm,
reward_normalization=reward_normalization,
eps_clip=eps_clip,
value_clip=value_clip,
vf_coef=vf_coef,
gae_lambda=gae_lambda,
max_batchsize=max_batchsize,
deterministic_eval=deterministic_eval,
observation_space=obs_space,
action_space=action_space,
)
if weight_file is not None:
load_weight(self, weight_file)
# utilities: these should be put in a separate (common) file. #
def auto_device(module: nn.Module) -> torch.device:
    for param in module.parameters():
        return param.device
    return torch.device("cpu")  # fallback to cpu


def load_weight(policy, path):
    assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
    loaded_weight = torch.load(path, map_location="cpu")
    try:
        policy.load_state_dict(loaded_weight)
    except RuntimeError:
        # try again by loading the converted weight
        # https://github.com/thu-ml/tianshou/issues/468
        for k in list(loaded_weight):
            loaded_weight["_actor_critic." + k] = loaded_weight[k]
        policy.load_state_dict(loaded_weight)


def chain_dedup(*iterables):
    seen = set()
    for iterable in iterables:
        for i in iterable:
            if i not in seen:
                seen.add(i)
                yield i
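In `PPO.__init__`, the actor and the critic wrap the same extractor network, so chaining their parameters naively would hand the shared parameters to the optimizer twice. The sketch below shows what `chain_dedup` does about that, using a hypothetical `Param` class in place of real tensors; like the objects here, torch parameters are assumed to hash by identity.

```python
def chain_dedup(*iterables):
    """Chain iterables, yielding each distinct item only once, in first-seen order."""
    seen = set()
    for iterable in iterables:
        for item in iterable:
            if item not in seen:
                seen.add(item)
                yield item


class Param:
    """Hypothetical stand-in for a network parameter (default hashing is identity-based)."""

    def __init__(self, name: str):
        self.name = name


shared = [Param("extractor.weight"), Param("extractor.bias")]
actor_params = shared + [Param("actor.out")]
critic_params = shared + [Param("critic.out")]

unique = list(chain_dedup(actor_params, critic_params))
names = [p.name for p in unique]
```

The shared extractor parameters appear once, followed by the head-specific parameters of the actor and then the critic.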


@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Placeholder for qlib-based simulator."""


@@ -0,0 +1,403 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from pathlib import Path
from typing import NamedTuple, Any, TypeVar, cast
import numpy as np
import pandas as pd
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS
from qlib.rl.simulator import Simulator
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
from qlib.rl.utils import LogLevel
from qlib.typehint import TypedDict
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
class SAOEMetrics(TypedDict):
"""Metrics for SAOE accumulated for a "period".
It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute.
Warnings
--------
The type hints are for single elements. In many cases, they can be vectorized.
For example, ``market_volume`` could be a list of float (or ndarray) rather than a single float.
"""
stock_id: str
"""Stock ID of this record."""
datetime: pd.Timestamp
"""Datetime of this record (this is index in the dataframe)."""
direction: int
"""Direction of the order. 0 for sell, 1 for buy."""
# Market information.
market_volume: float
"""(total) market volume traded in the period."""
market_price: float
"""Deal price. If it's a period of time, this is the average market deal price."""
# Strategy records.
amount: float
"""Total amount (volume) strategy intends to trade."""
inner_amount: float
"""Total amount that the lower-level strategy intends to trade
(might be larger than amount, e.g., to ensure ffr)."""
deal_amount: float
"""Amount that successfully takes effect (must be less than inner_amount)."""
trade_price: float
"""The average deal price for this strategy."""
trade_value: float
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
position: float
"""Position left after this "period"."""
# Accumulated metrics
ffr: float
"""Percentage of the daily order that has been completed."""
pa: float
"""Price advantage compared to baseline (i.e., trade with baseline market price).
The baseline is trade price when using TWAP strategy to execute this order.
Please note that there could be a data leak here.
Unit is BP (basis point, 1/10000)."""
class SAOEState(NamedTuple):
"""Data structure holding a state for SAOE simulator."""
order: Order
"""The order we are dealing with."""
cur_time: pd.Timestamp
"""Current time, e.g., 9:30."""
position: float
"""Current remaining volume to execute."""
history_exec: pd.DataFrame
"""See :attr:`SingleAssetOrderExecution.history_exec`."""
history_steps: pd.DataFrame
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
metrics: SAOEMetrics | None
"""Daily metric, only available when the trading is in "done" state."""
backtest_data: IntradayBacktestData
"""Backtest data is included in the state.
Actually, only the time index of this data is needed, at this moment.
I include the full data so that algorithms (e.g., VWAP) that rely on the raw data can be implemented.
Interpreter can use this as they wish, but they should be careful not to leak future data.
"""
ticks_per_step: int
"""How many ticks for each step."""
ticks_index: pd.DatetimeIndex
"""Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59]."""
ticks_for_order: pd.DatetimeIndex
"""Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44]."""
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""Single-asset order execution (SAOE) simulator.
As there's no "calendar" in the simple simulator, ticks are used to trade.
A tick is a record (a line) in the pickle-styled data file.
Each tick is considered as an individual trading opportunity.
If such fine granularity is not needed, use ``ticks_per_step`` to
lengthen the ticks for each step.
In each step, the traded amount is split "equally" across the ticks,
then bounded by the maximum execution volume (i.e., ``vol_threshold``),
and if it's the last step, the simulator tries to ensure that all the remaining amount is executed.
Parameters
----------
initial
The seed to start an SAOE simulator is an order.
ticks_per_step
How many ticks per step.
data_dir
Path to load backtest data
vol_threshold
Maximum execution volume (divided by market execution volume).
"""
history_exec: pd.DataFrame
"""All execution history at every possible time tick. See :class:`SAOEMetrics` for available columns."""
history_steps: pd.DataFrame
"""Positions at each step. The position before first step is also recorded.
See :class:`SAOEMetrics` for available columns."""
metrics: SAOEMetrics | None
"""Metrics. Only available when done."""
twap_price: float
"""This price is used to compute price advantage.
It's defined as the average price in the period from order's start time to end time."""
ticks_index: pd.DatetimeIndex
"""All available ticks for the day (not restricted to order)."""
ticks_for_order: pd.DatetimeIndex
"""Ticks that are available for trading (sliced by order)."""
def __init__(
self,
order: Order,
data_dir: Path,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: float | None = None,
) -> None:
self.order = order
self.ticks_per_step: int = ticks_per_step
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
self.backtest_data = load_intraday_backtest_data(
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
)
self.ticks_index = self.backtest_data.get_time_index()
# Get time index available for trading
self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)
self.cur_time = self.ticks_for_order[0]
# NOTE: astype(float) is necessary in some systems.
# this will align the precision with `.to_numpy()` in `_split_exec_vol`
self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean())
self.position = order.amount
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
# NOTE: can empty dataframe contain index?
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.metrics = None
self.market_price: np.ndarray | None = None
self.market_vol: np.ndarray | None = None
self.market_vol_limit: np.ndarray | None = None
def step(self, amount: float) -> None:
"""Execute one step of SAOE.
Parameters
----------
amount
The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
"""
assert not self.done()
self.market_price = self.market_vol = None # avoid misuse
exec_vol = self._split_exec_vol(amount)
assert self.market_price is not None and self.market_vol is not None
ticks_position = self.position - np.cumsum(exec_vol)
self.position -= exec_vol.sum()
if self.position < -EPS or (exec_vol < -EPS).any():
raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")
# Get time index available for this step
time_index = self._get_ticks_slice(self.cur_time, self._next_time())
self.history_exec = self._dataframe_append(
self.history_exec,
SAOEMetrics(
# It should have the same keys as SAOEMetrics,
# but the values do not necessarily have the annotated type.
# Some values could be vectorized (e.g., exec_vol).
stock_id=self.order.stock_id,
datetime=time_index,
direction=self.order.direction,
market_volume=self.market_vol,
market_price=self.market_price,
amount=exec_vol,
inner_amount=exec_vol,
deal_amount=exec_vol,
trade_price=self.market_price,
trade_value=self.market_price * exec_vol,
position=ticks_position,
ffr=exec_vol / self.order.amount,
pa=price_advantage(self.market_price, self.twap_price, self.order.direction),
),
)
self.history_steps = self._dataframe_append(
self.history_steps,
[self._metrics_collect(self.cur_time, self.market_vol, self.market_price, amount, exec_vol)],
)
if self.done():
if self.env is not None:
self.env.logger.add_any("history_steps", self.history_steps, loglevel=LogLevel.DEBUG)
self.env.logger.add_any("history_exec", self.history_exec, loglevel=LogLevel.DEBUG)
self.metrics = self._metrics_collect(
self.ticks_index[0], # start time
self.history_exec["market_volume"],
self.history_exec["market_price"],
self.history_steps["amount"].sum(),
self.history_exec["deal_amount"],
)
# NOTE (yuge): It looks to me that it's the "correct" decision to
# put all the logs here, because only components like simulators themselves
# have the knowledge about what could appear in the logs, and what's the format.
# But I admit it's not necessarily the most convenient way.
# I'll rethink it when we have the second environment
# Maybe some APIs like self.logger.enable_auto_log() ?
if self.env is not None:
for key, value in self.metrics.items():
if isinstance(value, float):
self.env.logger.add_scalar(key, value)
else:
self.env.logger.add_any(key, value)
self.cur_time = self._next_time()
def get_state(self) -> SAOEState:
return SAOEState(
order=self.order,
cur_time=self.cur_time,
position=self.position,
history_exec=self.history_exec,
history_steps=self.history_steps,
metrics=self.metrics,
backtest_data=self.backtest_data,
ticks_per_step=self.ticks_per_step,
ticks_index=self.ticks_index,
ticks_for_order=self.ticks_for_order,
)
def done(self) -> bool:
return self.position < EPS or self.cur_time >= self.order.end_time
def _next_time(self) -> pd.Timestamp:
"""The "current time" (``cur_time``) for next step."""
# Look for next time on time index
current_loc = self.ticks_index.get_loc(self.cur_time)
next_loc = current_loc + self.ticks_per_step
# Calibrate the next location to multiple of ticks_per_step.
# This is to make sure that:
# as long as ticks_per_step is a multiple of something, each step won't cross morning and afternoon.
next_loc = next_loc - next_loc % self.ticks_per_step
if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.order.end_time:
return self.ticks_index[next_loc]
else:
return self.order.end_time
def _cur_duration(self) -> pd.Timedelta:
"""The "duration" of this step (step that is about to happen)."""
return self._next_time() - self.cur_time
def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray:
"""
Split the volume in each step into minutes, considering possible constraints.
This follows TWAP strategy.
"""
next_time = self._next_time()
# get the backtest data for next interval
self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - ONE_SEC].to_numpy()
assert self.market_vol is not None and self.market_price is not None
# split the volume equally into each minute
exec_vol = np.repeat(exec_vol_sum / len(self.market_price), len(self.market_price))
# apply the volume threshold
market_vol_limit = self.vol_threshold * self.market_vol if self.vol_threshold is not None else np.inf
exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore
# Complete all the order amount at the last moment.
if next_time >= self.order.end_time:
exec_vol[-1] += self.position - exec_vol.sum()
exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore
return exec_vol
def _metrics_collect(
self,
datetime: pd.Timestamp,
market_vol: np.ndarray,
market_price: np.ndarray,
amount: float, # intended to trade such amount
exec_vol: np.ndarray,
) -> SAOEMetrics:
assert len(market_vol) == len(market_price) == len(exec_vol)
if np.abs(np.sum(exec_vol)) < EPS:
exec_avg_price = 0.0
else:
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
if hasattr(exec_avg_price, "item"): # could be numpy scalar
exec_avg_price = exec_avg_price.item() # type: ignore
return SAOEMetrics(
stock_id=self.order.stock_id,
datetime=datetime,
direction=self.order.direction,
market_volume=market_vol.sum(),
market_price=market_price.mean(),
amount=amount,
inner_amount=exec_vol.sum(),
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
trade_price=exec_avg_price,
trade_value=np.sum(market_price * exec_vol),
position=self.position,
ffr=float(exec_vol.sum() / self.order.amount),
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
)
def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex:
if not include_end:
end = end - ONE_SEC
return self.ticks_index[self.ticks_index.slice_indexer(start, end)]
@staticmethod
def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
# dataframe.append is deprecated
other_df = pd.DataFrame(other).set_index("datetime")
other_df.index.name = "datetime"
return pd.concat([df, other_df], axis=0)
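# Illustration only (not part of this commit): a minimal, pure-Python sketch of the
# splitting rule described in ``_split_exec_vol``: equal split per tick, per-tick cap
# derived from ``vol_threshold``, and a top-up on the last step. The function name
# ``split_exec_vol`` and the ``is_last_step`` flag are assumptions made for this sketch;
# the real method works on numpy arrays and derives the last step from the order end time.
```python
from typing import List, Optional


def split_exec_vol(
    total: float,
    market_vol: List[float],
    position: float,
    vol_threshold: Optional[float] = None,
    is_last_step: bool = False,
) -> List[float]:
    """Split ``total`` over the ticks of one step, TWAP-style."""
    n = len(market_vol)
    # Equal share for every tick in this step.
    exec_vol = [total / n] * n
    # Per-tick cap: a fraction of the market volume, or no cap at all.
    if vol_threshold is None:
        limits = [float("inf")] * n
    else:
        limits = [vol_threshold * v for v in market_vol]
    exec_vol = [min(e, cap) for e, cap in zip(exec_vol, limits)]
    if is_last_step:
        # Try to finish the remaining position on the final tick,
        # then re-apply the cap so the limit is never exceeded.
        exec_vol[-1] += position - sum(exec_vol)
        exec_vol = [min(e, cap) for e, cap in zip(exec_vol, limits)]
    return exec_vol
```
# Note that with a volume cap the last step may still leave part of the order unexecuted.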
_float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
def price_advantage(
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
) -> _float_or_ndarray:
if baseline_price == 0: # something is wrong with data. Should be nan here
if isinstance(exec_price, float):
return 0.0
else:
return np.zeros_like(exec_price)
if direction == OrderDir.BUY:
res = (1 - exec_price / baseline_price) * 10000
elif direction == OrderDir.SELL:
res = (exec_price / baseline_price - 1) * 10000
else:
raise ValueError(f"Unexpected order direction: {direction}")
res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
if res_wo_nan.size == 1:
return res_wo_nan.item()
else:
return cast(_float_or_ndarray, res_wo_nan)
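# Illustration only (not part of this commit): a scalar-only sketch of the price-advantage
# formula above, in basis points. ``price_advantage_bp`` and its ``is_buy`` flag are
# hypothetical names for this sketch; the real function also accepts ndarrays and
# ``OrderDir`` values, and replaces NaN results with 0.
```python
def price_advantage_bp(exec_price: float, baseline_price: float, is_buy: bool) -> float:
    """Price advantage of ``exec_price`` over ``baseline_price`` in basis points."""
    if baseline_price == 0:
        # Mirrors the guard above: a zero baseline indicates bad data.
        return 0.0
    if is_buy:
        # Buying below the baseline is an advantage.
        return (1 - exec_price / baseline_price) * 10000
    # Selling above the baseline is an advantage.
    return (exec_price / baseline_price - 1) * 10000
```
# For instance, buying at 9.9 against a TWAP baseline of 10.0 gives roughly +100 BP,
# while selling at 9.9 against the same baseline gives roughly -100 BP.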

84
qlib/rl/reward.py Normal file

@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Generic, Any, TypeVar, TYPE_CHECKING
from qlib.typehint import final
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
SimulatorState = TypeVar("SimulatorState")
class Reward(Generic[SimulatorState]):
"""
Reward calculation component that takes a single argument: state of simulator. Returns a real number: reward.
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
"""
env: EnvWrapper | None = None
@final
def __call__(self, simulator_state: SimulatorState) -> float:
return self.reward(simulator_state)
def reward(self, simulator_state: SimulatorState) -> float:
"""Implement this method for your own reward."""
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
def log(self, name, value):
self.env.logger.add_scalar(name, value)
class RewardCombination(Reward):
"""Combination of multiple rewards."""
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
self.rewards = rewards
def reward(self, simulator_state: Any) -> float:
total_reward = 0.0
for name, (reward_fn, weight) in self.rewards.items():
rew = reward_fn(simulator_state) * weight
total_reward += rew
self.log(name, rew)
return total_reward
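# Illustration only (not part of this commit): a standalone sketch of the weighted sum that
# ``RewardCombination.reward`` computes. ``combine_rewards`` is a hypothetical helper that
# uses plain callables instead of ``Reward`` objects and returns the per-name terms
# instead of logging them.
```python
from typing import Any, Callable, Dict, Tuple


def combine_rewards(
    rewards: Dict[str, Tuple[Callable[[Any], float], float]],
    state: Any,
) -> Tuple[float, Dict[str, float]]:
    """Return the weighted total and the weighted term of each named reward."""
    total = 0.0
    parts: Dict[str, float] = {}
    for name, (reward_fn, weight) in rewards.items():
        # Each term is weighted before it is accumulated, as in the class above.
        parts[name] = reward_fn(state) * weight
        total += parts[name]
    return total, parts
```
# The keys of the dict double as the names under which each weighted term would be logged.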
# TODO:
# reward_factory is disabled for now
# _RegistryConfigReward = RegistryConfig[REWARDS]
# @configclass
# class _WeightedRewardConfig:
# weight: float
# reward: _RegistryConfigReward
# RewardConfig = Union[_RegistryConfigReward, Dict[str, Union[_RegistryConfigReward, _WeightedRewardConfig]]]
# def reward_factory(reward_config: RewardConfig) -> Reward:
# """
# Use this factory to instantiate the reward from config.
# Simply using ``reward_config.build()`` might not work because reward can have complex combinations.
# """
# if isinstance(reward_config, dict):
# # as reward combination
# rewards = {}
# for name, rew in reward_config.items():
# if not isinstance(rew, _WeightedRewardConfig):
# # default weight is 1.
# rew = _WeightedRewardConfig(weight=1., reward=rew)
# # no recursive build in this step
# rewards[name] = (rew.reward.build(), rew.weight)
# return RewardCombination(rewards)
# else:
# # single reward
# return reward_config.build()

12
qlib/rl/seed.py Normal file

@@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Defines a set of initial state definitions and state-set definitions.
With single-asset order execution only, the only seed is order.
"""
from typing import TypeVar
InitialStateType = TypeVar("InitialStateType")
"""Type of data that creates the simulator."""

75
qlib/rl/simulator.py Normal file

@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import TypeVar, Generic, Any, TYPE_CHECKING
from .seed import InitialStateType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
StateType = TypeVar("StateType")
"""StateType stores all the useful data in the simulation process
(as well as utilities to generate/retrieve data when needed)."""
ActType = TypeVar("ActType")
"""This ActType is the type of action at the simulator end."""
class Simulator(Generic[InitialStateType, StateType, ActType]):
"""
Simulator that resets with ``__init__``, and transits with ``step(action)``.
To make the data-flow clear, we make the following restrictions to Simulator:
1. The only way to modify the inner status of a simulator is by using ``step(action)``.
2. External modules can *read* the status of a simulator by using ``simulator.get_state()``,
and check whether the simulator is in the ending state by calling ``simulator.done()``.
A simulator is defined to be bounded with three types:
- *InitialStateType* that is the type of the data used to create the simulator.
- *StateType* that is the type of the **status** (state) of the simulator.
- *ActType* that is the type of the **action**, which is the input received in each step.
Different simulators might share the same StateType. For example, when they are dealing with the same task,
but with different simulation implementation. With the same type, they can safely share other components in the MDP.
Simulators are ephemeral. The lifecycle of a simulator starts with an initial state, and ends with the trajectory.
In other words, when the trajectory ends, the simulator is recycled.
If simulators want to share context between trajectories (e.g., for speed-up purposes),
this can be done by accessing the weak reference of the environment wrapper.
Attributes
----------
env
A reference of env-wrapper, which could be useful in some corner cases.
Simulators are discouraged from using this, because it's error-prone.
"""
env: EnvWrapper | None = None
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
pass
def step(self, action: ActType) -> None:
"""Receives an action of ActType.
Simulator should update its internal state, and return None.
The updated state can be retrieved with ``simulator.get_state()``.
"""
raise NotImplementedError()
def get_state(self) -> StateType:
raise NotImplementedError()
def done(self) -> bool:
"""Check whether the simulator is in a "done" state.
When simulator is in a "done" state,
it should no longer receive any ``step`` request.
As simulators are ephemeral, to reset the simulator,
the old one should be destroyed and a new simulator can be created.
"""
raise NotImplementedError()
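# Illustration only (not part of this commit): a toy simulator that follows the contract
# above (state changes only through ``step``, state is read through ``get_state``, and a new
# instance is created for every trajectory). ``CountdownSimulator`` and ``run_episode`` are
# hypothetical names; the sketch does not subclass ``Simulator`` so that it stays standalone.
```python
class CountdownSimulator:
    """Counts an integer down to zero; the initial state is the starting value."""

    def __init__(self, initial: int) -> None:
        self._remaining = initial

    def step(self, action: int) -> None:
        # The only place where the internal state is modified.
        if self.done():
            raise RuntimeError("Simulator is done; create a new one instead.")
        self._remaining = max(0, self._remaining - action)

    def get_state(self) -> int:
        return self._remaining

    def done(self) -> bool:
        return self._remaining == 0


def run_episode(initial: int, action: int) -> int:
    """Run one trajectory with a constant action and return the number of steps."""
    simulator = CountdownSimulator(initial)  # a fresh simulator per trajectory
    steps = 0
    while not simulator.done():
        simulator.step(action)
        steps += 1
    return steps
```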


@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .data_queue import *
from .env_wrapper import *
from .finite_env import *
from .log import *

179
qlib/rl/utils/data_queue.py Normal file

@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import multiprocessing
import threading
import time
import warnings
from queue import Empty
from typing import TypeVar, Generic, Sequence, cast
from qlib.log import get_module_logger
_logger = get_module_logger(__name__)
T = TypeVar("T")
__all__ = ["DataQueue"]
class DataQueue(Generic[T]):
"""Main process (producer) produces data and stores them in a queue.
Sub-processes (consumers) can retrieve the data-points from the queue.
Data-points are generated via reading items from ``dataset``.
:class:`DataQueue` is ephemeral. You must create a new DataQueue
when the ``repeat`` is exhausted.
See the documents of :class:`qlib.rl.utils.FiniteVectorEnv` for more background.
Parameters
----------
dataset
The dataset to read data from. Must implement ``__len__`` and ``__getitem__``.
repeat
How many times to iterate over the data-points. Use ``-1`` to iterate forever.
shuffle
If ``shuffle`` is true, the items will be read in random order.
producer_num_workers
Concurrent workers for data-loading.
queue_maxsize
Maximum items to put into queue before it jams.
Examples
--------
>>> data_queue = DataQueue(my_dataset)
>>> with data_queue:
... ...
In worker:
>>> for data in data_queue:
... print(data)
"""
def __init__(
self,
dataset: Sequence[T],
repeat: int = 1,
shuffle: bool = True,
producer_num_workers: int = 0,
queue_maxsize: int = 0,
):
if queue_maxsize == 0:
if os.cpu_count() is not None:
queue_maxsize = cast(int, os.cpu_count())
_logger.info(f"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming.")
else:
queue_maxsize = 1
_logger.warning("CPU count not available. Setting queue maxsize to 1.")
self.dataset: Sequence[T] = dataset
self.repeat: int = repeat
self.shuffle: bool = shuffle
self.producer_num_workers: int = producer_num_workers
self._activated: bool = False
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
self._done = multiprocessing.Value("i", 0)
def __enter__(self):
self.activate()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.cleanup()
def cleanup(self):
with self._done.get_lock():
self._done.value += 1
for repeat in range(500):
if repeat >= 1:
warnings.warn(f"After {repeat} cleanup, the queue is still not empty.", category=RuntimeWarning)
while not self._queue.empty():
try:
self._queue.get(block=False)
except Empty:
pass
# Sometimes when the queue gets emptied, more data have already been sent,
# and they are on the way into the queue.
# If these data didn't get consumed, it will jam the queue and make the process hang.
# We wait a second here for potential data arriving, and check again (for ``repeat`` times).
time.sleep(1.0)
if self._queue.empty():
break
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
def get(self, block=True):
if not hasattr(self, "_first_get"):
self._first_get = True
if self._first_get:
timeout = 5.0
self._first_get = False
else:
timeout = 0.5
while True:
try:
return self._queue.get(block=block, timeout=timeout)
except Empty:
if self._done.value:
raise StopIteration # pylint: disable=raise-missing-from
def put(self, obj, block=True, timeout=None):
return self._queue.put(obj, block=block, timeout=timeout)
def mark_as_done(self):
with self._done.get_lock():
self._done.value = 1
def done(self):
return self._done.value
def activate(self):
if self._activated:
raise ValueError("DataQueue can not activate twice.")
thread = threading.Thread(target=self._producer, daemon=True)
thread.start()
self._activated = True
return self
def __del__(self):
_logger.debug(f"__del__ of {__name__}.DataQueue")
self.cleanup()
def __iter__(self):
if not self._activated:
raise ValueError(
"Need to call activate() to launch a daemon worker " "to produce data into data queue before using it."
)
return self._consumer()
def _consumer(self):
while True:
try:
yield self.get()
except StopIteration:
_logger.debug("Data consumer timed-out from get.")
return
def _producer(self):
# pytorch dataloader is used here only because we need its sampler and multi-processing
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
dataloader = DataLoader(
cast(Dataset[T], self.dataset),
batch_size=None,
num_workers=self.producer_num_workers,
shuffle=self.shuffle,
collate_fn=lambda t: t, # identity collate fn
)
repeat = 10**18 if self.repeat == -1 else self.repeat
for _rep in range(repeat):
for data in dataloader:
if self._done.value:
# Already done.
return
self._queue.put(data)
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
self.mark_as_done()
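# Illustration only (not part of this commit): a thread-based sketch of the
# producer / "done" flag / timed ``get`` pattern used by ``DataQueue``. It uses the standard
# library ``queue.Queue`` and ``threading.Event`` instead of ``multiprocessing`` and a torch
# ``DataLoader``; ``consume_all`` is a hypothetical helper. Unlike ``DataQueue.get``, this
# sketch also re-checks that the queue is empty before stopping.
```python
import queue
import threading
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def consume_all(items: Iterable[T], maxsize: int = 2, timeout: float = 0.1) -> List[T]:
    """Feed ``items`` through a bounded queue and collect them on the consumer side."""
    data_queue: "queue.Queue[T]" = queue.Queue(maxsize=maxsize)
    done = threading.Event()

    def producer() -> None:
        for item in items:
            data_queue.put(item)  # blocks while the queue is full
        done.set()  # signals that no more items will arrive

    threading.Thread(target=producer, daemon=True).start()

    received: List[T] = []
    while True:
        try:
            received.append(data_queue.get(timeout=timeout))
        except queue.Empty:
            # Stop only when the producer has finished AND nothing is left to read.
            if done.is_set() and data_queue.empty():
                break
    return received
```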


@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import weakref
from typing import Callable, Any, Iterable, Iterator, Generic, cast
import gym
from qlib.rl.aux_info import AuxiliaryInfoCollector
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
from qlib.rl.reward import Reward
from qlib.typehint import TypedDict
from .finite_env import generate_nan_observation
from .log import LogCollector, LogLevel
__all__ = ["InfoDict", "EnvWrapperStatus", "EnvWrapper"]
# in this case, there won't be any seed for simulator
SEED_INTERATOR_MISSING = "_missing_"
class InfoDict(TypedDict):
"""The type of dict that is used in the 4th return value of ``env.step()``."""
aux_info: dict
"""Any information that depends on the auxiliary info collector."""
log: dict[str, Any]
"""Collected by LogCollector."""
class EnvWrapperStatus(TypedDict):
"""
This is the status data structure used in EnvWrapper.
The fields here are in the semantics of RL.
For example, ``obs`` means the observation fed into policy.
``action`` means the raw action returned by policy.
"""
cur_step: int
done: bool
initial_state: Any | None
obs_history: list
action_history: list
reward_history: list
class EnvWrapper(
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
):
"""Qlib-based RL environment, subclassing ``gym.Env``.
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
This is what the framework of simulator - interpreter - policy looks like in RL training.
All the components other than policy need to be assembled into a single object called "environment".
The "environment" is replicated into multiple workers, and (at least in tianshou's implementation),
one single policy (agent) plays against a batch of environments.
Parameters
----------
simulator_fn
A callable that is the simulator factory.
When ``seed_iterator`` is present, the factory should take one argument,
that is the seed (aka initial state).
Otherwise, it should take no arguments.
state_interpreter
State-observation converter.
action_interpreter
Policy-simulator action converter.
seed_iterator
An iterable of seeds. With the help of :class:`qlib.rl.utils.DataQueue`,
environment workers in different processes can share one ``seed_iterator``.
reward_fn
A callable that accepts the StateType and returns a float (at least in single-agent case).
aux_info_collector
Collect auxiliary information. Could be useful in MARL.
logger
Log collector that collects the logs. The collected logs are sent back to main process,
via the return value of ``env.step()``.
Attributes
----------
status : EnvWrapperStatus
Status indicator. All terms are in *RL language*.
It can be used if users care about data on the RL side.
Can be none when no trajectory is available.
"""
simulator: Simulator[InitialStateType, StateType, ActType]
seed_iterator: str | Iterator[InitialStateType] | None
def __init__(
self,
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
state_interpreter: StateInterpreter[StateType, ObsType],
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
seed_iterator: Iterable[InitialStateType] | None,
reward_fn: Reward | None = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
logger: LogCollector | None = None,
):
# Assign weak reference to wrapper.
#
# Use weak reference here, because:
# 1. Logically, the other components should be able to live without an env_wrapper.
# For example, they might live in a strategy_wrapper in future.
# Therefore injecting a "hard" attribute called "env" is not appropriate.
# 2. When the environment gets destroyed, it gets destroyed.
# We don't want it to silently live inside some interpreters.
# 3. Avoid circular reference.
# 4. When the components get serialized, we can throw away the env without any burden.
# (though this part is not implemented yet)
for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
if obj is not None:
obj.env = weakref.proxy(self) # type: ignore
self.simulator_fn = simulator_fn
self.state_interpreter = state_interpreter
self.action_interpreter = action_interpreter
if seed_iterator is None:
# In this case, there won't be any seed for simulator
# We can't set it to None because None actually means something else.
# If `seed_iterator` is None, it means that it's exhausted.
self.seed_iterator = SEED_INTERATOR_MISSING
else:
self.seed_iterator = iter(seed_iterator)
self.reward_fn = reward_fn
self.aux_info_collector = aux_info_collector
self.logger: LogCollector = logger or LogCollector()
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
@property
def action_space(self):
return self.action_interpreter.action_space
@property
def observation_space(self):
return self.state_interpreter.observation_space
def reset(self, **kwargs: Any) -> ObsType:
"""
Try to get a state from state queue, and init the simulator with this state.
If the queue is exhausted, generate an invalid (nan) observation.
"""
try:
if self.seed_iterator is None:
raise RuntimeError("You are trying to get a state from a dead environment wrapper.")
# TODO: simulator/observation might need seed to prefetch something
# as only seed has the ability to do the work beforehand
# NOTE: though logger is reset here, logs in this function won't work,
# because we can't send them outside.
# See https://github.com/thu-ml/tianshou/issues/605
self.logger.reset()
if self.seed_iterator is SEED_INTERATOR_MISSING:
# no initial state
initial_state = None
self.simulator = cast(Callable[[], Simulator], self.simulator_fn)()
else:
initial_state = next(cast(Iterator[InitialStateType], self.seed_iterator))
self.simulator = self.simulator_fn(initial_state)
self.status = EnvWrapperStatus(
cur_step=0,
done=False,
initial_state=initial_state,
obs_history=[],
action_history=[],
reward_history=[],
)
self.simulator.env = cast(EnvWrapper, weakref.proxy(self))
sim_state = self.simulator.get_state()
obs = self.state_interpreter(sim_state)
self.status["obs_history"].append(obs)
return obs
except StopIteration:
# The environment should be recycled because it's in a dead state.
self.seed_iterator = None
return generate_nan_observation(self.observation_space)
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
"""Environment step.
See the code along with comments to get a sequence of things happening here.
"""
if self.seed_iterator is None:
raise RuntimeError("State queue is already exhausted, but the environment is still receiving action.")
# Clear the logged information from last step
self.logger.reset()
# Action is what we have got from policy
self.status["action_history"].append(policy_action)
action = self.action_interpreter(self.simulator.get_state(), policy_action)
# This update must be after action interpreter and before simulator.
self.status["cur_step"] += 1
# Use the converted action to update the simulator
self.simulator.step(action)
# Update "done" first, as this status might be used by reward_fn later
done = self.simulator.done()
self.status["done"] = done
# Get state and calculate observation
sim_state = self.simulator.get_state()
obs = self.state_interpreter(sim_state)
self.status["obs_history"].append(obs)
# Reward and extra info
if self.reward_fn is not None:
rew = self.reward_fn(sim_state)
else:
# No reward. Treated as 0.
rew = 0.0
self.status["reward_history"].append(rew)
if self.aux_info_collector is not None:
aux_info = self.aux_info_collector(sim_state)
else:
aux_info = {}
# Final logging stuff: RL-specific logs
if done:
self.logger.add_scalar("steps_per_episode", self.status["cur_step"])
self.logger.add_scalar("reward", rew)
self.logger.add_any("obs", obs, loglevel=LogLevel.DEBUG)
self.logger.add_any("policy_act", policy_action, loglevel=LogLevel.DEBUG)
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
return obs, rew, done, info_dict
def render(self):
raise NotImplementedError("Render is not implemented in EnvWrapper.")

337
qlib/rl/utils/finite_env.py Normal file

@@ -0,0 +1,337 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This is to support finite env in vector env.
See https://github.com/thu-ml/tianshou/issues/322 for details.
"""
from __future__ import annotations
import copy
import warnings
from contextlib import contextmanager
import gym
import numpy as np
from typing import Any, Set, Callable, Type
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
from qlib.typehint import Literal
from .log import LogWriter
__all__ = [
"generate_nan_observation",
"check_nan_observation",
"FiniteVectorEnv",
"FiniteDummyVectorEnv",
"FiniteSubprocVectorEnv",
"FiniteShmemVectorEnv",
"FiniteEnvType",
"vectorize_env",
]
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
def fill_invalid(obj):
if isinstance(obj, (int, float, bool)):
return fill_invalid(np.array(obj))
if hasattr(obj, "dtype"):
if isinstance(obj, np.ndarray):
if np.issubdtype(obj.dtype, np.floating):
return np.full_like(obj, np.nan)
return np.full_like(obj, np.iinfo(obj.dtype).max)
# dealing with corner cases that numpy number is not supported by tianshou's sharray
return fill_invalid(np.array(obj))
elif isinstance(obj, dict):
return {k: fill_invalid(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [fill_invalid(v) for v in obj]
elif isinstance(obj, tuple):
return tuple(fill_invalid(v) for v in obj)
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
def is_invalid(arr):
if hasattr(arr, "dtype"):
if np.issubdtype(arr.dtype, np.floating):
return np.isnan(arr).all()
return (np.iinfo(arr.dtype).max == arr).all()
if isinstance(arr, dict):
return all(is_invalid(o) for o in arr.values())
if isinstance(arr, (list, tuple)):
return all(is_invalid(o) for o in arr)
if isinstance(arr, (int, float, bool, np.number)):
return is_invalid(np.array(arr))
return True
def generate_nan_observation(obs_space: gym.Space) -> Any:
"""The NaN observation that indicates the environment receives no seed.
We assume that obs is complex and there must be something like float.
Otherwise this logic doesn't work.
"""
sample = obs_space.sample()
sample = fill_invalid(sample)
return sample
def check_nan_observation(obs: Any) -> bool:
"""Check whether obs is generated by :func:`generate_nan_observation`."""
return is_invalid(obs)
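# Illustration only (not part of this commit): a pure-Python sketch of the "nan observation"
# round trip, i.e. fill an observation with an invalid marker, then detect it later.
# ``fill_nan`` and ``is_all_nan`` are hypothetical simplifications of ``fill_invalid`` and
# ``is_invalid``: they handle Python floats in nested containers only, with no numpy arrays
# and no integer sentinel.
```python
import math
from typing import Any


def fill_nan(obj: Any) -> Any:
    """Return a copy of ``obj`` with every float replaced by NaN."""
    if isinstance(obj, float):
        return math.nan
    if isinstance(obj, dict):
        return {key: fill_nan(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Preserve the container type (list stays list, tuple stays tuple).
        return type(obj)(fill_nan(value) for value in obj)
    raise ValueError(f"Unsupported value to fill with invalid: {obj}")


def is_all_nan(obj: Any) -> bool:
    """Check whether every float leaf in ``obj`` is NaN."""
    if isinstance(obj, float):
        return math.isnan(obj)
    if isinstance(obj, dict):
        return all(is_all_nan(value) for value in obj.values())
    if isinstance(obj, (list, tuple)):
        return all(is_all_nan(value) for value in obj)
    return False
```
# A vector env can then drop every worker whose reset result satisfies ``is_all_nan``.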
class FiniteVectorEnv(BaseVectorEnv):
"""Allow parallel env workers to consume a single DataQueue until it's exhausted.
See `tianshou issue #322 <https://github.com/thu-ml/tianshou/issues/322>`_.
The requirement is to make every possible seed (stored in :class:`qlib.rl.utils.DataQueue` in our case)
consumed by exactly one environment. This is not possible with tianshou's native VectorEnv and Collector,
because tianshou is unaware of this "exactly one" constraint, and might launch extra workers.
Consider a corner case, where concurrency is 2, but there is only one seed in DataQueue.
According to the logic in collect, reset must be called on both workers.
The results returned by both workers are collected, regardless of what they are.
The problem is that one of the reset results must be invalid or repeated,
because there's only one seed in the queue, and the collector isn't aware of this situation.
Luckily, we can hack the vector env, and make a protocol between single env and vector env.
The single environment (should be :class:`qlib.rl.utils.EnvWrapper` in our case) is responsible for
reading from the queue and generating a special observation when the queue is exhausted. The special obs
is called "nan observation", because simply using None causes problems in the shared-memory vector env.
:class:`FiniteVectorEnv` then reads the observations from all workers and selects the non-nan
observations. It also maintains ``_alive_env_ids``; workers removed from this set are never
called again. When all the environments are exhausted, it raises a StopIteration exception.
This vector env can be used with a collector in two ways:
1. If the data queue is finite (usually in inference), the collector should collect an "infinite" number of
episodes, until the vector env exhausts itself.
2. If the data queue is infinite (usually in training), the collector can set the number of episodes / steps.
In this case, data would be randomly ordered, and some repetitions wouldn't matter.
One extra function of this vector env is that it has a logger that explicitly collects logs
from child workers. See :class:`qlib.rl.utils.LogWriter`.
"""
def __init__(
self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
) -> None:
super().__init__(env_fns, **kwargs)
self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
self._alive_env_ids: Set[int] = set()
self._reset_alive_envs()
self._default_obs = self._default_info = self._default_rew = None
self._zombie = False
self._collector_guarded: bool = False
def _reset_alive_envs(self):
if not self._alive_env_ids:
# starting or running out
self._alive_env_ids = set(range(self.env_num))
# to work around tianshou's buffer and batch
def _set_default_obs(self, obs):
if obs is not None and self._default_obs is None:
self._default_obs = copy.deepcopy(obs)
def _set_default_info(self, info):
if info is not None and self._default_info is None:
self._default_info = copy.deepcopy(info)
def _set_default_rew(self, rew):
if rew is not None and self._default_rew is None:
self._default_rew = copy.deepcopy(rew)
def _get_default_obs(self):
return copy.deepcopy(self._default_obs)
def _get_default_info(self):
return copy.deepcopy(self._default_info)
def _get_default_rew(self):
return copy.deepcopy(self._default_rew)
# END
@staticmethod
def _postproc_env_obs(obs):
# reserved for shmem vector env to restore empty observation
if obs is None or check_nan_observation(obs):
return None
return obs
@contextmanager
def collector_guard(self):
"""Guard the collector. Recommended to guard every collect.
This guard is for two purposes.
1. Catch and ignore the StopIteration exception, which is the stopping signal
thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
2. Notify the loggers when the collect is done.
Examples
--------
>>> with finite_env.collector_guard():
... collector.collect(n_episode=INF)
"""
self._collector_guarded = True
try:
yield self
except StopIteration:
pass
finally:
self._collector_guarded = False
# At last trigger the loggers
for logger in self._logger:
logger.on_env_all_done()
def reset(self, id=None):
assert not self._zombie
# Check whether it's guarded by collector_guard()
if not self._collector_guarded:
warnings.warn(
"Collector is not guarded by FiniteEnv. "
"This may cause unexpected problems, like unexpected StopIteration exception, "
"or missing logs.",
RuntimeWarning,
)
id = self._wrap_id(id)
self._reset_alive_envs()
# ask super to reset alive envs and remap to current index
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
obs = [None] * len(id)
id2idx = {i: k for k, i in enumerate(id)}
if request_id:
for i, o in zip(request_id, super().reset(request_id)):
obs[id2idx[i]] = self._postproc_env_obs(o)
for i, o in zip(id, obs):
if o is None and i in self._alive_env_ids:
self._alive_env_ids.remove(i)
# logging
for i, o in zip(id, obs):
if i in self._alive_env_ids:
for logger in self._logger:
logger.on_env_reset(i, o)
# fill empty observation with default(fake) observation
for o in obs:
self._set_default_obs(o)
for i, o in enumerate(obs):
if o is None:
obs[i] = self._get_default_obs()
if not self._alive_env_ids:
# comment this line so that the env becomes indisposable
# self.reset()
self._zombie = True
raise StopIteration
return np.stack(obs)
def step(self, action, id=None):
assert not self._zombie
id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
result = [[None, None, False, None] for _ in range(len(id))]
# ask super to step alive envs and remap to current index
if request_id:
valid_act = np.stack([action[id2idx[i]] for i in request_id])
for i, r in zip(request_id, zip(*super().step(valid_act, request_id))):
result[id2idx[i]] = list(r)
result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])
# logging
for i, r in zip(id, result):
if i in self._alive_env_ids:
for logger in self._logger:
logger.on_env_step(i, *r)
# fill empty observation/info with default(fake)
for _, r, ___, i in result:
self._set_default_info(i)
self._set_default_rew(r)
for i, r in enumerate(result):
if r[0] is None:
result[i][0] = self._get_default_obs()
if r[1] is None:
result[i][1] = self._get_default_rew()
if r[3] is None:
result[i][3] = self._get_default_info()
return list(map(np.stack, zip(*result)))
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
pass
class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
pass
class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
pass
def vectorize_env(
env_factory: Callable[..., gym.Env],
env_type: FiniteEnvType,
concurrency: int,
logger: LogWriter | list[LogWriter],
) -> FiniteVectorEnv:
"""Helper function to create a vector env.
Parameters
----------
env_factory
Callable to instantiate one single ``gym.Env``.
All concurrent workers will have the same ``env_factory``.
env_type
dummy or subproc or shmem. Corresponding to
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
concurrency
Concurrent environment workers.
logger
Log writers.
Warnings
--------
Please do not use lambda expression here for ``env_factory`` as it may create incorrectly-shared instances.
Don't do: ::
vectorize_env(lambda: EnvWrapper(...), ...)
Please do: ::
def env_factory(): ...
vectorize_env(env_factory, ...)
"""
env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
"dummy": FiniteDummyVectorEnv,
"subproc": FiniteSubprocVectorEnv,
"shmem": FiniteShmemVectorEnv,
}
finite_env_cls = env_type_cls_mapping[env_type]
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
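
Reviewer note: the "exactly one" protocol described in the FiniteVectorEnv docstring can be hard to picture from the flattened diff. The following is a minimal, standard-library-only sketch of the idea, not the code in this commit. The names EXHAUSTED, ToyWorker, ToyFiniteVectorEnv and consume_all are hypothetical, and a plain sentinel object stands in for the NaN-filled observation that the real implementation needs for shared memory.

```python
from collections import deque

EXHAUSTED = object()  # hypothetical sentinel; the diff uses a NaN-filled observation instead


class ToyWorker:
    """Stands in for a single env that pulls seeds from a shared queue."""

    def __init__(self, queue):
        self.queue = queue

    def reset(self):
        # Return the next seed, or the sentinel once the queue is empty.
        return self.queue.popleft() if self.queue else EXHAUSTED


class ToyFiniteVectorEnv:
    """Tracks which workers are still alive and stops when none are left."""

    def __init__(self, workers):
        self.workers = workers
        self.alive = set(range(len(workers)))

    def reset(self, ids=None):
        ids = list(range(len(self.workers))) if ids is None else list(ids)
        obs = {}
        for i in ids:
            if i not in self.alive:
                continue
            result = self.workers[i].reset()
            if result is EXHAUSTED:
                self.alive.discard(i)  # never call this worker again
            else:
                obs[i] = result
        if not self.alive:
            raise StopIteration  # every seed has been consumed
        return obs


def consume_all(seeds, concurrency):
    """Drive the toy env until it exhausts itself; return seeds in consumption order."""
    queue = deque(seeds)
    env = ToyFiniteVectorEnv([ToyWorker(queue) for _ in range(concurrency)])
    consumed = []
    try:
        while True:
            consumed.extend(env.reset().values())
    except StopIteration:
        pass  # plays the role of collector_guard() in the diff
    return consumed
```

With one seed and two workers, the second worker reports exhaustion on its first reset and is dropped, so the single seed is consumed exactly once.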

398
qlib/rl/utils/log.py Normal file

@@ -0,0 +1,398 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Distributed logger for RL.
:class:`LogCollector` runs in every environment worker. It collects log info from simulator states,
and adds it (as a dict) to the auxiliary info returned for each step.
:class:`LogWriter` runs in the central worker. It decodes the dict collected by :class:`LogCollector`
in each worker, and writes them to console, log files, or tensorboard...
The two modules communicate by the "log" field in "info" returned by ``env.step()``.
"""
from __future__ import annotations
import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
if TYPE_CHECKING:
from .env_wrapper import InfoDict
__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
class LogLevel(IntEnum):
"""Log-levels for RL training.
The behavior of handling each log level depends on the implementation of :class:`LogWriter`.
"""
DEBUG = 10
"""If you only want to see the metric in debug mode."""
PERIODIC = 20
"""If you want to see the metric periodically."""
# FIXME: I haven't given much thought about this. Let's hold it for one iteration.
INFO = 30
"""Important log messages."""
CRITICAL = 40
"""LogWriter should always handle CRITICAL messages"""
class LogCollector:
"""Logs are first collected in each environment worker,
and then aggregated to stream at the central thread in vector env.
In :class:`LogCollector`, every metric is added to a dict, which needs to be ``reset()`` at each step.
The dict is sent via the ``info`` in ``env.step()``, and decoded by the :class:`LogWriter` at vector env.
``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
"""
_logged: dict[str, tuple[int, Any]]
_min_loglevel: int
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
self._min_loglevel = int(min_loglevel)
def reset(self):
"""Clear all collected contents."""
self._logged = {}
def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
if name in self._logged:
raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
self._logged[name] = (int(loglevel), metric)
def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Add a string with name into logged contents."""
if loglevel < self._min_loglevel:
return
if not isinstance(string, str):
raise TypeError(f"{string} is not a string.")
self._add_metric(name, string, loglevel)
def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Add a scalar with name into logged contents.
Scalar will be converted into a float.
"""
if loglevel < self._min_loglevel:
return
if hasattr(scalar, "item"):
# could be single-item number
scalar = scalar.item()
if not isinstance(scalar, (float, int)):
raise TypeError(f"{scalar} is not and can not be converted into float or integer.")
scalar = float(scalar)
self._add_metric(name, scalar, loglevel)
def add_array(
self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
) -> None:
"""Add an array with name into logging."""
if loglevel < self._min_loglevel:
return
if not isinstance(array, (np.ndarray, pd.DataFrame, pd.Series)):
raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.")
self._add_metric(name, array, loglevel)
def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
"""Log something with any type.
As it's an "any" object, the only LogWriter accepting it is pickle.
Therefore pickle must be able to serialize it.
"""
if loglevel < self._min_loglevel:
return
# FIXME: detect and rescue object that could be scalar or array
self._add_metric(name, obj, loglevel)
def logs(self) -> dict[str, np.ndarray]:
return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
class LogWriter(Generic[ObsType, ActType]):
"""Base class for log writers, triggered at every reset and step by finite env.
What to do with a specific log depends on the implementation of subclassing :class:`LogWriter`.
The general principle is that, it should handle logs above its loglevel (inclusive),
and discard logs that are not acceptable. For instance, console loggers obviously can't handle an image.
"""
episode_count: int
"""Counter of episodes."""
step_count: int
"""Counter of steps."""
global_step: int
"""Counter of steps. Won't be cleared in ``clear``."""
global_episode: int
"""Counter of episodes. Won't be cleared in ``clear``."""
active_env_ids: Set[int]
"""Active environment ids in vector env."""
episode_lengths: dict[int, int]
"""Map from environment id to episode length."""
episode_rewards: dict[int, list[float]]
"""Map from environment id to the list of step-wise rewards of the current episode."""
episode_logs: dict[int, list]
"""Map from environment id to episode logs."""
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
self.loglevel = loglevel
self.global_step = 0
self.global_episode = 0
# Information, logs of one episode is stored here.
# This assumes that episode is not too long to fit into the memory.
self.episode_lengths = dict()
self.episode_rewards = dict()
self.episode_logs = dict()
self.clear()
def clear(self):
self.episode_count = self.step_count = 0
self.active_env_ids = set()
self.logs = []
def aggregation(self, array: Sequence[Any]) -> Any:
"""Aggregation function from step-wise to episode-wise.
If it's a sequence of float, take the mean.
Otherwise, take the first element.
"""
assert len(array) > 0, "The aggregated array must be not empty."
if all(isinstance(v, float) for v in array):
return np.mean(array)
else:
return array[0]
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
"""This is triggered at the end of each trajectory.
Parameters
----------
length
Length of this trajectory.
rewards
A list of rewards at each step of this episode.
contents
Logged contents for every step.
"""
def log_step(self, reward: float, contents: dict[str, Any]) -> None:
"""This is triggered at each step.
Parameters
----------
reward
Reward for this step.
contents
Logged contents for this step.
"""
def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) -> None:
"""Callback for finite env, on each step."""
# Update counter
self.global_step += 1
self.step_count += 1
self.active_env_ids.add(env_id)
self.episode_lengths[env_id] += 1
# TODO: reward can be a list of list for MARL
self.episode_rewards[env_id].append(rew)
values: dict[str, Any] = {}
for key, (loglevel, value) in info["log"].items():
if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
values[key] = value
self.episode_logs[env_id].append(values)
self.log_step(rew, values)
if done:
# Update counter
self.global_episode += 1
self.episode_count += 1
self.log_episode(self.episode_lengths[env_id], self.episode_rewards[env_id], self.episode_logs[env_id])
def on_env_reset(self, env_id: int, obs: ObsType) -> None:
"""Callback for finite env.
Reset episode statistics. Nothing task-specific is logged here because of
`a limitation of tianshou <https://github.com/thu-ml/tianshou/issues/605>`__.
"""
self.episode_lengths[env_id] = 0
self.episode_rewards[env_id] = []
self.episode_logs[env_id] = []
def on_env_all_done(self) -> None:
"""All done. Time for cleanup."""
class ConsoleWriter(LogWriter):
"""Write log messages to console periodically.
It tracks an average meter for each metric, which is the average value since last ``clear()`` till now.
The display format for each metric is ``<name> <latest_value> (<average_value>)``.
Non-single-number metrics are auto skipped.
"""
prefix: str
"""Prefix can be set via ``writer.prefix``."""
def __init__(
self,
log_every_n_episode: int = 20,
total_episodes: int | None = None,
float_format: str = ":.4f",
counter_format: str = ":4d",
loglevel: int | LogLevel = LogLevel.PERIODIC,
):
super().__init__(loglevel)
# TODO: support log_every_n_step
self.log_every_n_episode = log_every_n_episode
self.total_episodes = total_episodes
self.counter_format = counter_format
self.float_format = float_format
self.prefix = ""
self.console_logger = get_module_logger(__name__, level=logging.INFO)
def clear(self):
super().clear()
# Clear average meters
self.metric_counts: dict[str, int] = defaultdict(int)
self.metric_sums: dict[str, float] = defaultdict(float)
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
# Aggregate step-wise to episode-wise
episode_wise_contents: dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
if isinstance(value, float):
episode_wise_contents[name].append(value)
# Generate log contents and track them in average-meter.
# This should be done at every step, regardless of periodic or not.
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
for name, value in logs.items():
self.metric_counts[name] += 1
self.metric_sums[name] += value
if self.episode_count % self.log_every_n_episode == 0 or self.episode_count == self.total_episodes:
# Only log periodically or at the end
self.console_logger.info(self.generate_log_message(logs))
def generate_log_message(self, logs: dict[str, float]) -> str:
if self.prefix:
msg_prefix = self.prefix + " "
else:
msg_prefix = ""
if self.total_episodes is None:
msg_prefix += "[Step {" + self.counter_format + "}]"
else:
msg_prefix += "[{" + self.counter_format + "}/" + str(self.total_episodes) + "]"
msg_prefix = msg_prefix.format(self.episode_count)
msg = ""
for name, value in logs.items():
# Double-space as delimiter
format_template = r" {} {" + self.float_format + "} ({" + self.float_format + "})"
msg += format_template.format(name, value, self.metric_sums[name] / self.metric_counts[name])
msg = msg_prefix + " " + msg
return msg
class CsvWriter(LogWriter):
"""Dump all episode metrics to a ``result.csv``.
This is not the correct implementation. It's only used for first iteration.
"""
SUPPORTED_TYPES = (float, str, pd.Timestamp)
all_records: list[dict[str, Any]]
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
super().__init__(loglevel)
self.output_dir = output_dir
self.output_dir.mkdir(exist_ok=True)
def clear(self):
super().clear()
self.all_records = []
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
# FIXME Same as ConsoleWriter, needs a refactor to eliminate code-dup
episode_wise_contents: dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
if isinstance(value, self.SUPPORTED_TYPES):
episode_wise_contents[name].append(value)
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
self.all_records.append(logs)
def on_env_all_done(self) -> None:
# FIXME: this is temporary
pd.DataFrame.from_records(self.all_records).to_csv(self.output_dir / "result.csv", index=False)
# The following are not implemented yet.
class PickleWriter(LogWriter):
"""Dump logs to pickle files."""
class TensorboardWriter(LogWriter):
"""Write logs to event files that can be visualized with tensorboard."""
class MlflowWriter(LogWriter):
"""Add logs to mlflow."""
class LogBuffer(LogWriter):
"""Keep everything in memory."""

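Reviewer note: log.py splits logging into a worker-side collector and a central writer that talk through the "log" field of ``info``. The sketch below illustrates that hand-off under simplifying assumptions; it is not the module's code. The function names collect and episode_summary are hypothetical, only float metrics are aggregated, and the mean is used as in ``LogWriter.aggregation``.

```python
from statistics import mean

DEBUG, PERIODIC = 10, 20  # same numeric levels as LogLevel in the diff


def collect(metrics, min_loglevel=PERIODIC):
    """Worker side: keep (loglevel, value) pairs at or above min_loglevel."""
    return {
        name: (level, value)
        for name, (level, value) in metrics.items()
        if level >= min_loglevel
    }


def episode_summary(step_infos, writer_loglevel=PERIODIC):
    """Writer side: filter each step's "log" dict, then average float metrics."""
    per_metric = {}
    for info in step_infos:
        for name, (level, value) in info["log"].items():
            if level >= writer_loglevel and isinstance(value, float):
                per_metric.setdefault(name, []).append(value)
    return {name: mean(values) for name, values in per_metric.items()}


# Two steps of one episode; the DEBUG entry is dropped before it leaves the worker.
steps = [
    {"log": collect({"pnl": (PERIODIC, 1.0), "debug_note": (DEBUG, "skipped")})},
    {"log": collect({"pnl": (PERIODIC, 3.0)})},
]
summary = episode_summary(steps)
```

Filtering on the worker side is what ``min_loglevel`` is for in the diff: entries below the threshold never travel through the pipe.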
13
qlib/typehint.py Normal file

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Commonly used types."""
import sys
__all__ = ["Literal", "TypedDict", "final"]
if sys.version_info >= (3, 8):
from typing import Literal, TypedDict, final # type: ignore # pylint: disable=no-name-in-module
else:
from typing_extensions import Literal, TypedDict, final
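
Reviewer note: typehint.py only re-exports names, so its purpose is easiest to see from a caller's side. The sketch below shows one way such a ``Literal`` alias could be used for a runtime check. It is an illustration, not qlib code: EnvKind mirrors ``FiniteEnvType`` from the diff, validate_env_kind is a hypothetical helper, and the fallback branch assumes ``typing_extensions`` is installed on interpreters older than 3.8.

```python
import sys

if sys.version_info >= (3, 8):
    from typing import Literal, get_args
else:  # assumption: typing_extensions is available on older interpreters
    from typing_extensions import Literal, get_args

EnvKind = Literal["dummy", "subproc", "shmem"]  # mirrors FiniteEnvType in the diff


def validate_env_kind(name: str) -> str:
    """Check a string against the values allowed by the Literal (hypothetical helper)."""
    if name not in get_args(EnvKind):
        raise ValueError(f"Unsupported env type: {name!r}")
    return name
```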


@@ -376,7 +376,7 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
def init_instance_by_config(
config: Union[str, dict, object],
config: Union[str, dict, object, Path],
default_module=None,
accept_types: Union[type, Tuple[type]] = (),
try_kwargs: Dict = {},
@@ -409,6 +409,9 @@ def init_instance_by_config(
- "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
object example:
instance of accept_types
Path example:
specify a pickle object
- it will be treated like 'file:///<path to pickle file>/obj.pkl'
default_module : Python module
Optional. It should be a python module.
NOTE: the "module_path" will be overridden by the `module` argument
@@ -432,11 +435,15 @@ def init_instance_by_config(
if isinstance(config, accept_types):
return config
if isinstance(config, str):
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
if isinstance(config, (str, Path)):
if isinstance(config, str):
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
return pickle.load(f)
else:
with config.open("rb") as f:
return pickle.load(f)
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
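
Reviewer note: the hunk above adds a ``Path`` branch next to the existing ``file://`` string branch. The sketch below isolates just those two branches so they can be exercised on their own. load_pickled is a hypothetical stand-in, not ``init_instance_by_config`` itself, and the URI form assumes a POSIX-style absolute path.

```python
import os
import pickle
import tempfile
from pathlib import Path
from urllib.parse import urlparse


def load_pickled(config):
    """Hypothetical stand-in covering only the pickle branches of the hunk."""
    if isinstance(config, Path):
        # New branch: a Path is opened directly.
        with config.open("rb") as f:
            return pickle.load(f)
    if isinstance(config, str):
        # Existing branch: a string like 'file:///<path to pickle file>/obj.pkl'.
        pr = urlparse(config)
        if pr.scheme == "file":
            with open(os.path.join(pr.netloc, pr.path), "rb") as f:
                return pickle.load(f)
    raise TypeError(f"Not a pickle reference: {config!r}")


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "obj.pkl"
    with target.open("wb") as f:
        pickle.dump({"answer": 42}, f)
    from_path = load_pickled(target)
    from_uri = load_pickled("file://" + str(target))  # POSIX assumption
```

As with any use of pickle, only files from a trusted source should be loaded this way.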


@@ -170,7 +170,7 @@ class BaseCollector(abc.ABC):
df["symbol"] = symbol
if instrument_path.exists():
_old_df = pd.read_csv(instrument_path)
df = _old_df.append(df, sort=False)
df = pd.concat([_old_df, df], sort=False)
df.to_csv(instrument_path, index=False)
def cache_small_data(self, symbol, df):


@@ -55,13 +55,13 @@ class IBOVIndex(IndexBase):
def get_current_4_month_period(self, current_month: int):
"""
This function is used to calculated what is the current
four month period for the current month. For example,
This function is used to calculated what is the current
four month period for the current month. For example,
If the current month is August 8, its four month period
is 2Q.
OBS: In english Q is used to represent *quarter*
which means a three month period. However, in
which means a three month period. However, in
portuguese we use Q to represent a four month period.
In other words,
@@ -90,8 +90,8 @@ class IBOVIndex(IndexBase):
def get_four_month_period(self):
"""
The ibovespa index is updated every four months.
Therefore, we will represent each time period as 2003_1Q
The ibovespa index is updated every four months.
Therefore, we will represent each time period as 2003_1Q
which means 2003 first four month period (Jan, Feb, Mar, Apr)
"""
four_months_period = ["1Q", "2Q", "3Q"]
@@ -101,14 +101,13 @@ class IBOVIndex(IndexBase):
current_month = now.month
for year in [item for item in range(init_year, current_year)]:
for el in four_months_period:
self.years_4_month_periods.append(str(year)+"_"+el)
self.years_4_month_periods.append(str(year) + "_" + el)
# For current year the logic must be a little different
current_4_month_period = self.get_current_4_month_period(current_month)
for i in range(int(current_4_month_period[0])):
self.years_4_month_periods.append(str(current_year) + "_" + str(i+1) + "Q")
self.years_4_month_periods.append(str(current_year) + "_" + str(i + 1) + "Q")
return self.years_4_month_periods
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
"""formatting the datetime in an instrument
@@ -189,11 +188,19 @@ class IBOVIndex(IndexBase):
try:
df_changes_list = []
for i in tqdm(range(len(self.years_4_month_periods) - 1)):
df = pd.read_csv(self.ibov_index_composition.format(self.years_4_month_periods[i]), on_bad_lines="skip")["symbol"]
df_ = pd.read_csv(self.ibov_index_composition.format(self.years_4_month_periods[i + 1]), on_bad_lines="skip")["symbol"]
df = pd.read_csv(
self.ibov_index_composition.format(self.years_4_month_periods[i]), on_bad_lines="skip"
)["symbol"]
df_ = pd.read_csv(
self.ibov_index_composition.format(self.years_4_month_periods[i + 1]), on_bad_lines="skip"
)["symbol"]
## Remove Dataframe
remove_date = self.years_4_month_periods[i].split("_")[0] + "-" + quarter_dict[self.years_4_month_periods[i].split("_")[1]]
remove_date = (
self.years_4_month_periods[i].split("_")[0]
+ "-"
+ quarter_dict[self.years_4_month_periods[i].split("_")[1]]
)
list_remove = list(df[~df.isin(df_)])
df_removed = pd.DataFrame(
{
@@ -204,7 +211,11 @@ class IBOVIndex(IndexBase):
)
## Add Dataframe
add_date = self.years_4_month_periods[i + 1].split("_")[0] + "-" + quarter_dict[self.years_4_month_periods[i + 1].split("_")[1]]
add_date = (
self.years_4_month_periods[i + 1].split("_")[0]
+ "-"
+ quarter_dict[self.years_4_month_periods[i + 1].split("_")[1]]
)
list_add = list(df_[~df_.isin(df)])
df_added = pd.DataFrame(
{"date": len(list_add) * [add_date], "type": len(list_add) * ["add"], "symbol": list_add}
@@ -272,6 +283,5 @@ class IBOVIndex(IndexBase):
return df.loc[:, ["Código"]].copy()
if __name__ == "__main__":
fire.Fire(partial(get_instruments, market_index="br_index" ))
fire.Fire(partial(get_instruments, market_index="br_index"))


@@ -90,7 +90,6 @@ class CSIIndex(IndexBase):
raise NotImplementedError("rewrite index_code")
@property
@abc.abstractmethod
def html_table_index(self) -> int:
"""Which table of changes in html
@@ -98,7 +97,7 @@ class CSIIndex(IndexBase):
CSI100: 1
:return:
"""
raise NotImplementedError()
raise NotImplementedError("rewrite html_table_index")
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
"""formatting the datetime in an instrument
@@ -184,12 +183,7 @@ class CSIIndex(IndexBase):
df = pd.DataFrame()
_tmp_count = 0
for _df in pd.read_html(content):
if (
_df.shape[-1] != 4
or _df.iloc[2:,][0].str.contains(
"."
)[2]
):
if _df.shape[-1] != 4 or _df.isnull().loc(0)[0][0]:
continue
_tmp_count += 1
if self.html_table_index + 1 > _tmp_count:
@@ -341,8 +335,8 @@ class CSI300Index(CSIIndex):
return pd.Timestamp("2005-01-01")
@property
def html_table_index(self):
return 1
def html_table_index(self) -> int:
return 0
class CSI100Index(CSIIndex):
@@ -355,8 +349,8 @@ class CSI100Index(CSIIndex):
return pd.Timestamp("2006-05-29")
@property
def html_table_index(self):
return 2
def html_table_index(self) -> int:
return 1
class CSI500Index(CSIIndex):
@@ -368,10 +362,6 @@ class CSI500Index(CSIIndex):
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2007-01-15")
@property
def html_table_index(self) -> int:
return 0
def get_changes(self) -> pd.DataFrame:
"""get companies changes
@@ -475,5 +465,4 @@ class CSI500Index(CSIIndex):
if __name__ == "__main__":
get_instruments(index_name="CSI300", qlib_dir="~/.qlib/qlib_data/cn_data", method="parse_instruments")
# fire.Fire(get_instruments)
fire.Fire(get_instruments)


@@ -225,7 +225,7 @@ class IndexBase:
] = _row.date
else:
_tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
new_df = new_df.append(_tmp_df, sort=False)
new_df = pd.concat([new_df, _tmp_df], sort=False)
inst_df = new_df.loc[:, instruments_columns]
_inst_prefix = self.INST_PREFIX.strip()


@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import re
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Iterable, Optional, Union
@@ -11,10 +12,11 @@ import pandas as pd
import baostock as bs
from loguru import logger
from scripts.data_collector.base import BaseCollector, BaseRun, BaseNormalize
from scripts.data_collector.utils import get_hs_stock_symbols, get_calendar_list
BASE_DIR = Path(__file__).resolve().parent
sys.path.append(str(BASE_DIR.parent.parent))
BASE_DIR = Path(__file__).resolve().parent.parent
from data_collector.base import BaseCollector, BaseRun, BaseNormalize
from data_collector.utils import get_hs_stock_symbols, get_calendar_list
class PitCollector(BaseCollector):


@@ -271,6 +271,5 @@ class SP400Index(WIKIIndex):
logger.warning(f"No suitable data source has been found!")
if __name__ == "__main__":
fire.Fire(partial(get_instruments, market_index="us_index"))


@@ -559,6 +559,7 @@ def generate_minutes_calendar_from_daily(
return pd.Index(sorted(set(np.hstack(res))))
def get_instruments(
qlib_dir: str,
index_name: str,
@@ -566,7 +567,7 @@ def get_instruments(
freq: str = "day",
request_retry: int = 5,
retry_sleep: int = 3,
market_index: str = "cn_index"
market_index: str = "cn_index",
):
"""
@@ -585,7 +586,7 @@ def get_instruments(
retry_sleep: int
request sleep, by default 3
market_index: str
Where the files to obtain the index are located,
Where the files to obtain the index are located,
for example data_collector.cn_index.collector
Examples
@@ -605,4 +606,4 @@ def get_instruments(
if __name__ == "__main__":
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM


@@ -10,6 +10,7 @@
> *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
**NOTE**: Yahoo! Finance has blocked the access from China. Please change your network if you want to use the Yahoo data crawler.
> **Examples of abnormal data**


@@ -245,7 +245,7 @@ class YahooCollectorCN1d(YahooCollectorCN):
_path = self.save_dir.joinpath(f"sh{_index_code}.csv")
if _path.exists():
_old_df = pd.read_csv(_path)
df = _old_df.append(df, sort=False)
df = pd.concat([_old_df, df], sort=False)
df.to_csv(_path, index=False)
time.sleep(5)
@@ -317,24 +317,24 @@ class YahooCollectorIN1min(YahooCollectorIN):
class YahooCollectorBR(YahooCollector, ABC):
def retry(cls):
""""
The reason to use retry=2 is due to the fact that
Yahoo Finance unfortunately does not keep track of some
Brazilian stocks.
Therefore, the decorator deco_retry with retry argument
set to 5 will keep trying to get the stock data up to 5 times,
which makes the code to download Brazilians stocks very slow.
In future, this may change, but for now
I suggest to leave retry argument to 1 or 2 in
order to improve download speed.
"""
The reason to use retry=2 is due to the fact that
Yahoo Finance unfortunately does not keep track of some
Brazilian stocks.
To achieve this goal an abstract attribute (retry)
was added into YahooCollectorBR base class
Therefore, the decorator deco_retry with retry argument
set to 5 will keep trying to get the stock data up to 5 times,
which makes the code to download Brazilians stocks very slow.
In future, this may change, but for now
I suggest to leave retry argument to 1 or 2 in
order to improve download speed.
To achieve this goal an abstract attribute (retry)
was added into YahooCollectorBR base class
"""
raise NotImplementedError
def get_instrument_list(self):
logger.info("get BR stock symbols......")
symbols = get_br_stock_symbols() + [
@@ -404,7 +404,7 @@ class YahooNormalize(BaseNormalize):
.index
)
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
change_series = YahooNormalize.calc_change(df, last_close)
# NOTE: The data obtained by Yahoo finance sometimes has exceptions


@@ -78,7 +78,8 @@ REQUIRED = [
"dill",
"dataclasses;python_version<'3.7'",
"filelock",
"jinja2<3.1.0" # for passing the readthedocs workflow.
"jinja2<3.1.0", # for passing the readthedocs workflow.
"gym",
]
# Numpy include
@@ -134,7 +135,12 @@ setup(
"sphinx",
"sphinx_rtd_theme",
"pre-commit",
]
],
"rl": [
"tianshou",
"gym",
"torch",
],
},
include_package_data=True,
classifiers=[

10
tests/conftest.py Normal file

@@ -0,0 +1,10 @@
import os
import sys
"""Ignore RL tests on non-linux platform."""
collect_ignore = []
if sys.platform != "linux":
for root, dirs, files in os.walk("rl"):
for file in files:
collect_ignore.append(os.path.join(root, file))

61
tests/misc/test_sepdf.py Normal file

@@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import numpy as np
import pandas as pd
from qlib.contrib.data.utils.sepdf import SepDataFrame
class SepDF(unittest.TestCase):
    def to_str(self, obj):
        return "".join(str(obj).split())
    def test_index_data(self):
        np.random.seed(42)
        index = [
            np.array(["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"]),
            np.array(["one", "two", "one", "two", "one", "two", "one", "two"]),
        ]
        cols = [
            np.repeat(np.array(["g1", "g2"]), 2),
            np.arange(4),
        ]
        df = pd.DataFrame(np.random.randn(8, 4), index=index, columns=cols)
        sdf = SepDataFrame(df_dict={"g2": df["g2"]}, join=None)
        sdf[("g2", 4)] = 3
        sdf["g1"] = df["g1"]
        exp = """
{'g2': 2 3 4
bar one 0.647689 1.523030 3
two 1.579213 0.767435 3
baz one -0.463418 -0.465730 3
two -1.724918 -0.562288 3
foo one -0.908024 -1.412304 3
two 0.067528 -1.424748 3
qux one -1.150994 0.375698 3
two -0.601707 1.852278 3, 'g1': 0 1
bar one 0.496714 -0.138264
two -0.234153 -0.234137
baz one -0.469474 0.542560
two 0.241962 -1.913280
foo one -1.012831 0.314247
two 1.465649 -0.225776
qux one -0.544383 0.110923
two -0.600639 -0.291694}
"""
        self.assertEqual(self.to_str(sdf._df_dict), self.to_str(exp))
        del df["g1"]
        del df["g2"]
        # it will not raise error, and df will be an empty dataframe
        del sdf["g1"]
        del sdf["g2"]
        # sdf should support deleting all the columns
if __name__ == "__main__":
    unittest.main()
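The `to_str` helper in this test compares printed frames after removing all whitespace, which makes the expected string insensitive to column alignment. A standalone sketch of that comparison follows; the sample strings are invented for illustration.

```python
def to_str(obj):
    # str.split() with no arguments splits on any run of whitespace,
    # so joining the pieces removes spaces, tabs, and newlines alike.
    return "".join(str(obj).split())


expected = """
     a    b
0    1    2
"""
actual = "a b\n0 1 2"

same_layout_free_text = to_str(expected) == to_str(actual)

# Trade-off: token boundaries are lost, so "1 2" and "12" compare equal.
collision = to_str("1 2") == to_str("12")
```

The collision shows the cost of this approach: it tolerates formatting drift, but it cannot detect values that merge once the separators are gone.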

4
tests/pytest.ini Normal file

@@ -0,0 +1,4 @@
[pytest]
filterwarnings =
    ignore:.*rng.randint:DeprecationWarning
    ignore:.*Casting input x to numpy array:UserWarning
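The two `filterwarnings` entries above use the `action:message:category` form. As a rough illustration of what they suppress, the sketch below registers equivalent filters with the standard `warnings` module; the warning texts are invented, and pytest's own handling of the ini option may differ in detail.

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Filters added later take precedence over the "always" rule above.
    warnings.filterwarnings("ignore", message=".*rng.randint", category=DeprecationWarning)
    warnings.filterwarnings(
        "ignore", message=".*Casting input x to numpy array", category=UserWarning
    )

    # Invented messages: the first two match the filters, the third does not.
    warnings.warn("use of rng.randint is deprecated", DeprecationWarning)
    warnings.warn("Casting input x to numpy array.", UserWarning)
    warnings.warn("something unrelated", UserWarning)

remaining = [str(item.message) for item in caught]
```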


@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import multiprocessing
import time
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from qlib.rl.utils.data_queue import DataQueue
class DummyDataset(Dataset):
    def __init__(self, length):
        self.length = length
    def __getitem__(self, index):
        assert 0 <= index < self.length
        return pd.DataFrame(np.random.randint(0, 100, size=(index + 1, 4)), columns=list("ABCD"))
    def __len__(self):
        return self.length
def _worker(dataloader, collector):
    # for i in range(3):
    for i, data in enumerate(dataloader):
        collector.put(len(data))
def _queue_to_list(queue):
    result = []
    while not queue.empty():
        result.append(queue.get())
    return result
def test_pytorch_dataloader():
    dataset = DummyDataset(100)
    dataloader = DataLoader(dataset, batch_size=None, num_workers=1)
    queue = multiprocessing.Queue()
    _worker(dataloader, queue)
    assert len(set(_queue_to_list(queue))) == 100
def test_multiprocess_shared_dataloader():
    dataset = DummyDataset(100)
    with DataQueue(dataset, producer_num_workers=1) as data_queue:
        queue = multiprocessing.Queue()
        processes = []
        for _ in range(3):
            processes.append(multiprocessing.Process(target=_worker, args=(data_queue, queue)))
            processes[-1].start()
        for p in processes:
            p.join()
        assert len(set(_queue_to_list(queue))) == 100
def test_exit_on_crash_finite():
    def _exit_finite():
        dataset = DummyDataset(100)
        with DataQueue(dataset, producer_num_workers=4) as data_queue:
            time.sleep(3)
            raise ValueError
        # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess
    process = multiprocessing.Process(target=_exit_finite)
    process.start()
    process.join()
def test_exit_on_crash_infinite():
    def _exit_infinite():
        dataset = DummyDataset(100)
        with DataQueue(dataset, repeat=-1, queue_maxsize=100) as data_queue:
            time.sleep(3)
            raise ValueError
    process = multiprocessing.Process(target=_exit_infinite)
    process.start()
    process.join()
if __name__ == "__main__":
    test_multiprocess_shared_dataloader()
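`test_multiprocess_shared_dataloader` checks that three consumers sharing one queue together see each of the 100 samples. The sketch below shows the same property with threads and `queue.Queue` from the standard library; it is an analogy under simplified assumptions, not how `DataQueue` is implemented, and the `producer` and `consumer` helpers are hypothetical.

```python
import queue
import threading

NUM_CONSUMERS = 3


def producer(items, shared):
    for item in items:
        shared.put(item)
    for _ in range(NUM_CONSUMERS):
        shared.put(None)  # one stop marker per consumer


def consumer(shared, collected, lock):
    while True:
        item = shared.get()
        if item is None:
            break
        with lock:
            collected.append(item)


shared = queue.Queue(maxsize=8)  # bounded, so the producer waits when it is full
collected = []
lock = threading.Lock()

consumers = [
    threading.Thread(target=consumer, args=(shared, collected, lock))
    for _ in range(NUM_CONSUMERS)
]
for thread in consumers:
    thread.start()

producer_thread = threading.Thread(target=producer, args=(range(100), shared))
producer_thread.start()
producer_thread.join()
for thread in consumers:
    thread.join()
```

Because every item is taken from the queue exactly once, the consumers split the work instead of each iterating over the full dataset.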

249
tests/rl/test_finite_env.py Normal file

@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import Counter
import gym
import numpy as np
from tianshou.data import Batch, Collector
from tianshou.policy import BasePolicy
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from qlib.rl.utils.finite_env import (
LogWriter,
FiniteDummyVectorEnv,
FiniteShmemVectorEnv,
FiniteSubprocVectorEnv,
check_nan_observation,
generate_nan_observation,
)
_test_space = gym.spaces.Dict(
{
"sensors": gym.spaces.Dict(
{
"position": gym.spaces.Box(low=-100, high=100, shape=(3,)),
"velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)),
"front_cam": gym.spaces.Tuple(
(gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)))
),
"rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
}
),
"ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)),
"inner_state": gym.spaces.Dict(
{
"charge": gym.spaces.Discrete(100),
"system_checks": gym.spaces.MultiBinary(10),
"job_status": gym.spaces.Dict(
{
"task": gym.spaces.Discrete(5),
"progress": gym.spaces.Box(low=0, high=100, shape=()),
}
),
}
),
}
)
class FiniteEnv(gym.Env):
def __init__(self, dataset, num_replicas, rank):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
self.iterator = None
self.observation_space = gym.spaces.Discrete(255)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
if self.iterator is None:
self.iterator = iter(self.loader)
try:
self.current_sample, self.step_count = next(self.iterator)
self.current_step = 0
return self.current_sample
except StopIteration:
self.iterator = None
return generate_nan_observation(self.observation_space)
def step(self, action):
self.current_step += 1
assert self.current_step <= self.step_count
return (
0,
1.0,
self.current_step >= self.step_count,
{"sample": self.current_sample, "action": action, "metric": 2.0},
)
class FiniteEnvWithComplexObs(FiniteEnv):
def __init__(self, dataset, num_replicas, rank):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
self.iterator = None
self.observation_space = gym.spaces.Discrete(255)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
if self.iterator is None:
self.iterator = iter(self.loader)
try:
self.current_sample, self.step_count = next(self.iterator)
self.current_step = 0
return _test_space.sample()
except StopIteration:
self.iterator = None
return generate_nan_observation(self.observation_space)
def step(self, action):
self.current_step += 1
assert self.current_step <= self.step_count
return (
_test_space.sample(),
1.0,
self.current_step >= self.step_count,
{"sample": _test_space.sample(), "action": action, "metric": 2.0},
)
class DummyDataset(Dataset):
def __init__(self, length):
self.length = length
self.episodes = [3 * i % 5 + 1 for i in range(self.length)]
def __getitem__(self, index):
assert 0 <= index < self.length
return index, self.episodes[index]
def __len__(self):
return self.length
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.stack([1] * len(batch)))
def learn(self, batch):
pass
def _finite_env_factory(dataset, num_replicas, rank, complex=False):
if complex:
return lambda: FiniteEnvWithComplexObs(dataset, num_replicas, rank)
return lambda: FiniteEnv(dataset, num_replicas, rank)
class MetricTracker(LogWriter):
def __init__(self, length):
super().__init__()
self.counter = Counter()
self.finished = set()
self.length = length
def on_env_step(self, env_id, obs, rew, done, info):
assert rew == 1.0
index = info["sample"]
if done:
# assert index not in self.finished
self.finished.add(index)
self.counter[index] += 1
def validate(self):
assert len(self.finished) == self.length
for k, v in self.counter.items():
assert v == k * 3 % 5 + 1
class DoNothingTracker(LogWriter):
def on_env_step(self, *args, **kwargs):
pass
def test_finite_dummy_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_finite_shmem_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_finite_subproc_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_nan():
assert check_nan_observation(generate_nan_observation(_test_space))
assert not check_nan_observation(_test_space.sample())
def test_finite_dummy_vector_env_complex():
length = 100
dataset = DummyDataset(length)
envs = FiniteDummyVectorEnv(
DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
)
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
try:
test_collector.collect(n_step=10**18)
except StopIteration:
pass
def test_finite_shmem_vector_env_complex():
length = 100
dataset = DummyDataset(length)
envs = FiniteShmemVectorEnv(
DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
)
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
try:
test_collector.collect(n_step=10**18)
except StopIteration:
pass
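`MetricTracker.validate` expects every sample to finish exactly once and to receive `k * 3 % 5 + 1` steps, matching the episode lengths in `DummyDataset`. A plain-Python sketch of that bookkeeping is given here; the round-robin split across replicas is an assumption made for illustration (the tests use `DistributedSampler`, which may order samples differently).

```python
from collections import Counter

length, num_replicas = 100, 5
# Same formula as DummyDataset.episodes: between 1 and 5 steps per sample.
episode_lengths = [3 * i % 5 + 1 for i in range(length)]

counter = Counter()
finished = set()
for rank in range(num_replicas):
    # Assumed sharding: replica `rank` handles every num_replicas-th sample.
    for index in range(rank, length, num_replicas):
        for step in range(1, episode_lengths[index] + 1):
            counter[index] += 1  # one on_env_step call per environment step
            if step >= episode_lengths[index]:
                finished.add(index)  # the episode reports done on its last step

all_finished_once = len(finished) == length
steps_match = all(count == k * 3 % 5 + 1 for k, count in counter.items())
```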

156
tests/rl/test_logger.py Normal file

@@ -0,0 +1,156 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from random import randint, choice
from pathlib import Path
import re
import gym
import numpy as np
import pandas as pd
from gym import spaces
from tianshou.data import Collector, Batch
from tianshou.policy import BasePolicy
from qlib.log import set_log_with_config
from qlib.config import C
from qlib.constant import INF
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.utils.data_queue import DataQueue
from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper
from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter
from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
def __init__(self):
self.logger = LogCollector()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.step_count = 0
return 0
def step(self, action: int):
self.logger.reset()
self.logger.add_scalar("reward", 42.0)
self.logger.add_scalar("a", randint(1, 10))
self.logger.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
if self.step_count >= 3:
done = choice([False, True])
else:
done = False
if 2 <= self.step_count <= 3:
self.logger.add_scalar("c", randint(11, 20))
self.step_count += 1
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.stack([1] * len(batch)))
def learn(self, batch):
pass
def test_simple_env_logger(caplog):
set_log_with_config(C.logging_config)
for venv_cls_name in ["dummy", "shmem", "subproc"]:
writer = ConsoleWriter()
csv_writer = CsvWriter(Path(__file__).parent / ".output")
venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer])
with venv.collector_guard():
collector = Collector(AnyPolicy(), venv)
collector.collect(n_episode=30)
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert output_file.columns.tolist() == ["reward", "a", "c"]
assert len(output_file) >= 30
line_counter = 0
for line in caplog.text.splitlines():
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3
class SimpleSimulator(Simulator[int, float, float]):
def __init__(self, initial: int, **kwargs) -> None:
self.initial = float(initial)
def step(self, action: float) -> None:
import torch
self.initial += action
self.env.logger.add_scalar("test_a", torch.tensor(233.0))
self.env.logger.add_scalar("test_b", np.array(200))
def get_state(self) -> float:
return self.initial
def done(self) -> bool:
return self.initial % 1 > 0.5
class DummyStateInterpreter(StateInterpreter[float, float]):
def interpret(self, state: float) -> float:
return state
@property
def observation_space(self) -> spaces.Box:
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
class DummyActionInterpreter(ActionInterpreter[float, int, float]):
def interpret(self, state: float, action: int) -> float:
return action / 100
@property
def action_space(self) -> spaces.Discrete:
return spaces.Discrete(5)
class RandomFivePolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.random.randint(5, size=len(batch)))
def learn(self, batch):
pass
def test_logger_with_env_wrapper():
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
env_wrapper_factory = lambda: EnvWrapper(
SimpleSimulator,
DummyStateInterpreter(),
DummyActionInterpreter(),
data_iterator,
logger=LogCollector(LogLevel.DEBUG),
)
# loglevel can be debug here because metrics can all dump into csv
# otherwise, csv writer might crash
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
with venv.collector_guard():
collector = Collector(RandomFivePolicy(), venv)
collector.collect(n_episode=INF * len(venv))
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(output_df) == 20
# obs has an increasing trend
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
assert (output_df["test_a"] == 233).all()
assert (output_df["test_b"] == 200).all()
assert "steps_per_episode" in output_df and "reward" in output_df
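In `test_logger_with_env_wrapper`, `DummyActionInterpreter` scales an action to `action / 100`, `SimpleSimulator.step` adds it to the state, and `done()` fires once the fractional part exceeds 0.5. The sketch below replays that arithmetic for a fixed action to show how long an episode lasts; `steps_until_done` is a hypothetical helper, and a fixed action is a simplification of the random policy used in the test.

```python
def steps_until_done(initial, action, max_steps=1000):
    """Count steps until the fractional part of the state exceeds 0.5."""
    state = float(initial)
    for step in range(1, max_steps + 1):
        state += action / 100  # same scaling as DummyActionInterpreter
        if state % 1 > 0.5:  # same condition as SimpleSimulator.done()
            return step
    return None  # never finished within max_steps


steps_with_max_action = steps_until_done(0, 4)
steps_from_last_seed = steps_until_done(19, 4)
steps_with_action_three = steps_until_done(0, 3)
steps_with_zero_action = steps_until_done(5, 0)
```

The starting value only shifts the integer part, so the episode length depends on the actions alone, while the observations grow with the seed taken from the data queue.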


@@ -0,0 +1,308 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from functools import partial
from pathlib import Path
from typing import NamedTuple
import numpy as np
import pandas as pd
import pytest
import torch
from tianshou.data import Batch
from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.entries.test import backtest
from qlib.rl.order_execution import *
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "intraday_saoe"
DATA_DIR = DATA_ROOT_DIR / "us"
BACKTEST_DATA_DIR = DATA_DIR / "backtest"
FEATURE_DATA_DIR = DATA_DIR / "processed"
ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
CN_DATA_DIR = DATA_ROOT_DIR / "cn"
CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
def test_pickle_data_inspect():
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
assert len(data) == 390
data = pickle_styled.load_intraday_processed_data(
DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index()
)
assert len(data.today) == len(data.yesterday) == 390
def test_simulator_first_step():
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
assert state.position == 30.0
simulator.step(15.0)
state = simulator.get_state()
assert len(state.history_exec) == 30
assert state.history_exec.index[0] == pd.Timestamp("2013-12-11 09:30:00")
assert state.history_exec["market_volume"].iloc[0] == 450072.0
assert abs(state.history_exec["market_price"].iloc[0] - 25.370001) < 1e-4
assert (state.history_exec["amount"] == 0.5).all()
assert (state.history_exec["deal_amount"] == 0.5).all()
assert abs(state.history_exec["trade_price"].iloc[0] - 25.370001) < 1e-4
assert abs(state.history_exec["trade_value"].iloc[0] - 12.68500) < 1e-4
assert state.history_exec["position"].iloc[0] == 29.5
assert state.history_exec["ffr"].iloc[0] == 1 / 60
assert state.history_steps["market_volume"].iloc[0] == 5041147.0
assert state.history_steps["amount"].iloc[0] == 15.0
assert state.history_steps["deal_amount"].iloc[0] == 15.0
assert state.history_steps["ffr"].iloc[0] == 0.5
assert (
state.history_steps["pa"].iloc[0]
== (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000
)
assert state.position == 15.0
assert state.cur_time == pd.Timestamp("2013-12-11 10:00:00")
def test_simulator_stop_twap():
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
for _ in range(13):
simulator.step(1.0)
state = simulator.get_state()
assert len(state.history_exec) == 390
assert (state.history_exec["deal_amount"] == 13 / 390).all()
assert state.history_steps["position"].iloc[0] == 12 and state.history_steps["position"].iloc[-1] == 0
assert abs(state.metrics["ffr"] - 1) < 1e-3
assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4
assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
assert state.position == 0.0
assert abs(state.metrics["trade_price"] - state.metrics["market_price"]) < 1e-4
assert abs(state.metrics["pa"]) < 1e-2
assert simulator.done()
def test_simulator_stop_early():
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
with pytest.raises(ValueError):
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator.step(2.0)
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
simulator.step(1.0)
with pytest.raises(AssertionError):
simulator.step(1.0)
def test_simulator_start_middle():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
simulator.step(2.0)
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:30:00")
for _ in range(10):
simulator.step(1.0)
simulator.step(2.0)
assert len(simulator.history_exec) == 330
assert simulator.done()
assert abs(simulator.history_exec["amount"].iloc[-1] - (1 + 2 / 15)) < 1e-4
assert abs(simulator.metrics["ffr"] - 1) < 1e-4
def test_interpreter():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
# emulate an env status
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
interpreter_step = CurrentStepStateInterpreter(13)
interpreter_action = CategoricalActionInterpreter(20)
interpreter_action_twap = TwapRelativeActionInterpreter()
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
# first step
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 45
assert obs["cur_step"] == 0
assert obs["position"] == 15.0
assert obs["position_history"][0] == 15.0
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(45))
assert np.sum(obs["data_processed"][45:]) == 0
assert obs["data_processed_prev"].shape == (390, 5)
# first step: second interpreter
interpreter_step.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
obs = interpreter_step(simulator.get_state())
assert obs["acquiring"] == 1
assert obs["position"] == 15.0
# second step
simulator.step(5.0)
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 60
assert obs["cur_step"] == 1
assert obs["position"] == 10.0
assert obs["position_history"][:2].tolist() == [15.0, 10.0]
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(60))
assert np.sum(obs["data_processed"][60:]) == 0
# second step: action
action = interpreter_action(simulator.get_state(), 1)
assert action == 15 / 20
interpreter_action_twap.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)
)
action = interpreter_action_twap(simulator.get_state(), 1.5)
assert action == 1.5
# fast-forward
for _ in range(10):
simulator.step(0.0)
# last step
simulator.step(5.0)
interpreter.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs)
)
assert interpreter.env.status["done"]
obs = interpreter(simulator.get_state())
assert obs["cur_tick"] == 375
assert obs["cur_step"] == 12
assert obs["position"] == 0.0
assert obs["position_history"][1:11].tolist() == [10.0] * 10
assert all(np.sum(obs["data_processed"][i]) != 0 for i in range(375))
assert np.sum(obs["data_processed"][375:]) == 0
def test_network_sanity():
# we won't check the correctness of networks here
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR)
assert len(simulator.ticks_for_order) == 390
class EmulateEnvWrapper(NamedTuple):
status: EnvWrapperStatus
interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
action_interp = CategoricalActionInterpreter(13)
wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
network = Recurrent(interpreter.observation_space)
policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3)
for i in range(14):
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs))
obs = interpreter(simulator.get_state())
batch = Batch(obs=[obs])
output = policy(batch)
assert 0 <= output["act"].item() <= 13
if i < 13:
simulator.step(1.0)
else:
assert obs["cur_tick"] == 389
assert obs["cur_step"] == 12
assert obs["position_history"][-1] == 3
@pytest.mark.parametrize("finite_env_type", ["dummy", "subproc", "shmem"])
def test_twap_strategy(finite_env_type):
set_log_with_config(C.logging_config)
orders = pickle_styled.load_orders(ORDER_DIR)
assert len(orders) == 248
state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
action_interp = TwapRelativeActionInterpreter()
policy = AllOne(state_interp.observation_space, action_interp.action_space)
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
[ConsoleWriter(total_episodes=len(orders)), csv_writer],
concurrency=4,
finite_env_type=finite_env_type,
)
metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(metrics) == 248
assert np.isclose(metrics["ffr"].mean(), 1.0)
assert np.isclose(metrics["pa"].mean(), 0.0)
assert np.allclose(metrics["pa"], 0.0, atol=2e-3)
def test_cn_ppo_strategy():
set_log_with_config(C.logging_config)
# The data starts with 9:31 and ends with 15:00
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
action_interp = CategoricalActionInterpreter(4)
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
[ConsoleWriter(total_episodes=len(orders)), csv_writer],
concurrency=4,
)
metrics = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(metrics) == len(orders)
assert np.isclose(metrics["ffr"].mean(), 1.0)
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
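Several assertions in `test_simulator_first_step` follow from simple arithmetic: a step amount is spread evenly over the ticks of that step, and `pa` is computed as in the test, `(trade_price / twap_price - 1) * 10000`. The sketch below recomputes those numbers from the values used in the test; the variable names and the `price_advantage_bps` helper are hypothetical, and the even split per tick is inferred from the assertions rather than from the simulator code.

```python
order_amount = 30.0          # Order("AAL", 30.0, ...)
ticks_per_step = 30          # one step covers 30 one-minute ticks
step_amount = 15.0           # simulator.step(15.0)
first_tick_price = 25.370001  # market price asserted for the first tick

per_tick_amount = step_amount / ticks_per_step             # 0.5 per tick
position_after_first_tick = order_amount - per_tick_amount  # 29.5
ffr_after_first_tick = per_tick_amount / order_amount       # 1 / 60
trade_value_first_tick = per_tick_amount * first_tick_price  # about 12.685
ffr_after_first_step = step_amount / order_amount           # 0.5


def price_advantage_bps(trade_price, twap_price):
    """Relative price difference against TWAP, expressed in basis points."""
    return (trade_price / twap_price - 1) * 10000
```

For the TWAP test, the same reasoning gives `13 / 390` per tick when 13 steps of `1.0` are spread over 390 ticks.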


@@ -1,28 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import sys
import qlib
from qlib.data import D
import shutil
import unittest
import pandas as pd
import baostock as bs
from pathlib import Path
from qlib.data import D
from scripts.get_data import GetData
from scripts.dump_pit import DumpPitData
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
from collector import Run
pd.set_option("display.width", 1000)
pd.set_option("display.max_columns", None)
DATA_DIR = Path(__file__).parent.joinpath("test_pit_data")
SOURCE_DIR = DATA_DIR.joinpath("stock_data/source")
SOURCE_DIR.mkdir(exist_ok=True, parents=True)
QLIB_DIR = DATA_DIR.joinpath("qlib_data")
QLIB_DIR.mkdir(exist_ok=True, parents=True)
class TestPIT(unittest.TestCase):
"""
NOTE!!!!!!
The asserts in this test assume that users follow the commands below and download only 2 stocks.
1. `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn`
2. `python scripts/data_collector/pit/collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*"`
3. `python scripts/data_collector/pit/collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized`
4. `python scripts/dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly`
"""
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree(str(DATA_DIR.resolve()))
@classmethod
def setUpClass(cls) -> None:
cn_data_dir = str(QLIB_DIR.joinpath("cn_data").resolve())
pit_dir = str(SOURCE_DIR.joinpath("pit").resolve())
pit_normalized_dir = str(SOURCE_DIR.joinpath("pit_normalized").resolve())
GetData().qlib_data(name="qlib_data_simple", target_dir=cn_data_dir, region="cn")
bs.login()
Run(
source_dir=pit_dir,
interval="quarterly",
).download_data(start="2000-01-01", end="2020-01-01", symbol_regex="^(600519|000725).*")
Run(
source_dir=pit_dir,
normalize_dir=pit_normalized_dir,
interval="quarterly",
).normalize_data()
bs.logout()
DumpPitData(
csv_path=pit_normalized_dir,
qlib_dir=cn_data_dir,
).dump(interval="quarterly")
def setUp(self):
# qlib.init(kernels=1) # NOTE: set kernel to 1 to make it debug easier
qlib.init()
provider_uri = str(QLIB_DIR.joinpath("cn_data").resolve())
qlib.init(provider_uri=provider_uri)
def to_str(self, obj):
return "".join(str(obj).split())
@@ -66,7 +102,7 @@ class TestPIT(unittest.TestCase):
data["$close"] = 1 # in case of different dataset gives different values
expect = """
P($$roewa_q) P($$yoyni_q) $close
instrument datetime
instrument datetime
sh600519 2019-01-02 0.25522 0.243892 1
2019-01-03 0.25522 0.243892 1
2019-01-04 0.25522 0.243892 1
@@ -78,7 +114,7 @@ class TestPIT(unittest.TestCase):
2019-07-17 NaN NaN 1
2019-07-18 NaN NaN 1
2019-07-19 NaN NaN 1
[266 rows x 3 columns]
"""
self.check_same(data, expect)
@@ -191,7 +227,7 @@ class TestPIT(unittest.TestCase):
data = D.features(instruments, fields, start_time="2019-01-01", end_time="2020-01-01", freq="day")
except_data = """
P($$roewa_q) P($$yoyni_q) P(($$roewa_q / $$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1) P(Sum($$yoyni_q, 4)) $close P($$roewa_q) * $close
instrument datetime
instrument datetime
sh600519 2019-01-02 0.255220 0.243892 1.484224 1.661578 63.595333 16.230801
2019-01-03 0.255220 0.243892 1.484224 1.661578 62.641907 15.987467
2019-01-04 0.255220 0.243892 1.484224 1.661578 63.915985 16.312637
@@ -203,7 +239,7 @@ class TestPIT(unittest.TestCase):
2019-12-27 0.255819 0.219821 0.677052 1.081693 125.307404 32.056015
2019-12-30 0.255819 0.219821 0.677052 1.081693 127.763992 32.684456
2019-12-31 0.255819 0.219821 0.677052 1.081693 127.462303 32.607277
[244 rows x 6 columns]
"""
self.check_same(data, except_data)
@@ -219,7 +255,7 @@ class TestPIT(unittest.TestCase):
data = D.features(instruments, fields, start_time="2018-04-28", end_time="2019-07-19", freq="day")
except_data = """
PRef($$roewa_q, 201902) PRef($$yoyni_q, 201801) P($$roewa_q) P($$roewa_q) / PRef($$roewa_q, 201801)
instrument datetime
instrument datetime
sh600519 2018-05-02 NaN 0.395075 0.088887 1.000000
2018-05-03 NaN 0.395075 0.088887 1.000000
2018-05-04 NaN 0.395075 0.088887 1.000000
@@ -231,7 +267,7 @@ class TestPIT(unittest.TestCase):
2019-07-17 0.000000 0.395075 0.000000 0.000000
2019-07-18 0.175322 0.395075 0.175322 1.972414
2019-07-19 0.175322 0.395075 0.175322 1.972414
[299 rows x 4 columns]
"""
self.check_same(data, except_data)