1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

init commit

This commit is contained in:
Young
2020-09-22 01:43:21 +00:00
parent aa51e5aad3
commit 99ebd87cba
131 changed files with 20218 additions and 0 deletions

33
.gitignore vendored Normal file
View File

@@ -0,0 +1,33 @@
# https://github.com/github/gitignore/blob/master/Python.gitignore
__pycache__/
*.pyc
*.so
*.ipynb
.ipynb_checkpoints
_build
build/
dist/
*.pkl
*.hd5
*.csv
.env
.vim
.nvimrc
.vscode
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
examples/estimator/estimator_example/
*.egg-info/
# special software
mlruns/
tags

152
CHANGES.rst Normal file
View File

@@ -0,0 +1,152 @@
Changelog
====================
Here you can see the full list of changes between each QLib release.
Version 0.1.0
--------------------
This is the initial release of QLib library.
Version 0.1.1
--------------------
Performance optimize. Add more features and operators.
Version 0.1.2
--------------------
- Support operator syntax. Now ``High() - Low()`` is equivalent to ``Sub(High(), Low())``.
- Add more technical indicators.
Version 0.1.3
--------------------
Bug fix and add instruments filtering mechanism.
Version 0.2.0
--------------------
- Redesign ``LocalProvider`` database format for performance improvement.
- Support load features as string fields.
- Add scripts for database construction.
- More operators and technical indicators.
Version 0.2.1
--------------------
- Support registering user-defined ``Provider``.
- Support use operators in string format, e.g. ``['Ref($close, 1)']`` is valid field format.
- Support dynamic fields in ``$some_field`` format. And exising fields like ``Close()`` may be deprecated in the future.
Version 0.2.2
--------------------
- Add ``disk_cache`` for reusing features (enabled by default).
- Add ``qlib.contrib`` for experimental model construction and evaluation.
Version 0.2.3
--------------------
- Add ``backtest`` module
- Decoupling the Strategy, Account, Position, Exchange from the backtest module
Version 0.2.4
--------------------
- Add ``profit attribution`` module
- Add ``rick_control`` and ``cost_control`` strategies
Version 0.3.0
--------------------
- Add ``estimator`` module
Version 0.3.1
--------------------
- Add ``filter`` module
Version 0.3.2
--------------------
- Add real price trading, if the ``factor`` field in the data set is incomplete, use ``adj_price`` trading
- Refactor ``handler`` ``launcher`` ``trainer`` code
- Support ``backtest`` configuration parameters in the configuration file
- Fix bug in position ``amount`` is 0
- Fix bug of ``filter`` module
Version 0.3.3
-------------------
- Fix bug of ``filter`` module
Version 0.3.4
--------------------
- Support for ``finetune model``
- Refactor ``fetcher`` code
Version 0.3.5
--------------------
- Support multi-label training, you can provide multiple label in ``handler``. (But LightGBM doesn't support due to the algorithm itself)
- Refactor ``handler`` code, dataset.py is no longer used, and you can deploy your own labels and features in ``feature_label_config``
- Handler only offer DataFrame. Also, ``trainer`` and model.py only receive DataFrame
- Change ``split_rolling_data``, we roll the data on market calender now, not on normal date
- Move some date config from ``handler`` to ``trainer``
Version 0.4.0
--------------------
- Add `data` package that holds all data-related codes
- Reform the data provider structure
- Create a server for data centralized management `qlib-server<https://amc-msra.visualstudio.com/trading-algo/_git/qlib-server>`_
- Add a `ClientProvider` to work with server
- Add a pluggable cache mechanism
- Add a recursive backtracking algorithm to inspect the furthest reference date for an expression
.. note::
The ``D.instruments`` function does not support ``start_time``, ``end_time``, and ``as_list`` parameters, if you want to get the results of previous versions of ``D.instruments``, you can do this:
>>> from qlib.data import D
>>> instruments = D.instruments(market='csi500')
>>> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)
Version 0.4.1
--------------------
- Add support Windows
- Fix ``instruments`` type bug
- Fix ``features`` is empty bug(It will cause failure in updating)
- Fix ``cache`` lock and update bug
- Fix use the same cache for the same field (the original space will add a new cache)
- Change "logger handler" from config
- Change model load support 0.4.0 later
- The default value of the ``method`` parameter of ``risk_analysis`` function is changed from **ci** to **si**
Version 0.4.2
--------------------
- Refactor DataHandler
- Add ``ALPHA360`` DataHandler
Version 0.4.3
--------------------
- Implementing Online Inference and Trading Framework
- Refactoring The interfaces of backtest and strategy module.
Version 0.4.4
--------------------
- Optimize cache generation performance
- Add report module
- Fix bug when using ``ServerDatasetCache`` offline.
- In the previous version of ``long_short_backtest``, there is a case of ``np.nan`` in long_short. The current version ``0.4.4`` has been fixed, so ``long_short_backtest`` will be different from the previous version.
- In the ``0.4.2`` version of ``risk_analysis`` function, ``N`` is ``250``, and ``N`` is ``252`` from ``0.4.3``, so ``0.4.2`` is ``0.002122`` smaller than the ``0.4.3`` the backtest result is slightly different between ``0.4.2`` and ``0.4.3``.
- refactor the argument of backtest function.
- **NOTE**:
- The default arguments of topk margin strategy is changed. Please pass the arguments explicitly if you want to get the same backtest result as previous version.
- The TopkWeightStrategy is changed slightly. It will try to sell the stocks more than ``topk``. (The backtest result of TopkAmountStrategy remains the same)
- The margin ratio mechanism is supported in the Topk Margin strategies.
Version 0.4.5
--------------------
- Add multi-kernel implementation for both client and server.
- Support a new way to load data from client which skips dataset cache.
- Change the default dataset method from single kernel implementation to multi kernel implementation.
- Accelerate the high frequency data reading by optimizing the relative modules.
- Support a new method to write config file by using dict.
Version 0.4.6
--------------------
- Some bugs are fixed
- The default config in `Version 0.4.5` is not friendly to daily frequency data.
- Backtest error in TopkWeightStrategy when `WithInteract=True`.

196
README.md
View File

@@ -1,3 +1,199 @@
Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
With Qlib, you can easily apply your favorite model to create a better Quant investment strategy.
- [Framework of Qlib](#framework-of-qlib)
- [Quick start](#quick-start)
- [Installation](#installation)
- [Get Data](#get-data)
- [Auto Quant research workflow with _estimator_](#auto-quant-research-workflow-with-estimator)
- [Customized Quant research workflow by code](#customized-quant-research-workflow-by-code)
- [More About Qlib](#more-about-qlib)
- [Offline mode and online mode](#offline-mode-and-online-mode)
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
- [Contributing](#contributing)
# Framework of Qlib
![framework](docs/_static/img/framework.png)
At the module level, Qlib is a platform that consists of the above components. Each component is loose-coupling and can be used stand-alone.
| Name | Description |
| ------ | ----- |
| _Data layer_ | _DataServer_ focus on providing high performance infrastructure for user to retrieve and get raw data. _DataEnhancement_ will preprocess the data and provide the best dataset to be fed in to the models |
| _Interday Model_ | _Interday model_ focus on producing forecasting signals(aka. _alpha_). Models are trained by _Model Creator_ and managed by _Model Manager_. User could choose one or multiple models for forecasting. Multiple models could be combined with _Ensemble_ module |
| _Interday Strategy_ | _Portfolio Generator_ will take forecasting signals as input and output the orders based on current position to achieve target portfolio |
| _Intraday Trading_ | _Order Executor_ is responsible for executing orders produced by _Interday Strategy_ and returning the executed results. |
| _Analysis_ | User could get detailed analysis report of forecasting signal and portfolio in this part. |
* The modules with hand-drawn style is under development and will be released in the future.
* The modules with dashed border is highly user-customizable and extendible.
# Quick start
## Installation
To install Qlib from source you need _Cython_ in addition to the normal dependencies above:
```bash
pip install numpy
pip install --upgrade cython
```
Clone the repository and then run:
```bash
python setup.py install
```
## Get Data
- Load and prepare the Data: execute the following command to load the stock data:
```bash
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
```
<!--
- Run the initialization code and get stock data:
```python
import qlib
from qlib.data import D
from qlib.config import REG_CN
# Initialization
mount_path = "~/.qlib/qlib_data/cn_data" # target_dir
qlib.init(mount_path=mount_path, region=REG_CN)
# Get stock data by Qlib
# Load trading calendar with the given time range and frequency
print(D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2])
# Parse a given market name into a stockpool config
instruments = D.instruments('csi500')
print(D.list_instruments(instruments=instruments, start_time='2010-01-01', end_time='2017-12-31', as_list=True)[:6])
# Load features of certain instruments in given time range
instruments = ['SH600000']
fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
print(D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head())
```
-->
## Auto Quant research workflow with _estimator_
Qlib provides a tool named `estimator` to run whole workflow automatically(including building dataset, train models, backtest, analysis)
1. Run _estimator_ (_config.yaml_ for: [estimator_config.yaml](examples/estimator/estimator_config.yaml)):
```bash
cd examples # Avoid running program under the directory contains `qlib`
estimator -c estimator/estimator_config.yaml
```
Estimator result:
```bash
risk
sub_bench mean 0.000662
std 0.004487
annual 0.166720
sharpe 2.340526
mdd -0.080516
sub_cost mean 0.000577
std 0.004482
annual 0.145392
sharpe 2.043494
mdd -0.083584
```
See the full documents for [Use _Estimator_ to Start An Experiment](TODO:URL).
2. Analysis
Run `examples/estimator/analyze_from_estimator.ipynb` in `jupyter notebook`
1. forecasting signal analysis
- Cumulative Return
![Cumulative Return](docs/_static/img/analysis/analysis_model_cumulative_return.png)
![long_short](docs/_static/img/analysis/analysis_model_long_short.png)
- Information Coefficient(IC)
![Information Coefficient](docs/_static/img/analysis/analysis_model_IC.png)
![Monthly IC](docs/_static/img/analysis/analysis_model_monthly_IC.png)
![IC](docs/_static/img/analysis/analysis_model_NDQ.png)
- Auto Correlation
![Auto Correlation](docs/_static/img/analysis/analysis_model_auto_correlation.png)
2. portfolio analysis
- Report
![Report](docs/_static/img/analysis/report.png)
<!--
- Score IC
![Score IC](docs/_static/img/score_ic.png)
- Cumulative Return
![Cumulative Return](docs/_static/img/cumulative_return.png)
- Risk Analysis
![Risk Analysis](docs/_static/img/risk_analysis.png)
- Rank Label
![Rank Label](docs/_static/img/rank_label.png)
-->
## Customized Quant research workflow by code
Automatic workflow may not suite the research workflow of all Quant researchers. To support flexible Quant research workflow, Qlib also provide modularized interface to allow researchers to build their own workflow. [Here](TODO_URL) is a demo for customized Quant research workflow by code
# More About Qlib
The detailed documents are organized in [docs](docs).
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
```bash
cd docs/
conda install sphinx sphinx_rtd_theme -y
# Otherwise, you can install them with pip
# pip install sphinx sphinx_rtd_theme
make html
```
You can also view the [latest document](TODO_URL) online directly.
The roadmap is managed as a [github project](https://github.com/microsoft/qlib/projects/1).
## Offline mode and online mode
The data server of Qlib can both deployed as offline mode and online mode. The default mode is offline mode.
Under offline mode, the data will be deployed locally.
Under online mode, the data will be deployed as a shared data service. The data and their cache will be shared by clients. The data retrieving performance is expected to be improved due to a higher rate of cache hits. It will use less disk space, too. The documents of the online mode can be found in [Qlib-Server](TODO_link). The online mode can be deployed automatically with [Azure CLI based scripts](TODO_link)
## Performance of Qlib Data Server
The performance of data processing is important to data-driven methods like AI technologies. As an AI-oriented platform, Qlib provides a solution for data storage and data processing. To demonstrate the performance of Qlib, We
compare Qlib with several other solutions.
We evaluate the performance of several solutions by completing the same task,
which creates a dataset(14 features/factors) from the basic OHLCV daily data of a stock market(800 stocks each day from 2007 to 2020). The task involves data queries and processing.
| | HDF5 | MySQL | MongoDB | InfluxDB | Qlib -E -D | Qlib +E -D | Qlib +E +D |
| -- | ------ | ------ | -------- | --------- | ----------- | ------------ | ----------- |
| Total (1CPU) (seconds) | 184.4±3.7 | 365.3±7.5 | 253.6±6.7 | 368.2±3.6 | 147.0±8.8 | 47.6±1.0 | **7.4±0.3** |
| Total (64CPU) (seconds) | | | | | 8.8±0.6 | **4.2±0.2** | |
* `+(-)E` indicates with(out) `ExpressionCache`
* `+(-)D` indicates with(out) `DatasetCache`
Most general-purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
Such overheads greatly slow down the data loading process.
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
# Contributing

20
docs/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = python3 -msphinx
SPHINXPROJ = Quantlab
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

BIN
docs/_static/img/analysis/report.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 148 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
docs/_static/img/analysis/score_ic.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

BIN
docs/_static/img/framework.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 205 KiB

BIN
docs/_static/img/topk_drop.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

104
docs/advanced/alpha.rst Normal file
View File

@@ -0,0 +1,104 @@
.. _alpha:
===========================
Building Formulaic Alphas
===========================
.. currentmodule:: qlib
Introduction
===================
In quantitative trading practice, designing novel factors that can explain and predict future asset returns are of vital importance to the profitability of a strategy. Such factors are usually called alpha factors, or alphas in short.
A formulaic alpha, as the name suggests, is a kind of alpha that can be presented as a formula or a mathematical expression.
Building Formulaic Alphas in ``Qlib``
======================================
In ``Qlib``, users can easily build formulaic alphas.
Example
-----------------
`MACD`, short for moving average convergence/divergence, is a formulaic alpha used in technical analysis of stock prices. It is designed to reveal changes in the strength, direction, momentum, and duration of a trend in a stock's price.
`MACD` can be presented as the following formula:
.. math::
MACD = 2\times (DIF-DEA)
.. note::
`DIF` means Differential value, which is 12-period EMA minus 26-period EMA.
.. math::
DIF = \frac{EMA(CLOSE, 12) - EMA(CLOSE, 26)}{CLOSE}
`DEA`means a 9-period EMA of the DIF.
.. math::
DEA = \frac{EMA(DIF, 9)}{CLOSE}
Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first. Please refer to `initialization <initialization.rst>`_.
.. code-block:: python
>>> from qlib.contrib.estimator.handler import QLibDataHandler
>>> fields = ['(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'] # MACD
>>> names = ['MACD']
>>> labels = ['Ref($vwap, -2)/Ref($vwap, -1) - 1'] # label
>>> label_names = ['LABEL']
>>> data_handler = QLibDataHandler(start_date='2010-01-01', end_date='2017-12-31', fields=fields, names=names, labels=labels, label_names=label_names)
>>> TRAINER_CONFIG = {
... "train_start_date": "2007-01-01",
... "train_end_date": "2014-12-31",
... "validate_start_date": "2015-01-01",
... "validate_end_date": "2016-12-31",
... "test_start_date": "2017-01-01",
... "test_end_date": "2020-08-01",
... }
>>> feature_train, label_train, feature_validate, label_validate, feature_test, label_test = data_handler.get_split_data(**TRAINER_CONFIG)
>>> print(feature_train, label_train)
MACD
instrument datetime
SH600004 2012-01-04 -0.030853
2012-01-05 -0.030452
2012-01-06 -0.028252
2012-01-09 -0.024507
2012-01-10 -0.019744
... ...
SZ300273 2014-12-25 0.031339
2014-12-26 0.029695
2014-12-29 0.025577
2014-12-30 0.020493
2014-12-31 0.017089
[605882 rows x 1 columns]
label
instrument datetime
SH600004 2012-01-04 0.003021
2012-01-05 0.017434
2012-01-06 0.015490
2012-01-09 0.002324
2012-01-10 -0.002542
... ...
SZ300273 2014-12-25 -0.032454
2014-12-26 -0.016638
2014-12-29 0.008263
2014-12-30 -0.011985
2014-12-31 0.047797
[605882 rows x 1 columns]
Reference
===========
To kown more about ``Data Handler``, please refer to `Data Handler <../component/data.html>`_
To kown more about ``Data Api``, please refer to `Data Api <../component/data.html>`_

View File

@@ -0,0 +1,2 @@
.. include:: ../../CHANGES.rst

106
docs/component/backtest.rst Normal file
View File

@@ -0,0 +1,106 @@
.. _backtest:
============================================
Intraday Trading: Model&Strategy Testing
============================================
.. currentmodule:: qlib
Introduction
===================
``Intraday Trading`` is designed to test models and strategies, which help users to check the performance of custom model/strategy.
.. note::
``Intraday Trading`` uses ``Order Executor`` to trade and execute orders output by ``Interday Strategy``. ``Order Executor`` is a component in `Qlib Framework <../introduction/introduction.html#framework>`_, which can execute orders. ``Vwap Executor`` and ``Close Executor`` is supported by ``Qlib`` now. In the future, ``Qlib`` will support ``HighFreq Executor`` also.
Example
===========================
Users need to generate a prediction score(a pandas DataFrame) with MultiIndex<instrument, datetime> and a `score` column. And users need to assign a strategy used in backtest, if strategy is not assigned,
a `TopkDropoutStrategy` strategy with `(topk=50, n_drop=5, risk_degree=0.95, limit_threshold=0.0095)` will be used.
If ``Strategy`` module is not user's interested part, `TopkDropoutStrategy` is enough.
The simple example with default strategy is as follows.
.. code-block:: python
from qlib.contrib.evaluate import backtest
# pred_score is the prediction score
report, positions = backtest(pred_score, topk=50, n_drop=0.5, verbose=False, limit_threshold=0.0095)
To know more about backtesting with specific strategy, please refer to `Strategy <strategy.html>`_.
To know more about the prediction score `pred_score` output by ``Model``, please refer to `Interday Model: Model Training & Prediction <model.html>`_.
Prediction Score
-----------------
The prediction score is a pandas DataFrame. Its index is <instrument(str), datetime(pd.Timestamp)> and it must
contains a `score` column.
A prediction sample is shown as follows.
.. code-block:: python
instrument datetime score
SH600000 2019-01-04 -0.505488
SZ002531 2019-01-04 -0.320391
SZ000999 2019-01-04 0.583808
SZ300569 2019-01-04 0.819628
SZ001696 2019-01-04 -0.137140
... ...
SZ000996 2019-04-30 -1.027618
SH603127 2019-04-30 0.225677
SH603126 2019-04-30 0.462443
SH603133 2019-04-30 -0.302460
SZ300760 2019-04-30 -0.126383
``Model`` module can make predictions, please refer to `Model <model.html>`_.
Backtest Result
------------------
The backtest results are in the following form:
.. code-block:: python
sub_bench mean 0.000662
std 0.004487
annual 0.166720
sharpe 2.340526
mdd -0.080516
sub_cost mean 0.000577
std 0.004482
annual 0.145392
sharpe 2.043494
mdd -0.083584
- `sub_bench`
Returns of the portfolio without deduction of fees
- `sub_cost`
Returns of the portfolio with deduction of fees
- `mean`
Mean value of the returns sequence(difference sequence of assets).
- `std`
Standard deviation of the returns sequence(difference sequence of assets).
- `annual`
Average annualized returns of the portfolio.
- `ir`
Information Ratio, please refer to `Information Ratio IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
- `mdd`
Maximum Drawdown, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
Reference
==============
To know more about ``Intraday Trading``, please refer to `Backtest API <../reference/api.html>`_.

333
docs/component/data.rst Normal file
View File

@@ -0,0 +1,333 @@
.. _data:
================================
Data Layer: Data Framework&Usage
================================
Introduction
============================
``Data Layer`` is designed to download raw data, retrieve data, construct datasets and get frequently-used data.
Also, users can building formulaic alphas with ``Data Layer`` easliy. If users are interesting formulaic alphas, please refer to `Building Formulaic Alphas <../advanced/alpha.html>`_.
The ``Data Layer`` framework includes four components as follows.
- Raw Data
- Data API
- Data Handler
- Cache
Raw Data
============================
``Qlib`` provides the script ``scripts/get_data.py`` to download the raw data that will be used to initialize the qlib package, please refer to `Initialization <../start/initialization.rst>`_.
When ``Qlib`` is initialized, users can choose china-stock mode or US-stock mode, please refer to `Initialization <../start/initialization.rst>`_.
China-Stock Market Mode
--------------------------------
If users use ``Qlib`` in china-stock mode, china-stock data is required. The script ``scripts/get_data.py`` can be used to download china-stock data. If users want to use ``Qlib`` in china-stock mode, they need to do as follows.
- Download data in qlib format
Run the following command to download china-stock data in csv format.
.. code-block:: bash
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
Users can find china-stock data in qlib format in the'~/.qlib/csv_data/cn_data' directory.
- Initialize ``Qlib`` in china-stock mode
Users only need to initialize ``Qlib`` as follows.
.. code-block:: python
from qlib.config import REG_CN
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)
US-Stock Market Mode
-------------------------
If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` does not provide script to download US-stock data. If users want to use ``Qlib`` in US-stock market mode, they need to do as follows.
- Prepare data in csv format
Users need to prepare US-stock data in csv format by themselves, which is in the same format as the china-stock data in csv format. Please download the china-stock data in csv format as follows for reference of format.
.. code-block:: bash
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
- Convert data from csv format to ``Qlib`` format
``Qlib`` provides the script ``scripts/dump_bin.py`` to convert data from csv format to qlib format.
Assuming that the users store the US-stock data in csv format in path '~/.qlib/csv_data/us_data', they need to execute the following command to convert the data from csv format to ``Qlib`` format:
.. code-block:: bash
python scripts/dump_bin.py dump --csv_path ~/.qlib/csv_data/us_data --qlib_dir ~/.qlib/qlib_data/us_data --include_fields open,close,high,low,volume,factor
- Initialize ``Qlib`` in US-stock mode
Users only need to initialize ``Qlib`` as follows.
.. code-block:: python
from qlib.config import REG_US
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
Please refer to `Script API <../reference/api.html>`_ for more details.
Data API
========================
Data Retrieval
---------------
Users can use APIs in ``qlib.data`` to retrieve data, please refer to `Data Retrieval <../start/getdata.html>`_.
Feature
------------------
``Qlib`` provides `Feature` and `ExpressionOps` to fetch the features according to users' need.
- `Feature`
Load data from data provider.
- `ExpressionOps`
`ExpressionOps` will use operator for feature construction.
To know more about ``Operator``, please refer to `Operator API <../reference/api.html>`_.
To know more about ``Feature``, please refer to `Feature API <../reference/api.html>`_.
Filter
-------------------
``Qlib`` provides `NameDFilter` and `ExpressionDFilter` to filter the instruments according to users' need.
- `NameDFilter`
Name dynamic instrument filter. Filter the instruments based on a regulated name format. A name rule regular expression is required.
- `ExpressionDFilter`
Expression dynamic instrument filter. Filter the instruments based on a certain expression. An expression rule indicating a certain feature field is required.
- `basic features filter`: rule_expression = '$close/$open>5'
- `cross-sectional features filter` : rule_expression = '$rank($close)<10'
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
To know more about ``Filter``, please refer to `Filter API <../reference/api.html>`_.
API
-------------
To know more about ``Data Api``, please refer to `Data Api <../reference/api.html>`_.
Data Handler
=================
``Data Handler`` is a part of ``estimator`` and can also be used as a single module.
``Data Handler`` can be used to load raw data, prepare features and label columns, preprocess data(standardization, remove NaN, etc.), split training, validation, and test sets. It is a subclass of ``qlib.contrib.estimator.handler.BaseDataHandler``, which provides some interfaces, for example:
Base Class & Interface
----------------------
Qlib provides a base class `qlib.contrib.estimator.BaseDataHandler <../reference/api.html#class-qlib.contrib.estimator.BaseDataHandler>`_, which provides the following interfaces:
- `setup_feature`
Implement the interface to load the data features.
- `setup_label`
Implement the interface to load the data labels and calculate user's labels.
- `setup_processed_data`
Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on.
Qlib also provides two functions to help user init the data handler, user can override them for user's need.
- `_init_kwargs`
User can init the kwargs of the data handler in this function, some kwargs may be used when init the raw df.
Kwargs are the other attributes in data.args, like dropna_label, dropna_feature
- `_init_raw_df`
User can init the raw df, feature names and label names of data handler in this function.
If the index of feature df and label df are not same, user need to override this method to merge them (e.g. inner, left, right merge).
If users want to load features and labels by config, users can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also have provided some preprocess method in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
Usage
--------------
'Data Handler' can be used as a single module, which provides the following mehtod:
- `get_split_data`
- According to the start and end dates, return features and labels of the pandas DataFrame type used for the 'Model'
- `get_rolling_data`
- According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling.
Example
--------------
``Data Handler`` can be run with ``estimator`` by modifying the configuration file, and can also be used as a single module.
Know more about how to run ``Data Handler`` with ``estimator``, please refer to `Estimator <estimator.html#about-data>`_.
Qlib provides implemented data handler `QLibDataHandlerV1`. The following example shows how to run 'QLibDataHandlerV1' as a single module.
.. note:: User needs to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <initialization.rst>`_.
.. code-block:: Python
from qlib.contrib.estimator.handler import QLibDataHandlerV1
from qlib.contrib.model.gbdt import LGBModel
DATA_HANDLER_CONFIG = {
"dropna_label": True,
"start_date": "2007-01-01",
"end_date": "2020-08-01",
"market": "csi500",
}
TRAINER_CONFIG = {
"train_start_date": "2007-01-01",
"train_end_date": "2014-12-31",
"validate_start_date": "2015-01-01",
"validate_end_date": "2016-12-31",
"test_start_date": "2017-01-01",
"test_end_date": "2020-08-01",
}
exampleDataHandler = QLibDataHandlerV1(**DATA_HANDLER_CONFIG)
# example of 'get_split_data'
x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG)
# example of 'get_rolling_data'
for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG):
print(x_train, y_train, x_validate, y_validate, x_test, y_test)
.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the ``fit``, ``predict``, and ``score`` methods of the 'Model' , please refer to `Model <model.html#Interface>`_.
Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
API
---------
To know more abot ``Data Handler``, please refer to `Data Handler API <../reference/api.html#handler>`_.
Cache
==========
``Cache`` is an optional module that helps accelerate providing data by saving some frequently-used data as cache file.
Memory Cache
--------------
Base Class & Interface
~~~~~~~~~~~~~~~~~~~~~~~
``Qlib`` provides a `Memcache` class to cache the most-frequently-used data in memory, an inheritable `ExpressionCache` class, and an inheritable `DatasetCache` class.
`Memcache` is a memory cache mechanism that composes of three `MemCacheUnit` instances to cache **Calendar**, **Instruments**, and **Features**. The MemCache is defined globally in `cache.py` as `H`. User can use `H['c'], H['i'], H['f']` to get/set memcache.
.. autoclass:: qlib.data.cache.MemCacheUnit
:members:
.. autoclass:: qlib.data.cache.MemCache
:members:
Disk Cache
--------------
Base Class & Interface
~~~~~~~~~~~~~~~~~~~~~~~
`ExpressionCache` is a disk cache mechanism that saves expressions such as **Mean($close, 5)**. Users can inherit this base class to define their own cache mechanism. Users need to override `self._uri` method to define how their cache file path is generated, `self._expression` method to define what data they want to cache and how to cache it.
`DatasetCache` is a disk cache mechanism that saves datasets. A certain dataset is regulated by a stockpool configuration (or a series of instruments, though not recommended), a list of expressions or static feature fields, the start time and end time for the collected features and the frequency. Users need to override `self._uri` method to define how their cache file path is generated, `self._expression` method to define what data they want to cache and how to cache it.
`ExpressionCache` and `DatasetCache` actually provides the same interfaces with `ExpressionProvider` and `DatasetProvider` so that the disk cache layer is transparent to users and will only be used if they want to define their own cache mechanism. The users can plug the cache mechanism into the server system by assigning the cache class they want to use in `config.py`:
.. code-block:: python
'ExpressionCache': 'ServerExpressionCache',
'DatasetCache': 'ServerDatasetCache',
Users can find the cache interface here.
ExpressionCache
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: qlib.data.cache.ExpressionCache
:members:
DatasetCache
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: qlib.data.cache.DatasetCache
:members:
Implemented Disk Cache
~~~~~~~~~~~~~~~~~~~~~~~
.. note::
If the user does not use QlibServer, please ignore the content of this section
Qlib has currently provided `ServerExpressionCache` class and `ServerDatasetCache` class as the cache mechanisms used for QlibServer. The class interface and file structure designed for server cache mechanism is listed below.
DiskExpressionCache
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: qlib.data.cache.ServerExpressionCache
DiskDatasetCache
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: qlib.data.cache.ServerDatasetCache
Data and Cache File Structure
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: json
- data/
[raw data] updated by data providers
- calendars/
- day.txt
- instruments/
- all.txt
- csi500.txt
- ...
- features/
- sh600000/
- open.day.bin
- close.day.bin
- ...
- ...
[cached data] updated by server when raw data is updated
- calculated features/
- sh600000/
- [hash(instrtument, field_expression, freq)]
- all-time expression -cache data file
- .meta : an assorted meta file recording the instrument name, field name, freq, and visit times
- ...
- cache/
- [hash(stockpool_config, field_expression_list, freq)]
- all-time Dataset-cache data file
- .meta : an assorted meta file recording the stockpool config, field names and visit times
- .index : an assorted index file recording the line index of all calendars
- ...

View File

@@ -0,0 +1,674 @@
.. _estimator:
=================================
Estimator: Workflow Management
=================================
.. currentmodule:: qlib
Introduction
===================
The components in `Qlib Framework <../introduction/introduction.html#framework>`_ is designed in a loosely-coupled way. Users could build their own quant research workflow with these components like `Example <http://TODO_URL>`_
Besides, ``Qlib`` provides more user-friendly interfaces named ``Estimator`` to automatically run the whole workflow defined by a config. A concrete execution of the whole workflow is called an `experiment`.
With ``Estimator``, user can easily run an `experiment`, which includes the following steps:
- Data
- Loading
- Processing
- Slicing
- Model
- Training and inference(static or rolling)
- Saving & loading
- Evaluation(Back-testing)
For each `experiment`, ``Qlib`` will capture the details of model training, performance evalution results and basic infomation(e.g. names, ids). The captured data will be stored in backend-storge(disk or database).
Example
===================
The following is an example:
.. note:: Make sure install the latest version of `qlib`, please refer to `Qlib installation <../start/installation.html>`_.
If users want to use the models and data provided by `Qlib`, they only need to do as follows.
First, Write a simple configuration file as following,
.. code-block:: YAML
experiment:
name: estimator_example
observer_type: file_storage
mode: train
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
args:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
data:
class: QLibDataHandlerClose
args:
dropna_label: True
filter:
market: csi500
trainer:
class: StaticTrainer
args:
rolling_period: 360
train_start_date: 2007-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-12-31
test_start_date: 2017-01-01
test_end_date: 2020-08-01
strategy:
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: SH000905
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"
Then run the following command:
.. code-block:: bash
estimator -c configuration.yaml
.. note:: 'estimator' is a built-in command of our program.
Configuration File
===================
Before using ``estimator``, users need to prepare a configuration file. The following shows how to prepare each part of the configuration file.
Experiment Field
--------------------
First, the configuration file needs to have a field about the experiment, whose key is `experiment`. This field and its contents determine how `estimator` tracks and persists this `experiment`. ``Qlib`` used `sacred`, a lightweight open-source tool designed to configure, organize, generate logs, and manage experiment results. The field `experiment` will determine the partial behavior of `sacred`.
Usually, in the running process of `estimator`, those following will be managed by `sacred`:
- `model.bin`, model binary file
- `pred.pkl`, model prediction result file
- `analysis.pkl`, backtest performance analysis file
- `positions.pkl`, backtest position record file
- `run`, the experiment information object, usually contains some meta information such as the experiment name, experiment date, etc.
Usually, it should contain the following:
.. code-block:: YAML
experiment:
name: test_experiment
observer_type: mongo
mongo_url: mongodb://MONGO_URL
db_name: public
finetune: false
exp_info_path: /home/test_user/exp_info.json
mode: test
loader:
id: 677
The meaning of each field is as follows:
- `name`
The experiment name, str type, `sacred` will use this experiment name as an identifier for some important internal processes. Usually, users can see this field in `sacred` by `run` object. The default value is `test_experiment`.
- `observer_type`
Observer type, str type, there are two values which are `file_storage` and `mongo` respectively. If it is `file_storage`, all the above-mentioned managed contents will be stored in the `dir` directory, separated by the number of times of experiments as a subfolder. If it is `mongo`, the content will be stored in the database. The default is `file_storage`.
- For `file_storage` observer.
- `dir`
Directory url, str type, directory for `file_storage` observer type, files captures and managed by sacred with observer type of `file_storage` will be save to this directory, default is the directory of `config.json`.
- For `mongo` observer.
- `mongo_url`
Database URL, str type, required if the observer type is `mongo`.
- `db_name`
Database name, str type, required if the observer type is `mongo`.
- `finetune`
Estimator will produce a model based on this flag
The following table is the processing logic for different situations.
========== =========================================== ==================================== =========================================== ==========================================
. Static Rolling
. Finetune=True Finetune=False Finetune=True Finetune=False
========== =========================================== ==================================== =========================================== ==========================================
Train - Need to provide model(Static or Rolling) - No need to provide model - Need to provide model(Static or Rolling) - Need to provide model(Static or Rolling)
- The args in model section will be - The args in model section will be - The args in model section will be - The args in model section will be
used for finetuning used for training used for finetuning used for finetuning
- Update based on the provided model - Train model from scratch - Update based on the provided model - Based on the provided model update
and parameters and parameters - Train model from scratch
- **Each rolling time slice is based on** - **Train each rolling time slice**
**a model updated from the previous** **separately**
**time**
Test - Model must exist, otherwise an exception will be raised.
- For `StaticTrainer`, users need to train a model and record 'exp_info' for 'Test'.
- For `RollingTrainer`, users need to train a set of models until the latest time, and record 'exp_info' for 'Test'.
========== =============================================================================================================================================================================
.. note::
1. finetune parameters: share model.args parameters.
2. provide model: from `loader.model_index`, load the index of the model(starting from 0).
3. If `loader.model_index` is None:
- In 'Static Finetune=True', if provide 'Rolling', use the last model to update.
- For RollingTrainer with Finetune=Ture.
- If StaticTrainer is used in loader, the model will be used for initialization for finetuning.
- If RollingTrainer is used in loader, the existing models will be used without any modification and the new models will be initialized with the model in the last period and finetune one by one.
- `exp_info_path`
experiment info save path, str type, save the experiment info and model prediction score after the experiment is finished. Optional parameter, the default value is `config_file_dir/ex_name/exp_info.json`
- `mode`
`train` or `test`, str type, if `mode` is test, it will load the model according to the parameters of `loader`. The default value is `train`.
Also note that when the load model failed, it will `fit` model.
.. note::
if users choose `mode` test, they need to make sure:
- The loader of `test_start_date` must be less than or equal to the current `test_start_date`.
- If other parameters of the `loader` model args are different, a warning will appear.
- `loader`
If the `mode` is `test` or `finetune` is `true`, it will be used.
- `model_index`
Model index, int type. The index of the loaded model in loader_models (starting at 0) for the first `finetune`. The default value is None.
- `exp_info_path`
Loader model experiment info path, str type. If the field exists, the following parameters will be parsed from `exp_info_path`, and the following parameters will not work. This field and `id` must exist one.
- `id`
The experiment id of the model that needs to be loaded, int type. If the `mode` is `test`, this value is required. This field and `exp_info_path` must exist one.
- `name`
The experiment name of the model that needs to be loaded, str type. The default value is the current experiment `name`.
- `observer_type`
The experiment observer type of the model that needs to be loaded, str type. The default value is the current experiment `observer_type`.
.. note:: The observer type is a concept of the `sacred` module, which determines how files, standard input and output which are managed by sacred are stored.
- `file_storage`
If `observer_type` is `file_storage`, the config may be as follows.
.. code-block:: YAML
experiment:
name: test_experiment
dir: <path to a directory> # default is dir of `config.yml`
observer_type: file_storage
- `mongo`
If `observer_type` is `mongo`, the config may be as follows.
.. code-block:: YAML
experiment:
name: test_experiment
observer_type: mongo
mongo_url: mongodb://MONGO_URL
db_name: public
Users need to indicate `mongo_url` and `db_name` for a mongo observer.
.. note::
If users choose mongo observer, they need to make sure:
- have an environment with the mongodb installed and a mongo database dedicated for storing the experiments results.
- The python environment(the version of python and package) to run the experiments and the one to fetch the results are consistent.
Model Field
-----------------
Users can use a specified model by configuration with hyper-parameters.
Custom Models
~~~~~~~~~~~~~~~~~
Qlib support custom models, but it must be a subclass of the `qlib.contrib.model.Model`, the config for custom model may be as following.
.. code-block:: YAML
model:
class: SomeModel
module_path: /tmp/my_experment/custom_model.py
args:
loss: binary
The class `SomeModel` should be in the module `custom_model`, and ``Qlib`` could parse the `module_path` to load the class.
To Know more about ``Model``, please refer to `Model <model.html>`_.
Data Field
-----------------
``Data Handler`` can be used to load raw data, prepare features and label columns, preprocess data(standardization, remove NaN, etc.), split training, validation, and test sets. It is a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`.
Users can use the specified data handler by config as follows.
.. code-block:: YAML
data:
class: QLibDataHandlerClose
args:
start_date: 2005-01-01
end_date: 2018-04-30
dropna_label: True
filter:
market: csi500
filter_pipeline:
-
class: NameDFilter
module_path: qlib.filter
args:
name_rule_re: S(?!Z3)
fstart_time: 2018-01-01
fend_time: 2018-12-11
-
class: ExpressionDFilter
module_path: qlib.filter
args:
rule_expression: $open/$factor<=45
fstart_time: 2018-01-01
fend_time: 2018-12-11
- `class`
Data handler class, str type, which should be a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`, and implements 5 important interfaces for loading features, loading raw data, preprocessing raw data, slicing train, validation, and test data. The default value is `ALPHA360`. If users want to write a data handler to retrieve the data in qlib, `QlibDataHandler` is suggested.
- `module_path`
The module path, str type, absolute url is also supported, indicates the path of the `class` implementation of data processor class. The default value is `qlib.contrib.estimator.handler`.
- `args`
Parameters used for ``Data Handler`` initialization.
- `train_start_date`
Training start time, str type, default value is `2005-01-01`.
- `start_date`
Data start date, str type.
- `end_date`
Data end date, str type. the data from start_date to end_date decides which part of data will be loaded in datahandler, users can only use these data in the following parts.
- `dropna_feature` (Optional in args)
Drop Nan feature, bool type, default value is False.
- `dropna_label` (Optional in args)
Drop Nan label, bool type, default value is True. Some multi-label tasks will use this.
- `normalize_method` (Optional in args)
Normalzie data by given method. str type. ``Qlib`` give two normalize method, `MinMax` and `Std`.
If users wants to build their own method, please override `_process_normalize_feature`.
- `filter`
Dynamically filtering the stocks based on the filter pipeline.
- `market`
index name, str type, the default value is `csi500`.
- `filter_pipeline`
Filter rule list, list type, the default value is []. Can be customized according to users' needs.
- `class`
Filter class name, str type.
- `module_path`
The module path, str type.
- `args`
The filter class parameters, this parameters are set according to the `class`, and all the parameters as kwargs to `class`.
Custom Data Handler
~~~~~~~~~~~~~~~~~~~~~~
Qlib support custom data handler, but it must be a subclass of the ``qlib.contrib.estimator.handler.BaseDataHandler``, the config for custom data handler may be as follows.
.. code-block:: YAML
data:
class: SomeDataHandler
module_path: /tmp/my_experment/custom_data_handler.py
args:
start_date: 2005-01-01
end_date: 2018-04-30
The class `SomeDataHandler` should be in the module `custom_data_handler`, and ``Qlib`` could parse the `module_path` to load the class.
If users want to load features and labels by config, they can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also has provided some preprocess method in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended, from which users can inherit custom class. `QLibDataHandler` is also a subclass of `ConfigDataHandler`.
To Know more about ``Data Handler``, please refer to `Data Framework&Usage <data.html>`_.
Trainer Field
-----------------
Users can specify the trainer ``Trainer`` by the config file, which is subclass of ``qlib.contrib.estimator.trainer.BaseTrainer`` and implement three important interfaces for training the model, restoring the model, and getting model predictions as follows.
- `train`
Implement this interface to train the model.
- `load`
Implement this interface to recover the model from disk.
- `get_pred`
Implement this interface to get model prediction results.
Qlib have provided two implemented trainer,
- `StaticTrainer`
The static trainer will be trained using the training, validation, and test data of the data processor static slicing.
- `RollingTrainer`
The rolling trainer will use the rolling iterator of the data processor to split data for rolling training.
Users can specify `trainer` with the configuration file:
.. code-block:: YAML
trainer:
class: StaticTrainer # or RollingTrainer
args:
rolling_period: 360
train_start_date: 2005-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-06-30
test_start_date: 2016-07-01
test_end_date: 2017-07-31
- `class`
Trainer class, which should be a subclass of `qlib.contrib.estimator.trainer.BaseTrainer`, and needs to implement three important interfaces, the default value is `StaticTrainer`.
- `module_path`
The module path, str type, absolute url is also supported, indicates the path of the trainer class implementation.
- `args`
Parameters used for ``Trainer`` initialization.
- `rolling_period`
The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. Only used in `RollingTrainer`.
- `train_start_date`
Training start time, str type.
- `train_end_date`
Training end time, str type.
- `validate_start_date`
Validation start time, str type.
- `validate_end_date`
Validation end time, str type.
- `test_start_date`
Test start time, str type.
- `test_end_date`
Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.
Custom Trainer
~~~~~~~~~~~~~~~~~~
Qlib support custom trainer, but it must be a subclass of the `qlib.contrib.estimator.trainer.BaseTrainer`, the config for custom trainer may be as following,
.. code-block:: YAML
trainer:
class: SomeTrainer
module_path: /tmp/my_experment/custom_trainer.py
args:
train_start_date: 2005-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-06-30
test_start_date: 2016-07-01
test_end_date: 2017-07-31
The class `SomeTrainer` should be in the module `custom_trainer`, and ``Qlib`` could parse the `module_path` to load the class.
Strategy Field
-----------------
Users can specify strategy through a config file, for example:
.. code-block:: YAML
strategy :
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
- `class`
The strategy class, str type, should be a subclass of `qlib.contrib.strategy.strategy.BaseStrategy`. The default value is `TopkDropoutStrategy`.
- `module_path`
The module location, str type, absolute url is also supported, and absolute path is also supported, indicates the location of the policy class implementation.
- `args`
Parameters used for ``Trainer`` initialization.
- `topk`
The number of stocks in the portfolio
- `n_drop`
Number of stocks to be replaced in each trading date
Custom Strategy
^^^^^^^^^^^^^^^^^^^
Qlib support custom strategy, but it must be a subclass of the ``qlib.contrib.strategy.strategy.BaseStrategy``, the config for custom strategy may be as following,
.. code-block:: YAML
strategy :
class: SomeStrategy
module_path: /tmp/my_experment/custom_strategy.py
The class `SomeStrategy` should be in the module `custom_strategy`, and ``Qlib`` could parse the `module_path` to load the class.
To Know more about ``Strategy``, please refer to `Strategy <strategy.html>`_.
Backtest Field
-----------------
Users can specify `backtest` through a config file, for example:
.. code-block:: YAML
backtest :
normal_backtest_args:
topk: 50
benchmark: SH000905
account: 500000
deal_price: close
min_cost: 5
subscribe_fields:
- $close
- $change
- $factor
- `normal_backtest_args`
Normal backtest parameters. All the parameters in this section will be passed to the ``qlib.contrib.evaluate.backtest`` function in the form of `**kwargs`.
- `benchmark`
Stock index symbol, str or list type, the default value is `None`.
.. note::
* If `benchmark` is None, it will use the average change of the day of all stocks in 'pred' as the 'bench'.
* If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'.
* If `benchmark` is str, it will use the daily change as the 'bench'.
- `account`
Backtest initial cash, integer type. The `account` in `strategy` section is deprecated. It only works when `account` is not set in `backtest` section. It will be overridden by `account` in the `backtest` section. The default value is 1e9.
- `deal_price`
Order transaction price field, str type, the default value is vwap.
- `min_cost`
Min transaction cost, float type, the default value is 5.
- `subscribe_fields`
Subscribe quote fields, array type, the default value is [`deal_price`, $close, $change, $factor].
Qlib Data Field
--------------------
The `qlib_data` field describes the parameters of qlib initialization.
.. code-block:: YAML
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"
- `provider_uri`
The local directory where the data loaded by 'get_data.py' is stored.
- `region`
- If region == ``qlib.config.REG_CN``, 'qlib' will be initialized in US-stock mode.
- If region == ``qlib.config.REG_US``, 'qlib' will be initialized in china-stock mode.
Please refer to `Initialization <../start/initialization.rst>`_.
Experiment Result
===================
Form of Experimental Result
----------------------------
The result of the experiment is the result of the backtest, please refer to `Backtest <backtest.html>`_.
Get Experiment Result
----------------------------
Users can check the experiment results from file storage directly, or check the experiment results from database, or get the experiment results through two API of a module `fetcher` provided by ``Qlib``.
- `get_experiments()`
The API takes two parameters. The first parameter is the experiment name. The default is all experiments. The second parameter is the observer type. Users can get the experiment name dictionary with a list of ids and test end date by the API as follows.
.. code-block:: JSON
{
"ex_a": [
{
"id": 1,
"test_end_date": "2017-01-01"
}
],
"ex_b": [
...
]
}
- `get_experiment(exp_name, exp_id, fields=None)`
The API takes three parameters, the first parameter is the experiment name, the second parameter is the experiment id, and the third parameter is field list.
If fields is None, will get all fields.
.. note::
Currently supported fields:
['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
.. code-block:: JSON
{
'analysis': analysis_df,
'pred': pred_df,
'positions': positions_dic,
'report_normal': report_normal_df,
}
Here is a simple example of `FileFetcher`, which could fetch files from `file_storage` observer.
.. code-block:: python
>>> from qlib.contrib.estimator.fetcher import FileFetcher
>>> f = FileFetcher(experiments_dir=r'./')
>>> print(f.get_experiments())
{
'test_experiment': [
{
'id': '1',
'config': ...
},
{
'id': '2',
'config': ...
},
{
'id': '3',
'config': ...
}
]
}
>>> print(f.get_experiment('test_experiment', '1'))
risk
sub_bench mean 0.000662
std 0.004487
annual 0.166720
sharpe 2.340526
mdd -0.080516
sub_cost mean 0.000577
std 0.004482
annual 0.145392
sharpe 2.043494
mdd -0.083584
If users use mongo observer when training, they should initialize their fether with mongo_url
.. code-block:: python
>>> from qlib.contrib.estimator.fetcher import MongoFetcher
>>> f = MongoFetcher(mongo_url=..., db_name=...)

179
docs/component/model.rst Normal file
View File

@@ -0,0 +1,179 @@
.. _model:
============================================
Interday Model: Model Training & Prediction
============================================
Introduction
===================
``Interday Model`` is designed to make the prediction score about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``Estimator``, please refer to `Estimator <estimator.html>`_.
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Interday Model`` can be used as a independent module also.
Base Class & Interface
======================
``Qlib`` provides a base class `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_, which all models should inherit from.
The base class provides the following interfaces:
- `__init__(**kwargs)`
- Initialization.
- If users use ``Estimator`` to start an `experiment`, the parameter of `__init__` method shoule be consistent with the hyperparameters in the configuration file.
- `fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)`
- Train model.
- Parameter:
- `x_train`, pd.DataFrame type, train feature
The following example explains the value of `x_train`:
.. code-block:: YAML
KMID KLEN KMID2 KUP KUP2
instrument datetime
SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275
2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998
2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666
2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001
2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648
... ... ... ... ... ...
SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052
2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824
2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000
2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433
2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216
`x_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. Each column of `x_train` corresponds to a feature, and the column name is the feature name.
.. note::
The number and names of the columns is determined by the data handler, please refer to `Data Handler <data.html#data-handler>`_ and `Estimator Data <estimator.html#about-data>`_.
- `y_train`, pd.DataFrame type, train label
The following example explains the value of `y_train`:
.. code-block:: YAML
LABEL
instrument datetime
SH600004 2012-01-04 -0.798456
2012-01-05 -1.366716
2012-01-06 -0.491026
2012-01-09 0.296900
2012-01-10 0.501426
... ...
SZ300273 2014-12-25 -0.465540
2014-12-26 0.233864
2014-12-29 0.471368
2014-12-30 0.411914
2014-12-31 1.342723
`y_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. The `LABEL` column represents the value of train label.
.. note::
The number and names of the columns is determined by the ``Data Handler``, please refer to `Data Handler <data.html#data-handler>`_.
- `x_valid`, pd.DataFrame type, validation feature
The format of `x_valid` is same as `x_train`
- `y_valid`, pd.DataFrame type, validation label
The format of `y_valid` is same as `y_train`
- `w_train`(Optional args, default is None), pd.DataFrame type, train weight
`w_train` is a pandas DataFrame, whose shape and index is same as `x_train`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
- `w_train`(Optional args, default is None), pd.DataFrame type, validation weight
`w_train` is a pandas DataFrame, whose shape and index is same as `x_valid`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
- `predict(self, x_test, **kwargs)`
- Predict test data 'x_test'
- Parameter:
- `x_test`, pd.DataFrame type, test features
The form of `x_test` is same as `x_train` in 'fit' method.
- Return:
- `label`, np.ndarray type, test label
The label of `x_test` that predicted by model.
- `score(self, x_test, y_test, w_test=None, **kwargs)`
- Evaluate model with test feature/label
- Parameter:
- `x_test`, pd.DataFrame type, test feature
The format of `x_test` is same as `x_train` in `fit` method.
- `x_test`, pd.DataFrame type, test label
The format of `y_test` is same as `y_train` in `fit` method.
- `w_test`, pd.DataFrame type, test weight
The format of `w_test` is same as `w_train` in `fit` method.
- Return: float type, evaluation score
For other interfaces such as `save`, `load`, `finetune`, please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
Example
==================
``Qlib`` provides ``LightGBM`` and ``DNN`` models as the baseline, the following steps shows how to run`` LightGBM`` as an independent module.
- Initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <initialization.rst>`_.
- Run the following code to get the prediction score `pred_score`
.. code-block:: Python
from qlib.contrib.estimator.handler import QLibDataHandlerClose
from qlib.contrib.model.gbdt import LGBModel
DATA_HANDLER_CONFIG = {
"dropna_label": True,
"start_date": "2007-01-01",
"end_date": "2020-08-01",
"market": MARKET,
}
TRAINER_CONFIG = {
"train_start_date": "2007-01-01",
"train_end_date": "2014-12-31",
"validate_start_date": "2015-01-01",
"validate_end_date": "2016-12-31",
"test_start_date": "2017-01-01",
"test_end_date": "2020-08-01",
}
x_train, y_train, x_validate, y_validate, x_test, y_test = QLibDataHandlerClose(
**DATA_HANDLER_CONFIG
).get_split_data(**TRAINER_CONFIG)
MODEL_CONFIG = {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
}
# use default model
# custom Model, refer to: TODO: Model API url
model = LGBModel(**MODEL_CONFIG)
model.fit(x_train, y_train, x_validate, y_validate)
_pred = model.predict(x_test)
pred_score = pd.DataFrame(index=_pred.index)
pred_score["score"] = _pred.iloc(axis=1)[0]
.. note:: `QLibDataHandlerClose` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
Custom Model
===================
Qlib supports custom models. If users are interested in customizing their own models and integrating the models into ``Qlib``, please refer to `Custom Model Integration <../start/integration.html>`_.
API
===================
Please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.

197
docs/component/report.rst Normal file
View File

@@ -0,0 +1,197 @@
.. _report:
==========================================
Aanalysis: Evaluation & Results Analysis
==========================================
Introduction
===================
``Aanalysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. There are the following graphics to view:
- analysis_position
- report_graph
- score_ic_graph
- cumulative_return_graph
- risk_analysis_graph
- rank_label_graph
- analysis_model
- model_performance_graph
Graphical Reports
===================
Users can run the following code to get all supported reports.
.. code-block:: python
>>> import qlib.contrib.report as qcr
>>> print(qcr.GRAPH_NAME_LISt)
['analysis_position.report_graph', 'analysis_position.score_ic_graph', 'analysis_position.cumulative_return_graph', 'analysis_position.risk_analysis_graph', 'analysis_position.rank_label_graph', 'analysis_model.model_performance_graph']
.. note::
For more details, please refer to the function document: similar to ``help(qcr.analysis_position.report_graph)``
Usage&Example
===================
Usage of `analysis_position.report`
-----------------------------------
API
~~~~~~~~~~~~~~~~
.. automodule:: qlib.contrib.report.analysis_position.report
:members:
Graphical Result
~~~~~~~~~~~~~~~~
.. note::
- Axis X: Trading day
- Axis Y: Accumulated value
- The shaded part above: Maximum drawdown corresponding to `cum return`
- The shaded part below: Maximum drawdown corresponding to `cum ex return wo cost` %
.. image:: ../_static/img/analysis/report.png
Usage of `analysis_position.score_ic`
-------------------------------------
API
~~~~~~~~~~~~~~~~
.. automodule:: qlib.contrib.report.analysis_position.score_ic
:members:
Graphical Result
~~~~~~~~~~~~~~~~~
.. note::
- Axis X: Trading day
- Axis Y: `Ref($close, -1)/$close - 1` and `score` IC%
.. image:: ../_static/img/analysis/score_ic.png
Usage of `analysis_position.cumulative_return`
----------------------------------------------
API
~~~~~~~~~~~~~~~~
.. automodule:: qlib.contrib.report.analysis_position.cumulative_return
:members:
Graphical Result
~~~~~~~~~~~~~~~~~
.. note::
- Cumulative return graphics.
- Axis X: Trading day
- Axis Y:
- Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`
- Below axis Y: Daily weight sum
- In the **sell** graph, `y < 0` stands for profit; in other cases, `y > 0` stands for profit.
- In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + sell_weight`.
- In each graph, the **red line** in the histogram on the right represents the average.%
.. image:: ../_static/img/analysis/cumulative_return_buy.png
.. image:: ../_static/img/analysis/cumulative_return_sell.png
.. image:: ../_static/img/analysis/cumulative_return_buy_minus_sell.png
.. image:: ../_static/img/analysis/cumulative_return_hold.png
Usage of `analysis_position.risk_analysis`
----------------------------------------------
API
~~~~~~~~~~~~~~~~
.. automodule:: qlib.contrib.report.analysis_position.risk_analysis
:members:
.. note::
- annual/mdd/sharpe/std graphics
- Axis X: Trading days are grouped by month
- Axis Y: monthly(trading date) value
Graphical Result
~~~~~~~~~~~~~~~~~
.. image:: ../_static/img/analysis/risk_analysis_bar.png
.. image:: ../_static/img/analysis/risk_analysis_annual.png
.. image:: ../_static/img/analysis/risk_analysis_mdd.png
.. image:: ../_static/img/analysis/risk_analysis_sharpe.png
.. image:: ../_static/img/analysis/risk_analysis_std.png
Usage of `analysis_position.rank_label`
----------------------------------------------
API
~~~~~
.. automodule:: qlib.contrib.report.analysis_position.rank_label
:members:
Graphical Result
~~~~~~~~~~~~~~~~~
.. note::
- hold/sell/buy graphics:
- Axis X: Trading day
- Axis Y: Percentage of `'Ref($close, -1)/$close - 1'.rank(ascending=False) / (number of lines on the day) * 100` every trading day. (`ascending=False`: The higher the value, the higher the ranking)%
.. image:: ../_static/img/analysis/rank_label_hold.png
.. image:: ../_static/img/analysis/rank_label_buy.png
.. image:: ../_static/img/analysis/rank_label_sell.png
Usage of `analysis_model.analysis_model_performance`
-----------------------------------------------------
API
~~~~~
.. automodule:: qlib.contrib.report.analysis_model.analysis_model_performance
:members:
Graphical Result
~~~~~~~~~~~~~~~~~
.. image:: ../_static/img/analysis/analysis_model_cumulative_return.png
.. image:: ../_static/img/analysis/analysis_model_long_short.png
.. image:: ../_static/img/analysis/analysis_model_IC.png
.. image:: ../_static/img/analysis/analysis_model_monthly_IC.png
.. image:: ../_static/img/analysis/analysis_model_NDQ.png
.. image:: ../_static/img/analysis/analysis_model_auto_correlation.png

119
docs/component/strategy.rst Normal file
View File

@@ -0,0 +1,119 @@
.. _strategy:
========================================
Interday Strategy: Portfolio Management
========================================
.. currentmodule:: qlib
Introduction
===================
``Interday Strategy`` is designed to adopt different trading strategies, which means that users can adopt different algorithms to generate investment portfolios based on the prediction scores of the ``Interday Model``. Users can use the ``Interday Strategy`` in an automatic workflow by ``Estimator``, please refer to `Estimator <estimator.html>`_.
Because the componets in ``Qlib`` are designed in a loosely-coupled way, ``Interday Strategy`` can be used as a independent module also.
``Qlib`` provides several implemented trading strategy. Also, ``Qlib`` supports costom strategy, users can customize strategies according to their own needs.
Base Class & Interface
======================
BaseStrategy
------------------
Qlib provides a base class ``qlib.contrib.strategy.BaseStrategy``. All strategy classes need to inherit the base class and implement its interface.
- `get_risk_degree`
Return the proportion of your total value you will use in investment. Dynamically risk_degree will result in Market timing.
- `generate_order_list`
Rerturn the order list.
User can inherit `BaseStrategy` to costomize their strategy class.
WeightStrategyBase
--------------------
Qlib alse provides a class ``qlib.contrib.strategy.WeightStrategyBase`` that is a subclass of `BaseStrategy`.
`WeightStrategyBase` only focuses on the target positions, and automatically generates an order list based on positions. It provides the `generate_target_weight_position` interface.
- `generate_target_weight_position`
- According to the current position and trading date to generate the target position. The cash is not considered.
- Return the target position.
.. note::
Here the `target position` means the target percentage of total assets.
`WeightStrategyBase` implements the interface `generate_order_list`, whose processions is as follows.
- Call `generate_target_weight_position` method to generate the target position.
- Generate the target amount of stocks from the target position.
- Generate the order list from the target amount
Users can inherit `WeightStrategyBase` and implement the inteface `generate_target_weight_position` to costomize their strategy class, which only focuses on the target positions.
Implemented Strategy
====================
Qlib provides several implemented strategy classes `TopkDropoutStrategy`.
TopkDropoutStrategy
------------------
`TopkDropoutStrategy` is a subclass of `BaseStrategy` and implement the interface `generate_order_list` whose process is as follows.
- Adopt the the ``Topk-Drop`` algorithm to calculate the target amount of each stock
.. note::
``Topk-Drop`` algorithm
- `Topk`: The number of stocks held
- `Drop`: The number of stocks sold on each trading day
Currently, the number of held stocks is `Topk`.
On each trading day, the `Drop` number of held stocks with worst prediction score will be sold, and the same number of unheld stocks with best prediction score will be bought.
.. image:: ../_static/img/topk_drop.png
:alt: Topk-Drop
``TopkDrop`` algorithm sells `Drop` stocks every trading day, which guarantees a fixed turnover rate.
- Generate the order list from the target amount
Usage & Example
====================
``Interday Strategy`` can be specified in the ``Intraday Trading(Backtest)``, the example is as follows.
.. code-block:: python
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import backtest
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
}
BACKTEST_CONFIG = {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"deal_price": "vwap",
}
# use default strategy
# custom Strategy, refer to: TODO: Strategy API url
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
# pred_score is the prediction score output by Model
report_normal, positions_normal = backtest(
pred_score, strategy=strategy, **BACKTEST_CONFIG
)
Also, the above example has been given in ``examples\train_backtest_analyze.ipynb``.
To know more about the prediction score `pred_score` output by ``Interday Model``, please refer to `Interday Model: Model Training & Prediction <model.html>`_.
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
Reference
===================
TO konw more about ``Interday Strategy``, please refer to `Strategy API <../reference/api.html>`_.

224
docs/conf.py Normal file
View File

@@ -0,0 +1,224 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# QLib documentation build configuration file, created by
# sphinx-quickstart on Wed Sep 27 15:16:05 2017.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
import pkg_resources
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.todo',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = u"QLib"
copyright = u"Microsoft"
author = u"Microsoft"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = pkg_resources.get_distribution("qlib").version
# The full version, including alpha/beta/rc tags.
release = pkg_resources.get_distribution("qlib").version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = 'en_US'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# If true, '()' will be appended to :func: etc. cross-reference text.
add_function_parentheses = False
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
add_module_names = True
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
# html_context = {
# "display_github": False,
# "last_updated": True,
# "commit": True,
# "github_user": "Microsoft",
# "github_repo": "QLib",
# 'github_version': 'master',
# 'conf_py_path': '/docs/',
# }
#
html_theme_options = {
'collapse_navigation': False,
'display_version': False,
'navigation_depth': 3,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
#html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
'**': [
'about.html',
'navigation.html',
'relations.html', # needs 'show_related': True theme option to display
'searchbox.html',
]
}
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'qlibdoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, "qlib.tex", u"QLib Documentation", u"Microsoft", "manual"),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'qlib', u'QLib Documentation',
[author], 1)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'QLib', u'QLib Documentation',
author, 'QLib', 'One line description of project.',
'Miscellaneous'),
]
# -- Options for Epub output ----------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
epub_author = author
epub_publisher = author
epub_copyright = copyright
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
autodoc_member_order = 'bysource'
autodoc_default_flags = ['members']

171
docs/hidden/client.rst Normal file
View File

@@ -0,0 +1,171 @@
.. _client:
Qlib Client-Server Framework
===================
.. currentmodule:: qlib
Introduction
-----------
Client-Server is designed to solve following problems
- Manage the data in a centralized way. Users don't have to manage data of different versions.
- Reduce the amount of cache to be generated.
- Make the data can be accessed in a remote way.
Therefore, we designed the client-server framework to solve these problems.
We will maintain a server and provide the data.
You have to initialize you qlib with specific config for using the client-server framework.
Here is a typical initialization process.
qlib ``init`` commonly used parameters; ``nfs-common`` must be installed on the server where the client is located, execute: ``sudo apt install nfs-common``:
- ``provider_uri``: nfs-server path; the format is ``host: data_dir``, for example: ``172.23.233.89:/data2/gaochao/sync_qlib/qlib``. If using offline, it can be a local data directory
- ``mount_path``: local data directory, ``provider_uri`` will be mounted to this directory
- ``auto_mount``: whether to automatically mount ``provider_uri`` to ``mount_path`` during qlib ``init``; You can also mount it manually: sudo mount.nfs ``provider_uri`` ``mount_path``. If on PAI, it is recommended to set ``auto_mount=True``
- ``flask_server``: data service host; if you are on the intranet, you can use the default host: 172.23.233.89
- ``flask_port``: data service port
If running on 10.150.144.153 or 10.150.144.154 server, it's recommended to use the following code to ``init`` qlib:
.. code-block:: python
>>> import qlib
>>> qlib.init(auto_mount=False, mount_path='/data/csdesign/qlib')
>>> from qlib.data import D
>>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head()
[39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client.
[39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.
[39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib
[39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710
[39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server!
Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode
Out[5]:
$close
instrument datetime
SH600000 2008-01-02 119.079704
2008-01-03 113.120125
2008-01-04 117.878860
2008-01-07 124.505539
2008-01-08 125.395004
If running on PAI, it's recommended to use the following code to ``init`` qlib:
.. code-block:: python
>>> import qlib
>>> qlib.init(auto_mount=True, mount_path='/data/csdesign/qlib', provider_uri='172.23.233.89:/data2/gaochao/sync_qlib/qlib')
>>> from qlib.data import D
>>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head()
[39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client.
[39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.
[39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib
[39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710
[39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server!
Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode
Out[5]:
$close
instrument datetime
SH600000 2008-01-02 119.079704
2008-01-03 113.120125
2008-01-04 117.878860
2008-01-07 124.505539
2008-01-08 125.395004
If running on Windows, open **NFS** features and write correct **mount_path**, it's recommended to use the following code to ``init`` qlib:
1.windows System open NFS Features
* Open ``Programs and Features``.
* Click ``Turn Windows features on or off``.
* Scroll down and check the option ``Services for NFS``, then click OK
Reference address: https://graspingtech.com/mount-nfs-share-windows-10/
2.config correct mount_path
* In windows, mount path must be not exist path and root path,
* correct format path eg: `H`, `i`...
* error format path eg: `C`, `C:/user/name`, `qlib_data`...
.. code-block:: python
>>> import qlib
>>> qlib.init(auto_mount=True, mount_path='H', provider_uri='172.23.233.89:/data2/gaochao/sync_qlib/qlib')
>>> from qlib.data import D
>>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head()
[39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client.
[39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.
[39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib
[39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710
[39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server!
Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode
Out[5]:
$close
instrument datetime
SH600000 2008-01-02 119.079704
2008-01-03 113.120125
2008-01-04 117.878860
2008-01-07 124.505539
2008-01-08 125.395004
The client will mount the data in `provider_uri` on `mount_path`. Then the server and client will communicate with flask and transporting data with this NFS.
If you have a local qlib data files and want to use the qlib data offline instead of online with client server framework.
It is also possible with specific config.
You can created such a config. `client_config_local.yml`
.. code-block:: YAML
provider_uri: /data/csdesign/qlib
calendar_provider: 'LocalCalendarProvider'
instrument_provider: 'LocalInstrumentProvider'
feature_provider: 'LocalFeatureProvider'
expression_provider: 'LocalExpressionProvider'
dataset_provider: 'LocalDatasetProvider'
provider: 'LocalProvider'
dataset_cache: 'SimpleDatasetCache'
local_cache_path: '~/.cache/qlib/'
`provider_uri` is the directory of your local data.
.. code-block:: python
>>> import qlib
>>> qlib.init_from_yaml_conf('client_config_local.yml')
>>> from qlib.data import D
>>> D.features(['SH600001'], ['$close'], start_time='20180101', end_time='20190101').head()
21232:MainThread](2019-05-29 10:16:05,066) INFO - Initialization - [__init__.py:16] - default_conf: client.
[21232:MainThread](2019-05-29 10:16:05,066) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings.
[21232:MainThread](2019-05-29 10:16:05,067) INFO - Initialization - [__init__.py:56] - provider_uri=/data/csdesign/qlib
Out[9]:
$close
instrument datetime
SH600001 2008-01-02 21.082111
2008-01-03 23.195362
2008-01-04 23.874615
2008-01-07 24.880930
2008-01-08 24.277143
Limitations
-----------
1. The following API under the client-server module may not be as fast as the older off-line API.
- Cal.calendar
- Inst.list_instruments
2. The rolling operation expression with parameter `0` can not be updated rightly under mechanism of the client-server framework.
API
********************
The client is based on `python-socketio<https://python-socketio.readthedocs.io>`_ which is a framework that supports WebSocket client for Python language. The client can only propose requests and receive results, which do not include any calculating procedure.
Class
--------------------
.. automodule:: qlib.data.client

285
docs/hidden/online.rst Normal file
View File

@@ -0,0 +1,285 @@
.. _online:
Online
===================
.. currentmodule:: qlib
Introduction
-------------------
Welcome to use Online, this module simulates what will be like if we do the real trading use our model and strategy.
Just like Estimator and other modules in Qlib, you need to determine parameters through the configuration file,
and in this module, you need to add an account in a folder to do the simulation. Then in each coming day,
this module will use the newest information to do the trade for your account,
the performance can be viewed at any time using the API we defined.
Each account will experience the following processes, the pred_date represents the date you predict the target
positions after trading, also, the trade_date is the date you do the trading.
- Generate the order list (pre_date)
- Execute the order list (trade_date)
- Update account (trade_date)
In the meantime, you can just create an account and use this module to test its performance in a period.
- Simulate (start_date, end_date)
This module need to save your account in a folder, the model and strategy will be saved as pickle files,
and the position and report will be saved as excel.
The file structure can be viewed at fileStruct_.
Example
-------------------
Let's take an example,
.. note:: Make sure you have the latest version of `qlib` installed.
If you want to use the models and data provided by `qlib`, you only need to do as follows.
Firstly, write a simple configuration file as following,
.. code-block:: YAML
strategy:
class: TopkAmountStrategy
module_path: qlib.contrib.strategy
args:
market: csi500
trade_freq: 5
model:
class: ScoreFileModel
module_path: qlib.contrib.online.online_model
args:
loss: mse
model_path: ./model.bin
init_cash: 1000000000
We then can use this command to create a folder and do trading from 2017-01-01 to 2018-08-01.
.. code-block:: bash
online simulate -id v-test -config ./config/config.yaml -exchange_config ./config/exchange.yaml -start 2017-01-01 -end 2018-08-01 -path ./user_data/
The start date (2017-01-01) is the add date of the user, which also is the first predict date,
and the end date (2018-08-01) is the last trade date. You can use "`online generate -date 2018-08-02...`"
command to continue generate the order_list at next trading date.
If Your account was saved in "./user_data/", you can see the performance of your account compared to a benchmark by
.. code-block:: bash
>> online show -id v-test -path ./user_data/ -bench SH000905
...
Result of porfolio:
sub_bench:
risk
mean 0.001157
std 0.003039
annual 0.289131
sharpe 6.017635
mdd -0.013185
sub_cost:
risk
mean 0.000800
std 0.003043
annual 0.199944
sharpe 4.155963
mdd -0.015517
Here 'SH000905' represents csi500 and 'SH000300' represents csi300
Manage your account
--------------------
Any account processed by `online` should be saved in a folder. you can use commands
defined to manage your accounts.
- add an new account
This will add an new account with user_id='v-test', add_date='2019-10-15' in ./user_data.
.. code-block:: bash
>> online add_user -id {user_id} -config {config_file} -path {folder_path} -date {add_date}
>> online add_user -id v-test -config config.yaml -path ./user_data/ -date 2019-10-15
- remove an account
.. code-block:: bash
>> online remove_user -id {user_id} -path {folder_path}
>> online remove_user -id v-test -path ./user_data/
- show the performance
Here benchmark indicates the baseline is to be compared with yours.
.. code-block:: bash
>> online show -id {user_id} -path {folder_path} -bench {benchmark}
>> online show -id v-test -path ./user_data/ -bench SH000905
The default value of all the parameter 'date' below is trade date
(will be today if today is trading date and information has been updated in `qlib`).
The 'generate' and 'update' will check whether input date is valid, the following 3 processes should
be called at each trading date.
- generate the order list
generate the order list at trade date, and save them in {folder_path}/{user_id}/temp/ as a json file.
.. code-block:: bash
>> online generate -date {date} -path {folder_path}
>> online generate -date 2019-10-16 -path ./user_data/
- execute the order list
execute the order list and generate the transactions result in {folder_path}/{user_id}/temp/ at trade date
.. code-block:: bash
>> online execute -date {date} -exchange_config {exchange_config_path} -path {folder_path}
>> online execute -date 2019-10-16 -exchange_config ./config/exchange.yaml -path ./user_data/
A simple exchange config file can be as
.. code-block:: yaml
open_cost: 0.003
close_cost: 0.003
limit_threshold: 0.095
deal_price: vwap
- update accounts
update accounts in "{folder_path}/" at trade date
.. code-block:: bash
>> online update -date {date} -path {folder_path}
>> online update -date 2019-10-16 -path ./user_data/
API
------------------
All those operations are based on defined in `qlib.contrib.online.operator`
.. automodule:: qlib.contrib.online.operator
.. _fileStruct:
File structure
------------------
'user_data' indicates the root of folder.
Name that bold indicates its a folder, otherwise its a document.
.. code-block:: yaml
{user_folder}
│ users.csv: (Init date for each users)
└───{user_id1}: (users' sub-folder to save their data)
│ │ position.xlsx
│ │ report.csv
│ │ model_{user_id1}.pickle
│ │ strategy_{user_id1}.pickle
│ │
│ └───score
│ │ └───{YYYY}
│ │ └───{MM}
│ │ │ score_{YYYY-MM-DD}.csv
│ │
│ └───trade
│ └───{YYYY}
│ └───{MM}
│ │ orderlist_{YYYY-MM-DD}.json
│ │ transaction_{YYYY-MM-DD}.csv
└───{user_id2}
│ │ position.xlsx
│ │ report.csv
│ │ model_{user_id2}.pickle
│ │ strategy_{user_id2}.pickle
│ │
│ └───score
│ └───trade
....
Configuration file
------------------
The configure file used in `online` should contain the model and strategy information.
About the model
~~~~~~~~~~~~~~~~~~~~
First, your configuration file needs to have a field about the model,
this field and its contents determine the model we used when generating score at predict date.
Followings are two examples for ScoreFileModel and a model that read a score file and return score at trade date.
.. code-block:: YAML
model:
class: ScoreFileModel
module_path: qlib.contrib.online.OnlineModel
args:
loss: mse
.. code-block:: YAML
model:
class: ScoreFileModel
module_path: qlib.contrib.online.OnlineModel
args:
score_path: <your score path>
If your model doesn't belong to above models, you need to coding your model manually.
Your model should be a subclass of models defined in 'qlib.contfib.model'. And it must
contains 2 methods used in `online` module.
About the strategy
~~~~~~~~~~~~~~~~~~~~
Your need define the strategy used to generate the order list at predict date.
Followings are two examples for a TopkAmountStrategy
.. code-block:: YAML
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
args:
topk: 100
n_drop: 10
Generated files
------------------
The 'online_generate' command will create the order list at {folder_path}/{user_id}/temp/,
the name of that is orderlist_{YYYY-MM-DD}.json, YYYY-MM-DD is the date that those orders to be executed.
The format of json file is like
.. code-block:: python
{
'sell': {
{'$stock_id1': '$amount1'},
{'$stock_id2': '$amount2'}, ...
},
'buy': {
{'$stock_id1': '$amount1'},
{'$stock_id2': '$amount2'}, ...
}
}
Then after executing the order list (either by 'online_execute' or other executors), a transaction file
will be created also at {folder_path}/{user_id}/temp/.

327
docs/hidden/tuner.rst Normal file
View File

@@ -0,0 +1,327 @@
.. _tuner:
Tuner
===================
.. currentmodule:: qlib
Introduction
-------------------
Welcome to use Tuner, this document is based on that you can use Estimator proficiently and correctly.
You can find the optimal hyper-parameters and combinations of models, trainers, strategies and data labels.
The usage of program `tuner` is similar with `estimator`, you need provide the URL of the configuration file.
The `tuner` will do the following things:
- Construct tuner pipeline
- Search and save best hyper-parameters of one tuner
- Search next tuner in pipeline
- Save the global best hyper-parameters and combination
Each tuner is consisted with a kind of combination of modules, and its goal is searching the optimal hyper-parameters of this combination.
The pipeline is consisted with different tuners, it is aim at finding the optimal combination of modules.
The result will be printed on screen and saved in file, you can check the result in your experiment saving files.
Example
~~~~~~~
Let's see an example,
First make sure you have the latest version of `qlib` installed.
Then, you need to privide a configuration to setup the experiment.
We write a simple configuration example as following,
.. code-block:: YAML
experiment:
name: tuner_experiment
tuner_class: QLibTuner
qlib_client:
auto_mount: False
logging_level: INFO
optimization_criteria:
report_type: model
report_factor: model_score
optim_type: max
tuner_pipeline:
-
model:
class: SomeModel
space: SomeModelSpace
trainer:
class: RollingTrainer
strategy:
class: TopkAmountStrategy
space: TopkAmountStrategySpace
max_evals: 2
time_period:
rolling_period: 360
train_start_date: 2005-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-06-30
test_start_date: 2016-07-01
test_end_date: 2018-04-30
data:
class: ALPHA360
provider_uri: /data/qlib
args:
start_date: 2005-01-01
end_date: 2018-04-30
dropna_label: True
dropna_feature: True
filter:
market: csi500
filter_pipeline:
-
class: NameDFilter
module_path: qlib.data.filter
args:
name_rule_re: S(?!Z3)
fstart_time: 2018-01-01
fend_time: 2018-12-11
-
class: ExpressionDFilter
module_path: qlib.data.filter
args:
rule_expression: $open/$factor<=45
fstart_time: 2018-01-01
fend_time: 2018-12-11
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 500000
benchmark: SH000905
deal_price: vwap
long_short_backtest_args:
topk: 50
Next, we run the following command, and you can see:
.. code-block:: bash
~/v-yindzh/Qlib/cfg$ tuner -c tuner_config.yaml
Searching params: {'model_space': {'colsample_bytree': 0.8870905643607678, 'lambda_l1': 472.3188735122233, 'lambda_l2': 92.75390994877243, 'learning_rate': 0.09741751430635413, 'loss': 'mse', 'max_depth': 8, 'num_leaves': 160, 'num_threads': 20, 'subsample': 0.7536051584789751}, 'strategy_space': {'buffer_margin': 250, 'topk': 40}}
...
(Estimator experiment screen log)
...
Searching params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}}
...
(Estimator experiment screen log)
...
Local best params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}}
Time cost: 489.87220 | Finished searching best parameters in Tuner 0.
Time cost: 0.00069 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_0/local_best_params.json .
Searching params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 2',)}, 'model_space': {'input_dim': 158, 'lr': 0.001, 'lr_decay': 0.9100529502185579, 'lr_decay_steps': 162.48901403763966, 'optimizer': 'gd', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 300, 'topk': 35}}
...
(Estimator experiment screen log)
...
Searching params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}}
...
(Estimator experiment screen log)
...
Local best params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}}
Time cost: 550.74039 | Finished searching best parameters in Tuner 1.
Time cost: 0.00023 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_1/local_best_params.json .
Time cost: 1784.14691 | Finished tuner pipeline.
Time cost: 0.00014 | Finished save global best tuner parameters.
Best Tuner id: 0.
You can check the best parameters at tuner_experiment/global_best_params.json.
Finally, you can check the results of your experiment in the given path.
Configuration file
------------------
Before using `tuner`, you need to prepare a configuration file. Next we will show you how to prepare each part of the configuration file.
About the experiment
~~~~~~~~~~~~~~~~~~~~
First, your configuration file needs to have a field about the experiment, whose key is `experiment`, this field and its contents determine the saving path and tuner class.
Usually it should contain the following content:
.. code-block:: YAML
experiment:
name: tuner_experiment
tuner_class: QLibTuner
Also, there are some optional fields. The meaning of each field is as follows:
- `name`
The experiment name, str type, the program will use this experiment name to construct a directory to save the process of the whole experiment and the results. The default value is `tuner_experiment`.
- `dir`
The saving path, str type, the program will construct the experiment directory in this path. The default value is the path where configuration locate.
- `tuner_class`
The class of tuner, str type, must be an already implemented model, such as `QLibTuner` in `qlib`, or a custom tuner, but it must be a subclass of `qlib.contrib.tuner.Tuner`, the default value is `QLibTuner`.
- `tuner_module_path`
The module path, str type, absolute url is also supported, indicates the path of the implementation of tuner. The default value is `qlib.contrib.tuner.tuner`
About the optimization criteria
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
You need to designate a factor to optimize, for tuner need a factor to decide which case is better than other cases.
Usually, we use the result of `estimator`, such as backtest results and the score of model.
This part needs contain these fields:
.. code-block:: YAML
optimization_criteria:
report_type: model
report_factor: model_pearsonr
optim_type: max
- `report_type`
The type of the report, str type, determines which kind of report you want to use. If you want to use the backtest result type, you can choose `pred_long`, `pred_long_short`, `pred_short`, `sub_bench` and `sub_cost`. If you want to use the model result type, you can only choose `model`.
- `report_factor`
The factor you want to use in the report, str type, determines which factor you want to optimize. If your `report_type` is backtest result type, you can choose `annual`, `sharpe`, `mdd`, `mean` and `std`. If your `report_type` is model result type, you can choose `model_score` and `model_pearsonr`.
- `optim_type`
The optimization type, str type, determines what kind of optimization you want to do. you can minimize the factor or maximize the factor, so you can choose `max`, `min` or `correlation` at this field.
Note: `correlation` means the factor's best value is 1, such as `model_pearsonr` (a corraltion coefficient).
If you want to process the factor or you want fetch other kinds of factor, you can override the `objective` method in your own tuner.
About the tuner pipeline
~~~~~~~~~~~~~~~~~~~~~~~~
The tuner pipeline contains different tuners, and the `tuner` program will process each tuner in pipeline. Each tuner will get an optimal hyper-parameters of its specific combination of modules. The pipeline will contrast the results of each tuner, and get the best combination and its optimal hyper-parameters. So, you need to configurate the pipeline and each tuner, here is an example:
.. code-block:: YAML
tuner_pipeline:
-
model:
class: SomeModel
space: SomeModelSpace
trainer:
class: RollingTrainer
strategy:
class: TopkAmountStrategy
space: TopkAmountStrategySpace
max_evals: 2
Each part represents a tuner, and its modules which are to be tuned. Space in each part is the hyper-parameters' space of a certain module, you need to create your searching space and modify it in `/qlib/contrib/tuner/space.py`. We use `hyperopt` package to help us to construct the space, you can see the detail of how to use it in https://github.com/hyperopt/hyperopt/wiki/FMin .
- model
You need to provide the `class` and the `space` of the model. If the model is user's own implementation, you need to privide the `module_path`.
- trainer
You need to proveide the `class` of the trainer. If the trainer is user's own implementation, you need to privide the `module_path`.
- strategy
You need to provide the `class` and the `space` of the strategy. If the strategy is user's own implementation, you need to privide the `module_path`.
- data_label
The label of the data, you can search which kinds of labels will lead to a better result. This part is optional, and you only need to provide `space`.
- max_evals
Allow up to this many function evaluations in this tuner. The default value is 10.
If you don't want to search some modules, you can fix their spaces in `space.py`. We will not give the default module.
About the time period
~~~~~~~~~~~~~~~~~~~~~
You need to use the same dataset to evaluate your different `estimator` experiments in `tuner` experiment. Two experiments using different dataset are uncomparable. You can specify `time_period` through the configuration file:
.. code-block:: YAML
time_period:
rolling_period: 360
train_start_date: 2005-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-06-30
test_start_date: 2016-07-01
test_end_date: 2018-04-30
- `rolling_period`
The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. If you use `RollingTrainer`, this config will be used, or it will be ignored.
- `train_start_date`
Training start time, str type.
- `train_end_date`
Training end time, str type.
- `validate_start_date`
Validation start time, str type.
- `validate_end_date`
Validation end time, str type.
- `test_start_date`
Test start time, str type.
- `test_end_date`
Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.
About the data and backtest
~~~~~~~~~~~~~~~~~~~~~~~~~~~
`data` and `backtest` are all same in the whole `tuner` experiment. Different `estimator` experiments must use the same data and backtest method. So, these two parts of config are same with that in `estimator` configuration. You can see the precise defination of these parts in `estimator` introduction. We only provide an example here.
.. code-block:: YAML
data:
class: ALPHA360
provider_uri: /data/qlib
args:
start_date: 2005-01-01
end_date: 2018-04-30
dropna_label: True
dropna_feature: True
feature_label_config: /home/v-yindzh/v-yindzh/QLib/cfg/feature_config.yaml
filter:
market: csi500
filter_pipeline:
-
class: NameDFilter
module_path: qlib.filter
args:
name_rule_re: S(?!Z3)
fstart_time: 2018-01-01
fend_time: 2018-12-11
-
class: ExpressionDFilter
module_path: qlib.filter
args:
rule_expression: $open/$factor<=45
fstart_time: 2018-01-01
fend_time: 2018-12-11
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 500000
benchmark: SH000905
deal_price: vwap
long_short_backtest_args:
topk: 50
Experiment Result
-----------------
All the results are stored in experiment file directly, you can check them directly in the corresponding files.
What we save are as following:
- Global optimal parameters
- Local optimal parameters of each tuner
- Config file of this `tuner` experiment
- Every `estimator` experiments result in the process

60
docs/index.rst Normal file
View File

@@ -0,0 +1,60 @@
============================================================
``Qlib`` Documentation
============================================================
``Qlib`` is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
.. _user_guide:
Document Structure
====================
.. toctree::
:hidden:
Home <self>
.. toctree::
:maxdepth: 3
:caption: INTRODUCTION:
Qlib <introduction/introduction.rst>
.. toctree::
:maxdepth: 3
:caption: GETTING STARTED:
Installation <start/installation.rst>
Initialization <start/initialization.rst>
Data Retrieval <start/getdata.rst>
Custom Model Integration <start/integration.rst>
.. toctree::
:maxdepth: 3
:caption: COMPONENTS:
Estimator: Workflow Management <component/estimator.rst>
Data Layer: Data Framework&Usage <component/data.rst>
Interday Model: Model Training & Prediction <component/model.rst>
Interday Strategy: Portfolio Management <component/strategy.rst>
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Aanalysis: Evaluation & Results Analysis <component/report.rst>
.. toctree::
:maxdepth: 3
:caption: ADVANCED TOPICS:
Building Formulaic Alphas <advanced/alpha.rst>
.. toctree::
:maxdepth: 3
:caption: REFERENCE:
API <reference/api.rst>
.. toctree::
:maxdepth: 3
:caption: Change Log:
Change Log <changelog/changelog.rst>

View File

@@ -0,0 +1,45 @@
===============================
``Qlib``: Quantitative Library
===============================
Introduction
===================
``Qlib`` is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
With ``Qlib``, users can easily apply their favorite model to create better Quant investment strategy.
Framework
==================
.. image:: ../_static/img/framework.png
:alt: Framework
At module level, ``Qlib`` is a platform that consists of the above components. Each components is loose-coupling and can be used stand-alone.
====================== ========================================================================
Name Description
====================== ========================================================================
`Data layer` `DataServer` focus on providing high performance infrastructure for user
to retrieve and get raw data. `DataEnhancement` will preprocess the data
and provide the best dataset to be fed in to the models.
`Interday Model` `Interday model` focus on producing forecasting signals(aka. `alpha`).
Models are trained by `Model Creator` and managed by `Model Manager`.
User could choose one or multiple models for forecasting. Multiple models
could be combined with `Ensemble` module.
`Interday Strategy` `Portfolio Generator` will take forecasting signals as input and output
the orders based on current position to achieve target portfolio.
`Intraday Trading` `Order Executor` is responsible for executing orders output by
`Interday Strategy` and returning the executed results.
`Analysis` User could get detailed analysis report of forecasting signal and portfolio
in this part.
====================== ========================================================================
- The modules with hand-drawn style is under development and will be released in the future.
- The modules with dashed border is highly user-customizable and extendible.

117
docs/reference/api.rst Normal file
View File

@@ -0,0 +1,117 @@
================================
API Reference
================================
Here you can find all ``QLib`` interfaces.
Data
====================
Provider
--------------------
.. automodule:: qlib.data.data
:members:
Filter
--------------------
.. automodule:: qlib.data.filter
:members:
Feature
--------------------
Class
~~~~~~~~~~~~~~~~~~~~
.. automodule:: qlib.data.base
:members:
Operator
~~~~~~~~~~~~~~~~~~~~
.. automodule:: qlib.data.ops
:members:
Cache
----------------
.. autoclass:: qlib.data.cache.MemCacheUnit
:members:
.. autoclass:: qlib.data.cache.MemCache
:members:
.. autoclass:: qlib.data.cache.ExpressionCache
:members:
.. autoclass:: qlib.data.cache.DatasetCache
:members:
.. autoclass:: qlib.data.cache.ServerExpressionCache
:members:
.. autoclass:: qlib.data.cache.ServerDatasetCache
:members:
Contrib
====================
Data Handler
---------------
.. automodule:: qlib.contrib.estimator.handler
:members:
Model
--------------------
.. automodule:: qlib.contrib.model.base
:members:
Strategy
-------------------
.. automodule:: qlib.contrib.strategy.strategy
:members:
Evaluate
-----------------
.. automodule:: qlib.contrib.evaluate
:members:
Report
-----------------
.. automodule:: qlib.contrib.report.analysis_position.report
:members:
.. automodule:: qlib.contrib.report.analysis_position.score_ic
:members:
.. automodule:: qlib.contrib.report.analysis_position.cumulative_return
:members:
.. automodule:: qlib.contrib.report.analysis_position.risk_analysis
:members:
.. automodule:: qlib.contrib.report.analysis_position.rank_label
:members:
.. automodule:: qlib.contrib.report.analysis_model.analysis_model_performance
:members:

1
docs/requirements.txt Normal file
View File

@@ -0,0 +1 @@
Cython==0.29.21

137
docs/start/getdata.rst Normal file
View File

@@ -0,0 +1,137 @@
.. _getdata:
=============================
Data Retrieval
=============================
.. currentmodule:: qlib
Introduction
====================
Users can get stock data by ``Qlib``. Following examples will demonstrate the basic user interface.
Examples
====================
``QLib`` Initialization:
.. note:: In order to get the data, users need to initialize ``Qlib`` with `qlib.init` first. Please refer to `initialization <initialization.rst>`_.
It is recommended to use the following code to initialize qlib:
.. code-block:: python
>>> import qlib
>>> qlib.init(provider_uri='~/.qlib/qlib_data/cn_data')
Load trading calendar with the given time range and frequency:
.. code-block:: python
>>> from qlib.data import D
>>> D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2]
[Timestamp('2010-01-04 00:00:00'), Timestamp('2010-01-05 00:00:00')]
Parse a given market name into a stockpool config:
.. code-block:: python
>>> from qlib.data import D
>>> D.instruments(market='all')
{'market': 'all', 'filter_pipe': []}
Load instruments of certain stockpool in the given time range:
.. code-block:: python
>>> from qlib.data import D
>>> instruments = D.instruments(market='csi300')
>>> D.list_instruments(instruments=instruments, start_time='2010-01-01', end_time='2017-12-31', as_list=True)[:6]
Load dynamic instruments from a base market according to a name filter
.. code-block:: python
>>> from qlib.data import D
>>> from qlib.data.filter import NameDFilter
>>> nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')
>>> instruments = D.instruments(market='csi300', filter_pipe=[nameDFilter])
>>> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)
Load dynamic instruments from a base market according to an expression filter
.. code-block:: python
>>> from qlib.data import D
>>> from qlib.data.filter import ExpressionDFilter
>>> expressionDFilter = ExpressionDFilter(rule_expression='$close>100')
>>> instruments = D.instruments(market='csi300', filter_pipe=[expressionDFilter])
>>> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)
To know more about how to use the filter or how to build one's own filter, go to API Reference: `filter API <../reference/api.html#filter>`_
Load features of certain instruments in given time range:
.. note:: This is not a recommended way to get features.
.. code-block:: python
>>> from qlib.data import D
>>> instruments = ['SH600000']
>>> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
>>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
$close $volume Ref($close,1) Mean($close,3) \
instrument datetime
SH600000 2010-01-04 81.809998 17144536.0 NaN 81.809998
2010-01-05 82.419998 29827816.0 81.809998 82.114998
2010-01-06 80.800003 25070040.0 82.419998 81.676666
2010-01-07 78.989998 22077858.0 80.800003 80.736666
2010-01-08 79.879997 17019168.0 78.989998 79.889999
Sub($high,$low)
instrument datetime
SH600000 2010-01-04 2.741158
2010-01-05 3.049736
2010-01-06 1.621399
2010-01-07 2.856926
2010-01-08 1.930397
2010-01-08 1.930397
Load features of certain stockpool in given time range:
.. note:: Since the server need to cache all-time data for your request stockpool and fields, it may take longer to process your request than before. But in the second time, your request will be processed and responded in a flash even if you change the timespan.
.. code-block:: python
>>> from qlib.data import D
>>> from qlib.data.filter import NameDFilter, ExpressionDFilter
>>> nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')
>>> expressionDFilter = ExpressionDFilter(rule_expression='($close/$factor)>100')
>>> instruments = D.instruments(market='csi300', filter_pipe=[nameDFilter, expressionDFilter])
>>> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
>>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
$close $volume Ref($close, 1) \
instrument datetime
SH600655 2015-06-15 4342.160156 258706.359375 4530.459961
2015-06-16 4409.270020 257349.718750 4342.160156
2015-06-17 4312.330078 235214.890625 4409.270020
2015-06-18 4086.729980 196772.859375 4312.330078
2015-06-19 3678.250000 182916.453125 4086.729980
Mean($close, 3) high low
instrument datetime
SH600655 2015-06-15 4480.743327 285.251465
2015-06-16 4427.296712 298.301270
2015-06-16 4354.586751 356.098145
2015-06-16 4269.443359 363.554932
2015-06-16 4025.770020 368.954346
.. note:: When calling D.features() at client, use parameter 'disk_cache=0' to skip dataset cache, use 'disk_cache=1' to generate and use dataset cache. In addition, when calling at server, you can use 'disk_cache=2' to update the dataset cache.
API
====================
To know more about how to use the Data, go to API Reference: `Data API <../reference/api.html#Data>`_

View File

@@ -0,0 +1,60 @@
.. _initialization:
====================
Qlib Initialization
====================
.. currentmodule:: qlib
Initialization
=========================
Please execute the following process to initialize ``Qlib``.
- Download and prepare the Data: execute the following command to download the stock data.
.. code-block:: bash
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
Know more about how to use ``get_data.py``, refer to `Raw Data <../advanced/data.html#raw-data>`_.
- Run the initialization code: run the following code in python:
.. code-block:: Python
import qlib
# region in [REG_CN, REG_US]
from qlib.config import REG_CN
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)
Parameters
-------------------
In fact, in addition to `provider_uri` and `region`, `qlib.init` has other parameters. The following are all the parameters of `qlib.init`:
- `provider_uri`
Type: str. The local directory where the data loaded by ``get_data.py`` is stored.
- `region`
Type: str, optional parameter(default: ``qlib.config.REG_CN``).
Currently: ``qlib.config.REG_US``('us') and ``qlib.config.REG_CN``('cn') is supported. Different value of ``region`` will
result in different stock market mode.
- ``qlib.config.REG_US``: US stock market.
- ``qlib.config.REG_CN``: China stock market.
- `redis_host`
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
The lock and cache mechanism relies on redis.
- `redis_port`
Type: int, optional parameter(default: 6379), port of `redis`
.. note::
The value of `region` should be aligned with the data stored in `provider_uri`. Currently, ``scripts/get_data.py`` only provides China stock market data. If users want to use the US stock market data, they should prepare their own US-stock data in `provider_uri` and switch to US-stock mode.
.. note::
If redis connection failed with `redis_host` and `redis_port`, cache will not be used! Please refer to `Cache <../advanced/cache.rst>`_.

View File

@@ -0,0 +1,43 @@
.. _installation:
====================
Installation
====================
.. currentmodule:: qlib
How to Install ``Qlib``
====================
``Qlib`` only supports Python3, and supports up to Python3.8.
Please execute the following process to install ``Qlib``:
- Change the directory to ``Qlib``, in which the file ``setup.py`` exists.
- Then, please execute the following command:
.. code-block:: bash
$ pip install numpy
$ pip install --upgrade cython
$ python setup.py install
.. note::
It's recommended to use anaconda/miniconda to setup environment.
``Qlib`` needs lightgbm and tensorflow packages, use pip to install them.
.. note::
Do not import qlib in the repository folder which contains ``qlib``, otherwise errors may occur.
Use the following code to confirm installation successful:
.. code-block:: python
>>> import qlib
>>> qlib.__version__
<LATEST VERSION>

146
docs/start/integration.rst Normal file
View File

@@ -0,0 +1,146 @@
=========================================
Custom Model Integration
=========================================
Introduction
===================
``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``.
Users can integrate their own custom models according to the following steps.
- Define a custom model class, which should be a subclass of the `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_
- Write a configuration file that describes the path and parameters of the custom model
- Test the custom model
Custom Model Class
===========================
The Custom models need to inherit `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ and override the methods in it.
- Override the `__init__` method
- ``Qlib`` passes the initialized parameters to the \_\_init\_\_ method
- The parameter must be consistent with the hyperparameters in the configuration file.
- Code Example: In the following example, the hyperparameter filed of the configuration file should contain parameters such as loss:mse.
.. code-block:: Python
def __init__(self, loss='mse', **kwargs):
if loss not in {'mse', 'binary'}:
raise NotImplementedError
self._scorer = mean_squared_error if loss == 'mse' else roc_auc_score
self._params.update(objective=loss, **kwargs)
self._model = None
- Override the `fit` method
- ``Qlib`` calls the fit method to train the model
- The parameters must include training feature 'x_train', training label 'y_train', test feature 'x_valid', test label 'y_valid'at least.
- The parameters could include some optional parameters with default values, such as train weight 'w_train', test weight 'w_valid' and 'num_boost_round = 1000'.
- Code Example: In the following example, 'num_boost_round = 1000' is an optional parameter.
.. code-block:: Python
def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame,
w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs):
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError('LightGBM doesn\'t support multi-label training')
w_train_weight = None if w_train is None else w_train.values
w_valid_weight = None if w_valid is None else w_valid.values
dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
self._model = lgb.train(
self._params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=['train', 'valid'],
**kwargs
)
- Override the `predict` method
- The parameters include the test features
- Return the prediction score
- Please refer to `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ for the parameter types of the fit method
- Code Example:In the following example, user need to user dnn to predict the label(such as 'preds') of test data 'x_test' and return it.
.. code-block:: Python
def predict(self, x_test:pd.DataFrame, **kwargs)-> numpy.ndarray:
if self._model is None:
raise ValueError('model is not fitted yet!')
return self._model.predict(x_test.values)
- Override the `score` method
- The parameters include the test features and test labels
- Return the evaluation score of model. It's recommended to adopt the loss between labels and prediction score.
- Code Example:In the following example, user need to calculate the weighted loss with test data 'x_test', test label 'y_test' and the weight 'w_test'.
.. code-block:: Python
def score(self, x_test:pd.Dataframe, y_test:pd.Dataframe, w_test:pd.DataFrame = None) -> float:
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
preds = self.predict(x_test)
w_test_weight = None if w_test is None else w_test.values
scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score
return scorer(y_test.values, preds, sample_weight=w_test_weight)
- Override the `save` method & `load` method
- The `save` method parameter include the a `filename` that represents an absolute path, user need to save model into the path.
- The `load` method parameter include the a `buffer` read from the `filename` passed in `save` method , user need to load model from the `buffer`.
- Code Example:
.. code-block:: Python
def save(self, filename):
if self._model is None:
raise ValueError('model is not fitted yet!')
self._model.save_model(filename)
def load(self, buffer):
self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')})
Configuration File
=======================
The configuration file is described in detail in the `estimator <../advanced/estimator.html#Example>`_ document. In order to integrate the custom model into ``Qlib``, you need to modify the "model" field in the configuration file.
- Example: The following example describes the model field of configuration file about the custom lightgbm model mentioned above , where module_path is the module path, class is the class name, and args is the hyperparameter passed into the __init__ method. All parameters in the field is passed to 'self._params' by '\*\*kwargs' in `__init__` except 'loss = mse'.
.. code-block:: YAML
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
args:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
Users could find configuration file of the baseline of the ``Model`` in ``qlib/examples/estimator/estimator_config.yaml`` and ``qlib/examples/estimator/estimator_config_dnn.yaml``
Model Testing
=====================
Assuming that the configuration file is ``examples/estimator/estimator_config.yaml``, user can run the following command to test the custom model:
.. code-block:: bash
cd examples # Avoid running program under the directory contains `qlib`
estimator -c estimator/estimator_config.yaml
.. note:: ``estimator`` is a built-in command of ``Qlib``.
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
Reference
=====================
To know more about ``Model``, please refer to `Interday Model: Model Training & Prediction <../advanced/model.rst>`_ and `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.

View File

@@ -0,0 +1,257 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import json\n",
"import yaml\n",
"import pickle\n",
"from pathlib import Path\n",
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.config import REG_CN\n",
"from qlib.utils import exists_qlib_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CUR_DIR = Path.cwd()\n",
"MARKET = \"csi300\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use default data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data_cn(provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with CUR_DIR.joinpath('estimator_config.yaml').open() as fp:\n",
" estimator_name = yaml.load(fp, Loader=yaml.FullLoader)['experiment']['name']\n",
"with CUR_DIR.joinpath(estimator_name, 'exp_info.json').open() as fp:\n",
" latest_id = json.load(fp)['id']\n",
" \n",
"estimator_dir = CUR_DIR.joinpath(estimator_name, 'sacred', latest_id)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# read estimator result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_df = pd.read_pickle(estimator_dir.joinpath('pred.pkl'))\n",
"report_normal_df = pd.read_pickle(estimator_dir.joinpath('report_normal.pkl'))\n",
"report_normal_df.index.names = ['index']\n",
"\n",
"analysis_df = pd.read_pickle(estimator_dir.joinpath('analysis.pkl'))\n",
"positions = pickle.load(estimator_dir.joinpath('positions.pkl').open('rb'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# get label data from qlib"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qlib.data import D\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
"features_df = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
"features_df.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze graphs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis position"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### score IC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_label = pd.concat([features_df, pred_df], axis=1, sort=True).reindex(features_df.index)\n",
"analysis_position.score_ic_graph(pred_label)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### cumulative return"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### risk analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### rank label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### model performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"analysis_model.model_performance_graph(pred_label)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,55 @@
experiment:
name: estimator_example
observer_type: file_storage
mode: train
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
args:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 64
num_threads: 20
min_data_in_leaf: 10
data:
class: QLibDataHandlerClose
args:
dropna_label: True
filter:
market: csi300
trainer:
class: StaticTrainer
args:
train_start_date: 2008-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-12-31
test_start_date: 2017-01-01
test_end_date: 2020-08-01
strategy:
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: SH000300
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"
redis_port: 4312

View File

@@ -0,0 +1,57 @@
experiment:
name: estimator_example
observer_type: file_storage
mode: train
model:
module_path: qlib.contrib.model.pytorch_nn
class: DNNModelPytorch
args:
loss: mse
input_dim: 158
output_dim: 1
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: 'adam'
max_steps: 8000
batch_size: 4096
GPU: '0'
data:
class: QLibDataHandlerClose
args:
dropna_label: True
dropna_feature: True
filter:
market: csi300
trainer:
class: StaticTrainer
args:
train_start_date: 2007-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-12-31
test_start_date: 2017-01-01
test_end_date: 2020-08-01
strategy:
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: SH000300
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
long_short_backtest_args:
topk: 50
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"

View File

@@ -0,0 +1,119 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.estimator.handler import QLibDataHandlerClose
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data_cn(provider_uri)
qlib.init(provider_uri=provider_uri, region=REG_CN)
MARKET = "CSI300"
BENCHMARK = "SH000300"
###################################
# train model
###################################
DATA_HANDLER_CONFIG = {
"dropna_label": True,
"start_date": "2008-01-01",
"end_date": "2020-08-01",
"market": MARKET,
}
TRAINER_CONFIG = {
"train_start_date": "2008-01-01",
"train_end_date": "2014-12-31",
"validate_start_date": "2015-01-01",
"validate_end_date": "2016-12-31",
"test_start_date": "2017-01-01",
"test_end_date": "2020-08-01",
}
# use default DataHandler
# custom DataHandler, refer to: TODO: DataHandler API url
x_train, y_train, x_validate, y_validate, x_test, y_test = QLibDataHandlerClose(
**DATA_HANDLER_CONFIG
).get_split_data(**TRAINER_CONFIG)
MODEL_CONFIG = {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
}
# use default model
# custom Model, refer to: TODO: Model API url
model = LGBModel(**MODEL_CONFIG)
model.fit(x_train, y_train, x_validate, y_validate)
_pred = model.predict(x_test)
_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
# backtest requires pred_score
pred_score = pd.DataFrame(index=_pred.index)
pred_score["score"] = _pred.iloc(axis=1)[0]
# save pred_score to file
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
pred_score.to_pickle(pred_score_path)
###################################
# backtest
###################################
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
}
BACKTEST_CONFIG = {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
# use default strategy
# custom Strategy, refer to: TODO: Strategy API url
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
###################################
# analyze
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
###################################
analysis = dict()
analysis["sub_bench"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["sub_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
analysis_df = pd.concat(analysis) # type: pd.DataFrame
print(analysis_df)

View File

@@ -0,0 +1,355 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.config import REG_CN\n",
"from qlib.contrib.model.gbdt import LGBModel\n",
"from qlib.contrib.estimator.handler import QLibDataHandlerClose\n",
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
"from qlib.contrib.evaluate import (\n",
" backtest as normal_backtest,\n",
" risk_analysis,\n",
")\n",
"from qlib.utils import exists_qlib_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# use default data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data_cn(provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MARKET = \"csi300\"\n",
"BENCHMARK = \"SH000300\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# train model\n",
"###################################\n",
"DATA_HANDLER_CONFIG = {\n",
" \"dropna_label\": True,\n",
" \"start_date\": \"2008-01-01\",\n",
" \"end_date\": \"2020-08-01\",\n",
" \"market\": MARKET,\n",
"}\n",
"\n",
"TRAINER_CONFIG = {\n",
" \"train_start_date\": \"2008-01-01\",\n",
" \"train_end_date\": \"2014-12-31\",\n",
" \"validate_start_date\": \"2015-01-01\",\n",
" \"validate_end_date\": \"2016-12-31\",\n",
" \"test_start_date\": \"2017-01-01\",\n",
" \"test_end_date\": \"2020-08-01\",\n",
"}\n",
"\n",
"# use default DataHandler\n",
"# custom DataHandler, refer to: TODO: DataHandler api url\n",
"x_train, y_train, x_validate, y_validate, x_test, y_test = QLibDataHandlerClose(**DATA_HANDLER_CONFIG).get_split_data(**TRAINER_CONFIG)\n",
"\n",
"\n",
"MODEL_CONFIG = {\n",
" \"loss\": \"mse\",\n",
" \"colsample_bytree\": 0.8879,\n",
" \"learning_rate\": 0.0421,\n",
" \"subsample\": 0.8789,\n",
" \"lambda_l1\": 205.6999,\n",
" \"lambda_l2\": 580.9768,\n",
" \"max_depth\": 8,\n",
" \"num_leaves\": 210,\n",
" \"num_threads\": 20,\n",
"}\n",
"# use default model\n",
"# custom Model, refer to: TODO: Model api url\n",
"model = LGBModel(**MODEL_CONFIG)\n",
"model.fit(x_train, y_train, x_validate, y_validate)\n",
"_pred = model.predict(x_test)\n",
"_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)\n",
"\n",
"# backtest requires pred_score\n",
"pred_score = pd.DataFrame(index=_pred.index)\n",
"pred_score[\"score\"] = _pred.iloc(axis=1)[0]\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# backtest"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# backtest\n",
"###################################\n",
"STRATEGY_CONFIG = {\n",
" \"topk\": 50,\n",
" \"n_drop\": 5",
"}\n",
"BACKTEST_CONFIG = {\n",
" \"verbose\": False,\n",
" \"limit_threshold\": 0.095,\n",
" \"account\": 100000000,\n",
" \"benchmark\": BENCHMARK,\n",
" \"deal_price\": \"close\",\n",
" \"open_cost\": 0.0005,\n",
" \"close_cost\": 0.0015,\n",
" \"min_cost\": 5,\n",
" \n",
"}\n",
"\n",
"# use default strategy\n",
"# custom Strategy, refer to: TODO: Strategy api url\n",
"strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)\n",
"report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# analyze\n",
"# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb\n",
"###################################\n",
"analysis = dict()\n",
"analysis[\"sub_bench\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n",
"analysis[\"sub_cost\"] = risk_analysis(\n",
" report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"]\n",
")\n",
"analysis_df = pd.concat(analysis) # type: pd.DataFrame\n",
"print(analysis_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze graphs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get label data\n",
"from qlib.data import D\n",
"pred_df_dates = pred_score.index.get_level_values(level='datetime')\n",
"features_df = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
"features_df.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis position"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### score IC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_label = pd.concat([features_df, pred_score], axis=1, sort=True).reindex(features_df.index)\n",
"analysis_position.score_ic_graph(pred_label)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### cumulative return"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"analysis_position.cumulative_return_graph(positions_normal, report_normal, features_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### risk analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df, report_normal)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### rank label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.rank_label_graph(positions_normal, features_df, pred_df_dates.min(), pred_df_dates.max())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### model performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"analysis_model.model_performance_graph(pred_label)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}

197
qlib/__init__.py Normal file
View File

@@ -0,0 +1,197 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
__version__ = "0.4.6.dev"
import os
import copy
import logging
import re
import subprocess
import platform
from pathlib import Path
from .utils import can_use_cache
# init qlib
def init(default_conf="client", **kwargs):
from .config import (
C,
_default_client_config,
_default_server_config,
_default_region_config,
REG_CN,
)
from .data.data import register_all_wrappers
from .log import get_module_logger, set_log_with_config
_logging_config = C.logging_config
if "logging_config" in kwargs:
_logging_config = kwargs["logging_config"]
# set global config
if _logging_config:
set_log_with_config(_logging_config)
LOG = get_module_logger("Initialization", level=logging.INFO)
LOG.info(f"default_conf: {default_conf}.")
if default_conf == "server":
base_config = copy.deepcopy(_default_server_config)
elif default_conf == "client":
base_config = copy.deepcopy(_default_client_config)
else:
raise ValueError("Unknown system type")
if base_config:
base_config.update(_default_region_config[kwargs.get("region", REG_CN)])
for k, v in base_config.items():
C[k] = v
for k, v in kwargs.items():
C[k] = v
if k not in C:
LOG.warning("Unrecognized config %s" % k)
if default_conf == "client":
C["mount_path"] = str(Path(C["mount_path"]).expanduser().resolve())
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
# check redis
if not can_use_cache():
LOG.warning(
f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!"
)
C["expression_cache"] = None
C["dataset_cache"] = None
# check path if server/local
if re.match("^[^/ ]+:.+", C["provider_uri"]) is None:
C["provider_uri"] = str(Path(C["provider_uri"]).expanduser().resolve())
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
LOG.error(
"Invalid provider uri: {}, please check if a valid provider uri has been set. This path does not exist.".format(
C["provider_uri"]
)
)
else:
LOG.warning("auto_path is False, please make sure {} is mounted".format(C["mount_path"]))
else:
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
# If the provider uri looks like this 172.23.233.89//data/csdesign'
# It will be a nfs path. The client provider will be used
if not C["auto_mount"]:
if not os.path.exists(C["mount_path"]):
raise FileNotFoundError(
"Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format(
C["mount_path"], mount_command
)
)
else:
# Judging system type
sys_type = platform.system()
if "win" in sys_type.lower():
# system: window
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
result = exec_result.read()
if "85" in result:
LOG.warning("already mounted or window mount path already exists")
elif "53" in result:
raise OSError("not find network path")
elif "error" in result or "错误" in result:
raise OSError("Invalid mount path")
elif C["provider_uri"] in result:
LOG.info("window success mount..")
else:
raise OSError(f"unknown error: {result}")
# config mount path
C["mount_path"] = C["mount_path"] + ":\\"
else:
# system: linux/Unix/Mac
# check mount
_remote_uri = C["provider_uri"]
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
_mount_path = C["mount_path"]
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
_check_level_num = 2
_is_mount = False
while _check_level_num:
with subprocess.Popen(
'mount | grep "{}"'.format(_remote_uri),
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as shell_r:
_command_log = shell_r.stdout.readlines()
if len(_command_log) > 0:
for _c in _command_log:
_temp_mount = _c.decode("utf-8").split(" ")[2]
_temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount
if _temp_mount == _mount_path:
_is_mount = True
break
if _is_mount:
break
_remote_uri = "/".join(_remote_uri.split("/")[:-1])
_mount_path = "/".join(_mount_path.split("/")[:-1])
_check_level_num -= 1
if not _is_mount:
try:
os.makedirs(C["mount_path"], exist_ok=True)
except Exception:
raise OSError(
"Failed to create directory {}, please create {} manually!".format(
C["mount_path"], C["mount_path"]
)
)
# check nfs-common
command_res = os.popen("dpkg -l | grep nfs-common")
command_res = command_res.readlines()
if not command_res:
raise OSError(
"nfs-common is not found, please install it by execute: sudo apt install nfs-common"
)
# manually mount
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
"mount {} on {} error! Needs SUDO! Please mount manually: {}".format(
C["provider_uri"], C["mount_path"], mount_command
)
)
elif command_status == 32512:
# LOG.error("Command error")
raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"]))
elif command_status == 0:
LOG.info("Mount finished")
else:
LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path))
LOG.info("qlib successfully initialized based on %s settings." % default_conf)
register_all_wrappers()
try:
if C["auto_mount"]:
LOG.info(f"provider_uri={C['provider_uri']}")
else:
LOG.info(f"mount_path={C['mount_path']}")
except KeyError:
LOG.info(f"provider_uri={C['provider_uri']}")
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
def init_from_yaml_conf(conf_path):
"""init_from_yaml_conf
:param conf_path: A path to the qlib config in yml format
"""
import yaml
with open(conf_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)

167
qlib/config.py Normal file
View File

@@ -0,0 +1,167 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# REGION CONST
REG_CN = "cn"
REG_US = "US"
_default_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
"instrument_provider": "LocalInstrumentProvider",
"feature_provider": "LocalFeatureProvider",
"expression_provider": "LocalExpressionProvider",
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
"provider_uri": "",
# cache
"expression_cache": None,
"dataset_cache": None,
"calendar_cache": None,
# for simple dataset cache
"local_cache_path": None,
"kernels": 16,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
"default_disk_cache": 1, # 0:skip/1:use
"disable_disk_cache": False, # disable disk cache; if High-frequency data generally disable_disk_cache=True
"mem_cache_size_limit": 500,
# memory cache expire second, only in used 'ClientDatasetCache' and 'client D.calendar'
# default 1 hour
"mem_cache_expire": 60 * 60,
# memory cache space limit, default 5GB, only in used client
"mem_cache_space_limit": 1024 * 1024 * 1024 * 5,
# cache dir name
"dataset_cache_dir_name": "dataset_cache",
"features_cache_dir_name": "features_cache",
# redis
# in order to use cache
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
# This value can be reset via qlib.init
"logging_level": "INFO",
# Global configuration of qlib log
# logging_level can control the logging level more finely
"logging_config": {
"version": 1,
"formatters": {
"logger_format": {
"format": "[%(process)s:%(threadName)s](%(asctime)s) %(levelname)s - %(name)s - [%(filename)s:%(lineno)d] - %(message)s"
}
},
"filters": {
"field_not_found": {
"()": "qlib.log.LogFilter",
"param": [".*?WARN: data not found for.*?"],
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "logger_format",
"filters": ["field_not_found"],
}
},
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
},
}
_default_server_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
"instrument_provider": "LocalInstrumentProvider",
"feature_provider": "LocalFeatureProvider",
"expression_provider": "LocalExpressionProvider",
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
"provider_uri": "",
# redis
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
"kernels": 64,
# cache
"expression_cache": "ServerExpressionCache",
"dataset_cache": "ServerDatasetCache",
}
_default_client_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
"instrument_provider": "LocalInstrumentProvider",
"feature_provider": "LocalFeatureProvider",
"expression_provider": "LocalExpressionProvider",
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in user's own code
"provider_uri": "~/.qlib/qlib_data/cn_data",
# cache
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
"expression_cache": "ServerExpressionCache",
"dataset_cache": "ServerDatasetCache",
"calendar_cache": None,
# client config
"kernels": 16,
"mount_path": "~/.qlib/qlib_data/cn_data",
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
# The nfs should be auto-mounted by qlib on other
# serversS(such as PAI) [auto_mount:True]
"timeout": 100,
"logging_level": "INFO",
"region": REG_CN,
}
_default_region_config = {
REG_CN: {
"trade_unit": 100,
"limit_threshold": 0.1,
"deal_price": "vwap",
},
REG_US: {
"trade_unit": 1,
"limit_threshold": None,
"deal_price": "close",
},
}
class Config:
def __getitem__(self, key):
return _default_config[key]
def __getattr__(self, attr):
try:
return _default_config[attr]
except KeyError:
return AttributeError(f"No such {attr} in _default_config")
def __setitem__(self, key, value):
_default_config[key] = value
def __setattr__(self, attr, value):
_default_config[attr] = value
def __contains__(self, item):
return item in _default_config
def __getstate__(self):
return _default_config
def __setstate__(self, state):
_default_config.update(state)
def __str__(self):
return str(_default_config)
def __repr__(self):
return str(_default_config)
# global config
C = Config()

0
qlib/contrib/__init__.py Normal file
View File

View File

@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -*- coding: utf-8 -*-
from .order import Order
from .account import Account
from .position import Position
from .exchange import Exchange
from .report import Report

View File

@@ -0,0 +1,174 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
from .position import Position
from .report import Report
from .order import Order
"""
rtn & earning in the Account
rtn:
from order's view
1.change if any order is executed, sell order or buy order
2.change at the end of today, (today_clse - stock_price) * amount
earning
from value of current position
earning will be updated at the end of trade date
earning = today_value - pre_value
**is consider cost**
while earning is the difference of two position value, so it considers cost, it is the true return rate
in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning
"""
class Account:
def __init__(self, init_cash, last_trade_date=None):
self.init_vars(init_cash, last_trade_date)
def init_vars(self, init_cash, last_trade_date=None):
# init cash
self.init_cash = init_cash
self.current = Position(cash=init_cash)
self.positions = {}
self.rtn = 0
self.ct = 0
self.to = 0
self.val = 0
self.report = Report()
self.earning = 0
self.last_trade_date = last_trade_date
def get_positions(self):
return self.positions
def get_cash(self):
return self.current.position["cash"]
def update_state_from_order(self, order, trade_val, cost, trade_price):
# update cash
if order.direction == Order.SELL: # 0 for sell
self.current.position["cash"] += trade_val - cost
elif order.direction == Order.BUY: # 1 for buy
self.current.position["cash"] -= trade_val + cost
else:
raise NotImplementedError("{} ".format(order.direction))
# update turnover
self.to += trade_val
# update cost
self.ct += cost
# update return
# update self.rtn from order
if order.direction == Order.SELL: # 0 for sell
# when sell stock, get profit from price change
profit = trade_val - self.current.get_stock_price(order.stock_id) * order.deal_amount
self.rtn += profit # note here do not consider cost
elif order.direction == Order.BUY: # 1 for buy
# when buy stock, we get return for the rtn computing method
# profit in buy order is to make self.rtn is consistent with self.earning at the end of date
profit = self.current.get_stock_price(order.stock_id) * order.deal_amount - trade_val
self.rtn += profit
def update_order(self, order, trade_val, cost, trade_price):
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
# if stock is bought, there is no stock in current position, update current, then update account
if order.direction == Order.SELL:
# sell stock
self.update_state_from_order(order, trade_val, cost, trade_price)
# update current position
# for may sell all of stock_id
self.current.update_order(order, trade_price)
else:
# buy stock
# deal order, then update state
self.current.update_order(order, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
def update_daily_end(self, today, trader):
"""
today: pd.TimeStamp
quote: pd.DataFrame (code, date), collumns
when the end of trade date
- update rtn
- update price for each asset
- update value for this account
- update earning (2nd view of return )
- update holding day, count of stock
- update position hitory
- update report
:return: None
"""
# update price for stock in the position and the profit from changed_price
stock_list = self.current.get_stock_list()
profit = 0
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trader.check_stock_suspended(code, today):
continue
else:
today_close = trader.get_close(code, today)
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=today_close)
self.rtn += profit
# update holding day count
self.current.add_count_all()
# update value
self.val = self.current.calculate_value()
# update earning (2nd view of return)
# account_value - last_account_value
# for the first trade date, account_value - init_cash
# self.report.is_empty() to judge is_first_trade_date
# get last_account_value, today_account_value, today_stock_value
if self.report.is_empty():
last_account_value = self.init_cash
else:
last_account_value = self.report.get_latest_account_value()
today_account_value = self.current.calculate_value()
today_stock_value = self.current.calculate_stock_value()
self.earning = today_account_value - last_account_value
# update report for today
# judge whether the the trading is begin.
# and don't add init account state into report, due to we don't have excess return in those days.
self.report.update_report_record(
trade_date=today,
account_value=today_account_value,
cash=self.current.position["cash"],
return_rate=(self.earning + self.ct) / last_account_value,
# here use earning to calculate return, position's view, earning consider cost, true return
# in order to make same definition with original backtest in evaluate.py
turnover_rate=self.to / last_account_value,
cost_rate=self.ct / last_account_value,
stock_value=today_stock_value,
)
# set today_account_value to position
self.current.position["today_account_value"] = today_account_value
self.current.update_weight_all()
# update positions
# note use deepcopy
self.positions[today] = copy.deepcopy(self.current)
# finish today's updation
# reset the daily variables
self.rtn = 0
self.ct = 0
self.to = 0
self.last_trade_date = today
def load_account(self, account_path):
report = Report()
position = Position()
last_trade_date = position.load_position(account_path / "position.xlsx")
report.load_report(account_path / "report.csv")
# assign values
self.init_vars(position.init_cash)
self.current = position
self.report = report
self.last_trade_date = last_trade_date if last_trade_date else None
def save_account(self, account_path):
self.current.save_position(account_path / "position.xlsx", self.last_trade_date)
self.report.save_report(account_path / "report.csv")

View File

@@ -0,0 +1,128 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from ...utils import get_date_by_shift, get_date_range
from ..online.executor import SimulatorExecutor
from ...data import D
from .account import Account
from ...config import C
from ...log import get_module_logger
LOG = get_module_logger("backtest")
def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark):
"""Parameters
----------
pred : pandas.DataFrame
predict should has <instrument, datetime> index and one `score` column
strategy : Strategy()
strategy part for backtest
trade_exchange : Exchange()
exchage for backtest
shift : int
whether to shift prediction by one day
verbose : bool
whether to print log
account : float
init account value
benchmark : str/list/pd.Series
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
`benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000905 CSI500
"""
trade_account = Account(init_cash=account)
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
if isinstance(benchmark, pd.Series):
bench = benchmark
else:
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
_temp_result = D.features(
_codes,
["$close/Ref($close,1)-1"],
predict_dates[0],
get_date_by_shift(predict_dates[-1], shift=shift),
disk_cache=1,
)
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
executor = SimulatorExecutor(trade_exchange, verbose=verbose)
# trading apart
for pred_date, trade_date in zip(predict_dates, trade_dates):
# for loop predict date and trading date
# print
if verbose:
LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date))
# 1. Load the score_series at pred_date
try:
score = pred.loc(axis=0)[:, pred_date] # (stock_id, trade_date) multi_index, score in pdate
score_series = score.reset_index(level="datetime", drop=True)[
"score"
] # pd.Series(index:stock_id, data: score)
except KeyError:
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
score_series = None
if score_series is not None and score_series.count() > 0: # in case of the scores are all None
# 2. Update your strategy (and model)
strategy.update(score_series, pred_date, trade_date)
# 3. Generate order list
order_list = strategy.generate_order_list(
score_series=score_series,
current=trade_account.current,
trade_exchange=trade_exchange,
pred_date=pred_date,
trade_date=trade_date,
)
else:
order_list = []
# 4. Get result after executing order list
# NOTE: The following operation will modify order.amount.
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
trade_info = executor.execute(trade_account, order_list, trade_date)
# 5. Update account information according to transaction
update_account(trade_account, trade_info, trade_exchange, trade_date)
# generate backtest report
report_df = trade_account.report.generate_report_dataframe()
report_df["bench"] = bench
positions = trade_account.get_positions()
return report_df, positions
def update_account(trade_account, trade_info, trade_exchange, trade_date):
"""Update the account and strategy
Parameters
----------
trade_account : Account()
trade_info : list of [Order(), float, float, float]
(order, trade_val, trade_cost, trade_price), trade_info with out factor
trade_exchange : Exchange()
used to get the $close_price at trade_date to update account
trade_date : pd.Timestamp
"""
# update account
for [order, trade_val, trade_cost, trade_price] in trade_info:
if order.deal_amount == 0:
continue
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
# at the end of trade date, update the account based the $close_price of stocks.
trade_account.update_daily_end(today=trade_date, trader=trade_exchange)

View File

@@ -0,0 +1,430 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import random
import logging
import numpy as np
import pandas as pd
from ...data import D
from .order import Order
from ...config import C, REG_CN
from ...log import get_module_logger
class Exchange:
def __init__(
self,
trade_dates=None,
codes="all",
deal_price=None,
subscribe_fields=[],
limit_threshold=None,
open_cost=0.0015,
close_cost=0.0025,
trade_unit=None,
min_cost=5,
extra_quote=None,
):
"""__init__
:param trade_dates: list of pd.Timestamp
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
:param deal_price: str, 'close', 'open', 'vwap'
:param subscribe_fields: list, subscribe fields
:param limit_threshold: float, 0.1 for example, default None
:param open_cost: cost rate for open, default 0.0015
:param close_cost: cost rate for close, default 0.0025
:param trade_unit: trade unit, 100 for China A market
:param min_cost: min cost, default 5
:param extra_quote: pandas, dataframe consists of
columns: like ['$vwap', '$close', '$factor', 'limit'].
The limit indicates that the etf is tradable on a specific day.
Necessary fields:
$close is for calculating the total value at end of each day.
Optional fields:
$vwap is only necessary when we use the $vwap price as the deal price
$factor is for rounding to the trading unit
limit will be set to False by default(False indicates we can buy this
target on this day).
index: MultipleIndex(instrument, pd.Datetime)
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
self.logger = get_module_logger("online operator", level=logging.INFO)
self.trade_unit = trade_unit
# TODO: the quote, trade_dates, codes are not necessray.
# It is just for performance consideration.
if limit_threshold is None:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif abs(limit_threshold) > 0.1:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
if deal_price[0] != "$":
self.deal_price = "$" + deal_price
else:
self.deal_price = deal_price
if isinstance(codes, str):
codes = D.instruments(codes)
self.codes = codes
# Necessary fields
# $close is for calculating the total value at end of each day.
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock
necessary_fields = {self.deal_price, "$close", "$change", "$factor"}
subscribe_fields = list(necessary_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(subscribe_fields))
self.all_fields = all_fields
self.open_cost = open_cost
self.close_cost = close_cost
self.min_cost = min_cost
self.limit_threshold = limit_threshold
# TODO: the quote, trade_dates, codes are not necessray.
# It is just for performance consideration.
if trade_dates is not None and len(trade_dates):
start_date, end_date = trade_dates[0], trade_dates[-1]
else:
self.logger.warning("trade_dates have not been assigned, all dates will be loaded")
start_date, end_date = None, None
self.extra_quote = extra_quote
self.set_quote(codes, start_date, end_date)
def set_quote(self, codes, start_date, end_date):
if len(codes) == 0:
codes = D.instruments()
self.quote = D.features(codes, self.all_fields, start_date, end_date, disk_cache=True).dropna(subset=["$close"])
self.quote.columns = self.all_fields
if self.quote[self.deal_price].isna().any():
self.logger.warning("{} field data contains nan.".format(self.deal_price))
if self.quote["$factor"].isna().any():
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
# Use adjusted price
self.trade_w_adj_price = True
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
else:
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
# Use normal price
self.trade_w_adj_price = False
# update limit
# check limit_threshold
if self.limit_threshold is None:
self.quote["limit"] = False
else:
# set limit
self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold)
quote_df = self.quote
if self.extra_quote is not None:
# process extra_quote
if "$close" not in self.extra_quote:
raise ValueError("$close is necessray in extra_quote")
if self.deal_price not in self.extra_quote.columns:
self.extra_quote[self.deal_price] = self.extra_quote["$close"]
self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.")
if "$factor" not in self.extra_quote.columns:
self.extra_quote["$factor"] = 1.0
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
if "limit" not in self.extra_quote.columns:
self.extra_quote["limit"] = False
self.logger.warning("No limit set for extra_quote. All stock will be tradable.")
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
# update quote: pd.DataFrame to dict, for search use
self.quote = quote_df.to_dict("index")
def _update_limit(self, buy_limit, sell_limit):
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit)
def check_stock_limit(self, stock_id, trade_date):
"""Parameter
stock_id
trade_date
is limtited
"""
return self.quote[(stock_id, trade_date)]["limit"]
def check_stock_suspended(self, stock_id, trade_date):
# is suspended
return (stock_id, trade_date) not in self.quote
def is_stock_tradable(self, stock_id, trade_date):
# check if stock can be traded
# same as check in check_order
if self.check_stock_suspended(stock_id, trade_date) or self.check_stock_limit(stock_id, trade_date):
return False
else:
return True
def check_order(self, order):
# check limit and suspended
if self.check_stock_suspended(order.stock_id, order.trade_date) or self.check_stock_limit(
order.stock_id, order.trade_date
):
return False
else:
return True
def deal_order(self, order, trade_account=None, position=None):
"""
Deal order when the actual transaction
:param order: Deal the order.
:param trade_account: Trade account to be updated after dealing the order.
:param position: position to be updated after dealing the order.
:return: trade_val, trade_cost, trade_price
"""
# need to check order first
# TODO: check the order unit limit in the exchange!!!!
# The order limit is related to the adj factor and the cur_amount.
# factor = self.quote[(order.stock_id, order.trade_date)]['$factor']
# cur_amount = trade_account.current.get_stock_amount(order.stock_id)
if self.check_order(order) is False:
raise AttributeError("need to check order first")
if trade_account is not None and position is not None:
raise ValueError("trade_account and position can only choose one")
trade_price = self.get_deal_price(order.stock_id, order.trade_date)
trade_val, trade_cost = self._calc_trade_info_by_order(
order, trade_account.current if trade_account else position
)
# update account
if trade_val > 0:
# If the order can only be deal 0 trade_val. Nothing to be updated
# Otherwise, it will result some stock with 0 amount in the position
if trade_account:
trade_account.update_order(
order=order,
trade_val=trade_val,
cost=trade_cost,
trade_price=trade_price,
)
elif position:
position.update_order(order, trade_price)
return trade_val, trade_cost, trade_price
def get_quote_info(self, stock_id, trade_date):
return self.quote[(stock_id, trade_date)]
def get_close(self, stock_id, trade_date):
return self.quote[(stock_id, trade_date)]["$close"]
def get_deal_price(self, stock_id, trade_date):
deal_price = self.quote[(stock_id, trade_date)][self.deal_price]
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, trade_date)
return deal_price
def get_factor(self, stock_id, trade_date):
return self.quote[(stock_id, trade_date)]["$factor"]
def generate_amount_position_from_weight_position(self, weight_position, cash, trade_date):
"""
The generate the target position according to the weight and the cash.
NOTE: All the cash will assigned to the tadable stock.
Parameter:
weight_position : dict {stock_id : weight}; allocate cash by weight_position
among then, weight must be in this range: 0 < weight < 1
cash : cash
trade_date : trade date
"""
# calculate the total weight of tradable value
tradable_weight = 0.0
for stock_id in weight_position:
if self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
# weight_position must be greater than 0 and less than 1
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
raise ValueError(
"weight_position is {}, "
"weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
)
tradable_weight += weight_position[stock_id]
if tradable_weight - 1.0 >= 1e-5:
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
amount_dict = {}
for stock_id in weight_position:
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
amount_dict[stock_id] = (
cash
* weight_position[stock_id]
/ tradable_weight
// self.get_deal_price(stock_id=stock_id, trade_date=trade_date)
)
return amount_dict
def get_real_deal_amount(self, current_amount, target_amount, factor):
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
:param target_amount:
:param factor:
:return real_deal_amount; Positive deal_amount indicates buying more stock.
"""
if current_amount == target_amount:
return 0
elif current_amount < target_amount:
deal_amount = target_amount - current_amount
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return deal_amount
else:
if target_amount == 0:
return -current_amount
else:
deal_amount = current_amount - target_amount
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return -deal_amount
def generate_order_for_target_amount_position(self, target_position, current_position, trade_date):
"""Parameter:
target_position : dict { stock_id : amount }
current_postion : dict { stock_id : amount}
trade_unit : trade_unit
down sample : for amount 321 and trade_unit 100, deal_amount is 300
deal order on trade_date
"""
# split buy and sell for further use
buy_order_list = []
sell_order_list = []
# three parts: kept stock_id, dropped stock_id, new stock_id
# handle kept stock_id
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different;
# so here we sort stock_id, and then randomly shuffle the order of stock_id
# because the same random seed is used, the final stock_id order is fixed
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
random.seed(0)
random.shuffle(sorted_ids)
for stock_id in sorted_ids:
# Do not generate order for the nontradable stocks
if not self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
continue
target_amount = target_position.get(stock_id, 0)
current_amount = current_position.get(stock_id, 0)
factor = self.quote[(stock_id, trade_date)]["$factor"]
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
if deal_amount == 0:
continue
elif deal_amount > 0:
# buy stock
buy_order_list.append(
Order(
stock_id=stock_id,
amount=deal_amount,
direction=Order.BUY,
trade_date=trade_date,
factor=factor,
)
)
else:
# sell stock
sell_order_list.append(
Order(
stock_id=stock_id,
amount=abs(deal_amount),
direction=Order.SELL,
trade_date=trade_date,
factor=factor,
)
)
# return order_list : buy + sell
return sell_order_list + buy_order_list
def calculate_amount_position_value(self, amount_dict, trade_date, only_tradable=False):
"""Parameter
position : Position()
amount_dict : {stock_id : amount}
"""
value = 0
for stock_id in amount_dict:
if (
self.check_stock_suspended(stock_id=stock_id, trade_date=trade_date) is False
and self.check_stock_limit(stock_id=stock_id, trade_date=trade_date) is False
):
value += self.get_deal_price(stock_id=stock_id, trade_date=trade_date) * amount_dict[stock_id]
return value
def round_amount_by_trade_unit(self, deal_amount, factor):
"""Parameter
deal_amount : float, adjusted amount
factor : float, adjusted factor
return : float, real amount
"""
if not self.trade_w_adj_price:
# the minimal amount is 1. Add 0.1 for solving precision problem.
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
return deal_amount
def _calc_trade_info_by_order(self, order, position):
"""
Calculation of trade info
:param order:
:param position: Position
:return: trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.trade_date)
if order.direction == Order.SELL:
# sell
if position is not None:
if np.isclose(order.amount, position.get_stock_amount(order.stock_id)):
# when selling last stock. The amount don't need rounding
order.deal_amount = order.amount
else:
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
else:
# TODO: We don't know current position.
# We choose to sell all
order.deal_amount = order.amount
trade_val = order.deal_amount * trade_price
trade_cost = max(trade_val * self.close_cost, self.min_cost)
elif order.direction == Order.BUY:
# buy
if position is not None:
cash = position.get_cash()
trade_val = order.amount * trade_price
if cash < trade_val * (1 + self.open_cost):
# The money is not enough
order.deal_amount = self.round_amount_by_trade_unit(
cash / (1 + self.open_cost) / trade_price, order.factor
)
else:
# THe money is enough
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
else:
# Unknown amount of money. Just round the amount
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
trade_val = order.deal_amount * trade_price
trade_cost = trade_val * self.open_cost
else:
raise NotImplementedError("order type {} error".format(order.type))
return trade_val, trade_cost

View File

@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class Order:
SELL = 0
BUY = 1
def __init__(self, stock_id, amount, trade_date, direction, factor):
"""Parameter
direction : Order.SELL for sell; Order.BUY for buy
stock_id : str
amount : float
trade_date : pd.Timestamp
factor : float
presents the weight factor assigned in Exchange()
"""
# check direction
if direction not in {Order.SELL, Order.BUY}:
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
self.stock_id = stock_id
# amount of generated orders
self.amount = amount
# amount of successfully completed orders
self.deal_amount = 0
self.trade_date = trade_date
self.direction = direction
self.factor = factor

View File

@@ -0,0 +1,207 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import copy
import pathlib
from .order import Order
"""
Position module
"""
"""
current state of position
a typical example is :{
<instrument_id>: {
'count': <how many days the security has been hold>,
'amount': <the amount of the security>,
'price': <the close price of security in the last trading day>,
'weight': <the security weight of total position value>,
},
}
"""
class Position:
"""Position"""
def __init__(self, cash=0, position_dict={}, today_account_value=0):
# NOTE: The position dict must be copied!!!
# Otherwise the initial value
self.init_cash = cash
self.position = position_dict.copy()
self.position["cash"] = cash
self.position["today_account_value"] = today_account_value
def init_stock(self, stock_id, amount, price=None):
self.position[stock_id] = {}
self.position[stock_id]["count"] = 0 # update count in the end of this date
self.position[stock_id]["amount"] = amount
self.position[stock_id]["price"] = price
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
def buy_stock(self, stock_id, amount, price):
if stock_id not in self.position:
self.init_stock(stock_id=stock_id, amount=amount, price=price)
else:
# exist, add amount
self.position[stock_id]["amount"] += amount
def sell_stock(self, stock_id, amount):
if stock_id not in self.position:
raise KeyError("{} not in current position".format(stock_id))
else:
# decrease the amount of stock
self.position[stock_id]["amount"] -= amount
# check if to delete
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, amount)
)
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
self.del_stock(stock_id)
def del_stock(self, stock_id):
del self.position[stock_id]
def update_order(self, order, trade_price):
# handle order, order is a order class, defined in exchange.py
if order.direction == Order.BUY:
# BUY
self.buy_stock(stock_id=order.stock_id, amount=order.deal_amount, price=trade_price)
elif order.direction == Order.SELL:
# SELL
self.sell_stock(stock_id=order.stock_id, amount=order.deal_amount)
else:
raise NotImplementedError("do not suppotr order direction {}".format(order.direction))
def update_stock_price(self, stock_id, price):
self.position[stock_id]["price"] = price
def update_stock_count(self, stock_id, count):
self.position[stock_id]["count"] = count
def update_stock_weight(self, stock_id, weight):
self.position[stock_id]["weight"] = weight
def update_cash(self, cash):
self.position["cash"] = cash
def calculate_stock_value(self):
stock_list = self.get_stock_list()
value = 0
for stock_id in stock_list:
value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
return value
def calculate_value(self):
value = self.calculate_stock_value()
value += self.position["cash"]
return value
def get_stock_list(self):
stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"})
return stock_list
def get_stock_price(self, code):
return self.position[code]["price"]
def get_stock_amount(self, code):
return self.position[code]["amount"]
def get_stock_count(self, code):
return self.position[code]["count"]
def get_stock_weight(self, code):
return self.position[code]["weight"]
def get_cash(self):
return self.position["cash"]
def get_stock_amount_dict(self):
"""generate stock amount dict {stock_id : amount of stock} """
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:
d[stock_code] = self.get_stock_amount(code=stock_code)
return d
def get_stock_weight_dict(self, only_stock=False):
"""get_stock_weight_dict
generate stock weight fict {stock_id : value weight of stock in the position}
it is meaningful in the beginning or the end of each trade date
:param only_stock: If only_stock=True, the weight of each stock in total stock will be returned
If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned
"""
if only_stock:
position_value = self.calculate_stock_value()
else:
position_value = self.calculate_value()
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
return d
def add_count_all(self):
stock_list = self.get_stock_list()
for code in stock_list:
self.position[code]["count"] += 1
def update_weight_all(self):
weight_dict = self.get_stock_weight_dict()
for stock_code, weight in weight_dict.items():
self.update_stock_weight(stock_code, weight)
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series()
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]
cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None
del p["cash"]
del p["today_account_value"]
positions = pd.DataFrame.from_dict(p, orient="index")
with pd.ExcelWriter(path) as writer:
positions.to_excel(writer, sheet_name="position")
cash.to_excel(writer, sheet_name="info")
def load_position(self, path):
"""load position information from a file
should have format below
sheet "position"
columns: ['stock', 'count', 'amount', 'price', 'weight']
'count': <how many days the security has been hold>,
'amount': <the amount of the security>,
'price': <the close price of security in the last trading day>,
'weight': <the security weight of total position value>,
sheet "cash"
index: ['init_cash', 'cash', 'today_account_value']
'init_cash': <inital cash when account was created>,
'cash': <current cash in account>,
'today_account_value': <current total account value, should equal to sum(price[stock]*amount[stock])>
"""
path = pathlib.Path(path)
positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0)
cash_record = pd.read_excel(open(path, "rb"), sheet_name="info", index_col=0)
positions = positions.to_dict(orient="index")
init_cash = cash_record.loc["init_cash"].values[0]
cash = cash_record.loc["cash"].values[0]
today_account_value = cash_record.loc["today_account_value"].values[0]
last_trade_date = cash_record.loc["last_trade_date"].values[0]
# assign values
self.position = {}
self.init_cash = init_cash
self.position = positions
self.position["cash"] = cash
self.position["today_account_value"] = today_account_value
return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date)

View File

@@ -0,0 +1,324 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from .position import Position
from ...data import D
from ...config import C
import datetime
from pathlib import Path
def get_benchmark_weight(
bench,
start_date=None,
end_date=None,
path=None,
):
"""get_benchmark_weight
get the stock weight distribution of the benchmark
:param bench:
:param start_date:
:param end_date:
:param path:
:return: The weight distribution of the the benchmark described by a pandas dataframe
Every row corresponds to a trading day.
Every column corresponds to a stock.
Every cell represents the strategy.
"""
if not path:
path = Path(C.mount_path).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegent way
# TODO: The benchmark is not consistant with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
bench_weight_df = bench_weight_df[bench_weight_df["index"] == bench]
bench_weight_df["date"] = pd.to_datetime(bench_weight_df["date"])
if start_date is not None:
bench_weight_df = bench_weight_df[bench_weight_df.date >= start_date]
if end_date is not None:
bench_weight_df = bench_weight_df[bench_weight_df.date <= end_date]
bench_stock_weight = bench_weight_df.pivot_table(index="date", columns="code", values="weight") / 100.0
return bench_stock_weight
def get_stock_weight_df(positions):
"""get_stock_weight_df
:param positions: Given a positions from backtest result.
:return: A weight distribution for the position
"""
stock_weight = []
index = []
for date in sorted(positions.keys()):
pos = positions[date]
if isinstance(pos, dict):
pos = Position(position_dict=pos)
index.append(date)
stock_weight.append(pos.get_stock_weight_dict(only_stock=True))
return pd.DataFrame(stock_weight, index=index)
def decompose_portofolio_weight(stock_weight_df, stock_group_df):
"""decompose_portofolio_weight
'''
:param stock_weight_df: a pandas dataframe to describe the portofolio by weight.
every row corresponds to a day
every column corresponds to a stock.
Here is an example below.
code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \
date
2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN
2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN
....
:param stock_group_df: a pandas dataframe to describe the stock group.
every row corresponds to a day
every column corresponds to a stock.
the value in the cell repreponds the group id.
Here is a example by for stock_group_df for industry. The value is the industry code
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
datetime
2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
...
:return: Two dict will be returned. The group_weight and the stock_weight_in_group.
The key is the group. The value is a Series or Dataframe to describe the weight of group or weight of stock
"""
all_group = np.unique(stock_group_df.values.flatten())
all_group = all_group[~np.isnan(all_group)]
group_weight = {}
stock_weight_in_group = {}
for group_key in all_group:
group_mask = stock_group_df == group_key
group_weight[group_key] = stock_weight_df[group_mask].sum(axis=1)
stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide(group_weight[group_key], axis=0)
return group_weight, stock_weight_in_group
def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df):
"""
:param stock_weight_df: a pandas dataframe to describe the portofolio by weight.
every row corresponds to a day
every column corresponds to a stock.
Here is an example below.
code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \
date
2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN
2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN
2016-01-07 0.001555 0.001546 0.002772 0.001393 0.002904 NaN
2016-01-08 0.001564 0.001527 0.002791 0.001506 0.002948 NaN
2016-01-11 0.001597 0.001476 0.002738 0.001493 0.003043 NaN
....
:param stock_group_df: a pandas dataframe to describe the stock group.
every row corresponds to a day
every column corresponds to a stock.
the value in the cell repreponds the group id.
Here is a example by for stock_group_df for industry. The value is the industry code
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
datetime
2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-07 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-08 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
2016-01-11 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0
...
:param stock_ret_df: a pandas dataframe to describe the stock return.
every row corresponds to a day
every column corresponds to a stock.
the value in the cell repreponds the return of the group.
Here is a example by for stock_ret_df.
instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \
datetime
2016-01-05 0.007795 0.022070 0.099099 0.024707 0.009473 0.016216
2016-01-06 -0.032597 -0.075205 -0.098361 -0.098985 -0.099707 -0.098936
2016-01-07 -0.001142 0.022544 0.100000 0.004225 0.000651 0.047226
2016-01-08 -0.025157 -0.047244 -0.038567 -0.098177 -0.099609 -0.074408
2016-01-11 0.023460 0.004959 -0.034384 0.018663 0.014461 0.010962
...
:return: It will decompose the portofolio to the group weight and group return.
"""
all_group = np.unique(stock_group_df.values.flatten())
all_group = all_group[~np.isnan(all_group)]
group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df)
group_ret = {}
for group_key in stock_weight_in_group:
stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index)
stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index)
temp_stock_ret_df = stock_ret_df[
(stock_ret_df.index >= stock_weight_in_group_start_date)
& (stock_ret_df.index <= stock_weight_in_group_end_date)
]
group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1)
# If no weight is assigned, then the return of group will be np.nan
group_ret[group_key][group_weight[group_key] == 0.0] = np.nan
group_weight_df = pd.DataFrame(group_weight)
group_ret_df = pd.DataFrame(group_ret)
return group_weight_df, group_ret_df
def get_daily_bin_group(bench_values, stock_values, group_n):
"""get_daily_bin_group
Group the values of the stocks of benchmark into several bins in a day.
Put the stocks into these bins.
:param bench_values: A series contains the value of stocks in benchmark.
The index is the stock code.
:param stock_values: A series contains the value of stocks of your portofolio
The index is the stock code.
:param group_n: Bins will be produced
:return: A series with the same size and index as the stock_value.
The value in the series is the group id of the bins.
The No.1 bin contains the biggest values.
"""
stock_group = stock_values.copy()
# get the bin split points based on the daily proportion of benchmark
split_points = np.percentile(bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1))
# Modify the biggest uppper bound and smallest lowerbound
split_points[0], split_points[-1] = -np.inf, np.inf
for i, (lb, up) in enumerate(zip(split_points, split_points[1:])):
stock_group.loc[stock_values[(stock_values >= lb) & (stock_values < up)].index] = group_n - i
return stock_group
def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, group_n=None):
if group_method == "category":
# use the value of the benchmark as the category
return stock_group_field_df
elif group_method == "bins":
assert group_n is not None
# place the values into `group_n` fields.
# Each bin corresponds to a category.
new_stock_group_df = stock_group_field_df.copy().loc[
bench_stock_weight_df.index.min() : bench_stock_weight_df.index.max()
]
for idx, row in (~bench_stock_weight_df.isna()).iterrows():
bench_values = stock_group_field_df.loc[idx, row[row].index]
new_stock_group_df.loc[idx] = get_daily_bin_group(
bench_values, stock_group_field_df.loc[idx], group_n=group_n
)
return new_stock_group_df
def brinson_pa(
positions,
bench="SH000905",
group_field="industry",
group_method="category",
group_n=None,
deal_price="vwap",
):
"""brinson profit attribution
:param positions: The position produced by the backtest class
:param bench: The benchmark for comparing. TODO: if no benchmark is set, the equal-weighted is used.
:param group_field: The field used to set the group for assets allocation.
`industry` and `market_value` is often used.
:param group_method: 'category' or 'bins'. The method used to set the group for asstes allocation
`bin` will split the value into `group_n` bins and each bins represents a group
:param group_n: . Only used when group_method == 'bins'.
:return:
A dataframe with three columns: RAA(excess Return of Assets Allocation), RSS(excess Return of Stock Selectino), RTotal(Total excess Return)
Every row corresponds to a trading day, the value corresponds to the next return for this trading day
The middle info of brinson profit attribution
"""
# group_method will decide how to group the group_field.
dates = sorted(positions.keys())
start_date, end_date = min(dates), max(dates)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
# The attributes for allocation will not
if not group_field.startswith("$"):
group_field = "$" + group_field
if not deal_price.startswith("$"):
deal_price = "$" + deal_price
# FIXME: In current version. Some attributes(such as market_value) of some
# suspend stock is NAN. So we have to get more date to forward fill the NAN
shift_start_date = start_date - datetime.timedelta(days=250)
instruments = D.list_instruments(
D.instruments(market="all"),
start_time=shift_start_date,
end_time=end_date,
as_list=True,
)
stock_df = D.features(
instruments,
[group_field, deal_price],
start_time=shift_start_date,
end_time=end_date,
freq="day",
)
stock_df.columns = [group_field, "deal_price"]
stock_group_field = stock_df[group_field].unstack().T
# FIXME: some attributes of some suspend stock is NAN.
stock_group_field = stock_group_field.fillna(method="ffill")
stock_group_field = stock_group_field.loc[start_date:end_date]
stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n)
deal_price_df = stock_df["deal_price"].unstack().T
deal_price_df = deal_price_df.fillna(method="ffill")
# NOTE:
# The return will be slightly different from the of the return in the report.
# Here the position are adjusted at the end of the trading day with close
stock_ret = (deal_price_df - deal_price_df.shift(1)) / deal_price_df.shift(1)
stock_ret = stock_ret.shift(-1).loc[start_date:end_date]
port_stock_weight_df = get_stock_weight_df(positions)
# decomposing the portofolio
port_group_weight_df, port_group_ret_df = decompose_portofolio(port_stock_weight_df, stock_group, stock_ret)
bench_group_weight_df, bench_group_ret_df = decompose_portofolio(bench_stock_weight, stock_group, stock_ret)
# if the group return of the portofolio is NaN, replace it with the market
# value
mod_port_group_ret_df = port_group_ret_df.copy()
mod_port_group_ret_df[mod_port_group_ret_df.isna()] = bench_group_ret_df
Q1 = (bench_group_weight_df * bench_group_ret_df).sum(axis=1)
Q2 = (port_group_weight_df * bench_group_ret_df).sum(axis=1)
Q3 = (bench_group_weight_df * mod_port_group_ret_df).sum(axis=1)
Q4 = (port_group_weight_df * mod_port_group_ret_df).sum(axis=1)
return (
pd.DataFrame(
{
"RAA": Q2 - Q1, # The excess profit from the assets allocation
"RSS": Q3 - Q1, # The excess profit from the stocks selection
# The excess profit from the interaction of assets allocation and stocks selection
"RIN": Q4 - Q3 - Q2 + Q1,
"RTotal": Q4 - Q1, # The totoal excess profit
}
),
{
"port_group_ret": port_group_ret_df,
"port_group_weight": port_group_weight_df,
"bench_group_ret": bench_group_ret_df,
"bench_group_weight": bench_group_weight_df,
"stock_group": stock_group,
"bench_stock_weight": bench_stock_weight,
"port_stock_weight": port_stock_weight_df,
"stock_ret": stock_ret,
},
)

View File

@@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import OrderedDict
import pandas as pd
import pathlib
class Report:
# daily report of the account
# contain those followings: returns, costs turnovers, accounts, cash, bench, value
# update report
def __init__(self):
self.init_vars()
def init_vars(self):
self.accounts = OrderedDict() # account postion value for each trade date
self.returns = OrderedDict() # daily return rate for each trade date
self.turnovers = OrderedDict() # turnover for each trade date
self.costs = OrderedDict() # trade cost for each trade date
self.values = OrderedDict() # value for each trade date
self.cashes = OrderedDict()
self.latest_report_date = None # pd.TimeStamp
def is_empty(self):
return len(self.accounts) == 0
def get_latest_date(self):
return self.latest_report_date
def get_latest_account_value(self):
return self.accounts[self.latest_report_date]
def update_report_record(
self,
trade_date=None,
account_value=None,
cash=None,
return_rate=None,
turnover_rate=None,
cost_rate=None,
stock_value=None,
):
# check data
if None in [
trade_date,
account_value,
cash,
return_rate,
turnover_rate,
cost_rate,
stock_value,
]:
raise ValueError(
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
)
# update report data
self.accounts[trade_date] = account_value
self.returns[trade_date] = return_rate
self.turnovers[trade_date] = turnover_rate
self.costs[trade_date] = cost_rate
self.values[trade_date] = stock_value
self.cashes[trade_date] = cash
# update latest_report_date
self.latest_report_date = trade_date
# finish daily report update
def generate_report_dataframe(self):
report = pd.DataFrame()
report["account"] = pd.Series(self.accounts)
report["return"] = pd.Series(self.returns)
report["turnover"] = pd.Series(self.turnovers)
report["cost"] = pd.Series(self.costs)
report["value"] = pd.Series(self.values)
report["cash"] = pd.Series(self.cashes)
report.index.name = "date"
return report
def save_report(self, path):
r = self.generate_report_dataframe()
r.to_csv(path)
def load_report(self, path):
"""load report from a file
should have format like
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash']
:param
path: str/ pathlib.Path()
"""
path = pathlib.Path(path)
r = pd.read_csv(open(path, "rb"), index_col=0)
r.index = pd.DatetimeIndex(r.index)
index = r.index
self.init_vars()
for date in index:
self.update_report_record(
trade_date=date,
account_value=r.loc[date]["account"],
cash=r.loc[date]["cash"],
return_rate=r.loc[date]["return"],
turnover_rate=r.loc[date]["turnover"],
cost_rate=r.loc[date]["cost"],
stock_value=r.loc[date]["value"],
)

View File

View File

@@ -0,0 +1,176 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import yaml
import copy
import os
import json
import tempfile
from pathlib import Path
from ...config import REG_CN
class EstimatorConfigManager(object):
def __init__(self, config_path):
if not config_path:
raise ValueError("Config path is invalid.")
self.config_path = config_path
with open(config_path) as fp:
config = yaml.load(fp, Loader=yaml.FullLoader)
self.config = copy.deepcopy(config)
self.ex_config = ExperimentConfig(config.get("experiment", dict()), self)
self.data_config = DataConfig(config.get("data", dict()), self)
self.model_config = ModelConfig(config.get("model", dict()), self)
self.trainer_config = TrainerConfig(config.get("trainer", dict()), self)
self.strategy_config = StrategyConfig(config.get("strategy", dict()), self)
self.backtest_config = BacktestConfig(config.get("backtest", dict()), self)
self.qlib_data_config = QlibDataConfig(config.get("qlib_data", dict()), self)
# If the start_date and end_date are not given in data_config, they will be referred from the trainer_config.
handler_start_date = self.data_config.handler_parameters.get("start_date", None)
handler_end_date = self.data_config.handler_parameters.get("end_date", None)
if handler_start_date is None:
self.data_config.handler_parameters["start_date"] = self.trainer_config.parameters["train_start_date"]
if handler_end_date is None:
self.data_config.handler_parameters["end_date"] = self.trainer_config.parameters["test_end_date"]
class ExperimentConfig(object):
TRAIN_MODE = "train"
TEST_MODE = "test"
OBSERVER_FILE_STORAGE = "file_storage"
OBSERVER_MONGO = "mongo"
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for experiment
:param CONFIG_MANAGER: The estimator config manager
"""
self.name = config.get("name", "test_experiment")
# The dir of the result of all the experiments
self.global_dir = config.get("dir", os.path.dirname(CONFIG_MANAGER.config_path))
# The dir of the result of current experiment
self.ex_dir = os.path.join(self.global_dir, self.name)
if not os.path.exists(self.ex_dir):
os.makedirs(self.ex_dir)
self.tmp_run_dir = tempfile.mkdtemp(dir=self.ex_dir)
self.mode = config.get("mode", ExperimentConfig.TRAIN_MODE)
self.sacred_dir = os.path.join(self.ex_dir, "sacred")
self.observer_type = config.get("observer_type", ExperimentConfig.OBSERVER_FILE_STORAGE)
self.mongo_url = config.get("mongo_url", None)
self.db_name = config.get("db_name", None)
self.finetune = config.get("finetune", False)
# The path of the experiment id of the experiment
self.exp_info_path = config.get("exp_info_path", os.path.join(self.ex_dir, "exp_info.json"))
exp_info_dir = Path(self.exp_info_path).parent
exp_info_dir.mkdir(parents=True, exist_ok=True)
# Test mode config
loader_args = config.get("loader", dict())
if self.mode == ExperimentConfig.TEST_MODE or self.finetune:
loader_exp_info_path = loader_args.get("exp_info_path", None)
self.loader_model_index = loader_args.get("model_index", None)
if (loader_exp_info_path is not None) and (os.path.exists(loader_exp_info_path)):
with open(loader_exp_info_path) as fp:
loader_dict = json.load(fp)
for k, v in loader_dict.items():
setattr(self, "loader_{}".format(k), v)
# Check loader experiment id
assert hasattr(self, "loader_id"), "If mode is test or finetune is True, loader must contain id."
else:
self.loader_id = loader_args.get("id", None)
if self.loader_id is None:
raise ValueError("If mode is test or finetune is True, loader must contain id.")
self.loader_observer_type = loader_args.get("observer_type", self.observer_type)
self.loader_name = loader_args.get("name", self.name)
self.loader_dir = loader_args.get("dir", self.global_dir)
self.loader_mongo_url = loader_args.get("mongo_url", self.mongo_url)
self.loader_db_name = loader_args.get("db_name", self.db_name)
class DataConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for data
:param CONFIG_MANAGER: The estimator config manager
"""
self.handler_module_path = config.get("module_path", "qlib.contrib.estimator.handler")
self.handler_class = config.get("class", "ALPHA360")
self.handler_parameters = config.get("args", dict())
self.handler_filter = config.get("filter", dict())
# Update provider uri.
class ModelConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for model
:param CONFIG_MANAGER: The estimator config manager
"""
self.model_class = config.get("class", "Model")
self.model_module_path = config.get("module_path", "qlib.contrib.model")
self.save_dir = os.path.join(CONFIG_MANAGER.ex_config.tmp_run_dir, "model")
self.save_path = config.get("save_path", os.path.join(self.save_dir, "model.bin"))
self.parameters = config.get("args", dict())
# Make dir if need.
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
class TrainerConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for trainer
:param CONFIG_MANAGER: The estimator config manager
"""
self.trainer_class = config.get("class", "StaticTrainer")
self.trainer_module_path = config.get("module_path", "qlib.contrib.estimator.trainer")
self.parameters = config.get("args", dict())
class StrategyConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for strategy
:param CONFIG_MANAGER: The estimator config manager
"""
self.strategy_class = config.get("class", "TopkDropoutStrategy")
self.strategy_module_path = config.get("module_path", "qlib.contrib.strategy.strategy")
self.parameters = config.get("args", dict())
class BacktestConfig(object):
def __init__(self, config, CONFIG_MANAGE):
"""__init__
:param config: The config dict for strategy
:param CONFIG_MANAGE: The estimator config manager
"""
self.normal_backtest_parameters = config.get("normal_backtest_args", dict())
self.long_short_backtest_parameters = config.get("long_short_backtest_args", dict())
class QlibDataConfig(object):
def __init__(self, config, CONFIG_MANAGE):
"""__init__
:param config: The config dict for qlib_client
:param CONFIG_MANAGE: The estimator config manager
"""
self.provider_uri = config.pop("provider_uri", "~/.qlib/qlib_data/cn_data")
self.auto_mount = config.pop("auto_mount", False)
self.mount_path = config.pop("mount_path", "~/.qlib/qlib_data/cn_data")
self.region = config.pop("region", REG_CN)
self.args = config

View File

@@ -0,0 +1,323 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
import pandas as pd
import os
import copy
import json
import yaml
import pickle
import qlib
from ..evaluate import risk_analysis
from ..evaluate import backtest as normal_backtest
from ..evaluate import long_short_backtest
from .config import ExperimentConfig
from .fetcher import create_fetcher_with_config
from ...log import get_module_logger, TimeInspector
from ...utils import get_module_by_module_path, compare_dict_value
class Estimator(object):
def __init__(self, config_manager, sacred_ex):
# Set logger.
self.logger = get_module_logger("Estimator")
# 1. Set config manager.
self.config_manager = config_manager
# 2. Set configs.
self.ex_config = config_manager.ex_config
self.data_config = config_manager.data_config
self.model_config = config_manager.model_config
self.trainer_config = config_manager.trainer_config
self.strategy_config = config_manager.strategy_config
self.backtest_config = config_manager.backtest_config
# If experiment.mode is test or experiment.finetune is True, load the experimental results in the loader
if self.ex_config.mode == self.ex_config.TEST_MODE or self.ex_config.finetune:
self.compare_config_with_config_manger(self.config_manager)
# 3. Set sacred_experiment.
self.ex = sacred_ex
# 4. Init data handler.
self.data_handler = None
self._init_data_handler()
# 5. Init trainer.
self.trainer = None
self._init_trainer()
# 6. Init strategy.
self.strategy = None
self._init_strategy()
def _init_data_handler(self):
handler_module = get_module_by_module_path(self.data_config.handler_module_path)
# Set market
market = self.data_config.handler_filter.get("market", None)
if market is None:
if "market" in self.data_config.handler_parameters:
self.logger.warning(
"Warning: The market in data.args section is deprecated. "
"It only works when market is not set in data.filter section. "
"It will be overridden by market in the data.filter section."
)
market = self.data_config.handler_parameters["market"]
else:
market = "csi500"
self.data_config.handler_parameters["market"] = market
data_filter_list = []
handler_filters = self.data_config.handler_filter.get("filter_pipeline", list())
for h_filter in handler_filters:
filter_module_path = h_filter.get("module_path", "qlib.data.filter")
filter_class_name = h_filter.get("class", "")
filter_parameters = h_filter.get("args", {})
filter_module = get_module_by_module_path(filter_module_path)
filter_class = getattr(filter_module, filter_class_name)
data_filter = filter_class(**filter_parameters)
data_filter_list.append(data_filter)
self.data_config.handler_parameters["data_filter_list"] = data_filter_list
handler_class = getattr(handler_module, self.data_config.handler_class)
self.data_handler = handler_class(**self.data_config.handler_parameters)
def _init_trainer(self):
model_module = get_module_by_module_path(self.model_config.model_module_path)
trainer_module = get_module_by_module_path(self.trainer_config.trainer_module_path)
model_class = getattr(model_module, self.model_config.model_class)
trainer_class = getattr(trainer_module, self.trainer_config.trainer_class)
self.trainer = trainer_class(
model_class,
self.model_config.save_path,
self.model_config.parameters,
self.data_handler,
self.ex,
**self.trainer_config.parameters
)
def _init_strategy(self):
module = get_module_by_module_path(self.strategy_config.strategy_module_path)
strategy_class = getattr(module, self.strategy_config.strategy_class)
self.strategy = strategy_class(**self.strategy_config.parameters)
def run(self):
if self.ex_config.mode == ExperimentConfig.TRAIN_MODE:
self.trainer.train()
elif self.ex_config.mode == ExperimentConfig.TEST_MODE:
self.trainer.load()
else:
raise ValueError("unexpected mode: %s" % self.ex_config.mode)
analysis = self.backtest()
self.logger.info(analysis)
self.logger.info(
"experiment id: {}, experiment name: {}".format(self.ex.experiment.current_run._id, self.ex_config.name)
)
# Remove temp dir
# shutil.rmtree(self.ex_config.tmp_run_dir)
def backtest(self):
TimeInspector.set_time_mark()
# 1. Get pred and prediction score of model(s).
pred = self.trainer.get_test_score()
performance = self.trainer.get_test_performance()
# 2. Normal Backtest.
report_normal, positions_normal = self._normal_backtest(pred)
# 3. Long-Short Backtest.
# Deprecated
# long_short_reports = self._long_short_backtest(pred)
# 4. Analyze
analysis_df = self._analyze(report_normal)
# 5. Save.
self._save_backtest_result(
pred,
analysis_df,
positions_normal,
report_normal,
# long_short_reports,
performance,
)
return analysis_df
def _normal_backtest(self, pred):
TimeInspector.set_time_mark()
if "account" not in self.backtest_config.normal_backtest_parameters:
if "account" in self.strategy_config.parameters:
self.logger.warning(
"Warning: The account in strategy section is deprecated. "
"It only works when account is not set in backtest section. "
"It will be overridden by account in the backtest section."
)
self.backtest_config.normal_backtest_parameters["account"] = self.strategy_config.parameters["account"]
report_normal, positions_normal = normal_backtest(
pred, strategy=self.strategy, **self.backtest_config.normal_backtest_parameters
)
TimeInspector.log_cost_time("Finished normal backtest.")
return report_normal, positions_normal
def _long_short_backtest(self, pred):
TimeInspector.set_time_mark()
long_short_reports = long_short_backtest(pred, **self.backtest_config.long_short_backtest_parameters)
TimeInspector.log_cost_time("Finished long-short backtest.")
return long_short_reports
@staticmethod
def _analyze(report_normal):
TimeInspector.set_time_mark()
analysis = dict()
# analysis["pred_long"] = risk_analysis(long_short_reports["long"])
# analysis["pred_short"] = risk_analysis(long_short_reports["short"])
# analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"])
analysis["sub_bench"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["sub_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
analysis_df = pd.concat(analysis) # type: pd.DataFrame
TimeInspector.log_cost_time(
"Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean())
)
return analysis_df
def _save_backtest_result(self, pred, analysis, positions, report_normal, performance):
# 1. Result dir.
result_dir = os.path.join(self.config_manager.ex_config.tmp_run_dir, "result")
if not os.path.exists(result_dir):
os.makedirs(result_dir)
self.ex.add_info(
"task_config",
json.loads(json.dumps(self.config_manager.config, default=str)),
)
# 2. Pred.
TimeInspector.set_time_mark()
pred_pkl_path = os.path.join(result_dir, "pred.pkl")
pred.to_pickle(pred_pkl_path)
self.ex.add_artifact(pred_pkl_path)
TimeInspector.log_cost_time("Finished saving pred.pkl to: {}".format(pred_pkl_path))
# 3. Ana.
TimeInspector.set_time_mark()
analysis_pkl_path = os.path.join(result_dir, "analysis.pkl")
analysis.to_pickle(analysis_pkl_path)
self.ex.add_artifact(analysis_pkl_path)
TimeInspector.log_cost_time("Finished saving analysis.pkl to: {}".format(analysis_pkl_path))
# 4. Pos.
TimeInspector.set_time_mark()
positions_pkl_path = os.path.join(result_dir, "positions.pkl")
with open(positions_pkl_path, "wb") as fp:
pickle.dump(positions, fp)
self.ex.add_artifact(positions_pkl_path)
TimeInspector.log_cost_time("Finished saving positions.pkl to: {}".format(positions_pkl_path))
# 5. Report normal.
TimeInspector.set_time_mark()
report_normal_pkl_path = os.path.join(result_dir, "report_normal.pkl")
report_normal.to_pickle(report_normal_pkl_path)
self.ex.add_artifact(report_normal_pkl_path)
TimeInspector.log_cost_time("Finished saving report_normal.pkl to: {}".format(report_normal_pkl_path))
# 6. Report long short.
# Deprecated
# for k, name in zip(
# ["long", "short", "long_short"],
# ["report_long.pkl", "report_short.pkl", "report_long_short.pkl"],
# ):
# TimeInspector.set_time_mark()
# pkl_path = os.path.join(result_dir, name)
# long_short_reports[k].to_pickle(pkl_path)
# self.ex.add_artifact(pkl_path)
# TimeInspector.log_cost_time("Finished saving {} to: {}".format(name, pkl_path))
# 7. Origin test label.
TimeInspector.set_time_mark()
label_pkl_path = os.path.join(result_dir, "label.pkl")
self.data_handler.get_origin_test_label_with_date(
self.trainer_config.parameters["test_start_date"],
self.trainer_config.parameters["test_end_date"],
).to_pickle(label_pkl_path)
self.ex.add_artifact(label_pkl_path)
TimeInspector.log_cost_time("Finished saving label.pkl to: {}".format(label_pkl_path))
# 8. Experiment info, save the model(s) performance here.
TimeInspector.set_time_mark()
cur_ex_id = self.ex.experiment.current_run._id
exp_info = {
"id": cur_ex_id,
"name": self.ex_config.name,
"performance": performance,
"observer_type": self.ex_config.observer_type,
}
if self.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
exp_info.update(
{
"mongo_url": self.ex_config.mongo_url,
"db_name": self.ex_config.db_name,
}
)
else:
exp_info.update({"dir": self.ex_config.global_dir})
with open(self.ex_config.exp_info_path, "w") as fp:
json.dump(exp_info, fp, indent=4, sort_keys=True)
self.ex.add_artifact(self.ex_config.exp_info_path)
TimeInspector.log_cost_time("Finished saving ex_info to: {}".format(self.ex_config.exp_info_path))
@staticmethod
def compare_config_with_config_manger(config_manager):
"""Compare loader model args and current config with ConfigManage
:param config_manager: ConfigManager
:return:
"""
fetcher = create_fetcher_with_config(config_manager, load_form_loader=True)
loader_mode_config = fetcher.get_experiment(
exp_name=config_manager.ex_config.loader_name,
exp_id=config_manager.ex_config.loader_id,
fields=["task_config"],
)["task_config"]
with open(config_manager.config_path) as fp:
current_config = yaml.load(fp.read())
current_config = json.loads(json.dumps(current_config, default=str))
logger = get_module_logger("Estimator")
loader_mode_config = copy.deepcopy(loader_mode_config)
current_config = copy.deepcopy(current_config)
# Require test_mode_config.test_start_date <= current_config.test_start_date
loader_trainer_args = loader_mode_config.get("trainer", {}).get("args", {})
cur_trainer_args = current_config.get("trainer", {}).get("args", {})
loader_start_date = loader_trainer_args.pop("test_start_date")
cur_test_start_date = cur_trainer_args.pop("test_start_date")
assert (
loader_start_date <= cur_test_start_date
), "Require: loader_mode_config.test_start_date <= current_config.test_start_date"
# TODO: For the user's own extended `Trainer`, the support is not very good
if "RollingTrainer" == current_config.get("trainer", {}).get("class", None):
loader_period = loader_trainer_args.pop("rolling_period")
cur_period = cur_trainer_args.pop("rolling_period")
assert (
loader_period == cur_period
), "Require: loader_mode_config.rolling_period == current_config.rolling_period"
compare_section = ["trainer", "model", "data"]
for section in compare_section:
changes = compare_dict_value(loader_mode_config.get(section, {}), current_config.get(section, {}))
if changes:
logger.warning("Warning: Loader mode config and current config, `{}` are different:\n".format(section))

View File

@@ -0,0 +1,290 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
import copy
import json
import yaml
import pickle
import gridfs
import pymongo
from pathlib import Path
from abc import abstractmethod
from .config import EstimatorConfigManager, ExperimentConfig
class Fetcher(object):
"""Sacred Experiments Fetcher"""
@abstractmethod
def _get_experiment(self, exp_name, exp_id):
"""Get experiment basic info with experiment and experiment id
:param exp_name: experiment name
:param exp_id: experiment id
:return: dict
Must contain keys: _id, experiment, info, stop_time.
Here is an example below for FileFetcher.
exp = {
'_id': exp_id, # experiment id
'path': path, # experiment result path
'experiment': {'name': exp_name}, # experiment
'info': info, # experiment config info
'stop_time': run.get('stop_time', None) # The time the experiment ended
}
"""
pass
@abstractmethod
def _list_experiments(self, exp_name=None):
"""Get experiment basic info list with experiment name
:param exp_name: experiment name
:return: list
"""
pass
@abstractmethod
def _iter_artifacts(self, experiment):
"""Get information about the data in the experiment results
:param experiment: `self._get_experiment` method result
:return: iterable
Each element contains two elements.
first element : data name
second element : data uri
"""
pass
@abstractmethod
def _load_data(self, uri):
"""Load data with uri
:param uri: data uri
:return: bytes
"""
pass
@staticmethod
def model_dict_to_buffer_list(model_dict):
"""
:param model_dict:
:return:
"""
model_list = []
is_static_model = False
if len(model_dict) == 1 and list(model_dict.keys())[0] == "model.bin":
is_static_model = True
model_list.append(list(model_dict.values())[0])
else:
sep = "model.bin_"
model_ids = list(map(lambda x: int(x.split(sep)[1]), model_dict.keys()))
min_id, max_id = min(model_ids), max(model_ids)
for i in range(min_id, max_id + 1):
model_key = sep + str(i)
model = model_dict.get(model_key, None)
if model is None:
print(
"WARNING: In Fetcher, {} is missing when the get model is in the get_experiment function.".format(
model_key
)
)
break
else:
model_list.append(model)
if is_static_model:
return model_list[0]
return model_list
def get_experiments(self, exp_name=None):
"""Get experiments with name.
:param exp_name: str
If `exp_name` is set to None, then all experiments will return.
:return: dict
Experiments info dict(Including experiment id and task_config to run the
experiment). Here is an example below.
{
'a_experiment': [
{
'id': '1',
'task_config': {...}
},
...
]
...
}
"""
res = dict()
for ex in self._list_experiments(exp_name):
name = ex["experiment"]["name"]
tmp = {
"id": ex["_id"],
"task_config": ex["info"].get("task_config", {}),
"ex_run_stop_time": ex.get("stop_time", None),
}
res.setdefault(name, []).append(tmp)
return res
def get_experiment(self, exp_name, exp_id, fields=None):
"""
:param exp_name:
:param exp_id:
:param fields: list
Experiment result fields, if fields is None, will get all fields.
Currently supported fields:
['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
:return: dict
"""
fields = copy.copy(fields)
ex = self._get_experiment(exp_name, exp_id)
results = dict()
model_dict = dict()
for name, uri in self._iter_artifacts(ex):
# When saving, use `sacred.experiment.add_artifact(filename)` , so `name` is os.path.basename(filename)
prefix = name.split(".")[0]
if fields and prefix not in fields:
continue
data = self._load_data(uri)
if prefix == "model":
model_dict[name] = data
else:
results[prefix] = pickle.loads(data)
# Sort model
if model_dict:
results["model"] = self.model_dict_to_buffer_list(model_dict)
# Info
results["task_config"] = ex["info"].get("task_config", {})
return results
def estimator_config_to_dict(self, exp_name, exp_id):
"""Save configuration to file
:param exp_name:
:param exp_id:
:return: config dict
"""
return self.get_experiment(exp_name, exp_id, fields=["task_config"])["task_config"]
class FileFetcher(Fetcher):
"""File Fetcher"""
def __init__(self, experiments_dir):
self.experiments_dir = Path(experiments_dir)
def _get_experiment(self, exp_name, exp_id):
path = self.experiments_dir / exp_name / "sacred" / str(exp_id)
info_path = path / "info.json"
run_path = path / "run.json"
if info_path.exists():
with info_path.open("r") as f:
info = json.load(f)
else:
info = {}
if run_path.exists():
with run_path.open("r") as f:
run = json.load(f)
else:
run = {}
exp = {
"_id": exp_id,
"path": path,
"experiment": {"name": exp_name},
"info": info,
"stop_time": run.get("stop_time", None),
}
return exp
def _list_experiments(self, exp_name=None):
runs = []
for path in self.experiments_dir.glob("{}/sacred/[!_]*".format(exp_name or "*")):
exp_name, exp_id = path.parents[1].name, path.name
runs.append(self._get_experiment(exp_name, exp_id))
return runs
def _iter_artifacts(self, experiment):
if experiment is None:
return []
for fname in experiment["path"].iterdir():
if fname.suffix == ".pkl" or ".bin" in fname.suffix:
name, uri = fname.name, str(fname)
yield name, uri
def _load_data(self, uri):
with open(uri, "rb") as f:
data = f.read()
return data
class MongoFetcher(Fetcher):
"""MongoDB Fetcher"""
def __init__(self, mongo_url, db_name):
self.mongo_url = mongo_url
self.db_name = db_name
self.client = None
self.db = None
self.runs = None
self.fs = None
self._setup_mongo_client()
def _setup_mongo_client(self):
self.client = pymongo.MongoClient(self.mongo_url)
self.db = self.client[self.db_name]
self.runs = self.db.runs
self.fs = gridfs.GridFS(self.db)
def _get_experiment(self, exp_name, exp_id):
return self.runs.find_one({"_id": exp_id})
def _list_experiments(self, exp_name=None):
if exp_name is None:
return self.runs.find()
return self.runs.find({"experiment.name": exp_name})
def _iter_artifacts(self, experiment):
if experiment is None:
return []
for artifact in experiment.get("artifacts", []):
name, uri = artifact["name"], artifact["file_id"]
yield name, uri
def _load_data(self, uri):
data = self.fs.get(uri).read()
return data
def create_fetcher_with_config(config_manager: EstimatorConfigManager, load_form_loader: bool = False):
"""Create fetcher with loader config
:param config_manager:
:param load_form_loader
:return:
"""
flag = ""
if load_form_loader:
flag = "loader_"
if config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_FILE_STORAGE:
return FileFetcher(eval("config_manager.ex_config.{}_dir".format("loader" if load_form_loader else "global")))
elif config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
return MongoFetcher(
mongo_url=eval("config_manager.ex_config.{}mongo_url".format(flag)),
db_name=eval("config_manager.ex_config.{}db_name".format(flag)),
)
else:
return NotImplementedError("Unkown Backend")

View File

@@ -0,0 +1,585 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
import abc
import bisect
import logging
import pandas as pd
import numpy as np
from ...log import get_module_logger, TimeInspector
from ...data import D
from ...utils import parse_config, transform_end_date
from . import processor as processor_module
class BaseDataHandler(abc.ABC):
def __init__(self, processors=[], **kwargs):
"""
:param start_date:
:param end_date:
:param kwargs:
"""
# Set logger
self.logger = get_module_logger("DataHandler")
# init data using kwargs
self._init_kwargs(**kwargs)
# Setup data.
self.raw_df, self.feature_names, self.label_names = self._init_raw_df()
# Setup preprocessor
self.processors = []
for klass in processors:
if isinstance(klass, str):
try:
klass = getattr(processor_module, klass)
except:
raise ValueError("unknown Processor %s" % klass)
self.processors.append(klass(self.feature_names, self.label_names, **kwargs))
def _init_kwargs(self, **kwargs):
"""
init the kwargs of DataHandler
"""
pass
def _init_raw_df(self):
"""
init raw_df, feature_names, label_names of DataHandler
if the index of df_feature and df_label are not same, user need to overload this method to merge (e.g. inner, left, right merge).
"""
df_features = self.setup_feature()
feature_names = df_features.columns
df_labels = self.setup_label()
label_names = df_labels.columns
raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left")
return raw_df, feature_names, label_names
def reset_label(self, df_labels):
for col in self.label_names:
del self.raw_df[col]
self.label_names = df_labels.columns
self.raw_df = self.raw_df.merge(df_labels, left_index=True, right_index=True, how="left")
def split_rolling_periods(
self,
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
rolling_period,
calendar_freq="day",
):
"""
Calculating the Rolling split periods, the period rolling on market calendar.
:param train_start_date:
:param train_end_date:
:param validate_start_date:
:param validate_end_date:
:param test_start_date:
:param test_end_date:
:param rolling_period: The market period of rolling
:param calendar_freq: The frequence of the market calendar
:yield: Rolling split periods
"""
def get_start_index(calendar, start_date):
start_index = bisect.bisect_left(calendar, start_date)
return start_index
def get_end_index(calendar, end_date):
end_index = bisect.bisect_right(calendar, end_date)
return end_index - 1
calendar = self.raw_df.index.get_level_values("datetime").unique()
train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date))
train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date))
valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date))
valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date))
test_start_index = get_start_index(calendar, pd.Timestamp(test_start_date))
test_end_index = test_start_index + rolling_period - 1
need_stop_split = False
bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date))
while not need_stop_split:
if test_end_index > bound_test_end_index:
test_end_index = bound_test_end_index
need_stop_split = True
yield (
calendar[train_start_index],
calendar[train_end_index],
calendar[valid_start_index],
calendar[valid_end_index],
calendar[test_start_index],
calendar[test_end_index],
)
train_start_index += rolling_period
train_end_index += rolling_period
valid_start_index += rolling_period
valid_end_index += rolling_period
test_start_index += rolling_period
test_end_index += rolling_period
def get_rolling_data(
self,
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
rolling_period,
calendar_freq="day",
):
# Set generator.
for period in self.split_rolling_periods(
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
rolling_period,
calendar_freq,
):
(
x_train,
y_train,
x_validate,
y_validate,
x_test,
y_test,
) = self.get_split_data(*period)
yield x_train, y_train, x_validate, y_validate, x_test, y_test
def get_split_data(
self,
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
):
"""
all return types are DataFrame
"""
## TODO: loc can be slow, expecially when we put it at the second level index.
if self.raw_df.index.names[0] == "instrument":
df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date]
df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date]
df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date]
else:
df_train = self.raw_df.loc[train_start_date:train_end_date]
df_validate = self.raw_df.loc[validate_start_date:validate_end_date]
df_test = self.raw_df.loc[test_start_date:test_end_date]
TimeInspector.set_time_mark()
df_train, df_validate, df_test = self.setup_process_data(df_train, df_validate, df_test)
TimeInspector.log_cost_time("Finished setup processed data.")
x_train = df_train[self.feature_names]
y_train = df_train[self.label_names]
x_validate = df_validate[self.feature_names]
y_validate = df_validate[self.label_names]
x_test = df_test[self.feature_names]
y_test = df_test[self.label_names]
return x_train, y_train, x_validate, y_validate, x_test, y_test
def setup_process_data(self, df_train, df_valid, df_test):
"""
process the train, valid and test data
:return: the processed train, valid and test data.
"""
for processor in self.processors:
df_train, df_valid, df_test = processor(df_train, df_valid, df_test)
return df_train, df_valid, df_test
def get_origin_test_label_with_date(self, test_start_date, test_end_date, freq="day"):
"""Get origin test label
:param test_start_date: test start date
:param test_end_date: test end date
:param freq: freq
:return: pd.DataFrame
"""
test_end_date = transform_end_date(test_end_date, freq=freq)
return self.raw_df.loc[(slice(None), slice(test_start_date, test_end_date)), self.label_names]
@abc.abstractmethod
def setup_feature(self):
"""
Implement this method to load raw feature.
the format of the feature is below
return: df_features
"""
pass
@abc.abstractmethod
def setup_label(self):
"""
Implement this method to load and calculate label.
the format of the label is below
return: df_label
"""
pass
class QLibDataHandler(BaseDataHandler):
def __init__(self, start_date, end_date, *args, **kwargs):
# Dates.
self.start_date = start_date
self.end_date = end_date
super().__init__(*args, **kwargs)
def _init_kwargs(self, **kwargs):
# Instruments
instruments = kwargs.get("instruments", None)
if instruments is None:
market = kwargs.get("market", "csi500").lower()
data_filter_list = kwargs.get("data_filter_list", list())
self.instruments = D.instruments(market, filter_pipe=data_filter_list)
else:
self.instruments = instruments
# Config of features and labels
self._fields = kwargs.get("fields", [])
self._names = kwargs.get("names", [])
self._labels = kwargs.get("labels", [])
self._label_names = kwargs.get("label_names", [])
# Check arguments
assert len(self._fields) > 0, "features list is empty"
assert len(self._labels) > 0, "labels list is empty"
# Check end_date
# If test_end_date is -1 or greater than the last date, the last date is used
self.end_date = transform_end_date(self.end_date)
def setup_feature(self):
"""
Load the raw data.
return: df_features
"""
TimeInspector.set_time_mark()
if len(self._names) == 0:
names = ["F%d" % i for i in range(len(self._fields))]
else:
names = self._names
df_features = D.features(self.instruments, self._fields, self.start_date, self.end_date)
df_features.columns = names
TimeInspector.log_cost_time("Finished loading features.")
return df_features
def setup_label(self):
"""
Build up labels in df through users' method
:return: df_labels
"""
TimeInspector.set_time_mark()
if len(self._label_names) == 0:
label_names = ["LABEL%d" % i for i in range(len(self._labels))]
else:
label_names = self._label_names
df_labels = D.features(self.instruments, self._labels, self.start_date, self.end_date)
df_labels.columns = label_names
TimeInspector.log_cost_time("Finished loading labels.")
return df_labels
def parse_config_to_fields(config):
"""create factors from config
config = {
'kbar': {}, # whether to use some hard-code kbar features
'price': { # whether to use raw price features
'windows': [0, 1, 2, 3, 4], # use price at n days ago
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
},
'volume': { # whether to use raw volume features
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
},
'rolling': { # whether to use rolling operator based features
'windows': [5, 10, 20, 30, 60], # rolling windows size
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
#if include is None we will use default operators
'exclude': ['RANK'], # rolling operator not to use
}
}
"""
fields = []
names = []
if "kbar" in config:
fields += [
"($close-$open)/$open",
"($high-$low)/$open",
"($close-$open)/($high-$low+1e-12)",
"($high-Greater($open, $close))/$open",
"($high-Greater($open, $close))/($high-$low+1e-12)",
"(Less($open, $close)-$low)/$open",
"(Less($open, $close)-$low)/($high-$low+1e-12)",
"(2*$close-$high-$low)/$open",
"(2*$close-$high-$low)/($high-$low+1e-12)",
]
names += [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
]
if "price" in config:
windows = config["price"].get("windows", range(5))
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
for field in feature:
field = field.lower()
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
names += [field.upper() + str(d) for d in windows]
if "volume" in config:
windows = config["volume"].get("windows", range(5))
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
names += ["VOLUME" + str(d) for d in windows]
if "rolling" in config:
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
include = config["rolling"].get("include", None)
exclude = config["rolling"].get("exclude", [])
# `exclude` in dataset config unnecessary filed
# `include` in dataset config necessary field
use = lambda x: x not in exclude and (include is None or x in include)
if use("ROC"):
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]
if use("MA"):
fields += ["Mean($close, %d)/$close" % d for d in windows]
names += ["MA%d" % d for d in windows]
if use("STD"):
fields += ["Std($close, %d)/$close" % d for d in windows]
names += ["STD%d" % d for d in windows]
if use("BETA"):
fields += ["Slope($close, %d)/$close" % d for d in windows]
names += ["BETA%d" % d for d in windows]
if use("RSQR"):
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
fields += ["Max($high, %d)/$close" % d for d in windows]
names += ["MAX%d" % d for d in windows]
if use("LOW"):
fields += ["Min($low, %d)/$close" % d for d in windows]
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
if use("RSV"):
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
names += ["RSV%d" % d for d in windows]
if use("IMAX"):
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
if use("IMXD"):
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
names += ["IMXD%d" % d for d in windows]
if use("CORR"):
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
names += ["CORR%d" % d for d in windows]
if use("CORD"):
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
names += ["CORD%d" % d for d in windows]
if use("CNTP"):
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
names += ["CNTP%d" % d for d in windows]
if use("CNTN"):
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
names += ["CNTN%d" % d for d in windows]
if use("CNTD"):
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
names += ["CNTD%d" % d for d in windows]
if use("SUMP"):
fields += [
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMP%d" % d for d in windows]
if use("SUMN"):
fields += [
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMN%d" % d for d in windows]
if use("SUMD"):
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VMA%d" % d for d in windows]
if use("VSTD"):
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]
return fields, names
class ConfigQLibDataHandler(QLibDataHandler):
config_template = {} # template
def __init__(self, start_date, end_date, processors=None, **kwargs):
if processors is None:
processors = ["ConfigSectionProcessor"] # default processor
super().__init__(start_date, end_date, processors, **kwargs)
def _init_kwargs(self, **kwargs):
config = self.config_template.copy()
if "config_update" in kwargs:
config.update(kwargs["config_update"])
fields, names = parse_config_to_fields(config)
kwargs["fields"] = fields
kwargs["names"] = names
if "labels" not in kwargs:
kwargs["labels"] = ["Ref($vwap, -2)/Ref($vwap, -1) - 1"]
super()._init_kwargs(**kwargs)
class ALPHA360(ConfigQLibDataHandler):
config_template = {
"price": {"windows": range(60)},
"volume": {"windows": range(60)},
}
class QLibDataHandlerV1(ConfigQLibDataHandler):
config_template = {
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
},
"rolling": {},
}
def __init__(self, start_date, end_date, processors=None, **kwargs):
if processors is None:
processors = ["PanelProcessor"] # V1 default processor
super().__init__(start_date, end_date, processors, **kwargs)
def setup_label(self):
"""
load the labels df
:return: df_labels
"""
TimeInspector.set_time_mark()
df_labels = super().setup_label()
## calculate new labels
df_labels["LABEL1"] = df_labels["LABEL0"].groupby(level="datetime").apply(lambda x: (x - x.mean()) / x.std())
df_labels = df_labels.drop(["LABEL0"], axis=1)
TimeInspector.log_cost_time("Finished loading labels.")
return df_labels
class QLibDataHandlerClose(QLibDataHandlerV1):
config_template = {
'kbar': {},
'price': {
'windows': [0],
'feature': ['OPEN', 'HIGH', 'LOW', 'CLOSE'],
},
'rolling': {}
}
def _init_kwargs(self, **kwargs):
kwargs['labels'] = ["Ref($close, -2)/Ref($close, -1) - 1"]
super(QLibDataHandlerClose, self)._init_kwargs(**kwargs)
# if __name__ == '__main__':
# import qlib
#
# qlib.init()
#
# handler = ALPHA80('2010-01-01', '2018-12-31')
# data = handler.get_split_data(
# pd.Timestamp('2010-01-01'), pd.Timestamp('2014-01-01'),
# pd.Timestamp('2015-01-01'), pd.Timestamp('2016-01-01'),
# pd.Timestamp('2017-01-01'), pd.Timestamp('2018-01-01'))
# print(data[0])
# data[0].to_pickle('alpha80.pkl')

View File

@@ -0,0 +1,116 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
import argparse
import importlib
from ... import init
from .config import EstimatorConfigManager
from ...log import get_module_logger
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
args_parser = argparse.ArgumentParser(prog="estimator")
args_parser.add_argument(
"-c",
"--config_path",
required=True,
type=str,
help="json config path indicates where to load config.",
)
args = args_parser.parse_args()
class SacredExperiment(object):
def __init__(
self,
experiment_name,
experiment_dir,
observer_type="file_storage",
mongo_url=None,
db_name=None,
):
"""__init__
:param experiment_name: The name of the experiments.
:param experiment_dir: The directory to store all the results of the experiments(This is for file_storage).
:param observer_type: The observer to record the results: the `file_storage` or `mongo`
:param mongo_url: The mongo url(for mongo observer)
:param db_name: The mongo url(for mongo observer)
"""
self.experiment_name = experiment_name
self.experiment = Experiment(self.experiment_name)
self.experiment_dir = experiment_dir
self.experiment.logger = get_module_logger("Sacred")
self.observer_type = observer_type
self.mongo_db_url = mongo_url
self.mongo_db_name = db_name
self._setup_experiment()
def _setup_experiment(self):
if self.observer_type == "file_storage":
file_storage_observer = FileStorageObserver.create(basedir=self.experiment_dir)
self.experiment.observers.append(file_storage_observer)
elif self.observer_type == "mongo":
mongo_observer = MongoObserver.create(url=self.mongo_db_url, db_name=self.mongo_db_name)
self.experiment.observers.append(mongo_observer)
else:
raise NotImplementedError("Unsupported observer type: {}".format(self.observer_type))
def add_artifact(self, filename):
self.experiment.add_artifact(filename)
def add_info(self, key, value):
self.experiment.info[key] = value
def main_wrapper(self, func):
return self.experiment.main(func)
def config_wrapper(self, func):
return self.experiment.config(func)
CONFIG_MANAGER = EstimatorConfigManager(args.config_path)
ex = SacredExperiment(
CONFIG_MANAGER.ex_config.name,
CONFIG_MANAGER.ex_config.sacred_dir,
observer_type=CONFIG_MANAGER.ex_config.observer_type,
mongo_url=CONFIG_MANAGER.ex_config.mongo_url,
db_name=CONFIG_MANAGER.ex_config.db_name,
)
# qlib init
init(
provider_uri=CONFIG_MANAGER.qlib_data_config.provider_uri,
mount_path=CONFIG_MANAGER.qlib_data_config.mount_path,
auto_mount=CONFIG_MANAGER.qlib_data_config.auto_mount,
region=CONFIG_MANAGER.qlib_data_config.region,
**CONFIG_MANAGER.qlib_data_config.args
)
@ex.main_wrapper
def _main():
# 1. Get estimator class.
estimator_class = getattr(
importlib.import_module(".estimator", package="qlib.contrib.estimator"),
"Estimator",
)
# 2. Init estimator.
estimator = estimator_class(CONFIG_MANAGER, ex)
estimator.run()
def run():
ex.experiment.run()
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import numpy as np
import pandas as pd
from ...log import TimeInspector
EPS = 1e-12
class Processor(abc.ABC):
def __init__(self, feature_names, label_names, **kwargs):
self.feature_names = feature_names
self.label_names = label_names
@abc.abstractmethod
def __call__(self, df_train, df_valid, df_test):
pass
class PanelProcessor(Processor):
"""Panel Preprocessor"""
STD_NORM = "Std"
MINMAX_NORM = "MinMax"
def __init__(self, feature_names, label_names, **kwargs):
super().__init__(feature_names, label_names)
# Options.
self.dropna_label = kwargs.get("dropna_label", True)
self.dropna_feature = kwargs.get("dropna_feature", False)
self.normalize_method = kwargs.get("normalize_method", None)
self.replace_inf = kwargs.get("replace_inf_feature", False)
def __call__(self, df_train, df_valid, df_test):
"""
Preprocess the data
:param df: the dataframe to process data.
"""
# Drop null labels.
if self.dropna_label:
df_train, df_valid, df_test = self._process_drop_null_label(df_train, df_valid, df_test)
# Dropna if need.
if self.dropna_feature:
df_train, df_valid, df_test = self._process_drop_null_feature(df_train, df_valid, df_test)
# replace the 'inf' with the mean the corresponding dimension
if self.replace_inf:
df_train, df_valid, df_test = self._process_replace_inf_feature(df_train, df_valid, df_test)
# normalize data in given method.
if self.normalize_method is not None:
df_train, df_valid, df_test = self._process_normalize_feature(df_train, df_valid, df_test)
return df_train, df_valid, df_test
def _process_drop_null_label(self, df_train, df_valid, df_test):
"""
Drop null labels.
"""
TimeInspector.set_time_mark()
df_train = df_train.dropna(subset=self.label_names)
df_valid = df_valid.dropna(subset=self.label_names)
# The test data's label is Unkown. They can not be seen when preprocessing
TimeInspector.log_cost_time("Finished dropping null labels.")
return df_train, df_valid, df_test
def _process_drop_null_feature(self, df_train, df_valid, df_test):
"""
Drop data which contain null features if needed.
"""
# TODO - `Pandas.dropna` is a low performance method.
TimeInspector.set_time_mark()
df_train = df_train.dropna(subset=self.feature_names)
df_valid = df_valid.dropna(subset=self.feature_names)
df_test = df_test.dropna(subset=self.feature_names)
TimeInspector.log_cost_time("Finished dropping nan.")
return df_train, df_valid, df_test
def _process_replace_inf_feature(self, df_train, df_valid, df_test):
"""
replace the 'inf' in feature with the mean of this dimension.
"""
TimeInspector.set_time_mark()
def replace_inf(data):
def process_inf(df):
for col in df.columns:
df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean())
return df
data = data.groupby("datetime").apply(process_inf)
data.sort_index(inplace=True)
return data
df_train = replace_inf(df_train)
df_valid = replace_inf(df_valid)
df_test = replace_inf(df_test)
TimeInspector.log_cost_time("Finished replace inf.")
return df_train, df_valid, df_test
def _process_normalize_feature(self, df_train, df_valid, df_test):
"""
Normalize data if needed, we provide two method now: min-max normalization and standard normalization.
"""
TimeInspector.set_time_mark()
if self.normalize_method == self.MINMAX_NORM:
min_train = np.nanmin(df_train[self.feature_names].values, axis=0)
max_train = np.nanmax(df_train[self.feature_names].values, axis=0)
ignore = min_train == max_train
def normalize(x, min_train=min_train, max_train=max_train, ignore=ignore):
if (~ignore).all():
return (x - min_train) / (max_train - min_train)
for i in range(ignore.size):
if not ignore[i]:
x[i] = (x[i] - min_train) / (max_train - min_train)
return x
elif self.normalize_method == self.STD_NORM:
mean_train = np.nanmean(df_train[self.feature_names].values, axis=0)
std_train = np.nanstd(df_train[self.feature_names].values, axis=0)
ignore = std_train == 0
def normalize(x, mean_train=mean_train, std_train=std_train, ignore=ignore):
if (~ignore).all():
return (x - mean_train) / std_train
for i in range(ignore.size):
if not ignore[i]:
x[i] = (x[i] - mean_train) / std_train
return x
else:
raise ValueError("Normalize method {} is not allowed".format(self.normalize_method))
df_train.loc(axis=1)[self.feature_names] = normalize(df_train[self.feature_names].values)
df_valid.loc(axis=1)[self.feature_names] = normalize(df_valid[self.feature_names].values)
df_test.loc(axis=1)[self.feature_names] = normalize(df_test[self.feature_names].values)
TimeInspector.log_cost_time("Finished normalizing data.")
return df_train, df_valid, df_test
class ConfigSectionProcessor(Processor):
def __init__(self, feature_names, label_names, **kwargs):
super().__init__(feature_names, label_names)
# Options
self.fillna_feature = kwargs.get("fillna_feature", True)
self.fillna_label = kwargs.get("fillna_label", True)
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
def __call__(self, *args):
return [self._transform(x) for x in args]
def _transform(self, df):
def _label_norm(x):
x = x - x.mean() # copy
x /= x.std()
if self.clip_label_outlier:
x.clip(-3, 3, inplace=True)
if self.fillna_label:
x.fillna(0, inplace=True)
return x
def _feature_norm(x):
x = x - x.median() # copy
x /= x.abs().median() * 1.4826
if self.clip_feature_outlier:
x.clip(-3, 3, inplace=True)
if self.shrink_feature_outlier:
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
if self.fillna_feature:
x.fillna(0, inplace=True)
return x
TimeInspector.set_time_mark()
# Copy
df_new = df.copy()
# Label
cols = df.columns[df.columns.str.contains("^LABEL")]
df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm)
# Features
cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")]
df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")]
df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
_cols = [
"KMID",
"KSFT",
"OPEN",
"HIGH",
"LOW",
"CLOSE",
"VWAP",
"ROC",
"MA",
"BETA",
"RESI",
"QTLU",
"QTLD",
"RSV",
"SUMP",
"SUMN",
"SUMD",
"VSUMP",
"VSUMN",
"VSUMD",
]
pat = "|".join(["^" + x for x in _cols])
cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))]
df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^RSQR")]
df_new[cols] = df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")]
df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^MIN|^LOW0")]
df_new[cols] = df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^CORR|^CORD")]
df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^WVMA")]
df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
TimeInspector.log_cost_time("Finished preprocessing data.")
return df_new

View File

@@ -0,0 +1,315 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
from abc import abstractmethod
import pandas as pd
import numpy as np
from scipy.stats import pearsonr
from ...log import get_module_logger, TimeInspector
from .handler import BaseDataHandler
from .launcher import CONFIG_MANAGER
from .fetcher import create_fetcher_with_config
from ...utils import drop_nan_by_y_index, transform_end_date
class BaseTrainer(object):
def __init__(self, model_class, model_save_path, model_args, data_handler: BaseDataHandler, sacred_ex, **kwargs):
# 1. Model.
self.model_class = model_class
self.model_save_path = model_save_path
self.model_args = model_args
# 2. Data handler.
self.data_handler = data_handler
# 3. Sacred ex.
self.ex = sacred_ex
# 4. Logger.
self.logger = get_module_logger("Trainer")
# 5. Data time
self.train_start_date = kwargs.get("train_start_date", None)
self.train_end_date = kwargs.get("train_end_date", None)
self.validate_start_date = kwargs.get("validate_start_date", None)
self.validate_end_date = kwargs.get("validate_end_date", None)
self.test_start_date = kwargs.get("test_start_date", None)
self.test_end_date = transform_end_date(kwargs.get("test_end_date", None))
@abstractmethod
def train(self):
"""
Implement this method indicating how to train a model.
"""
pass
@abstractmethod
def load(self):
"""
Implement this method indicating how to restore a model and the data.
"""
pass
@abstractmethod
def get_test_pred(self):
"""
Implement this method indicating how to get prediction result(s) from a model.
"""
pass
@abstractmethod
def get_test_performance(self):
"""
Implement this method indicating how to get the performance of the model.
"""
pass
def get_test_score(self):
"""
Override this method to transfer the predict result(s) into the score of the stock.
Note: If this is a multi-label training, you need to transfer predict labels into one score.
Or you can just use the result of `get_test_pred()` (you can also process the result) if this is one label training.
We use the first column of the result of `get_test_pred()` as default method (regard it as one label training).
"""
pred = self.get_test_pred()
pred_score = pd.DataFrame(index=pred.index)
pred_score["score"] = pred.iloc(axis=1)[0]
return pred_score
class StaticTrainer(BaseTrainer):
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
super(StaticTrainer, self).__init__(model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs)
self.model = None
split_data = self.data_handler.get_split_data(
self.train_start_date,
self.train_end_date,
self.validate_start_date,
self.validate_end_date,
self.test_start_date,
self.test_end_date,
)
(
self.x_train,
self.y_train,
self.x_validate,
self.y_validate,
self.x_test,
self.y_test,
) = split_data
def train(self):
TimeInspector.set_time_mark()
model = self.model_class(**self.model_args)
if CONFIG_MANAGER.ex_config.finetune:
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
if isinstance(loader_model, list):
model_index = (
-1
if CONFIG_MANAGER.ex_config.loader_model_index is None
else CONFIG_MANAGER.ex_config.loader_model_index
)
loader_model = loader_model[model_index]
model.load(loader_model)
model.finetune(self.x_train, self.y_train, self.x_validate, self.y_validate)
else:
model.fit(self.x_train, self.y_train, self.x_validate, self.y_validate)
model.save(self.model_save_path)
self.ex.add_artifact(self.model_save_path)
self.model = model
TimeInspector.log_cost_time("Finished training model.")
def load(self):
model = self.model_class(**self.model_args)
# Load model
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
if isinstance(loader_model, list):
model_index = (
-1
if CONFIG_MANAGER.ex_config.loader_model_index is None
else CONFIG_MANAGER.ex_config.loader_model_index
)
loader_model = loader_model[model_index]
model.load(loader_model)
# Save model, after load, if you don't save the model, the result of this experiment will be no model
model.save(self.model_save_path)
self.ex.add_artifact(self.model_save_path)
self.model = model
def get_test_pred(self):
pred = self.model.predict(self.x_test)
pred = pd.DataFrame(pred, index=self.x_test.index, columns=self.y_test.columns)
return pred
def get_test_performance(self):
model_score = self.model.score(self.x_test, self.y_test)
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
x_test, y_test, __ = drop_nan_by_y_index(self.x_test, self.y_test)
pred_test = self.model.predict(x_test)
model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
return performance
class RollingTrainer(BaseTrainer):
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
super(RollingTrainer, self).__init__(
model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs
)
self.rolling_period = kwargs.get("rolling_period", 60)
self.models = []
self.rolling_data = []
self.all_x_test = []
self.all_y_test = []
for data in self.data_handler.get_rolling_data(
self.train_start_date,
self.train_end_date,
self.validate_start_date,
self.validate_end_date,
self.test_start_date,
self.test_end_date,
self.rolling_period,
):
self.rolling_data.append(data)
__, __, __, __, x_test, y_test = data
self.all_x_test.append(x_test)
self.all_y_test.append(y_test)
def train(self):
# 1. Get total data parts.
# total_data_parts = self.data_handler.total_data_parts
# self.logger.warning('Total numbers of model are: {}, start training models...'.format(total_data_parts))
if CONFIG_MANAGER.ex_config.finetune:
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
loader_model_index = CONFIG_MANAGER.ex_config.loader_model_index
previous_model_path = ""
# 2. Rolling train.
for (
index,
(x_train, y_train, x_validate, y_validate, x_test, y_test),
) in enumerate(self.rolling_data):
TimeInspector.set_time_mark()
model = self.model_class(**self.model_args)
if CONFIG_MANAGER.ex_config.finetune:
# Finetune model
if loader_model_index is None and isinstance(loader_model, list):
try:
model.load(loader_model[index])
except IndexError:
# Load model by previous_model_path
with open(previous_model_path, "rb") as fp:
model.load(fp)
model.finetune(x_train, y_train, x_validate, y_validate)
else:
if index == 0:
loader_model = (
loader_model[loader_model_index] if isinstance(loader_model, list) else loader_model
)
model.load(loader_model)
else:
with open(previous_model_path, "rb") as fp:
model.load(fp)
model.finetune(x_train, y_train, x_validate, y_validate)
else:
model.fit(x_train, y_train, x_validate, y_validate)
model_save_path = "{}_{}".format(self.model_save_path, index)
model.save(model_save_path)
previous_model_path = model_save_path
self.ex.add_artifact(model_save_path)
self.models.append(model)
TimeInspector.log_cost_time("Finished training model: {}.".format(index + 1))
def load(self):
"""
Load the data and the model
"""
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
for index in range(len(self.all_x_test)):
model = self.model_class(**self.model_args)
model.load(loader_model[index])
# Save model
model_save_path = "{}_{}".format(self.model_save_path, index)
model.save(model_save_path)
self.ex.add_artifact(model_save_path)
self.models.append(model)
def get_test_pred(self):
"""
Predict the score on test data with the models.
Please ensure the models and data are loaded before call this score.
:return: the predicted scores for the pred
"""
pred_df_list = []
y_test_columns = self.all_y_test[0].columns
# Start iteration.
for model, x_test in zip(self.models, self.all_x_test):
pred = model.predict(x_test)
pred_df = pd.DataFrame(pred, index=x_test.index, columns=y_test_columns)
pred_df_list.append(pred_df)
return pd.concat(pred_df_list)
def get_test_performance(self):
"""
Get the performances of the models
:return: the performances of models
"""
pred_test_list = []
y_test_list = []
scorer = self.models[0]._scorer
for model, x_test, y_test in zip(self.models, self.all_x_test, self.all_y_test):
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
pred_test_list.append(model.predict(x_test))
y_test_list.append(np.squeeze(y_test.values))
pred_test_array = np.concatenate(pred_test_list, axis=0)
y_test_array = np.concatenate(y_test_list, axis=0)
model_score = scorer(y_test_array, pred_test_array)
model_pearsonr = pearsonr(np.ravel(y_test_array), np.ravel(pred_test_array))[0]
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
return performance

396
qlib/contrib/evaluate.py Normal file
View File

@@ -0,0 +1,396 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
import inspect
from ..log import get_module_logger
from . import strategy as strategy_pool
from .strategy.strategy import BaseStrategy
from .backtest.exchange import Exchange
from .backtest.backtest import backtest as backtest_func, get_date_range
from ..data import D
from ..config import C
logger = get_module_logger("Evaluate")
def risk_analysis(r, N=252):
"""Risk Analysis
Parameters
----------
r : pandas.Series
daily return series
N: int
scaler for annualizing sharpe ratio (day: 250, week: 50, month: 12)
"""
mean = r.mean()
std = r.std(ddof=1)
annual = mean * N
sharpe = mean / std * np.sqrt(N)
mdd = (r.cumsum() - r.cumsum().cummax()).min()
data = {"mean": mean, "std": std, "annual": annual, "sharpe": sharpe, "mdd": mdd}
res = pd.Series(data, index=data.keys()).to_frame("risk")
return res
def get_strategy(
strategy=None,
topk=50,
margin=0.5,
n_drop=5,
risk_degree=0.95,
str_type="amount",
adjust_dates=None,
):
"""get_strategy
Parameters
----------
strategy : Strategy()
strategy used in backtest
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
if isinstance(margin, int):
sell_limit = margin
else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
sell_limit should be no less than topk
n_drop : int
number of stocks to be replaced in each trading date
risk_degree: float
0-1, 0.95 for example, use 95% money to trade
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
Returns
-------
:class: Strategy
an initialized strategy object
"""
if strategy is None:
str_cls_dict = {
"amount": "TopkAmountStrategy",
"weight": "TopkWeightStrategy",
"dropout": "TopkDropoutStrategy",
}
logger.info("Create new streategy ")
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
strategy = str_cls(
topk=topk,
buffer_margin=margin,
n_drop=n_drop,
risk_degree=risk_degree,
adjust_dates=adjust_dates,
)
if not isinstance(strategy, BaseStrategy):
raise TypeError("Strategy not supported")
return strategy
def get_exchange(
pred,
exchange=None,
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
trade_unit=None,
limit_threshold=None,
deal_price=None,
extract_codes=False,
shift=1,
):
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange()
subscribe_fields: list
subscribe fields
open_cost : float
open transaction cost
close_cost : float
close transaction cost
min_cost : float
min transaction cost
trade_unit : int
100 for China A
deal_price: str
dealing price type: 'close', 'open', 'vwap'
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
Returns
-------
:class: Exchange
an initialized Exchange object
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if exchange is None:
logger.info("Create new exchange")
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
if extract_codes:
codes = sorted(pred.index.get_level_values(0).unique())
else:
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
dates = sorted(pred.index.get_level_values(1).unique())
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
exchange = Exchange(
trade_dates=dates,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
min_cost=min_cost,
trade_unit=trade_unit,
)
return exchange
# This is the API for compatibility for legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
"""This function will help you set a reasonable Exchange and provide default value for strategy
Parameters
----------
# backtest workflow related or commmon arguments
pred : pandas.DataFrame
predict should has <instrument, datetime> index and one `score` column
account : float
init account value
shift : int
whether to shift prediction by one day
benchmark : str
benchmark code, default is SH000905 CSI 500
verbose : bool
whether to print log
# strategy related arguments
strategy : Strategy()
strategy used in backtest
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
if isinstance(margin, int):
sell_limit = margin
else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
sell_limit should be no less than topk
n_drop : int
number of stocks to be replaced in each trading date
risk_degree: float
0-1, 0.95 for example, use 95% money to trade
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
# exchange related arguments
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
subscribe fields
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
min transaction cost
trade_unit : int
100 for China A
deal_price: str
dealing price type: 'close', 'open', 'vwap'
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
.. note:: This will be faster with offline qlib.
"""
# check strategy:
spec = inspect.getfullargspec(get_strategy)
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
strategy = get_strategy(**str_args)
# init exchange:
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
# run backtest
report_df, positions = backtest_func(
pred=pred,
strategy=strategy,
trade_exchange=trade_exchange,
shift=shift,
verbose=verbose,
account=account,
benchmark=benchmark,
)
# for compatibility of the old API. return the dict positions
positions = {k: p.position for k, p in positions.items()}
return report_df, positions
def long_short_backtest(
pred,
topk=50,
deal_price=None,
shift=1,
open_cost=0,
close_cost=0,
trade_unit=None,
limit_threshold=None,
min_cost=5,
subscribe_fields=[],
extract_codes=False,
):
"""
A backtest for long-short strategy
:param pred: The trading signal produced on day `T`
:param topk: The short topk securities and long topk securities
:param deal_price: The price to deal the trading
:param shift: Whether to shift prediction by one day. The trading day will be T+1 if shift==1.
:param open_cost: open transaction cost
:param close_cost: close transaction cost
:param trade_unit: 100 for China A
:param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit
:param min_cost: min transaction cost
:param subscribe_fields: subscribe fields
:param extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
:return: The result of backtest, it is represented by a dict.
{ "long": long_returns(excess),
"short": short_returns(excess),
"long_short": long_short_returns}
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
subscribe_fields = subscribe_fields.copy()
profit_str = f"Ref({deal_price}, -1)/{deal_price} - 1"
subscribe_fields.append(profit_str)
trade_exchange = get_exchange(
pred=pred,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
min_cost=min_cost,
trade_unit=trade_unit,
extract_codes=extract_codes,
shift=shift,
)
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
long_returns = {}
short_returns = {}
ls_returns = {}
for pdate, date in zip(predict_dates, trade_dates):
score = pred.loc(axis=0)[:, pdate]
score = score.reset_index().sort_values(by="score", ascending=False)
long_stocks = list(score.iloc[:topk]["instrument"])
short_stocks = list(score.iloc[-topk:]["instrument"])
score = score.set_index(["instrument", "datetime"]).sort_index()
long_profit = []
short_profit = []
all_profit = []
for stock in long_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
if np.isnan(profit):
long_profit.append(0)
else:
long_profit.append(profit)
for stock in short_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
if np.isnan(profit):
short_profit.append(0)
else:
short_profit.append(-profit)
for stock in list(score.loc(axis=0)[:, pdate].index.get_level_values(level=0)):
# exclude the suspend stock
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
if np.isnan(profit):
all_profit.append(0)
else:
all_profit.append(profit)
long_returns[date] = np.mean(long_profit) - np.mean(all_profit)
short_returns[date] = np.mean(short_profit) + np.mean(all_profit)
ls_returns[date] = np.mean(short_profit) + np.mean(long_profit)
return dict(
zip(
["long", "short", "long_short"],
map(pd.Series, [long_returns, short_returns, ls_returns]),
)
)
def t_run():
pred_FN = "./check_pred.csv"
pred = pd.read_csv(pred_FN)
pred["datetime"] = pd.to_datetime(pred["datetime"])
pred = pred.set_index([pred.columns[0], pred.columns[1]])
pred = pred.iloc[:9000]
report_df, positions = backtest(pred=pred)
print(report_df.head())
print(positions.keys())
print(positions[list(positions.keys())[0]])
return 0
if __name__ == "__main__":
t_run()

View File

@@ -0,0 +1,246 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr
from ..data import D
from collections import OrderedDict
def _get_position_value_from_df(evaluate_date, position, close_data_df):
"""Get position value by existed close data df
close_data_df:
pd.DataFrame
multi-index
close_data_df['$close'][stock_id][evaluate_date]: close price for (stock_id, evaluate_date)
position:
same in get_position_value()
"""
value = 0
for stock_id, report in position.items():
if stock_id != "cash":
value += report["amount"] * close_data_df["$close"][stock_id][evaluate_date]
# value += report['amount'] * report['price']
if "cash" in position:
value += position["cash"]
return value
def get_position_value(evaluate_date, position):
"""sum of close*amount
get value of postion
use close price
postions:
{
Timestamp('2016-01-05 00:00:00'):
{
'SH600022':
{
'amount':100.00,
'price':12.00
},
'cash':100000.0
}
}
It means Hold 100.0 'SH600022' and 100000.0 RMB in '2016-01-05'
"""
# load close price for position
# position should also consider cash
instruments = list(position.keys())
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
fields = ["$close"]
close_data_df = D.features(
instruments,
fields,
start_time=evaluate_date,
end_time=evaluate_date,
freq="day",
disk_cache=0,
)
value = _get_position_value_from_df(evaluate_date, position, close_data_df)
return value
def get_position_list_value(positions):
# generate instrument list and date for whole poitions
instruments = set()
for day, position in positions.items():
instruments.update(position.keys())
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
instruments.sort()
day_list = list(positions.keys())
day_list.sort()
start_date, end_date = day_list[0], day_list[-1]
# load data
fields = ["$close"]
close_data_df = D.features(
instruments,
fields,
start_time=start_date,
end_time=end_date,
freq="day",
disk_cache=0,
)
# generate value
# return dict for time:position_value
value_dict = OrderedDict()
for day, position in positions.items():
value = _get_position_value_from_df(evaluate_date=day, position=position, close_data_df=close_data_df)
value_dict[day] = value
return value_dict
def get_daily_return_series_from_positions(positions, init_asset_value):
"""Parameters
generate daily return series from position view
positions: positions generated by strategy
init_asset_value : init asset value
return: pd.Series of daily return , return_series[date] = daily return rate
"""
value_dict = get_position_list_value(positions)
value_series = pd.Series(value_dict)
value_series = value_series.sort_index() # check date
return_series = value_series.pct_change()
return_series[value_series.index[0]] = (
value_series[value_series.index[0]] / init_asset_value - 1
) # update daily return for the first date
return return_series
def get_annual_return_from_positions(positions, init_asset_value):
"""Annualized Returns
p_r = (p_end / p_start)^{(250/n)} - 1
p_r annual return
p_end final value
p_start init value
n days of backtest
"""
date_range_list = sorted(list(positions.keys()))
end_time = date_range_list[-1]
p_end = get_position_value(end_time, positions[end_time])
p_start = init_asset_value
n_period = len(date_range_list)
annual = pow((p_end / p_start), (250 / n_period)) - 1
return annual
def get_annaul_return_from_return_series(r, method="ci"):
"""Risk Analysis from daily return series
Parameters
----------
r : pandas.Series
daily return series
method : str
interest calculation method, ci(compound interest)/si(simple interest)
"""
mean = r.mean()
annual = (1 + mean) ** 250 - 1 if method == "ci" else mean * 250
return annual
def get_sharpe_ratio_from_return_series(r, risk_free_rate=0.00, method="ci"):
"""Risk Analysis
Parameters
----------
r : pandas.Series
daily return series
method : str
interest calculation method, ci(compound interest)/si(simple interest)
risk_free_rate : float
risk_free_rate, default as 0.00, can set as 0.03 etc
"""
std = r.std(ddof=1)
annual = get_annaul_return_from_return_series(r, method=method)
sharpe = (annual - risk_free_rate) / std / np.sqrt(250)
return sharpe
def get_max_drawdown_from_series(r):
"""Risk Analysis from asset value
cumprod way
Parameters
----------
r : pandas.Series
daily return series
"""
# mdd = ((r.cumsum() - r.cumsum().cummax()) / (1 + r.cumsum().cummax())).min()
mdd = (((1 + r).cumprod() - (1 + r).cumprod().cummax()) / ((1 + r).cumprod().cummax())).min()
return mdd
def get_turnover_rate():
# in backtest
pass
def get_beta(r, b):
"""Risk Analysis beta
Parameters
----------
r : pandas.Series
daily return series of strategy
b : pandas.Series
daily return series of baseline
"""
cov_r_b = np.cov(r, b)
var_b = np.var(b)
return cov_r_b / var_b
def get_alpha(r, b, risk_free_rate=0.03):
beta = get_beta(r, b)
annaul_r = get_annaul_return_from_return_series(r)
annaul_b = get_annaul_return_from_return_series(b)
alpha = annaul_r - risk_free_rate - beta * (annaul_b - risk_free_rate)
return alpha
def get_volatility_from_series(r):
return r.std(ddof=1)
def get_rank_ic(a, b):
"""Rank IC
Parameters
----------
r : pandas.Series
daily score series of feature
b : pandas.Series
daily return series
"""
return spearmanr(a, b).correlation
def get_normal_ic(a, b):
return pearsonr(a, b).correlation

View File

@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import warnings
from .base import Model

155
qlib/contrib/model/base.py Normal file
View File

@@ -0,0 +1,155 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import abc
import six
@six.add_metaclass(abc.ABCMeta)
class Model(object):
"""Model base class"""
@property
def name(self):
return type(self).__name__
def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
"""fix train with cross-validation
Fit model when ex_config.finetune is False
Parameters
----------
x_train : pd.dataframe
train data
y_train : pd.dataframe
train label
x_valid : pd.dataframe
valid data
y_valid : pd.dataframe
valid label
w_train : pd.dataframe
train weight
w_valid : pd.dataframe
valid weight
Returns
----------
Model
trained model
"""
raise NotImplementedError()
def score(self, x_test, y_test, w_test=None, **kwargs):
"""evaluate model with test data/label
Parameters
----------
x_test : pd.dataframe
test data
y_test : pd.dataframe
test label
w_test : pd.dataframe
test weight
Returns
----------
float
evaluation score
"""
raise NotImplementedError()
def predict(self, x_test, **kwargs):
"""predict given test data
Parameters
----------
x_test : pd.dataframe
test data
Returns
----------
np.ndarray
test predict label
"""
raise NotImplementedError()
def save(self, fname, **kwargs):
"""save model
Parameters
----------
fname : str
model filename
"""
# TODO: Currently need to save the model as a single file, otherwise the estimator may not be compatible
raise NotImplementedError()
def load(self, buffer, **kwargs):
"""load model
Parameters
----------
buffer : bytes
binary data of model parameters
Returns
----------
Model
loaded model
"""
raise NotImplementedError()
def get_data_with_date(self, date, **kwargs):
"""
Will be called in online module
need to return the data that used to predict the label (score) of stocks at date.
:param
date: pd.Timestamp
predict date
:return:
data: the input data that used to predict the label (score) of stocks at predict date.
"""
raise NotImplementedError("get_data_with_date for this model is not implemented.")
def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
"""Finetune model
In `RollingTrainer`:
if loader.model_index is None:
If provide 'Static Model', based on the provided 'Static' model update.
If provide 'Rolling Model', skip the model of load, based on the last 'provided model' update.
if loader.model_index is not None:
Based on the provided model(loader.model_index) update.
In `StaticTrainer`:
If the load is 'static model':
Based on the 'static model' update
If the load is 'rolling model':
Based on the provided model(`loader.model_index`) update. If `loader.model_index` is None, use the last model.
Parameters
----------
x_train : pd.dataframe
train data
y_train : pd.dataframe
train label
x_valid : pd.dataframe
valid data
y_valid : pd.dataframe
valid label
w_train : pd.dataframe
train weight
w_valid : pd.dataframe
valid weight
Returns
----------
Model
finetune model
"""
raise NotImplementedError("Finetune for this model is not implemented.")

View File

@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import numpy as np
import lightgbm as lgb
from sklearn.metrics import roc_auc_score, mean_squared_error
from .base import Model
from ...utils import drop_nan_by_y_index
class LGBModel(Model):
"""LightGBM Model
Parameters
----------
param_update : dict
training parameters
"""
_params = dict()
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
self._params.update(objective=loss, **kwargs)
self._model = None
def fit(
self,
x_train,
y_train,
x_valid,
y_valid,
w_train=None,
w_valid=None,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
w_train_weight = None if w_train is None else w_train.values
w_valid_weight = None if w_valid is None else w_valid.values
dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
self._model = lgb.train(
self._params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, x_test):
if self._model is None:
raise ValueError("model is not fitted yet!")
return self._model.predict(x_test.values)
def score(self, x_test, y_test, w_test=None):
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
preds = self.predict(x_test)
w_test_weight = None if w_test is None else w_test.values
return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
def save(self, filename):
if self._model is None:
raise ValueError("model is not fitted yet!")
self._model.save_model(filename)
def load(self, buffer):
self._model = lgb.Booster(params={"model_str": buffer.decode("utf-8")})

View File

@@ -0,0 +1,356 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
from ...log import get_module_logger, TimeInspector
import torch
import torch.nn as nn
import torch.optim as optim
from .base import Model
class DNNModelPytorch(Model):
"""DNN Model
Parameters
----------
input_dim : int
input dimension
output_dim : int
output dimension
layers : tuple
layer sizes
lr : float
learning rate
lr_decay : float
learning rate decay
lr_decay_steps : int
learning rate decay steps
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
"""
def __init__(
self,
input_dim,
output_dim,
layers=(256, 256, 128),
lr=0.001,
max_steps=300,
batch_size=2000,
early_stop_rounds=50,
eval_steps=20,
lr_decay=0.96,
lr_decay_steps=100,
optimizer="gd",
loss="mse",
GPU="0",
**kwargs
):
# Set logger.
self.logger = get_module_logger("DNNModelPytorch")
self.logger.info("DNN pytorch version...")
# set hyper-parameters.
self.layers = layers
self.lr = lr
self.max_steps = max_steps
self.batch_size = batch_size
self.early_stop_rounds = early_stop_rounds
self.eval_steps = eval_steps
self.lr_decay = lr_decay
self.lr_decay_steps = lr_decay_steps
self.optimizer = optimizer.lower()
self.loss_type = loss
self.visible_GPU = GPU
self.logger.info(
"DNN parameters setting:"
"\nlayers : {}"
"\nlr : {}"
"\nmax_steps : {}"
"\nbatch_size : {}"
"\nearly_stop_rounds : {}"
"\neval_steps : {}"
"\nlr_decay : {}"
"\nlr_decay_steps : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\neval_steps : {}"
"\nvisible_GPU : {}".format(
layers,
lr,
max_steps,
batch_size,
early_stop_rounds,
eval_steps,
lr_decay,
lr_decay_steps,
optimizer,
loss,
eval_steps,
GPU,
)
)
if loss not in {"mse", "binary"}:
raise NotImplementedError("loss {} is not supported!".format(loss))
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
# Reduce learning rate when loss has stopped decrease
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.train_optimizer,
mode="min",
factor=0.5,
patience=10,
verbose=True,
threshold=0.0001,
threshold_mode="rel",
cooldown=0,
min_lr=0.00001,
eps=1e-08,
)
self._fitted = False
self.dnn_model.cuda()
# set the visible GPU
if self.visible_GPU:
os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
def fit(
self,
x_train,
y_train,
x_valid,
y_valid,
w_train=None,
w_valid=None,
evals_result=dict(),
verbose=True,
save_path=None,
):
if w_train is None:
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
if w_valid is None:
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
save_path = create_save_path(save_path)
stop_steps = 0
train_loss = 0
best_loss = np.inf
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self._fitted = True
#return
# prepare training data
x_train_values = torch.from_numpy(x_train.values).float()
y_train_values = torch.from_numpy(y_train.values).float()
w_train_values = torch.from_numpy(w_train.values).float()
train_num = y_train_values.shape[0]
# prepare validation data
x_val_cuda = torch.from_numpy(x_valid.values).float()
y_val_cuda = torch.from_numpy(y_valid.values).float()
w_val_cuda = torch.from_numpy(w_valid.values).float()
x_val_cuda = x_val_cuda.cuda()
y_val_cuda = y_val_cuda.cuda()
w_val_cuda = w_val_cuda.cuda()
for step in range(self.max_steps):
if stop_steps >= self.early_stop_rounds:
if verbose:
self.logger.info("\tearly stop")
break
loss = AverageMeter()
self.dnn_model.train()
self.train_optimizer.zero_grad()
choice = np.random.choice(train_num, self.batch_size)
x_batch = x_train_values[choice]
y_batch = y_train_values[choice]
w_batch = w_train_values[choice]
x_batch_cuda = x_batch.float().cuda()
y_batch_cuda = y_batch.float().cuda()
w_batch_cuda = w_batch.float().cuda()
# forward
preds = self.dnn_model(x_batch_cuda)
cur_loss = self.get_loss(preds, w_batch_cuda, y_batch_cuda, self.loss_type)
cur_loss.backward()
self.train_optimizer.step()
loss.update(cur_loss.item())
# validation
train_loss += loss.val
#print(loss.val)
if step and step % self.eval_steps == 0:
stop_steps += 1
train_loss /= self.eval_steps
with torch.no_grad():
self.dnn_model.eval()
loss_val = AverageMeter()
# forward
preds = self.dnn_model(x_val_cuda)
cur_loss_val = self.get_loss(preds, w_val_cuda, y_val_cuda, self.loss_type)
loss_val.update(cur_loss_val.item())
if verbose:
self.logger.info(
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
)
evals_result["train"].append(train_loss)
evals_result["valid"].append(loss_val.val)
if loss_val.val < best_loss:
if verbose:
self.logger.info(
"\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
best_loss, loss_val.val
)
)
best_loss = loss_val.val
stop_steps = 0
torch.save(self.dnn_model.state_dict(), save_path)
train_loss = 0
# update learning rate
self.scheduler.step(cur_loss_val)
# restore the optimal parameters after training ??
self.dnn_model.load_state_dict(torch.load(save_path))
torch.cuda.empty_cache()
def get_loss(self, pred, w, target, loss_type):
if loss_type == "mse":
sqr_loss = torch.mul(pred - target, pred - target)
loss = torch.mul(sqr_loss, w).mean()
return loss
elif loss_type == "binary":
loss = nn.BCELoss()
return loss(pred, target)
else:
raise NotImplementedError("loss {} is not supported!".format(loss_type))
def predict(self, x_test):
if not self._fitted:
raise ValueError("model is not fitted yet!")
x_test = torch.from_numpy(x_test.values).float().cuda()
self.dnn_model.eval()
with torch.no_grad():
preds = self.dnn_model(x_test).detach().cpu().numpy()
return preds
def score(self, x_test, y_test, w_test=None):
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
preds = self.predict(x_test)
w_test_weight = None if w_test is None else w_test.values
return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
def save(self, filename, **kwargs):
with save_multiple_parts_file(filename) as model_dir:
model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])
# Save model
torch.save(self.dnn_model.state_dict(), model_path)
def load(self, buffer, **kwargs):
with unpack_archive_with_buffer(buffer) as model_dir:
# Get model name
_model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[
0
]
_model_path = os.path.join(model_dir, _model_name)
# Load model
self.dnn_model.load_state_dict(torch.load(_model_path))
self._fitted = True
def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
self.fit(x_train, y_train, x_valid, y_valid, w_train=w_train, w_valid=w_valid, **kwargs)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class Net(nn.Module):
def __init__(self, input_dim, output_dim, layers=(256, 256, 256), loss="mse"):
super(Net, self).__init__()
layers = [input_dim] + list(layers)
dnn_layers = []
drop_input = nn.Dropout(0.1)
dnn_layers.append(drop_input)
for i, (input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
fc = nn.Linear(input_dim, hidden_units)
activation = nn.ReLU()
bn = nn.BatchNorm1d(hidden_units)
drop = nn.Dropout(0.1)
seq = nn.Sequential(fc, bn, activation, drop)
dnn_layers.append(seq)
if loss == "mse":
fc = nn.Linear(hidden_units, output_dim)
dnn_layers.append(fc)
elif loss == "binary":
fc = nn.Linear(hidden_units, output_dim)
sigmoid = nn.Sigmoid()
dnn_layers.append(nn.Sequential(fc, sigmoid))
else:
raise NotImplementedError("loss {} is not supported!".format(loss))
# optimizer
self.dnn_layers = nn.ModuleList(dnn_layers)
self._weight_init()
def _weight_init(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight, gain=1)
def forward(self, x):
cur_output = x
for i, now_layer in enumerate(self.dnn_layers):
cur_output = now_layer(cur_output)
return cur_output

View File

View File

@@ -0,0 +1,291 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import json
import copy
import pathlib
import pandas as pd
from ...data import D
from ...utils import get_date_in_file_name
from ...utils import get_pre_trading_date
from ..backtest.order import Order
class BaseExecutor:
"""
# Strategy framework document
class Executor(BaseExecutor):
"""
def execute(self, trade_account, order_list, trade_date):
"""
return the executed result (trade_info) after trading at trade_date.
NOTICE: trade_account will not be modified after executing.
Parameter
---------
trade_account : Account()
order_list : list
[Order()]
trade_date : pd.Timestamp
Return
---------
trade_info : list
[Order(), float, float, float]
"""
raise NotImplementedError("get_execute_result for this model is not implemented.")
def save_executed_file_from_trade_info(self, trade_info, user_path, trade_date):
"""
Save the trade_info to the .csv transaction file in disk
the columns of result file is
['date', 'stock_id', 'direction', 'trade_val', 'trade_cost', 'trade_price', 'factor']
Parameter
---------
trade_info : list of [Order(), float, float, float]
(order, trade_val, trade_cost, trade_price), trade_info with out factor
user_path: str / pathlib.Path()
the sub folder to save user data
transaction_path : string / pathlib.Path()
"""
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = pathlib.Path(user_path) / "trade" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
transaction_path = folder_path / "transaction_{}.csv".format(str(trade_date.date()))
columns = [
"date",
"stock_id",
"direction",
"amount",
"trade_val",
"trade_cost",
"trade_price",
"factor",
]
data = []
for [order, trade_val, trade_cost, trade_price] in trade_info:
data.append(
[
trade_date,
order.stock_id,
order.direction,
order.amount,
trade_val,
trade_cost,
trade_price,
order.factor,
]
)
df = pd.DataFrame(data, columns=columns)
df.to_csv(transaction_path, index=False)
def load_trade_info_from_executed_file(self, user_path, trade_date):
YYYY, MM, DD = str(trade_date.date()).split("-")
file_path = pathlib.Path(user_path) / "trade" / YYYY / MM / "transaction_{}.csv".format(str(trade_date.date()))
if not file_path.exists():
raise ValueError("File {} not exists!".format(file_path))
filedate = get_date_in_file_name(file_path)
transaction = pd.read_csv(file_path)
trade_info = []
for i in range(len(transaction)):
date = transaction.loc[i]["date"]
if not date == filedate:
continue
# raise ValueError("date in transaction file {} not equal to it's file date{}".format(date, filedate))
order = Order(
stock_id=transaction.loc[i]["stock_id"],
amount=transaction.loc[i]["amount"],
trade_date=transaction.loc[i]["date"],
direction=transaction.loc[i]["direction"],
factor=transaction.loc[i]["factor"],
)
trade_val = transaction.loc[i]["trade_val"]
trade_cost = transaction.loc[i]["trade_cost"]
trade_price = transaction.loc[i]["trade_price"]
trade_info.append([order, trade_val, trade_cost, trade_price])
return trade_info
class SimulatorExecutor(BaseExecutor):
def __init__(self, trade_exchange, verbose=False):
self.trade_exchange = trade_exchange
self.verbose = verbose
self.order_list = []
def execute(self, trade_account, order_list, trade_date):
"""
execute the order list, do the trading wil exchange at date.
Will not modify the trade_account.
Parameter
trade_account : Account()
order_list : list
list or orders
trade_date : pd.Timestamp
:return:
trade_info : list of [Order(), float, float, float]
(order, trade_val, trade_cost, trade_price), trade_info with out factor
"""
account = copy.deepcopy(trade_account)
trade_info = []
for order in order_list:
# check holding thresh is done in strategy
# if order.direction==0: # sell order
# # checking holding thresh limit for sell order
# if trade_account.current.get_stock_count(order.stock_id) < thresh:
# # can not sell this code
# continue
# is order executable
# check order
if self.trade_exchange.check_order(order) is True:
# execute the order
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=account)
trade_info.append([order, trade_val, trade_cost, trade_price])
if self.verbose:
if order.direction == Order.SELL: # sell
print(
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format(
trade_date,
order.stock_id,
trade_price,
order.deal_amount,
trade_val,
)
)
else:
print(
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format(
trade_date,
order.stock_id,
trade_price,
order.deal_amount,
trade_val,
)
)
else:
if self.verbose:
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_date, order.stock_id))
# do nothing
pass
return trade_info
def save_score_series(score_series, user_path, trade_date):
"""Save the score_series into a .csv file.
The columns of saved file is
[stock_id, score]
Parameter
---------
order_list: [Order()]
list of Order()
date: pd.Timestamp
the date to save the order list
user_path: str / pathlib.Path()
the sub folder to save user data
"""
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = user_path / "score" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
score_series.to_csv(file_path)
def load_score_series(user_path, trade_date):
"""Save the score_series into a .csv file.
The columns of saved file is
[stock_id, score]
Parameter
---------
order_list: [Order()]
list of Order()
date: pd.Timestamp
the date to save the order list
user_path: str / pathlib.Path()
the sub folder to save user data
"""
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = user_path / "score" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
score_series = pd.read_csv(file_path, index_col=0, header=None, names=["instrument", "score"])
return score_series
def save_order_list(order_list, user_path, trade_date):
"""
Save the order list into a json file.
Will calculate the real amount in order according to factors at date.
The format in json file like
{"sell": {"stock_id": amount, ...}
,"buy": {"stock_id": amount, ...}}
:param
order_list: [Order()]
list of Order()
date: pd.Timestamp
the date to save the order list
user_path: str / pathlib.Path()
the sub folder to save user data
"""
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = user_path / "trade" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
sell = {}
buy = {}
for order in order_list:
if order.direction == 0: # sell
sell[order.stock_id] = [order.amount, order.factor]
else:
buy[order.stock_id] = [order.amount, order.factor]
order_dict = {"sell": sell, "buy": buy}
file_path = folder_path / "orderlist_{}.json".format(str(trade_date.date()))
with file_path.open("w") as fp:
json.dump(order_dict, fp)
def load_order_list(user_path, trade_date):
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
path = user_path / "trade" / YYYY / MM / "orderlist_{}.json".format(str(trade_date.date()))
if not path.exists():
raise ValueError("File {} not exists!".format(path))
# get orders
with path.open("r") as fp:
order_dict = json.load(fp)
order_list = []
for stock_id in order_dict["sell"]:
amount, factor = order_dict["sell"][stock_id]
order = Order(
stock_id=stock_id,
amount=amount,
trade_date=pd.Timestamp(trade_date),
direction=Order.SELL,
factor=factor,
)
order_list.append(order)
for stock_id in order_dict["buy"]:
amount, factor = order_dict["buy"][stock_id]
order = Order(
stock_id=stock_id,
amount=amount,
trade_date=pd.Timestamp(trade_date),
direction=Order.BUY,
factor=factor,
)
order_list.append(order)
return order_list

View File

@@ -0,0 +1,147 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import pickle
import yaml
import pathlib
import pandas as pd
import shutil
from ..backtest.account import Account
from ..backtest.exchange import Exchange
from .user import User
from .utils import load_instance
from .utils import save_instance, init_instance_by_config
class UserManager:
def __init__(self, user_data_path, save_report=True):
"""
This module is designed to manager the users in online system
all users' data were assumed to be saved in user_data_path
Parameter
user_data_path : string
data path that all users' data were saved in
variables:
data_path : string
data path that all users' data were saved in
users_file : string
A path of the file record the add_date of users
save_report : bool
whether to save report after each trading process
users : dict{}
[user_id]->User()
the python dict save instances of User() for each user_id
user_record : pd.Dataframe
user_id(string), add_date(string)
indicate the add_date for each users
"""
self.data_path = pathlib.Path(user_data_path)
self.users_file = self.data_path / "users.csv"
self.save_report = save_report
self.users = {}
self.user_record = None
def load_users(self):
"""
load all users' data into manager
"""
self.users = {}
self.user_record = pd.read_csv(self.users_file, index_col=0)
for user_id in self.user_record.index:
self.users[user_id] = self.load_user(user_id)
def load_user(self, user_id):
"""
return a instance of User() represents a user to be processed
Parameter
user_id : string
:return
user : User()
"""
account_path = self.data_path / user_id
strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id)
model_file = self.data_path / user_id / "model_{}.pickle".format(user_id)
cur_user_list = [user_id for user_id in self.users]
if user_id in cur_user_list:
raise ValueError("User {} has been loaded".format(user_id))
else:
trade_account = Account(0)
trade_account.load_account(account_path)
strategy = load_instance(strategy_file)
model = load_instance(model_file)
user = User(account=trade_account, strategy=strategy, model=model)
return user
def save_user_data(self, user_id):
"""
save a instance of User() to user data path
Parameter
user_id : string
"""
if not user_id in self.users:
raise ValueError("Cannot find user {}".format(user_id))
self.users[user_id].account.save_account(self.data_path / user_id)
save_instance(
self.users[user_id].strategy,
self.data_path / user_id / "strategy_{}.pickle".format(user_id),
)
save_instance(
self.users[user_id].model,
self.data_path / user_id / "model_{}.pickle".format(user_id),
)
def add_user(self, user_id, config_file, add_date):
"""
add the new user {user_id} into user data
will create a new folder named "{user_id}" in user data path
Parameter
user_id : string
init_cash : int
config_file : str/pathlib.Path()
path of config file
"""
config_file = pathlib.Path(config_file)
if not config_file.exists():
raise ValueError("Cannot find config file {}".format(config_file))
user_path = self.data_path / user_id
if user_path.exists():
raise ValueError("User data for {} already exists".format(user_id))
with config_file.open("r") as fp:
config = yaml.load(fp)
# load model
model = init_instance_by_config(config["model"])
# load strategy
strategy = init_instance_by_config(config["strategy"])
init_args = strategy.get_init_args_from_model(model, add_date)
strategy.init(**init_args)
# init Account
trade_account = Account(init_cash=config["init_cash"])
# save user
user_path.mkdir()
save_instance(model, self.data_path / user_id / "model_{}.pickle".format(user_id))
save_instance(strategy, self.data_path / user_id / "strategy_{}.pickle".format(user_id))
trade_account.save_account(self.data_path / user_id)
user_record = pd.read_csv(self.users_file, index_col=0)
user_record.loc[user_id] = [add_date]
user_record.to_csv(self.users_file)
def remove_user(self, user_id):
"""
remove user {user_id} in current user dataset
will delete the folder "{user_id}" in user data path
:param
user_id : string
"""
user_path = self.data_path / user_id
if not user_path.exists():
raise ValueError("Cannot find user data {}".format(user_id))
shutil.rmtree(user_path)
user_record = pd.read_csv(self.users_file, index_col=0)
user_record.drop([user_id], inplace=True)
user_record.to_csv(self.users_file)

View File

@@ -0,0 +1,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import random
import pandas as pd
from ...data import D
from ..model.base import Model
class ScoreFileModel(Model):
"""
This model will load a score file, and return score at date exists in score file.
"""
def __init__(self, score_path):
pred_test = pd.read_csv(score_path, index_col=[0, 1], parse_dates=True, infer_datetime_format=True)
self.pred = pred_test
def get_data_with_date(self, date, **kwargs):
score = self.pred.loc(axis=0)[:, date] # (stock_id, trade_date) multi_index, score in pdate
score_series = score.reset_index(level="datetime", drop=True)[
"score"
] # pd.Series ; index:stock_id, data: score
return score_series
def predict(self, x_test, **kwargs):
return x_test
def score(self, x_test, **kwargs):
return
def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
return
def save(self, fname, **kwargs):
return

View File

@@ -0,0 +1,317 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import fire
import pandas as pd
import pathlib
import qlib
import logging
from ...data import D
from ...log import get_module_logger
from ...utils import get_pre_trading_date, is_tradable_date
from ..evaluate import risk_analysis
from ..backtest.backtest import update_account
from .manager import UserManager
from .utils import prepare
from .utils import create_user_folder
from .executor import load_order_list, save_order_list
from .executor import SimulatorExecutor
from .executor import save_score_series, load_score_series
class Operator(object):
def __init__(self, client: str):
"""
Parameters
----------
client: str
The qlib client config file(.yaml)
"""
self.logger = get_module_logger("online operator", level=logging.INFO)
self.client = client
@staticmethod
def init(client, path, date=None):
"""Initial UserManager(), get predict date and trade date
Parameters
----------
client: str
The qlib client config file(.yaml)
path : str
Path to save user account.
date : str (YYYY-MM-DD)
Trade date, when the generated order list will be traded.
Return
----------
um: UserManager()
pred_date: pd.Timestamp
trade_date: pd.Timestamp
"""
qlib.init_from_yaml_conf(client)
um = UserManager(user_data_path=pathlib.Path(path))
um.load_users()
if not date:
trade_date, pred_date = None, None
else:
trade_date = pd.Timestamp(date)
if not is_tradable_date(trade_date):
raise ValueError("trade date is not tradable date".format(trade_date.date()))
pred_date = get_pre_trading_date(trade_date, future=True)
return um, pred_date, trade_date
def add_user(self, id, config, path, date):
"""Add a new user into the a folder to run 'online' module.
Parameters
----------
id : str
User id, should be unique.
config : str
The file path (yaml) of user config
path : str
Path to save user account.
date : str (YYYY-MM-DD)
The date that user account was added.
"""
create_user_folder(path)
qlib.init_from_yaml_conf(self.client)
um = UserManager(user_data_path=path)
add_date = D.calendar(end_time=date)[-1]
if not is_tradable_date(add_date):
raise ValueError("add date is not tradable date".format(add_date.date()))
um.add_user(user_id=id, config_file=config, add_date=add_date)
def remove_user(self, id, path):
"""Remove user from folder used in 'online' module.
Parameters
----------
id : str
User id, should be unique.
path : str
Path to save user account.
"""
um = UserManager(user_data_path=path)
um.remove_user(user_id=id)
def generate(self, date, path):
"""Generate order list that will be traded at 'date'.
Parameters
----------
date : str (YYYY-MM-DD)
Trade date, when the generated order list will be traded.
path : str
Path to save user account.
"""
um, pred_date, trade_date = self.init(self.client, path, date)
for user_id, user in um.users.items():
dates, trade_exchange = prepare(um, pred_date, user_id)
# get and save the score at predict date
input_data = user.model.get_data_with_date(pred_date)
score_series = user.model.predict(input_data)
save_score_series(score_series, (pathlib.Path(path) / user_id), trade_date)
# update strategy (and model)
user.strategy.update(score_series, pred_date, trade_date)
# generate and save order list
order_list = user.strategy.generate_order_list(
score_series=score_series,
current=user.account.current,
trade_exchange=trade_exchange,
trade_date=trade_date,
)
save_order_list(
order_list=order_list,
user_path=(pathlib.Path(path) / user_id),
trade_date=trade_date,
)
self.logger.info("Generate order list at {} for {}".format(trade_date, user_id))
um.save_user_data(user_id)
def execute(self, date, exchange_config, path):
"""Execute the orderlist at 'date'.
Parameters
----------
date : str (YYYY-MM-DD)
Trade date, that the generated order list will be traded.
exchange_config: str
The file path (yaml) of exchange config
path : str
Path to save user account.
"""
um, pred_date, trade_date = self.init(self.client, path, date)
for user_id, user in um.users.items():
dates, trade_exchange = prepare(um, trade_date, user_id, exchange_config)
executor = SimulatorExecutor(trade_exchange=trade_exchange)
if not str(dates[0].date()) == str(pred_date.date()):
raise ValueError(
"The account data is not newest! last trading date {}, today {}".format(
dates[0].date(), trade_date.date()
)
)
# load and execute the order list
# will not modify the trade_account after executing
order_list = load_order_list(user_path=(pathlib.Path(path) / user_id), trade_date=trade_date)
trade_info = executor.execute(order_list=order_list, trade_account=user.account, trade_date=trade_date)
executor.save_executed_file_from_trade_info(
trade_info=trade_info,
user_path=(pathlib.Path(path) / user_id),
trade_date=trade_date,
)
self.logger.info("execute order list at {} for {}".format(trade_date.date(), user_id))
def update(self, date, path, type="SIM"):
"""Update account at 'date'.
Parameters
----------
date : str (YYYY-MM-DD)
Trade date, that the generated order list will be traded.
path : str
Path to save user account.
type : str
which executor was been used to execute the order list
'SIM': SimulatorExecutor()
"""
if type not in ["SIM", "YC"]:
raise ValueError("type is invalid, {}".format(type))
um, pred_date, trade_date = self.init(self.client, path, date)
for user_id, user in um.users.items():
dates, trade_exchange = prepare(um, trade_date, user_id)
if type == "SIM":
executor = SimulatorExecutor(trade_exchange=trade_exchange)
else:
raise ValueError("not found executor")
# dates[0] is the last_trading_date
if str(dates[0].date()) > str(pred_date.date()):
raise ValueError(
"The account data is not newest! last trading date {}, today {}".format(
dates[0].date(), trade_date.date()
)
)
# load trade info and update account
trade_info = executor.load_trade_info_from_executed_file(
user_path=(pathlib.Path(path) / user_id), trade_date=trade_date
)
score_series = load_score_series((pathlib.Path(path) / user_id), trade_date)
update_account(user.account, trade_info, trade_exchange, trade_date)
report = user.account.report.generate_report_dataframe()
self.logger.info(report)
um.save_user_data(user_id)
self.logger.info("Update account state {} for {}".format(trade_date, user_id))
def simulate(self, id, config, exchange_config, start, end, path, bench="SH000905"):
"""Run the ( generate_order_list -> execute_order_list -> update_account) process everyday
from start date to end date.
Parameters
----------
id : str
user id, need to be unique
config : str
The file path (yaml) of user config
exchange_config: str
The file path (yaml) of exchange config
start : str "YYYY-MM-DD"
The start date to run the online simulate
end : str "YYYY-MM-DD"
The end date to run the online simulate
path : str
Path to save user account.
bench : str
The benchmark that our result compared with.
'SH000905' for csi500, 'SH000300' for csi300
"""
# Clear the current user if exists, then add a new user.
create_user_folder(path)
um = self.init(self.client, path, None)[0]
start_date, end_date = pd.Timestamp(start), pd.Timestamp(end)
try:
um.remove_user(user_id=id)
except BaseException:
pass
um.add_user(user_id=id, config_file=config, add_date=pd.Timestamp(start_date))
# Do the online simulate
um.load_users()
user = um.users[id]
dates, trade_exchange = prepare(um, end_date, id, exchange_config)
executor = SimulatorExecutor(trade_exchange=trade_exchange)
for pred_date, trade_date in zip(dates[:-2], dates[1:-1]):
user_path = pathlib.Path(path) / id
# 1. load and save score_series
input_data = user.model.get_data_with_date(pred_date)
score_series = user.model.predict(input_data)
save_score_series(score_series, (pathlib.Path(path) / id), trade_date)
# 2. update strategy (and model)
user.strategy.update(score_series, pred_date, trade_date)
# 3. generate and save order list
order_list = user.strategy.generate_order_list(
score_series=score_series,
current=user.account.current,
trade_exchange=trade_exchange,
trade_date=trade_date,
)
save_order_list(order_list=order_list, user_path=user_path, trade_date=trade_date)
# 4. auto execute order list
order_list = load_order_list(user_path=user_path, trade_date=trade_date)
trade_info = executor.execute(trade_account=user.account, order_list=order_list, trade_date=trade_date)
executor.save_executed_file_from_trade_info(
trade_info=trade_info, user_path=user_path, trade_date=trade_date
)
# 5. update account state
trade_info = executor.load_trade_info_from_executed_file(user_path=user_path, trade_date=trade_date)
update_account(user.account, trade_info, trade_exchange, trade_date)
report = user.account.report.generate_report_dataframe()
self.logger.info(report)
um.save_user_data(id)
self.show(id, path, bench)
def show(self, id, path, bench="SH000905"):
"""show the newly report (mean, std, sharpe, annual)
Parameters
----------
id : str
user id, need to be unique
path : str
Path to save user account.
bench : str
The benchmark that our result compared with.
'SH000905' for csi500, 'SH000300' for csi300
"""
um = self.init(self.client, path, None)[0]
if id not in um.users:
raise ValueError("Cannot find user ".format(id))
bench = D.features([bench], ["$change"]).loc[bench, "$change"]
report = um.users[id].account.report.generate_report_dataframe()
report["bench"] = bench
analysis_result = {}
r = (report["return"] - report["bench"]).dropna()
analysis_result["sub_bench"] = risk_analysis(r)
r = (report["return"] - report["bench"] - report["cost"]).dropna()
analysis_result["sub_cost"] = risk_analysis(r)
print("Result:")
print("sub_bench:")
print(analysis_result["sub_bench"])
print("sub_cost:")
print(analysis_result["sub_cost"])
def run():
fire.Fire(Operator)
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from ...log import get_module_logger
from ..evaluate import risk_analysis
from ...data import D
class User:
def __init__(self, account, strategy, model, verbose=False):
"""
A user in online system, which contains account, strategy and model three module.
Parameter
account : Account()
strategy :
a strategy instance
model :
a model instance
report_save_path : string
the path to save report. Will not save report if None
verbose : bool
Whether to print the info during the process
"""
self.logger = get_module_logger("User", level=logging.INFO)
self.account = account
self.strategy = strategy
self.model = model
self.verbose = verbose
def init_state(self, date):
"""
init state when each trading date begin
Parameter
date : pd.Timestamp
"""
self.account.init_state(today=date)
self.strategy.init_state(trade_date=date, model=self.model, account=self.account)
return
def get_latest_trading_date(self):
"""
return the latest trading date for user {user_id}
Parameter
user_id : string
:return
date : string (e.g '2018-10-08')
"""
if not self.account.last_trade_date:
return None
return str(self.account.last_trade_date.date())
def showReport(self, benchmark="SH000905"):
"""
show the newly report (mean, std, sharpe, annual)
Parameter
benchmark : string
bench that to be compared, 'SH000905' for csi500
"""
bench = D.features([benchmark], ["$change"], disk_cache=True).loc[benchmark, "$change"]
report = self.account.report.generate_report_dataframe()
report["bench"] = bench
analysis_result = {"pred": {}, "sub_bench": {}, "sub_cost": {}}
r = (report["return"] - report["bench"]).dropna()
analysis_result["sub_bench"][0] = risk_analysis(r)
r = (report["return"] - report["bench"] - report["cost"]).dropna()
analysis_result["sub_cost"][0] = risk_analysis(r)
self.logger.info("Result of porfolio:")
self.logger.info("sub_bench:")
self.logger.info(analysis_result["sub_bench"][0])
self.logger.info("sub_cost:")
self.logger.info(analysis_result["sub_cost"][0])
return report

View File

@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pathlib
import pickle
import yaml
import pandas as pd
from ...data import D
from ...log import get_module_logger
from ...utils import get_module_by_module_path
from ...utils import get_next_trading_date
from ..backtest.exchange import Exchange
log = get_module_logger("utils")
def load_instance(file_path):
"""
load a pickle file
Parameter
file_path : string / pathlib.Path()
path of file to be loaded
:return
An instance loaded from file
"""
file_path = pathlib.Path(file_path)
if not file_path.exists():
raise ValueError("Cannot find file {}".format(file_path))
with file_path.open("rb") as fr:
instance = pickle.load(fr)
return instance
def save_instance(instance, file_path):
"""
save(dump) an instance to a pickle file
Parameter
instance :
data to te dumped
file_path : string / pathlib.Path()
path of file to be dumped
"""
file_path = pathlib.Path(file_path)
with file_path.open("wb") as fr:
pickle.dump(instance, fr)
def init_instance_by_config(config):
"""
generate an instance with settings in config
Parameter
config : dict
python dict indicate a init parameters to create an item
:return
An instance
"""
module = get_module_by_module_path(config["module_path"])
instance_class = getattr(module, config["class"])
instance = instance_class(**config["args"])
return instance
def create_user_folder(path):
path = pathlib.Path(path)
if path.exists():
return
path.mkdir(parents=True)
head = pd.DataFrame(columns=("user_id", "add_date"))
head.to_csv(path / "users.csv", index=None)
def prepare(um, today, user_id, exchange_config=None):
"""
1. Get the dates that need to do trading till today for user {user_id}
dates[0] indicate the latest trading date of User{user_id},
if User{user_id} haven't do trading before, than dates[0] presents the init date of User{user_id}.
2. Set the exchange with exchange_config file
Parameter
um : UserManager()
today : pd.Timestamp()
user_id : str
:return
dates : list of pd.Timestamp
trade_exchange : Exchange()
"""
# get latest trading date for {user_id}
# if is None, indicate it haven't traded, then last trading date is init date of {user_id}
latest_trading_date = um.users[user_id].get_latest_trading_date()
if not latest_trading_date:
latest_trading_date = um.user_record.loc[user_id][0]
if str(today.date()) < latest_trading_date:
log.warning("user_id:{}, last trading date {} after today {}".format(user_id, latest_trading_date, today))
return [pd.Timestamp(latest_trading_date)], None
dates = D.calendar(
start_time=pd.Timestamp(latest_trading_date),
end_time=pd.Timestamp(today),
future=True,
)
dates = list(dates)
dates.append(get_next_trading_date(dates[-1], future=True))
if exchange_config:
with pathlib.Path(exchange_config).open("r") as fp:
exchange_paras = yaml.load(fp)
else:
exchange_paras = {}
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)
return dates, trade_exchange

View File

@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
GRAPH_NAME_LISt = [
"analysis_position.report_graph",
"analysis_position.score_ic_graph",
"analysis_position.cumulative_return_graph",
"analysis_position.risk_analysis_graph",
"analysis_position.rank_label_graph",
"analysis_model.model_performance_graph",
]

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .analysis_model_performance import model_performance_graph

View File

@@ -0,0 +1,304 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import plotly.tools as tls
import plotly.graph_objs as go
import statsmodels.api as sm
import matplotlib.pyplot as plt
from scipy import stats
from ..graph import ScatterGraph, SubplotsGraph, BarGraph, HeatmapGraph
def _group_return(
pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs
) -> tuple:
"""
:param pred_label:
:param reverse:
:param N:
:return:
"""
if reverse:
pred_label["score"] *= -1
pred_label = pred_label.sort_values("score", ascending=False)
# Group1 ~ Group5 only consider the dropna values
pred_label_drop = pred_label.dropna(subset=["score"])
# Group
t_df = pd.DataFrame(
{
"Group-%d"
% (i + 1): pred_label_drop.groupby(level="datetime")["label"].apply(
lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean()
)
for i in range(N)
}
)
t_df.index = pd.to_datetime(t_df.index)
# Long-Short
t_df["long-short"] = t_df["Group-1"] - t_df["Group-%d" % N]
# Long-Average
t_df["long-average"] = (
t_df["Group-1"] - pred_label.groupby(level="datetime")["label"].mean()
)
t_df = t_df.dropna(how="all") # for days which does not contain label
# FIXME: support HIGH-FREQ
t_df.index = t_df.index.strftime("%Y-%m-%d")
# Cumulative Return By Group
group_scatter_figure = ScatterGraph(
t_df.cumsum(),
layout=dict(
title="Cumulative Return", xaxis=dict(type="category", tickangle=45)
),
).figure
t_df = t_df.loc[:, ["long-short", "long-average"]]
_bin_size = ((t_df.max() - t_df.min()) / 20).min()
group_hist_figure = SubplotsGraph(
t_df,
kind_map=dict(kind="DistplotGraph", kwargs=dict(bin_size=_bin_size)),
subplots_kwargs=dict(
rows=1,
cols=2,
print_grid=False,
subplot_titles=["long-short", "long-average"],
),
).figure
return group_scatter_figure, group_hist_figure
def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
"""
:param data:
:param dist:
:return:
"""
fig, ax = plt.subplots(figsize=(8, 5))
_mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax)
return tls.mpl_to_plotly(_mpl_fig)
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:
"""
:param pred_label:
:param rank:
:return:
"""
if rank:
ic = pred_label.groupby(level="datetime").apply(
lambda x: x["label"].rank(pct=True).corr(x["score"].rank(pct=True))
)
else:
ic = pred_label.groupby(level="datetime").apply(
lambda x: x["label"].corr(x["score"])
)
_index = (
ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
)
_monthly_ic = ic.groupby(_index).mean()
_monthly_ic.index = pd.MultiIndex.from_arrays(
[_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)],
names=["year", "month"],
)
# fill month
_month_list = pd.date_range(
start=pd.Timestamp(f"{_index.min()[:4]}0101"),
end=pd.Timestamp(f"{_index.max()[:4]}1231"),
freq="1M",
)
_years = []
_month = []
for _date in _month_list:
_date = _date.strftime("%Y%m%d")
_years.append(_date[:4])
_month.append(_date[4:6])
fill_index = pd.MultiIndex.from_arrays([_years, _month], names=["year", "month"])
_monthly_ic = _monthly_ic.reindex(fill_index)
_ic_df = ic.to_frame("ic")
ic_bar_figure = ic_figure(_ic_df, kwargs.get("show_nature_day", True))
ic_heatmap_figure = HeatmapGraph(
_monthly_ic.unstack(),
layout=dict(title="Monthly IC", yaxis=dict(tickformat=",d")),
graph_kwargs=dict(xtype="array", ytype="array"),
).figure
dist = stats.norm
_qqplot_fig = _plot_qq(ic, dist)
if isinstance(dist, stats.norm.__class__):
dist_name = "Normal"
else:
dist_name = "Unknown"
_bin_size = ((_ic_df.max() - _ic_df.min()) / 20).min()
_sub_graph_data = [
(
"ic",
dict(
row=1,
col=1,
name="",
kind="DistplotGraph",
graph_kwargs=dict(bin_size=_bin_size),
),
),
(_qqplot_fig, dict(row=1, col=2)),
]
ic_hist_figure = SubplotsGraph(
_ic_df.dropna(),
kind_map=dict(kind="HistogramGraph", kwargs=dict()),
subplots_kwargs=dict(
rows=1,
cols=2,
print_grid=False,
subplot_titles=["IC", "IC %s Dist. Q-Q" % dist_name],
),
sub_graph_data=_sub_graph_data,
layout=dict(
yaxis2=dict(title="Observed Quantile"),
xaxis2=dict(title=f"{dist_name} Distribution Quantile"),
),
).figure
return ic_bar_figure, ic_heatmap_figure, ic_hist_figure
def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:
pred = pred_label.copy()
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
ac = pred.groupby(level="datetime").apply(
lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True))
)
# FIXME: support HIGH-FREQ
_df = ac.to_frame("value")
_df.index = _df.index.strftime("%Y-%m-%d")
ac_figure = ScatterGraph(
_df,
layout=dict(
title="Auto Correlation", xaxis=dict(type="category", tickangle=45)
),
).figure
return (ac_figure,)
def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:
pred = pred_label.copy()
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
top = pred.groupby(level="datetime").apply(
lambda x: 1
- x.nlargest(len(x) // N, columns="score")
.index.isin(x.nlargest(len(x) // N, columns="score_last").index)
.sum()
/ (len(x) // N)
)
bottom = pred.groupby(level="datetime").apply(
lambda x: 1
- x.nsmallest(len(x) // N, columns="score")
.index.isin(x.nsmallest(len(x) // N, columns="score_last").index)
.sum()
/ (len(x) // N)
)
r_df = pd.DataFrame({"Top": top, "Bottom": bottom,})
# FIXME: support HIGH-FREQ
r_df.index = r_df.index.strftime("%Y-%m-%d")
turnover_figure = ScatterGraph(
r_df,
layout=dict(
title="Top-Bottom Turnover", xaxis=dict(type="category", tickangle=45)
),
).figure
return (turnover_figure,)
def ic_figure(ic_df: pd.DataFrame, show_nature_day=True, **kwargs) -> go.Figure:
"""IC figure
:param ic_df: ic DataFrame
:param show_nature_day: whether to display the abscissa of non-trading day
:return: plotly.graph_objs.Figure
"""
if show_nature_day:
date_index = pd.date_range(ic_df.index.min(), ic_df.index.max())
ic_df = ic_df.reindex(date_index)
# FIXME: support HIGH-FREQ
ic_df.index = ic_df.index.strftime("%Y-%m-%d")
ic_bar_figure = BarGraph(
ic_df,
layout=dict(
title="Information Coefficient (IC)",
xaxis=dict(type="category", tickangle=45),
),
).figure
return ic_bar_figure
def model_performance_graph(
pred_label: pd.DataFrame,
lag: int = 1,
N: int = 5,
reverse=False,
rank=False,
graph_names: list = ["group_return", "pred_ic", "pred_autocorr"],
show_notebook: bool = True,
show_nature_day=True,
) -> [list, tuple]:
"""Model performance
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**
.. code-block:: python
instrument datetime score label
SH600004 2017-12-11 -0.013502 -0.013502
2017-12-12 -0.072367 -0.072367
2017-12-13 -0.068605 -0.068605
2017-12-14 0.012440 0.012440
2017-12-15 -0.102778 -0.102778
:param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing.
:param N: group number, default 5
:param reverse: if `True`, `pred['score'] *= -1`
:param rank: if **True**, calculate rank ic
:param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover']
:param show_notebook: whether to display graphics in notebook, the default is `True`
:param show_nature_day: whether to display the abscissa of non-trading day
:return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list
"""
figure_list = []
for graph_name in graph_names:
fun_res = eval(f"_{graph_name}")(
pred_label=pred_label,
lag=lag,
N=N,
reverse=reverse,
rank=rank,
show_nature_day=show_nature_day,
)
figure_list += fun_res
if show_notebook:
BarGraph.show_graph_in_notebook(figure_list)
else:
return figure_list

View File

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .cumulative_return import cumulative_return_graph
from .score_ic import score_ic_graph
from .report import report_graph
from .rank_label import rank_label_graph
from .risk_analysis import risk_analysis_graph

View File

@@ -0,0 +1,281 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
from typing import Iterable
import pandas as pd
import plotly.graph_objs as go
from ..graph import BaseGraph, SubplotsGraph
from ..analysis_position.parse_position import get_position_data
def _get_cum_return_data_with_position(
position: dict,
report_normal: pd.DataFrame,
label_data: pd.DataFrame,
start_date=None,
end_date=None,
):
"""
:param position:
:param report_normal:
:param label_data:
:param start_date:
:param end_date:
:return:
"""
_cumulative_return_df = get_position_data(
position=position,
report_normal=report_normal,
label_data=label_data,
start_date=start_date,
end_date=end_date,
).copy()
_cumulative_return_df["label"] = (
_cumulative_return_df["label"] - _cumulative_return_df["bench"]
)
_cumulative_return_df = _cumulative_return_df.dropna()
df_gp = _cumulative_return_df.groupby(level="datetime")
result_list = []
for gp in df_gp:
date = gp[0]
day_df = gp[1]
_hold_df = day_df[day_df["status"] == 0]
_buy_df = day_df[day_df["status"] == 1]
_sell_df = day_df[day_df["status"] == -1]
hold_value = (_hold_df["label"] * _hold_df["weight"]).sum()
hold_weight = _hold_df["weight"].sum()
hold_mean = (hold_value / hold_weight) if hold_weight else 0
sell_value = (_sell_df["label"] * _sell_df["weight"]).sum()
sell_weight = _sell_df["weight"].sum()
sell_mean = (sell_value / sell_weight) if sell_weight else 0
buy_value = (_buy_df["label"] * _buy_df["weight"]).sum()
buy_weight = _buy_df["weight"].sum()
buy_mean = (buy_value / buy_weight) if buy_weight else 0
result_list.append(
dict(
hold_value=hold_value,
hold_mean=hold_mean,
hold_weight=hold_weight,
buy_value=buy_value,
buy_mean=buy_mean,
buy_weight=buy_weight,
sell_value=sell_value,
sell_mean=sell_mean,
sell_weight=sell_weight,
buy_minus_sell_value=buy_value - sell_value,
buy_minus_sell_mean=buy_mean - sell_mean,
buy_plus_sell_weight=buy_weight + sell_weight,
date=date,
)
)
r_df = pd.DataFrame(data=result_list)
r_df["cum_hold"] = r_df["hold_mean"].cumsum()
r_df["cum_buy"] = r_df["buy_mean"].cumsum()
r_df["cum_sell"] = r_df["sell_mean"].cumsum()
r_df["cum_buy_minus_sell"] = r_df["buy_minus_sell_mean"].cumsum()
return r_df
def _get_figure_with_position(
position: dict,
report_normal: pd.DataFrame,
label_data: pd.DataFrame,
start_date=None,
end_date=None,
) -> Iterable[go.Figure]:
"""Get average analysis figures
:param position: position
:param report_normal:
:param label_data:
:param start_date:
:param end_date:
:return:
"""
cum_return_df = _get_cum_return_data_with_position(
position, report_normal, label_data, start_date, end_date
)
cum_return_df = cum_return_df.set_index("date")
# FIXME: support HIGH-FREQ
cum_return_df.index = cum_return_df.index.strftime('%Y-%m-%d')
# Create figures
for _t_name in ["buy", "sell", "buy_minus_sell", "hold"]:
sub_graph_data = [
(
"cum_{}".format(_t_name),
dict(
row=1, col=1, graph_kwargs={"mode": "lines+markers", "xaxis": "x3"}
),
),
(
"{}_weight".format(
_t_name.replace("minus", "plus") if "minus" in _t_name else _t_name
),
dict(row=2, col=1),
),
(
"{}_value".format(_t_name),
dict(row=1, col=2, kind="HistogramGraph", graph_kwargs={}),
),
]
_default_xaxis = dict(showline=False, zeroline=True, tickangle=45)
_default_yaxis = dict(zeroline=True, showline=True, showticklabels=True)
sub_graph_layout = dict(
xaxis1=dict(**_default_xaxis, type="category", showticklabels=False),
xaxis3=dict(**_default_xaxis, type="category"),
xaxis2=_default_xaxis,
yaxis1=dict(**_default_yaxis, title=_t_name),
yaxis2=_default_yaxis,
yaxis3=_default_yaxis,
)
mean_value = cum_return_df["{}_value".format(_t_name)].mean()
layout = dict(
height=500,
title=f"{_t_name}(the red line in the histogram on the right represents the average)",
shapes=[
{
"type": "line",
"xref": "x2",
"yref": "paper",
"x0": mean_value,
"y0": 0,
"x1": mean_value,
"y1": 1,
# NOTE: 'fillcolor': '#d3d3d3', 'opacity': 0.3,
"line": {"color": "red", "width": 1},
},
],
)
kind_map = dict(kind="ScatterGraph", kwargs=dict(mode="lines+markers"))
specs = [
[{"rowspan": 1}, {"rowspan": 2}],
[{"rowspan": 1}, None],
]
subplots_kwargs = dict(
vertical_spacing=0.01,
rows=2,
cols=2,
row_width=[1, 2],
column_width=[3, 1],
print_grid=False,
specs=specs,
)
yield SubplotsGraph(
cum_return_df,
layout=layout,
kind_map=kind_map,
sub_graph_layout=sub_graph_layout,
sub_graph_data=sub_graph_data,
subplots_kwargs=subplots_kwargs,
).figure
def cumulative_return_graph(
position: dict,
report_normal: pd.DataFrame,
label_data: pd.DataFrame,
show_notebook=True,
start_date=None,
end_date=None,
) -> Iterable[go.Figure]:
"""Backtest buy, sell, and holding cumulative return graph
Example:
.. code-block:: python
from qlib.data import D
from qlib.contrib.evaluate import risk_analysis, backtest, long_short_backtest
from qlib.contrib.strategy import TopkDropoutStrategy
# backtest parameters
bparas = {}
bparas['limit_threshold'] = 0.095
bparas['account'] = 1000000000
sparas = {}
sparas['topk'] = 50
sparas['n_drop'] = 5
strategy = TopkDropoutStrategy(**sparas)
report_normal_df, positions = backtest(pred_df, strategy, **bparas)
pred_df_dates = pred_df.index.get_level_values(level='datetime')
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
Graph desc:
- Axis X: Trading day
- Axis Y:
- Above axis Y: (((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()
- Below axis Y: Daily weight sum
- In the sell graph, y < 0 stands for profit; in other cases, y > 0 stands for profit.
- In the buy_minus_sell graph, the y value of the weight graph at the bottom is buy_weight + sell_weight.
- In each graph, the red line in the histogram on the right represents the average.
:param position: position data
:param report_normal:
.. code-block:: python
return cost bench turnover
date
2017-01-04 0.003421 0.000864 0.011693 0.576325
2017-01-05 0.000508 0.000447 0.000721 0.227882
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
2017-01-09 0.006753 0.000212 0.006874 0.105864
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
:param label_data: `D.features` result; index is `pd.MultiIndex`, index name is [`instrument`, `datetime`]; columns names is [`label`].
**The ``label`` T is the change from T to T+1**, it is recommended to use ``close``, example: D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])
.. code-block:: python
label
instrument datetime
SH600004 2017-12-11 -0.013502
2017-12-12 -0.072367
2017-12-13 -0.068605
2017-12-14 0.012440
2017-12-15 -0.102778
:param show_notebook: True or False. If True, show graph in notebook, else return figures
:param start_date: start date
:param end_date: end date
:return:
"""
position = copy.deepcopy(position)
report_normal = report_normal.copy()
label_data.columns = ["label"]
_figures = _get_figure_with_position(
position, report_normal, label_data, start_date, end_date
)
if show_notebook:
BaseGraph.show_graph_in_notebook(_figures)
else:
return _figures

View File

@@ -0,0 +1,187 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from ...backtest.profit_attribution import get_stock_weight_df
def parse_position(position: dict = None) -> pd.DataFrame:
"""Parse position dict to position DataFrame
:param position: position data
:return: position DataFrame;
.. code-block:: python
position_df = parse_position(positions)
print(position_df.head())
# status: 0-hold, -1-sell, 1-buy
amount cash count price status weight
instrument datetime
SZ000547 2017-01-04 44.154290 211405.285654 1 205.189575 1 0.031255
SZ300202 2017-01-04 60.638845 211405.285654 1 154.356506 1 0.032290
SH600158 2017-01-04 46.531681 211405.285654 1 153.895142 1 0.024704
SH600545 2017-01-04 197.173093 211405.285654 1 48.607037 1 0.033063
SZ000930 2017-01-04 103.938300 211405.285654 1 80.759453 1 0.028958
"""
position_weight_df = get_stock_weight_df(position)
# If the day does not exist, use the last weight
position_weight_df.fillna(method="ffill", inplace=True)
previous_data = {"date": None, "code_list": []}
result_df = pd.DataFrame()
for _trading_date, _value in position.items():
# pd_date type: pd.Timestamp
_cash = _value.pop("cash")
for _item in ["today_account_value"]:
if _item in _value:
_value.pop(_item)
_trading_day_df = pd.DataFrame.from_dict(_value, orient="index")
_trading_day_df["weight"] = position_weight_df.loc[_trading_date]
_trading_day_df["cash"] = _cash
_trading_day_df["date"] = _trading_date
# status: 0-hold, -1-sell, 1-buy
_trading_day_df["status"] = 0
# T not exist, T-1 exist, T sell
_cur_day_sell = set(previous_data["code_list"]) - set(_trading_day_df.index)
# T exist, T-1 not exist, T buy
_cur_day_buy = set(_trading_day_df.index) - set(previous_data["code_list"])
# Trading day buy
_trading_day_df.loc[_trading_day_df.index.isin(_cur_day_buy), "status"] = 1
# Trading day sell
if not result_df.empty:
_trading_day_sell_df = result_df.loc[
(result_df["date"] == previous_data["date"])
& (result_df.index.isin(_cur_day_sell))
].copy()
if not _trading_day_sell_df.empty:
_trading_day_sell_df["status"] = -1
_trading_day_sell_df["date"] = _trading_date
_trading_day_df = _trading_day_df.append(
_trading_day_sell_df, sort=False
)
result_df = result_df.append(_trading_day_df, sort=True)
previous_data = dict(
date=_trading_date,
code_list=_trading_day_df[_trading_day_df["status"] != -1].index,
)
result_df.reset_index(inplace=True)
result_df.rename(columns={"date": "datetime", "index": "instrument"}, inplace=True)
return result_df.set_index(["instrument", "datetime"])
def _add_label_to_position(
position_df: pd.DataFrame, label_data: pd.DataFrame
) -> pd.DataFrame:
"""Concat position with custom label
:param position_df: position DataFrame
:param label_data:
:return: concat result
"""
_start_time = position_df.index.get_level_values(level="datetime").min()
_end_time = position_df.index.get_level_values(level="datetime").max()
label_data = label_data.loc(axis=0)[:, pd.to_datetime(_start_time) :]
_result_df = pd.concat([position_df, label_data], axis=1, sort=True).reindex(
label_data.index
)
_result_df = _result_df.loc[_result_df.index.get_level_values(1) <= _end_time]
return _result_df
def _add_bench_to_position(
position_df: pd.DataFrame = None, bench: pd.Series = None
) -> pd.DataFrame:
"""Concat position with bench
:param position_df: position DataFrame
:param bench: report normal data
:return: concat result
"""
_temp_df = position_df.reset_index(level="instrument")
# FIXME: After the stock is bought and sold, the rise and fall of the next trading day are calculated.
_temp_df["bench"] = bench.shift(-1)
res_df = _temp_df.set_index(["instrument", _temp_df.index])
return res_df
def _calculate_label_rank(df: pd.DataFrame) -> pd.DataFrame:
"""calculate label rank
:param df:
:return:
"""
_label_name = "label"
def _calculate_day_value(g_df: pd.DataFrame):
g_df = g_df.copy()
g_df["rank_ratio"] = g_df[_label_name].rank(ascending=False) / len(g_df) * 100
# Sell: -1, Hold: 0, Buy: 1
for i in [-1, 0, 1]:
g_df.loc[g_df["status"] == i, "rank_label_mean"] = g_df[
g_df["status"] == i
]["rank_ratio"].mean()
g_df["excess_return"] = g_df[_label_name] - g_df[_label_name].mean()
return g_df
return df.groupby(level="datetime").apply(_calculate_day_value)
def get_position_data(
position: dict,
label_data: pd.DataFrame,
report_normal: pd.DataFrame = None,
calculate_label_rank=False,
start_date=None,
end_date=None,
) -> pd.DataFrame:
"""Concat position data with pred/report_normal
:param position: position data
:param report_normal: report normal, must be container 'bench' column
:param label_data:
:param calculate_label_rank:
:param start_date: start date
:param end_date: end date
:return: concat result,
columns: ['amount', 'cash', 'count', 'price', 'status', 'weight', 'label',
'rank_ratio', 'rank_label_mean', 'excess_return', 'score', 'bench']
index: ['instrument', 'date']
"""
_position_df = parse_position(position)
# Add custom_label, rank_ratio, rank_mean, and excess_return field
_position_df = _add_label_to_position(_position_df, label_data)
if calculate_label_rank:
_position_df = _calculate_label_rank(_position_df)
if report_normal is not None:
# Add bench field
_position_df = _add_bench_to_position(_position_df, report_normal["bench"])
_date_list = _position_df.index.get_level_values(level="datetime")
start_date = _date_list.min() if start_date is None else start_date
end_date = _date_list.max() if end_date is None else end_date
_position_df = _position_df.loc[
(start_date <= _date_list) & (_date_list <= end_date)
]
return _position_df

View File

@@ -0,0 +1,127 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
from typing import Iterable
import pandas as pd
import plotly.graph_objs as go
from ..graph import ScatterGraph
from ..analysis_position.parse_position import get_position_data
def _get_figure_with_position(
position: dict, label_data: pd.DataFrame, start_date=None, end_date=None
) -> Iterable[go.Figure]:
"""Get average analysis figures
:param position: position
:param label_data:
:param start_date:
:param end_date:
:return:
"""
_position_df = get_position_data(
position,
label_data,
calculate_label_rank=True,
start_date=start_date,
end_date=end_date,
)
res_dict = dict()
_pos_gp = _position_df.groupby(level=1)
for _item in _pos_gp:
_date = _item[0]
_day_df = _item[1]
_day_value = res_dict.setdefault(_date, {})
for _i, _name in {0: "Hold", 1: "Buy", -1: "Sell"}.items():
_temp_df = _day_df[_day_df["status"] == _i]
if _temp_df.empty:
_day_value[_name] = 0
else:
_day_value[_name] = _temp_df["rank_label_mean"].values[0]
_res_df = pd.DataFrame.from_dict(res_dict, orient="index")
# FIXME: support HIGH-FREQ
_res_df.index = _res_df.index.strftime('%Y-%m-%d')
for _col in _res_df.columns:
yield ScatterGraph(
_res_df.loc[:, [_col]],
layout=dict(
title=_col,
xaxis=dict(type="category", tickangle=45),
yaxis=dict(title="lable-rank-ratio: %"),
),
graph_kwargs=dict(mode="lines+markers"),
).figure
def rank_label_graph(
position: dict,
label_data: pd.DataFrame,
start_date=None,
end_date=None,
show_notebook=True,
) -> Iterable[go.Figure]:
"""Ranking percentage of stocks buy, sell, and holding on the trading day.
Average rank-ratio(similar to **sell_df['label'].rank(ascending=False) / len(sell_df)**) of daily trading
Example:
.. code-block:: python
from qlib.data import D
from qlib.contrib.evaluate import backtest
from qlib.contrib.strategy import TopkDropoutStrategy
# backtest parameters
bparas = {}
bparas['limit_threshold'] = 0.095
bparas['account'] = 1000000000
sparas = {}
sparas['topk'] = 50
sparas['n_drop'] = 230
strategy = TopkDropoutStrategy(**sparas)
_, positions = backtest(pred_df, strategy, **bparas)
pred_df_dates = pred_df.index.get_level_values(level='datetime')
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result
:param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**.
**The ``label`` T is the change from T to T+1**, it is recommended to use ``close``, example: D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])
.. code-block:: python
label
instrument datetime
SH600004 2017-12-11 -0.013502
2017-12-12 -0.072367
2017-12-13 -0.068605
2017-12-14 0.012440
2017-12-15 -0.102778
:param start_date: start date
:param end_date: end_date
:param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures
:return:
"""
position = copy.deepcopy(position)
label_data.columns = ["label"]
_figures = _get_figure_with_position(position, label_data, start_date, end_date)
if show_notebook:
ScatterGraph.show_graph_in_notebook(_figures)
else:
return _figures

View File

@@ -0,0 +1,220 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from ..graph import SubplotsGraph, BaseGraph
def _calculate_maximum(df: pd.DataFrame, is_ex: bool = False):
"""
:param df:
:param is_ex:
:return:
"""
if is_ex:
end_date = df["cum_ex_return_wo_cost_mdd"].idxmin()
start_date = df.loc[df.index <= end_date]["cum_ex_return_wo_cost"].idxmax()
else:
end_date = df["return_wo_mdd"].idxmin()
start_date = df.loc[df.index <= end_date]["cum_return_wo_cost"].idxmax()
return start_date, end_date
def _calculate_mdd(series):
"""
Calculate mdd
:param series:
:return:
"""
return series - series.cummax()
def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
"""
:param df:
:return:
"""
df.index = df.index.strftime("%Y-%m-%d")
report_df = pd.DataFrame()
report_df["cum_bench"] = df["bench"].cumsum()
report_df["cum_return_wo_cost"] = df["return"].cumsum()
report_df["cum_return_w_cost"] = (df["return"] - df["cost"]).cumsum()
# report_df['cum_return'] - report_df['cum_return'].cummax()
report_df["return_wo_mdd"] = _calculate_mdd(report_df["cum_return_wo_cost"])
report_df["return_w_cost_mdd"] = _calculate_mdd(
(df["return"] - df["cost"]).cumsum()
)
report_df["cum_ex_return_wo_cost"] = (df["return"] - df["bench"]).cumsum()
report_df["cum_ex_return_w_cost"] = (
df["return"] - df["bench"] - df["cost"]
).cumsum()
report_df["cum_ex_return_wo_cost_mdd"] = _calculate_mdd(
(df["return"] - df["bench"]).cumsum()
)
report_df["cum_ex_return_w_cost_mdd"] = _calculate_mdd(
(df["return"] - df["cost"] - df["bench"]).cumsum()
)
# return_wo_mdd , return_w_cost_mdd, cum_ex_return_wo_cost_mdd, cum_ex_return_w
report_df["turnover"] = df["turnover"]
report_df.sort_index(ascending=True, inplace=True)
return report_df
def _report_figure(df: pd.DataFrame) -> [list, tuple]:
"""
:param df:
:return:
"""
# Get data
report_df = _calculate_report_data(df)
# Maximum Drawdown
max_start_date, max_end_date = _calculate_maximum(report_df)
ex_max_start_date, ex_max_end_date = _calculate_maximum(report_df, True)
_temp_df = report_df.reset_index()
_temp_df.loc[-1] = 0
_temp_df = _temp_df.shift(1)
_temp_df.loc[0, "index"] = "T0"
_temp_df.set_index("index", inplace=True)
_temp_df.iloc[0] = 0
report_df = _temp_df
# Create figure
_default_kind_map = dict(kind="ScatterGraph", kwargs={"mode": "lines+markers"})
_temp_fill_args = {"fill": "tozeroy", "mode": "lines+markers"}
_column_row_col_dict = [
("cum_bench", dict(row=1, col=1)),
("cum_return_wo_cost", dict(row=1, col=1)),
("cum_return_w_cost", dict(row=1, col=1)),
("return_wo_mdd", dict(row=2, col=1, graph_kwargs=_temp_fill_args)),
("return_w_cost_mdd", dict(row=3, col=1, graph_kwargs=_temp_fill_args)),
("cum_ex_return_wo_cost", dict(row=4, col=1)),
("cum_ex_return_w_cost", dict(row=4, col=1)),
("turnover", dict(row=5, col=1)),
("cum_ex_return_w_cost_mdd", dict(row=6, col=1, graph_kwargs=_temp_fill_args)),
("cum_ex_return_wo_cost_mdd", dict(row=7, col=1, graph_kwargs=_temp_fill_args)),
]
_subplot_layout = dict(
xaxis=dict(showline=True, type="category", tickangle=45),
yaxis=dict(zeroline=True, showline=True, showticklabels=True),
)
for i in range(2, 8):
# yaxis
_subplot_layout.update(
{
"yaxis{}".format(i): dict(
zeroline=True, showline=True, showticklabels=True
)
}
)
_layout_style = dict(
height=1200,
title=" ",
shapes=[
{
"type": "rect",
"xref": "x",
"yref": "paper",
"x0": max_start_date,
"y0": 0.55,
"x1": max_end_date,
"y1": 1,
"fillcolor": "#d3d3d3",
"opacity": 0.3,
"line": {"width": 0,},
},
{
"type": "rect",
"xref": "x",
"yref": "paper",
"x0": ex_max_start_date,
"y0": 0,
"x1": ex_max_end_date,
"y1": 0.55,
"fillcolor": "#d3d3d3",
"opacity": 0.3,
"line": {"width": 0,},
},
],
)
_subplot_kwargs = dict(
shared_xaxes=True,
vertical_spacing=0.01,
rows=7,
cols=1,
row_width=[1, 1, 1, 3, 1, 1, 3],
print_grid=False,
)
figure = SubplotsGraph(
df=report_df,
layout=_layout_style,
sub_graph_data=_column_row_col_dict,
subplots_kwargs=_subplot_kwargs,
kind_map=_default_kind_map,
sub_graph_layout=_subplot_layout,
).figure
return (figure,)
def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, tuple]:
"""display backtest report
Example:
.. code-block:: python
from qlib.contrib.evaluate import backtest
from qlib.contrib.strategy import TopkDropoutStrategy
# backtest parameters
bparas = {}
bparas['limit_threshold'] = 0.095
bparas['account'] = 1000000000
sparas = {}
sparas['topk'] = 50
sparas['n_drop'] = 230
strategy = TopkDropoutStrategy(**sparas)
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
qcr.report_graph(report_normal_df)
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**
.. code-block:: python
return cost bench turnover
date
2017-01-04 0.003421 0.000864 0.011693 0.576325
2017-01-05 0.000508 0.000447 0.000721 0.227882
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
2017-01-09 0.006753 0.000212 0.006874 0.105864
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
:param show_notebook: whether to display graphics in notebook, the default is **True**
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
"""
report_df = report_df.copy()
fig_list = _report_figure(report_df)
if show_notebook:
BaseGraph.show_graph_in_notebook(fig_list)
else:
return fig_list

View File

@@ -0,0 +1,271 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Iterable
import pandas as pd
import plotly.graph_objs as py
from ...evaluate import risk_analysis
from ..graph import SubplotsGraph, ScatterGraph
def _get_risk_analysis_data_with_report(
report_normal_df: pd.DataFrame,
# report_long_short_df: pd.DataFrame,
date: pd.Timestamp,
) -> pd.DataFrame:
"""Get risk analysis data with report
:param report_normal_df: report data
:param report_long_short_df: report data
:param date: date string
:return:
"""
analysis = dict()
# if not report_long_short_df.empty:
# analysis["pred_long"] = risk_analysis(report_long_short_df["long"])
# analysis["pred_short"] = risk_analysis(report_long_short_df["short"])
# analysis["pred_long_short"] = risk_analysis(report_long_short_df["long_short"])
if not report_normal_df.empty:
analysis["sub_bench"] = risk_analysis(
report_normal_df["return"] - report_normal_df["bench"]
)
analysis["sub_cost"] = risk_analysis(
report_normal_df["return"]
- report_normal_df["bench"]
- report_normal_df["cost"]
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
analysis_df["date"] = date
return analysis_df
def _get_all_risk_analysis(risk_df: pd.DataFrame) -> pd.DataFrame:
"""risk_df to standard
:param risk_df: risk data
:return:
"""
if risk_df is None:
return pd.DataFrame()
risk_df = risk_df.unstack()
risk_df.columns = risk_df.columns.droplevel(0)
return risk_df.drop("mean", axis=1)
def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd.DataFrame:
"""Get monthly analysis data
:param report_normal_df:
# :param report_long_short_df:
:return:
"""
# Group by month
report_normal_gp = report_normal_df.groupby(
[report_normal_df.index.year, report_normal_df.index.month]
)
# report_long_short_gp = report_long_short_df.groupby(
# [report_long_short_df.index.year, report_long_short_df.index.month]
# )
gp_month = sorted(set(report_normal_gp.size().index))
_monthly_df = pd.DataFrame()
for gp_m in gp_month:
_m_report_normal = report_normal_gp.get_group(gp_m)
# _m_report_long_short = report_long_short_gp.get_group(gp_m)
if len(_m_report_normal) < 3:
# The month's data is less than 3, not displayed
# FIXME: If the trading day of a month is less than 3 days, a breakpoint will appear in the graph
continue
month_days = pd.Timestamp(year=gp_m[0], month=gp_m[1], day=1).days_in_month
_temp_df = _get_risk_analysis_data_with_report(
_m_report_normal,
# _m_report_long_short,
pd.Timestamp(year=gp_m[0], month=gp_m[1], day=month_days),
)
_monthly_df = _monthly_df.append(_temp_df, sort=False)
return _monthly_df
def _get_monthly_analysis_with_feature(
monthly_df: pd.DataFrame, feature: str = "annual"
) -> pd.DataFrame:
"""
:param monthly_df:
:param feature:
:return:
"""
_monthly_df_gp = monthly_df.reset_index().groupby(["level_1"])
_name_df = _monthly_df_gp.get_group(feature).set_index(["level_0", "level_1"])
_temp_df = _name_df.pivot_table(
index="date", values=["risk"], columns=_name_df.index
)
_temp_df.columns = map(lambda x: "_".join(x[-1]), _temp_df.columns)
_temp_df.index = _temp_df.index.strftime("%Y-%m")
return _temp_df
def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]:
"""Get analysis graph figure
:param analysis_df:
:return:
"""
if analysis_df is None:
return []
_figure = SubplotsGraph(
_get_all_risk_analysis(analysis_df), kind_map=dict(kind="BarGraph", kwargs={})
).figure
return (_figure,)
def _get_monthly_risk_analysis_figure(report_normal_df: pd.DataFrame) -> Iterable[py.Figure]:
"""Get analysis monthly graph figure
:param report_normal_df:
:param report_long_short_df:
:return:
"""
# if report_normal_df is None and report_long_short_df is None:
# return []
if report_normal_df is None:
return []
# if report_normal_df is None:
# report_normal_df = pd.DataFrame(index=report_long_short_df.index)
# if report_long_short_df is None:
# report_long_short_df = pd.DataFrame(index=report_normal_df.index)
_monthly_df = _get_monthly_risk_analysis_with_report(
report_normal_df=report_normal_df,
# report_long_short_df=report_long_short_df,
)
for _feature in ["annual", "mdd", "sharpe", "std"]:
_temp_df = _get_monthly_analysis_with_feature(_monthly_df, _feature)
yield ScatterGraph(
_temp_df,
layout=dict(title=_feature, xaxis=dict(type="category", tickangle=45)),
graph_kwargs={"mode": "lines+markers"},
).figure
def risk_analysis_graph(
analysis_df: pd.DataFrame = None,
report_normal_df: pd.DataFrame = None,
report_long_short_df: pd.DataFrame = None,
show_notebook: bool = True,
) -> Iterable[py.Figure]:
"""Generate analysis graph and monthly analysis
Example:
.. code-block:: python
from qlib.contrib.evaluate import risk_analysis, backtest, long_short_backtest
from qlib.contrib.strategy import TopkDropoutStrategy
from qlib.contrib.report import analysis_position
# backtest parameters
bparas = {}
bparas['limit_threshold'] = 0.095
bparas['account'] = 1000000000
sparas = {}
sparas['topk'] = 50
sparas['n_drop'] = 230
strategy = TopkDropoutStrategy(**sparas)
report_normal_df, positions = backtest(pred_df, strategy, **bparas)
# long_short_map = long_short_backtest(pred_df)
# report_long_short_df = pd.DataFrame(long_short_map)
analysis = dict()
# analysis['pred_long'] = risk_analysis(report_long_short_df['long'])
# analysis['pred_short'] = risk_analysis(report_long_short_df['short'])
# analysis['pred_long_short'] = risk_analysis(report_long_short_df['long_short'])
analysis['sub_bench'] = risk_analysis(report_normal_df['return'] - report_normal_df['bench'])
analysis['sub_cost'] = risk_analysis(report_normal_df['return'] - report_normal_df['bench'] - report_normal_df['cost'])
analysis_df = pd.concat(analysis)
analysis_position.risk_analysis_graph(analysis_df, report_normal_df)
:param analysis_df: analysis data, index is **pd.MultiIndex**; columns names is **[risk]**.
.. code-block:: python
risk
sub_bench mean 0.000662
std 0.004487
annual 0.166720
sharpe 2.340526
mdd -0.080516
sub_cost mean 0.000577
std 0.004482
annual 0.145392
sharpe 2.043494
mdd -0.083584
:param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**
.. code-block:: python
return cost bench turnover
date
2017-01-04 0.003421 0.000864 0.011693 0.576325
2017-01-05 0.000508 0.000447 0.000721 0.227882
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
2017-01-09 0.006753 0.000212 0.006874 0.105864
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
:param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**
.. code-block:: python
long short long_short
date
2017-01-04 -0.001360 0.001394 0.000034
2017-01-05 0.002456 0.000058 0.002514
2017-01-06 0.000120 0.002739 0.002859
2017-01-09 0.001436 0.001838 0.003273
2017-01-10 0.000824 -0.001944 -0.001120
:param show_notebook: Whether to display graphics in a notebook, default **True**
If True, show graph in notebook
If False, return graph figure
:return:
"""
_figure_list = list(_get_risk_analysis_figure(analysis_df)) + list(
_get_monthly_risk_analysis_figure(
report_normal_df,
# report_long_short_df,
)
)
if show_notebook:
ScatterGraph.show_graph_in_notebook(_figure_list)
else:
return _figure_list

View File

@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from ..graph import ScatterGraph
def _get_score_ic(pred_label: pd.DataFrame):
"""
:param pred_label:
:return:
"""
concat_data = pred_label.copy()
concat_data.dropna(axis=0, how="any", inplace=True)
_ic = concat_data.groupby(level="datetime").apply(
lambda x: x["label"].corr(x["score"])
)
_rank_ic = concat_data.groupby(level="datetime").apply(
lambda x: x["label"].corr(x["score"], method="spearman")
)
return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic})
def score_ic_graph(
pred_label: pd.DataFrame, show_notebook: bool = True
) -> [list, tuple]:
"""score IC
Example:
.. code-block:: python
from qlib.data import D
from qlib.contrib.report import analysis_position
pred_df_dates = pred_df.index.get_level_values(level='datetime')
features_df = D.features(D.instruments('csi500'), ['Ref($close, -2)/Ref($close, -1)-1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
pred_label = pd.concat([features_df, pred], axis=1, sort=True).reindex(features_df.index)
analysis_position.score_ic_graph(pred_label)
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**
.. code-block:: python
instrument datetime score label
SH600004 2017-12-11 -0.013502 -0.013502
2017-12-12 -0.072367 -0.072367
2017-12-13 -0.068605 -0.068605
2017-12-14 0.012440 0.012440
2017-12-15 -0.102778 -0.102778
:param show_notebook: whether to display graphics in notebook, the default is **True**
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
"""
_ic_df = _get_score_ic(pred_label)
# FIXME: support HIGH-FREQ
_ic_df.index = _ic_df.index.strftime("%Y-%m-%d")
_figure = ScatterGraph(
_ic_df,
layout=dict(title="Score IC", xaxis=dict(type="category", tickangle=45)),
graph_kwargs={"mode": "lines+markers"},
).figure
if show_notebook:
ScatterGraph.show_graph_in_notebook([_figure])
else:
return (_figure,)

View File

@@ -0,0 +1,370 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import math
import importlib
from pathlib import Path
from typing import Iterable
import pandas as pd
import plotly.offline as py
import plotly.graph_objs as go
from plotly.tools import make_subplots
from plotly.figure_factory import create_distplot
from ...utils import get_module_by_module_path
class BaseGraph(object):
""""""
_name = None
def __init__(
self, df: pd.DataFrame = None, layout: dict = None, graph_kwargs: dict = None, name_dict: dict = None, **kwargs
):
"""
:param df:
:param layout:
:param graph_kwargs:
:param name_dict:
:param kwargs:
layout: dict
go.Layout parameters
graph_kwargs: dict
Graph parameters, eg: go.Bar(**graph_kwargs)
"""
self._df = df
self._layout = dict() if layout is None else layout
self._graph_kwargs = dict() if graph_kwargs is None else graph_kwargs
self._name_dict = name_dict
self.data = None
self._init_parameters(**kwargs)
self._init_data()
def _init_data(self):
"""
:return:
"""
if self._df.empty:
raise ValueError("df is empty.")
self.data = self._get_data()
def _init_parameters(self, **kwargs):
"""
:param kwargs
"""
# Instantiate graphics parameters
self._graph_type = self._name.lower().capitalize()
# Displayed column name
if self._name_dict is None:
self._name_dict = {_item: _item for _item in self._df.columns}
@staticmethod
def get_instance_with_graph_parameters(graph_type: str = None, **kwargs):
"""
:param graph_type:
:param kwargs:
:return:
"""
try:
_graph_module = importlib.import_module("plotly.graph_objs")
_graph_class = getattr(_graph_module, graph_type)
except AttributeError:
_graph_module = importlib.import_module("qlib.contrib.report.graph")
_graph_class = getattr(_graph_module, graph_type)
return _graph_class(**kwargs)
@staticmethod
def show_graph_in_notebook(figure_list: Iterable[go.Figure] = None):
"""
:param figure_list:
:return:
"""
py.init_notebook_mode()
for _fig in figure_list:
py.iplot(_fig)
def _get_layout(self) -> go.Layout:
"""
:return:
"""
return go.Layout(**self._layout)
def _get_data(self) -> list:
"""
:return:
"""
_data = [
self.get_instance_with_graph_parameters(
graph_type=self._graph_type, x=self._df.index, y=self._df[_col], name=_name, **self._graph_kwargs
)
for _col, _name in self._name_dict.items()
]
return _data
@property
def figure(self) -> go.Figure:
"""
:return:
"""
return go.Figure(data=self.data, layout=self._get_layout())
class ScatterGraph(BaseGraph):
_name = "scatter"
class BarGraph(BaseGraph):
_name = "bar"
class DistplotGraph(BaseGraph):
_name = "distplot"
def _get_data(self):
"""
:return:
"""
_t_df = self._df.dropna()
_data_list = [_t_df[_col] for _col in self._name_dict]
_label_list = [_name for _name in self._name_dict.values()]
_fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs)
return _fig["data"]
class HeatmapGraph(BaseGraph):
_name = "heatmap"
def _get_data(self):
"""
:return:
"""
_data = [
self.get_instance_with_graph_parameters(
graph_type=self._graph_type,
x=self._df.columns,
y=self._df.index,
z=self._df.values.tolist(),
**self._graph_kwargs
)
]
return _data
class HistogramGraph(BaseGraph):
_name = "histogram"
def _get_data(self):
"""
:return:
"""
_data = [
self.get_instance_with_graph_parameters(
graph_type=self._graph_type, x=self._df[_col], name=_name, **self._graph_kwargs
)
for _col, _name in self._name_dict.items()
]
return _data
class SubplotsGraph(object):
"""Create subplots same as df.plot(subplots=True)
Simple package for `plotly.tools.subplots`
"""
def __init__(
self,
df: pd.DataFrame = None,
kind_map: dict = None,
layout: dict = None,
sub_graph_layout: dict = None,
sub_graph_data: list = None,
subplots_kwargs: dict = None,
**kwargs
):
"""
:param df: pd.DataFrame
:param kind_map: dict, subplots graph kind and kwargs
eg: dict(kind='ScatterGraph', kwargs=dict())
:param layout: `go.Layout` parameters
:param sub_graph_layout: Layout of each graphic, similar to 'layout'
:param sub_graph_data: Instantiation parameters for each sub-graphic
eg: [(column_name, instance_parameters), ]
column_name: str or go.Figure
Instance_parameters:
- row: int, the row where the graph is located
- col: int, the col where the graph is located
- name: str, show name, default column_name in 'df'
- kind: str, graph kind, default `kind` param, eg: bar, scatter, ...
- graph_kwargs: dict, graph kwargs, default {}, used in `go.Bar(**graph_kwargs)`
:param subplots_kwargs: `plotly.tools.make_subplots` original parameters
- shared_xaxes: bool, default False
- shared_yaxes: bool, default False
- vertical_spacing: float, default 0.3 / rows
- subplot_titles: list, default []
If `sub_graph_data` is None, will generate 'subplot_titles' according to `df.columns`,
this field will be discarded
- specs: list, see `make_subplots` docs
- rows: int, Number of rows in the subplot grid, default 1
If `sub_graph_data` is None, will generate 'rows' according to `df`, this field will be discarded
- cols: int, Number of cols in the subplot grid, default 1
If `sub_graph_data` is None, will generate 'cols' according to `df`, this field will be discarded
:param kwargs:
"""
self._df = df
self._layout = layout
self._sub_graph_layout = sub_graph_layout
self._kind_map = kind_map
if self._kind_map is None:
self._kind_map = dict(kind="ScatterGraph", kwargs=dict())
self._subplots_kwargs = subplots_kwargs
if self._subplots_kwargs is None:
self._init_subplots_kwargs()
self.__cols = self._subplots_kwargs.get("cols", 2)
self.__rows = self._subplots_kwargs.get("rows", math.ceil(len(self._df.columns) / self.__cols))
self._sub_graph_data = sub_graph_data
if self._sub_graph_data is None:
self._init_sub_graph_data()
self._init_figure()
def _init_sub_graph_data(self):
"""
:return:
"""
self._sub_graph_data = list()
self._subplot_titles = list()
for i, column_name in enumerate(self._df.columns):
row = math.ceil((i + 1) / self.__cols)
_temp = (i + 1) % self.__cols
col = _temp if _temp else self.__cols
res_name = column_name.replace("_", " ")
_temp_row_data = (
column_name,
dict(
row=row,
col=col,
name=res_name,
kind=self._kind_map["kind"],
graph_kwargs=self._kind_map["kwargs"],
),
)
self._sub_graph_data.append(_temp_row_data)
self._subplot_titles.append(res_name)
def _init_subplots_kwargs(self):
"""
:return:
"""
# Default cols, rows
_cols = 2
_rows = math.ceil(len(self._df.columns) / 2)
self._subplots_kwargs = dict()
self._subplots_kwargs["rows"] = _rows
self._subplots_kwargs["cols"] = _cols
self._subplots_kwargs["shared_xaxes"] = False
self._subplots_kwargs["shared_yaxes"] = False
self._subplots_kwargs["vertical_spacing"] = 0.3 / _rows
self._subplots_kwargs["print_grid"] = False
self._subplots_kwargs["subplot_titles"] = self._df.columns.tolist()
def _init_figure(self):
"""
:return:
"""
self._figure = make_subplots(**self._subplots_kwargs)
for column_name, column_map in self._sub_graph_data:
if isinstance(column_name, go.Figure):
_graph_obj = column_name
elif isinstance(column_name, str):
temp_name = column_map.get("name", column_name.replace("_", " "))
kind = column_map.get("kind", self._kind_map.get("kind", "ScatterGraph"))
_graph_kwargs = column_map.get("graph_kwargs", self._kind_map.get("kwargs", {}))
_graph_obj = BaseGraph.get_instance_with_graph_parameters(
kind,
**dict(
df=self._df.loc[:, [column_name]],
name_dict={column_name: temp_name},
graph_kwargs=_graph_kwargs,
)
)
else:
raise TypeError()
row = column_map["row"]
col = column_map["col"]
_graph_data = getattr(_graph_obj, "data")
# for _item in _graph_data:
# _item.pop('xaxis', None)
# _item.pop('yaxis', None)
for _g_obj in _graph_data:
self._figure.append_trace(_g_obj, row=row, col=col)
if self._sub_graph_layout is not None:
for k, v in self._sub_graph_layout.items():
self._figure["layout"][k].update(v)
self._figure["layout"].update(self._layout)
@property
def figure(self):
return self._figure

View File

@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .strategy import (
TopkDropoutStrategy,
BaseStrategy,
WeightStrategyBase,
)

View File

@@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .strategy import StrategyWrapper, WeightStrategyBase
import copy
class SoftTopkStrategy(WeightStrategyBase):
def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"):
"""Parameter
topk : int
top-N stocks to buy
risk_degree : float
position percentage of total value
buy_method :
rank_fill: assign the weight stocks that rank high first(1/topk max)
average_fill: assign the weight to the stocks rank high averagely.
"""
super().__init__()
self.topk = topk
self.max_sold_weight = max_sold_weight
self.risk_degree = risk_degree
self.buy_method = buy_method
def get_risk_degree(self, date):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing
"""
# It will use 95% amoutn of your total value by default
return self.risk_degree
def generate_target_weight_position(self, score, current, trade_date):
"""Parameter:
score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
current : current position, use Position() class
trade_date : trade date
generate target position from score for this date and the current position
The cache is not considered in the position
"""
# TODO:
# If the current stock list is more than topk(eg. The weights are modified
# by risk control), the weight will not be handled correctly.
buy_signal_stocks = set(score.sort_values(ascending=False).iloc[: self.topk].index)
cur_stock_weight = current.get_stock_weight_dict(only_stock=True)
if len(cur_stock_weight) == 0:
final_stock_weight = {code: 1 / self.topk for code in buy_signal_stocks}
else:
final_stock_weight = copy.deepcopy(cur_stock_weight)
sold_stock_weight = 0.0
for stock_id in final_stock_weight:
if stock_id not in buy_signal_stocks:
sw = min(self.max_sold_weight, final_stock_weight[stock_id])
sold_stock_weight += sw
final_stock_weight[stock_id] -= sw
if self.buy_method == "first_fill":
for stock_id in buy_signal_stocks:
add_weight = min(
max(1 / self.topk - final_stock_weight.get(stock_id, 0), 0.0),
sold_stock_weight,
)
final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight
sold_stock_weight -= add_weight
elif self.buy_method == "average_fill":
for stock_id in buy_signal_stocks:
final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + sold_stock_weight / len(
buy_signal_stocks
)
else:
raise ValueError("Buy method not found")
return final_stock_weight

View File

@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This order generator is for strategies based on WeightStrategyBase
"""
from ..backtest.position import Position
from ..backtest.exchange import Exchange
import pandas as pd
import copy
class OrderGenerator:
def generate_order_list_from_target_weight_position(
self,
current: Position,
trade_exchange: Exchange,
target_weight_position: dict,
risk_degree: float,
pred_date: pd.Timestamp,
trade_date: pd.Timestamp,
) -> list:
"""generate_order_list_from_target_weight_position
:param current: The current position
:type current: Position
:param trade_exchange:
:type trade_exchange: Exchange
:param target_weight_position: {stock_id : weight}
:type target_weight_position: dict
:param risk_degree:
:type risk_degree: float
:param pred_date: the date the score is predicted
:type pred_date: pd.Timestamp
:param trade_date: the date the stock is traded
:type trade_date: pd.Timestamp
:rtype: list
"""
raise NotImplementedError()
class OrderGenWInteract(OrderGenerator):
"""Order Generator With Interact"""
def generate_order_list_from_target_weight_position(
self,
current: Position,
trade_exchange: Exchange,
target_weight_position: dict,
risk_degree: float,
pred_date: pd.Timestamp,
trade_date: pd.Timestamp,
) -> list:
"""generate_order_list_from_target_weight_position
No adjustment for for the nontradable share.
All the tadable value is assigned to the tadable stock according to the weight.
if interact == True, will use the price at trade date to generate order list
else, will only use the price before the trade date to generate order list
:param current:
:type current: Position
:param trade_exchange:
:type trade_exchange: Exchange
:param target_weight_position:
:type target_weight_position: dict
:param risk_degree:
:type risk_degree: float
:param pred_date:
:type pred_date: pd.Timestamp
:param trade_date:
:type trade_date: pd.Timestamp
:rtype: list
"""
# calculate current_tradable_value
current_amount_dict = current.get_stock_amount_dict()
current_total_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=False
)
current_tradable_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=True
)
# add cash
current_tradable_value += current.get_cash()
reserved_cash = (1.0 - risk_degree) * (current_total_value + current.get_cash())
current_tradable_value -= reserved_cash
if current_tradable_value < 0:
# if you sell all the tradable stock can not meet the reserved
# value. Then just sell all the stocks
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
for stock_id in list(target_amount_dict.keys()):
if trade_exchange.is_stock_tradable(stock_id, trade_date):
del target_amount_dict[stock_id]
else:
# consider cost rate
current_tradable_value /= 1 + max(trade_exchange.close_cost, trade_exchange.open_cost)
# strategy 1 : generate amount_position by weight_position
# Use API in Exchange()
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
weight_position=target_weight_position,
cash=current_tradable_value,
trade_date=trade_date,
)
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=target_amount_dict,
current_position=current_amount_dict,
trade_date=trade_date,
)
return order_list
class OrderGenWOInteract(OrderGenerator):
"""Order Generator Without Interact"""
def generate_order_list_from_target_weight_position(
self,
current: Position,
trade_exchange: Exchange,
target_weight_position: dict,
risk_degree: float,
pred_date: pd.Timestamp,
trade_date: pd.Timestamp,
) -> list:
"""generate_order_list_from_target_weight_position
generate order list directly not using the information (e.g. whether can be traded, the accurate trade price) at trade date.
In target weight position, generating order list need to know the price of objective stock in trade date, but we cannot get that
value when do not interact with exchange, so we check the %close price at pred_date or price recorded in current position.
:param current:
:type current: Position
:param trade_exchange:
:type trade_exchange: Exchange
:param target_weight_position:
:type target_weight_position: dict
:param risk_degree:
:type risk_degree: float
:param pred_date:
:type pred_date: pd.Timestamp
:param trade_date:
:type trade_date: pd.Timestamp
:rtype: list
"""
risk_total_value = risk_degree * current.calculate_value()
current_stock = current.get_stock_list()
amount_dict = {}
for stock_id in target_weight_position:
# Current rule will ignore the stock that not hold and cannot be traded at predict date
if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=pred_date):
amount_dict[stock_id] = (
risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date)
)
elif stock_id in current_stock:
amount_dict[stock_id] = (
risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id)
)
else:
continue
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=amount_dict,
current_position=current.get_stock_amount_dict(),
trade_date=trade_date,
)
return order_list

View File

@@ -0,0 +1,318 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import numpy as np
import pandas as pd
from ..backtest.order import Order
from ...utils import get_pre_trading_date
from .order_generator import OrderGenWInteract
class BaseStrategy:
def __init__(self):
pass
def get_risk_degree(self, date):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing
"""
# It will use 95% amount of your total value by default
return 0.95
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
"""Parameter
score_series : pd.Seires
stock_id , score
current : Position()
current state of position
DO NOT directly change the state of current
trade_exchange : Exchange()
trade exchange
pred_date : pd.Timestamp
predict date
trade_date : pd.Timestamp
trade date
DO NOT directly change the state of current
"""
pass
def update(self, score_series, pred_date, trade_date):
"""User can use this method to update strategy state each trade date.
Parameter
---------
score_series : pd.Series
stock_id , score
pred_date : pd.Timestamp
oredict date
trade_date : pd.Timestamp
trade date
"""
pass
def init(self, **kwargs):
"""Some strategy need to be initial after been implemented,
User can use this method to init his strategy with parameters needed.
"""
pass
def get_init_args_from_model(self, model, init_date):
"""
This method only be used in 'online' module, it will generate the *args to initial the strategy.
:param
mode : model used in 'online' module
"""
return {}
class StrategyWrapper:
"""
StrategyWrapper is a wrapper of another strategy.
By overriding some methods to make some changes on the basic strategy
Cost control and risk control will base on this class.
"""
def __init__(self, inner_strategy):
"""__init__
:param inner_strategy: set the inner strategy
"""
self.inner_strategy = inner_strategy
def __getattr__(self, name):
"""__getattr__
:param name: If no implementation in this method. Call the method in the innter_strategy by default.
"""
return getattr(self.inner_strategy, name)
class AdjustTimer:
"""AdjustTimer
Responsible for timing of position adjusting
This is designed as multiple inheritance mechanism due to
- the is_adjust may need access to the internel state of a strategyw
- it can be reguard as a enhancement to the existing strategy
"""
# adjust position in each trade date
def is_adjust(self, trade_date):
"""is_adjust
Return if the strategy can adjust positions on `trade_date`
Will normally be used in strategy do trading with trade frequency
"""
return True
class ListAdjustTimer(AdjustTimer):
def __init__(self, adjust_dates=None):
"""__init__
:param adjust_dates: an iterable object, it will return a timelist for trading dates
"""
if adjust_dates is None:
# None indicates that all dates is OK for adjusting
self.adjust_dates = None
else:
self.adjust_dates = {pd.Timestamp(dt) for dt in adjust_dates}
def is_adjust(self, trade_date):
if self.adjust_dates is None:
return True
return pd.Timestamp(trade_date) in self.adjust_dates
class WeightStrategyBase(BaseStrategy, AdjustTimer):
def __init__(self, order_generator_cls_or_obj=OrderGenWInteract, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
def generate_target_weight_position(self, score, current, trade_date):
"""Parameter:
score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
current : current position, use Position() class
trade_exchange : Exchange()
trade_date : trade date
generate target position from score for this date and the current position
The cash is not considered in the position
"""
raise NotImplementedError()
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
"""Parameter
score_series : pd.Seires
stock_id , score
current : Position()
current of account
trade_exchange : Exchange()
exchange
trade_date : pd.Timestamp
date
"""
# judge if to adjust
if not self.is_adjust(trade_date):
return []
# generate_order_list
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
current_temp = copy.deepcopy(current)
target_weight_position = self.generate_target_weight_position(
score=score_series, current=current_temp, trade_date=trade_date
)
order_list = self.order_generator.generate_order_list_from_target_weight_position(
current=current_temp,
trade_exchange=trade_exchange,
risk_degree=self.get_risk_degree(trade_date),
target_weight_position=target_weight_position,
pred_date=pred_date,
trade_date=trade_date,
)
return order_list
class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
def __init__(self, topk, n_drop, method="bottom", risk_degree=0.95, thresh=1, hold_thresh=1, **kwargs):
"""Parameter
topk : int
The number of stocks in the portfolio
n_drop : int
number of stocks to be replaced in each trading date
method : str
dropout method, random/bottom
risk_degree : float
position percentage of total value
thresh : int
minimun holding days since last buy singal of the stock
hold_thresh : int
minimum holding days
before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh
"""
super(TopkDropoutStrategy, self).__init__()
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
self.topk = topk
self.n_drop = n_drop
self.method = method
self.risk_degree = risk_degree
self.thresh = thresh
# self.stock_count['code'] will be the days the stock has been hold
# since last buy signal. This is designed for thresh
self.stock_count = {}
self.hold_thresh = hold_thresh
def get_risk_degree(self, date):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing
"""
# It will use 95% amoutn of your total value by default
return self.risk_degree
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
"""Gnererate order list according to score_series at trade_date.
will not change current.
Parameter
score_series : pd.Seires
stock_id , score
current : Position()
current of account
trade_exchange : Exchange()
exchange
pred_date : pd.Timestamp
predict date
trade_date : pd.Timestamp
trade date
"""
if not self.is_adjust(trade_date):
return []
current_temp = copy.deepcopy(current)
# generate order list for this adjust date
sell_order_list = []
buy_order_list = []
# load score
cash = current_temp.get_cash()
current_stock_list = current_temp.get_stock_list()
last = score_series.reindex(current_stock_list).sort_values(ascending=False).index
today = (
score_series[~score_series.index.isin(last)]
.sort_values(ascending=False)
.index[: self.n_drop + self.topk - len(last)]
)
comb = score_series.reindex(last.union(today)).sort_values(ascending=False).index
if self.method == "bottom":
sell = last[last.isin(comb[-self.n_drop :])]
elif self.method == "random":
sell = pd.Index(np.random.choice(last, self.n_drop) if len(last) else [])
buy = today[: len(sell) + self.topk - len(last)]
for code in current_stock_list:
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
continue
if code in sell:
# check hold limit
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh:
# can not sell this code
# no buy signal, but the stock is kept
self.stock_count[code] += 1
continue
# sell order
sell_amount = current_temp.get_stock_amount(code=code)
sell_order = Order(
stock_id=code,
amount=sell_amount,
trade_date=trade_date,
direction=Order.SELL, # 0 for sell, 1 for buy
factor=trade_exchange.get_factor(code, trade_date),
)
# is order executable
if trade_exchange.check_order(sell_order):
sell_order_list.append(sell_order)
trade_val, trade_cost, trade_price = trade_exchange.deal_order(sell_order, position=current_temp)
# update cash
cash += trade_val - trade_cost
# sold
del self.stock_count[code]
else:
# no buy signal, but the stock is kept
self.stock_count[code] += 1
elif code in buy:
# NOTE: This is different from the original version
# get new buy signal
# Only the stock fall in to topk will produce buy signal
self.stock_count[code] = 1
else:
self.stock_count[code] += 1
# buy new stock
# note the current has been changed
current_stock_list = current_temp.get_stock_list()
value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it
# as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
# value = value / (1+trade_exchange.open_cost) # set open_cost limit
for code in buy:
# check is stock supended
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
continue
# buy order
buy_price = trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date)
buy_amount = value / buy_price
factor = trade_exchange.quote[(code, trade_date)]["$factor"]
buy_amount = trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
buy_order = Order(
stock_id=code,
amount=buy_amount,
trade_date=trade_date,
direction=Order.BUY, # 1 for buy
factor=factor,
)
buy_order_list.append(buy_order)
self.stock_count[code] = 1
return sell_order_list + buy_order_list

View File

Some files were not shown because too many files have changed in this diff Show More