mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Merge remote-tracking branch 'qlib/main' into qlib_register_ops
This commit is contained in:
5
.github/ISSUE_TEMPLATE/bug-report.md
vendored
5
.github/ISSUE_TEMPLATE/bug-report.md
vendored
@@ -28,7 +28,8 @@ Steps to reproduce the behavior:
|
||||
|
||||
## Environment
|
||||
|
||||
**Note**: One could run `python scripts/collect_info.py` under the `qlib` directory to get the following information.
|
||||
**Note**: User could run `cd scripts && python collect_info.py all` under project directory to get system information
|
||||
and paste them here directly.
|
||||
|
||||
- Qlib version:
|
||||
- Python version:
|
||||
@@ -37,4 +38,4 @@ Steps to reproduce the behavior:
|
||||
|
||||
## Additional Notes
|
||||
|
||||
<!-- Add any other information about the problem here. -->
|
||||
<!-- Add any other information about the problem here. -->
|
||||
|
||||
62
.github/stale.yml
vendored
Normal file
62
.github/stale.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
# Configuration for probot-stale - https://github.com/probot/stale
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 60
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
|
||||
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
|
||||
daysUntilClose: 7
|
||||
|
||||
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
|
||||
onlyLabels: []
|
||||
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
exemptLabels:
|
||||
- bug
|
||||
- pinned
|
||||
- security
|
||||
- "[Status] Maybe Later"
|
||||
|
||||
# Set to true to ignore issues in a project (defaults to false)
|
||||
exemptProjects: false
|
||||
|
||||
# Set to true to ignore issues in a milestone (defaults to false)
|
||||
exemptMilestones: false
|
||||
|
||||
# Set to true to ignore issues with an assignee (defaults to false)
|
||||
exemptAssignees: false
|
||||
|
||||
# Label to use when marking as stale
|
||||
staleLabel: wontfix
|
||||
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you
|
||||
for your contributions.
|
||||
|
||||
# Comment to post when removing the stale label.
|
||||
# unmarkComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Comment to post when closing a stale Issue or Pull Request.
|
||||
# closeComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Limit the number of actions per hour, from 1-30. Default is 30
|
||||
limitPerRun: 30
|
||||
|
||||
# Limit to only `issues` or `pulls`
|
||||
# only: issues
|
||||
|
||||
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
|
||||
# pulls:
|
||||
# daysUntilStale: 30
|
||||
# markComment: >
|
||||
# This pull request has been automatically marked as stale because it has not had
|
||||
# recent activity. It will be closed if no further activity occurs. Thank you
|
||||
# for your contributions.
|
||||
|
||||
# issues:
|
||||
# exemptLabels:
|
||||
# - confirmed
|
||||
100
.github/workflows/test.yml
vendored
100
.github/workflows/test.yml
vendored
@@ -12,8 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, macos-latest]
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -23,38 +23,96 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
pip install --upgrade cython
|
||||
pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python setup.py install
|
||||
cd ..
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install black
|
||||
$CONDA\\python.exe -m black qlib -l 120 --check --diff
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install black
|
||||
$CONDA/bin/python -m black qlib -l 120 --check --diff
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml --user
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
else
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
$CONDA\\python.exe -m pip uninstall -y pyqlib
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade cython
|
||||
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
$CONDA\\python.exe setup.py install
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade cython
|
||||
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
sudo $CONDA/bin/python setup.py install
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install black pytest
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m black qlib -l 120 --check --diff
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade pip
|
||||
$CONDA\\python.exe -m pip install black pytest
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade pip
|
||||
sudo $CONDA/bin/python -m pip install black pytest
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
pytest . --durations=0
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pytest . --durations=0
|
||||
else
|
||||
$CONDA/bin/python -m pytest . --durations=0
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Test data downloads
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
fi
|
||||
shell: bash
|
||||
@@ -18,5 +18,4 @@ python:
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: setuptools
|
||||
path: .
|
||||
system_packages: true
|
||||
path: .
|
||||
78
README.md
78
README.md
@@ -34,6 +34,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
|
||||
@@ -61,17 +62,36 @@ At the module level, Qlib is a platform that consists of the above components. T
|
||||
|
||||
This quick start guide tries to demonstrate
|
||||
1. It's very easy to build a complete Quant research workflow and try your ideas with _Qlib_.
|
||||
1. Though with *public data* and *simple models*, machine learning technologies **work very well** in practical Quant investment.
|
||||
2. Though with *public data* and *simple models*, machine learning technologies **work very well** in practical Quant investment.
|
||||
|
||||
Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how to install ``Qlib``, and run LightGBM with ``qrun``. **But**, please make sure you have already prepared the data following the [instruction](#data-preparation).
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
Users can easily install ``Qlib`` by pip according to the following command
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
|
||||
```bash
|
||||
pip install pyqlib
|
||||
```
|
||||
|
||||
Also, users can install ``Qlib`` by the source code according to the following steps:
|
||||
**Note**: pip will install the latest stable qlib. However, the main branch of qlib is in active development. If you want to test the latest scripts or functions in the main branch. Please install qlib with the methods below.
|
||||
|
||||
### Install from source
|
||||
Also, users can install the latest dev version ``Qlib`` by the source code according to the following steps:
|
||||
|
||||
* Before installing ``Qlib`` from source, users need to install some dependencies:
|
||||
|
||||
@@ -80,13 +100,20 @@ Also, users can install ``Qlib`` by the source code according to the following s
|
||||
pip install --upgrade cython
|
||||
```
|
||||
|
||||
* Clone the repository and install ``Qlib``:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
python setup.py install
|
||||
```
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
* If you haven't installed qlib by the command ``pip install pyqlib`` before:
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
python setup.py install
|
||||
```
|
||||
* If you have already installed the stable version by the command ``pip install pyqlib``:
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
```
|
||||
**Note**: **Only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test.yml) may help you find the problem.
|
||||
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
@@ -130,12 +157,16 @@ Users could create the same dataset with it.
|
||||
## Auto Quant Research Workflow
|
||||
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
|
||||
1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml)) as following.
|
||||
1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm_Alpha158.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml) as following.
|
||||
```bash
|
||||
cd examples # Avoid running program under the directory contains `qlib`
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
The result of `qrun` is as follows, please refer to please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
If users want to use `qrun` under debug mode, please use the following command:
|
||||
```bash
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
The result of `qrun` is as follows, please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
|
||||
```bash
|
||||
|
||||
@@ -190,16 +221,16 @@ The automatic workflow may not suite the research workflow of all Quant research
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
|
||||
- [GBDT based on XGBoost](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al.)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al.)](qlib/contrib/model/catboost_model.py)
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al.)](qlib/contrib/model/xgboost.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow](examples/benchmarks/TFT/tft.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al.)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorcn (Sepp Hochreiter, et al.)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorcn (Yao Qin, et al.)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al.)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al.)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al.)](examples/benchmarks/TFT/tft.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -280,7 +311,10 @@ Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
|
||||
# Contributing
|
||||
|
||||
12
docs/_static/demo.sh
vendored
Normal file
12
docs/_static/demo.sh
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
#!/bin/sh
|
||||
git clone https://github.com/microsoft/qlib.git
|
||||
cd qlib
|
||||
ls
|
||||
pip install pyqlib
|
||||
# or
|
||||
# pip install numpy
|
||||
# pip install --upgrade cython
|
||||
# python setup.py install
|
||||
cd examples
|
||||
ls
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
@@ -50,57 +50,37 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.dataset.handler import QLibDataHandler
|
||||
>> from qlib.data.dataset.loader import QlibDataLoader
|
||||
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
|
||||
>> fields = [MACD_EXP] # MACD
|
||||
>> names = ['MACD']
|
||||
>> labels = ['$close'] # label
|
||||
>> labels = ['Ref($close, -2)/Ref($close, -1) - 1'] # label
|
||||
>> label_names = ['LABEL']
|
||||
>> data_handler = QLibDataHandler(start_date='2010-01-01', end_date='2017-12-31', fields=fields, names=names, labels=labels, label_names=label_names)
|
||||
>> TRAINER_CONFIG = {
|
||||
.. "train_start_date": "2007-01-01",
|
||||
.. "train_end_date": "2014-12-31",
|
||||
.. "validate_start_date": "2015-01-01",
|
||||
.. "validate_end_date": "2016-12-31",
|
||||
.. "test_start_date": "2017-01-01",
|
||||
.. "test_end_date": "2020-08-01",
|
||||
>> data_loader_config = {
|
||||
.. "feature": (fields, names),
|
||||
.. "label": (labels, label_names)
|
||||
.. }
|
||||
>> feature_train, label_train, feature_validate, label_validate, feature_test, label_test = data_handler.get_split_data(**TRAINER_CONFIG)
|
||||
>> print(feature_train, label_train)
|
||||
MACD
|
||||
instrument datetime
|
||||
SH600000 2010-01-04 -0.008625
|
||||
2010-01-05 -0.007234
|
||||
2010-01-06 -0.007693
|
||||
2010-01-07 -0.009633
|
||||
2010-01-08 -0.009891
|
||||
... ...
|
||||
SZ300251 2014-12-25 0.043072
|
||||
2014-12-26 0.041345
|
||||
2014-12-29 0.042733
|
||||
2014-12-30 0.042066
|
||||
2014-12-31 0.036299
|
||||
|
||||
[322025 rows x 1 columns]
|
||||
LABEL
|
||||
instrument datetime
|
||||
SH600000 2010-01-04 4.260015
|
||||
2010-01-05 4.292182
|
||||
2010-01-06 4.207747
|
||||
2010-01-07 4.113258
|
||||
2010-01-08 4.159496
|
||||
... ...
|
||||
SZ300251 2014-12-25 4.343212
|
||||
2014-12-26 4.470587
|
||||
2014-12-29 4.762474
|
||||
2014-12-30 4.369748
|
||||
2014-12-31 4.182222
|
||||
|
||||
[322025 rows x 1 columns]
|
||||
>> data_loader = QlibDataLoader(config=data_loader_config)
|
||||
>> df = data_loader.load(instruments='csi300', start_time='2010-01-01', end_time='2017-12-31')
|
||||
>> print(df)
|
||||
feature label
|
||||
MACD LABEL
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 -0.011547 -0.019672
|
||||
SH600004 0.002745 -0.014721
|
||||
SH600006 0.010133 0.002911
|
||||
SH600008 -0.001113 0.009818
|
||||
SH600009 0.025878 -0.017758
|
||||
... ... ...
|
||||
2017-12-29 SZ300124 0.007306 -0.005074
|
||||
SZ300136 -0.013492 0.056352
|
||||
SZ300144 -0.000966 0.011853
|
||||
SZ300251 0.004383 0.021739
|
||||
SZ300315 -0.030557 0.012455
|
||||
|
||||
Reference
|
||||
===========
|
||||
|
||||
To learn more about ``Data Handler``, please refer to `Data Handler <../component/data.html>`_
|
||||
To learn more about ``Data Loader``, please refer to `Data Loader <../component/data.html#data-loader>`_
|
||||
|
||||
To learn more about ``Data API``, please refer to `Data API <../component/data.html>`_
|
||||
|
||||
@@ -126,17 +126,17 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
|
||||
- `open`
|
||||
The opening price
|
||||
The adjusted opening price
|
||||
- `close`
|
||||
The closing price
|
||||
The adjusted closing price
|
||||
- `high`
|
||||
The highest price
|
||||
The adjusted highest price
|
||||
- `low`
|
||||
The lowest price
|
||||
The adjusted lowest price
|
||||
- `volume`
|
||||
The trading volume
|
||||
The adjusted trading volume
|
||||
- `factor`
|
||||
The Restoration factor
|
||||
The Restoration factor. Normally, original_price = adj_price / factor
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
|
||||
@@ -296,6 +296,7 @@ The ``Processor`` module in ``Qlib`` is designed to be learnable and it is respo
|
||||
- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
|
||||
- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
|
||||
- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
|
||||
- ``CSZFillna``: `processor` that fills N/A values in a cross sectional way by the mean of the column.
|
||||
|
||||
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
|
||||
|
||||
|
||||
@@ -34,8 +34,9 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
This experiment management system defines a set of interface and provided a concrete implementation based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
|
||||
Qlib Recorder
|
||||
===================
|
||||
@@ -91,7 +92,7 @@ Record Template
|
||||
|
||||
The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
|
||||
|
||||
- ``SignalRecord``: This class generates the `preidction` results of the model.
|
||||
- ``SignalRecord``: This class generates the `prediction` results of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
|
||||
@@ -103,6 +103,12 @@ After saving the config into `configuration.yaml`, users could start the workflo
|
||||
|
||||
qrun configuration.yaml
|
||||
|
||||
If users want to use ``qrun`` under debug mode, please use the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
.. note::
|
||||
|
||||
`qrun` will be placed in your $PATH directory when installing ``Qlib``.
|
||||
|
||||
@@ -226,3 +226,8 @@ epub_exclude_files = ["search.html"]
|
||||
|
||||
autodoc_member_order = "bysource"
|
||||
autodoc_default_flags = ["members"]
|
||||
autodoc_default_options = {
|
||||
"members": True,
|
||||
"member-order": "bysource",
|
||||
"special-members": "__init__",
|
||||
}
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
Cython==0.29.21
|
||||
Cython
|
||||
cmake
|
||||
numpy
|
||||
scipy
|
||||
scikit-learn
|
||||
|
||||
@@ -63,6 +63,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
|
||||
- `exp_manager`
|
||||
Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, if you want to set your tracking_uri to a <specific folder>, you can initialize qlib below
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Requirements
|
||||
|
||||
Here is the minimal hardware requirements to run the example.
|
||||
Here is the minimal hardware requirements to run the `workflow_by_code` example.
|
||||
- Memory: 16G
|
||||
- Free Disk: 5G
|
||||
|
||||
|
||||
@@ -64,7 +64,6 @@ task:
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
rnn_type: GRU
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -64,7 +64,6 @@ task:
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
rnn_type: GRU
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
# Benchmarks Performance
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 10 runs.
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0655±0.00 | -0.6985±0.00| -0.2961±0.00 |
|
||||
| CatBoost | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
|
||||
| XGBoost | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
|
||||
| LightGBM | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
|
||||
| MLP | Alpha360 | 0.0253±0.01 | 0.1954±0.05| 0.0329±0.00 | 0.2687±0.04 | 0.0161±0.01 | 0.1989±0.19| -0.1275±0.03 |
|
||||
| GRU | Alpha360 | 0.0503±0.01 | 0.3946±0.06| 0.0588±0.00 | 0.4737±0.05 | 0.0799±0.02 | 1.0940±0.26| -0.0810±0.03 |
|
||||
| LSTM | Alpha360 | 0.0466±0.01 | 0.3644±0.06| 0.0555±0.00 | 0.4451±0.04 | 0.0783±0.05 | 1.0539±0.65| -0.0844±0.03 |
|
||||
| ALSTM | Alpha360 | 0.0472±0.00 | 0.3558±0.04| 0.0577±0.00 | 0.4522±0.04 | 0.0522±0.02 | 0.7090±0.32| -0.1059±0.03 |
|
||||
| GATs | Alpha360 | 0.0480±0.00 | 0.3555±0.02| 0.0598±0.00 | 0.4616±0.01 | 0.0857±0.03 | 1.1317±0.42| -0.0917±0.01 |
|
||||
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0659±0.00 | -0.7072±0.00| -0.2955±0.00 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
|
||||
| MLP | Alpha360 | 0.0285±0.00 | 0.1981±0.02| 0.0402±0.00 | 0.2993±0.02 | 0.0073±0.02 | 0.0880±0.22| -0.1446±0.03 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha360 | 0.0490±0.01 | 0.3787±0.05| 0.0581±0.00 | 0.4664±0.04 | 0.0726±0.02 | 0.9817±0.34| -0.0902±0.03 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha158 | 0.0393±0.00 | 0.2980±0.00| 0.0475±0.00 | 0.3546±0.00 | 0.0795±0.00 | 1.0712±0.00| -0.1449±0.00 |
|
||||
| CatBoost | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1567±0.00| -0.0787±0.00 |
|
||||
| XGBoost | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0363±0.00 | 0.2770±0.02| 0.0421±0.00 | 0.3167±0.01 | 0.0856±0.01 | 1.0397±0.12| -0.1134±0.01 |
|
||||
| TFT | Alpha158 (with selected 20 features) | 0.0335±0.00 | 0.2009±0.01| 0.0090±0.00 | 0.0553±0.01 | 0.0605±0.01 | 0.5438±0.12| -0.1772±0.03 |
|
||||
| GRU | Alpha158 (with selected 20 features) | 0.0313±0.00 | 0.2427±0.01 | 0.0416±0.00 | 0.3370±0.01 | 0.0335±0.01 | 0.4808±0.22 | -0.1112±0.03 |
|
||||
| LSTM | Alpha158 (with selected 20 features) | 0.0337±0.01 | 0.2562±0.05 | 0.0427±0.01 | 0.3392±0.04 | 0.0269±0.06 | 0.3385±0.74 | -0.1285±0.04 |
|
||||
| ALSTM | Alpha158 (with selected 20 features) | 0.0366±0.00 | 0.2803±0.04 | 0.0478±0.00 | 0.3770±0.02 | 0.0520±0.03 | 0.7115±0.30 | -0.0986±0.01 |
|
||||
| GATs | Alpha158 (with selected 20 features) | 0.0355±0.00 | 0.2576±0.02 | 0.0465±0.00 | 0.3585±0.00 | 0.0509±0.02 | 0.7212±0.22 | -0.0821±0.01 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1561±0.00| -0.0787±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# State-Frequency-Memory
|
||||
- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
|
||||
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.)
|
||||
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf)
|
||||
@@ -25,7 +25,7 @@ import os
|
||||
import data_formatters.qlib_Alpha158
|
||||
|
||||
|
||||
class ExperimentConfig(object):
|
||||
class ExperimentConfig:
|
||||
"""Defines experiment configs and paths to outputs.
|
||||
|
||||
Attributes:
|
||||
|
||||
@@ -320,7 +320,7 @@ class InterpretableMultiHeadAttention:
|
||||
return outputs, attn
|
||||
|
||||
|
||||
class TFTDataCache(object):
|
||||
class TFTDataCache:
|
||||
"""Caches data for the TFT."""
|
||||
|
||||
_data_cache = {}
|
||||
@@ -348,7 +348,7 @@ class TFTDataCache(object):
|
||||
|
||||
|
||||
# TFT model definitions.
|
||||
class TemporalFusionTransformer(object):
|
||||
class TemporalFusionTransformer:
|
||||
"""Defines Temporal Fusion Transformer.
|
||||
|
||||
Attributes:
|
||||
@@ -972,7 +972,7 @@ class TemporalFusionTransformer(object):
|
||||
valid_quantiles = self.quantiles
|
||||
output_size = self.output_size
|
||||
|
||||
class QuantileLossCalculator(object):
|
||||
class QuantileLossCalculator:
|
||||
"""Computes the combined quantile loss for prespecified quantiles.
|
||||
|
||||
Attributes:
|
||||
|
||||
@@ -69,9 +69,9 @@ def handler(signum, frame):
|
||||
os.system("kill -9 %d" % os.getpid())
|
||||
|
||||
|
||||
signal.signal(signal.SIGTSTP, handler)
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
|
||||
|
||||
# function to calculate the mean and std of a list in the results dictionary
|
||||
def cal_mean_std(results) -> dict:
|
||||
mean_std = dict()
|
||||
|
||||
@@ -98,6 +98,7 @@ if __name__ == "__main__":
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"return_order": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.6.0.dev"
|
||||
__version__ = "0.6.1.dev"
|
||||
|
||||
|
||||
import os
|
||||
|
||||
@@ -20,17 +20,17 @@ import multiprocessing
|
||||
|
||||
class Config:
|
||||
def __init__(self, default_conf):
|
||||
self.__dict__["_default_config"] = default_conf # avoiding conflictions with __getattr__
|
||||
self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflictions with __getattr__
|
||||
self.reset()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.__dict__["_config"][key]
|
||||
|
||||
def __getattr__(self, attr):
|
||||
try:
|
||||
if attr in self.__dict__["_config"]:
|
||||
return self.__dict__["_config"][attr]
|
||||
except KeyError:
|
||||
return AttributeError(f"No such {attr} in self._config")
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._config")
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_config"][key] = value
|
||||
|
||||
@@ -1,9 +1,324 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func, get_date_range
|
||||
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ...utils import init_instance_by_config
|
||||
from ...log import get_module_logger
|
||||
from ...config import C
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy=None,
|
||||
topk=50,
|
||||
margin=0.5,
|
||||
n_drop=5,
|
||||
risk_degree=0.95,
|
||||
str_type="dropout",
|
||||
adjust_dates=None,
|
||||
):
|
||||
"""get_strategy
|
||||
|
||||
There will be 3 ways to return a stratgy. Please follow the code.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Strategy
|
||||
an initialized strategy object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a strategy.
|
||||
if strategy is None:
|
||||
# 1) create strategy with param `strategy`
|
||||
str_cls_dict = {
|
||||
"amount": "TopkAmountStrategy",
|
||||
"weight": "TopkWeightStrategy",
|
||||
"dropout": "TopkDropoutStrategy",
|
||||
}
|
||||
logger.info("Create new strategy ")
|
||||
from .. import strategy as strategy_pool
|
||||
|
||||
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
|
||||
strategy = str_cls(
|
||||
topk=topk,
|
||||
buffer_margin=margin,
|
||||
n_drop=n_drop,
|
||||
risk_degree=risk_degree,
|
||||
adjust_dates=adjust_dates,
|
||||
)
|
||||
elif isinstance(strategy, (dict, str)):
|
||||
# 2) create strategy with init_instance_by_config
|
||||
logger.info("Create new strategy ")
|
||||
strategy = init_instance_by_config(strategy)
|
||||
|
||||
from ..strategy.strategy import BaseStrategy
|
||||
|
||||
# else: nothing happens. 3) Use the strategy directly
|
||||
if not isinstance(strategy, BaseStrategy):
|
||||
raise TypeError("Strategy not supported")
|
||||
return strategy
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
extract_codes=False,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Exchange
|
||||
an initialized Exchange object
|
||||
"""
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values("instrument").unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
trade_dates=dates,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
min_cost=min_cost,
|
||||
trade_unit=trade_unit,
|
||||
)
|
||||
return exchange
|
||||
|
||||
|
||||
def get_executor(
|
||||
executor=None,
|
||||
trade_exchange=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""get_executor
|
||||
|
||||
There will be 3 ways to return a executor. Please follow the code.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
executor : BaseExecutor
|
||||
executor used in backtest.
|
||||
trade_exchange : Exchange
|
||||
exchange used in executor
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: BaseExecutor
|
||||
an initialized BaseExecutor object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a executor.
|
||||
if executor is None:
|
||||
# 1) create executor with param `executor`
|
||||
logger.info("Create new executor ")
|
||||
from ..online.executor import SimulatorExecutor
|
||||
|
||||
executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
|
||||
elif isinstance(executor, (dict, str)):
|
||||
# 2) create executor with config
|
||||
logger.info("Create new executor ")
|
||||
executor = init_instance_by_config(executor)
|
||||
|
||||
from ..online.executor import BaseExecutor
|
||||
|
||||
# 3) Use the executor directly
|
||||
if not isinstance(executor, BaseExecutor):
|
||||
raise TypeError("Executor not supported")
|
||||
return executor
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
Parameters
|
||||
----------
|
||||
|
||||
- **backtest workflow related or commmon arguments**
|
||||
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column.
|
||||
account : float
|
||||
init account value.
|
||||
shift : int
|
||||
whether to shift prediction by one day.
|
||||
benchmark : str
|
||||
benchmark code, default is SH000905 CSI 500.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
return_order : bool
|
||||
whether to return order list
|
||||
|
||||
- **strategy related arguments**
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
- **exchange related arguments**
|
||||
|
||||
exchange: Exchange()
|
||||
pass the exchange for speeding up.
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. The default value is 0.002(0.2%).
|
||||
close_cost : float
|
||||
close transaction cost. The default value is 0.002(0.2%).
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
# check strategy:
|
||||
spec = inspect.getfullargspec(get_strategy)
|
||||
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
strategy = get_strategy(**str_args)
|
||||
|
||||
# init exchange:
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
# init executor:
|
||||
executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
|
||||
|
||||
# run backtest
|
||||
report_dict = backtest_func(
|
||||
pred=pred,
|
||||
strategy=strategy,
|
||||
executor=executor,
|
||||
trade_exchange=trade_exchange,
|
||||
shift=shift,
|
||||
verbose=verbose,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
return_order=return_order,
|
||||
)
|
||||
# for compatibility of the old API. return the dict positions
|
||||
|
||||
positions = report_dict.get("positions")
|
||||
report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
|
||||
return report_dict
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from ...utils import get_date_by_shift, get_date_range
|
||||
from ..online.executor import SimulatorExecutor
|
||||
from ...data import D
|
||||
from .account import Account
|
||||
from ...config import C
|
||||
@@ -15,7 +14,7 @@ from ...data.dataset.utils import get_level_index
|
||||
LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark):
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
@@ -69,9 +68,9 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
|
||||
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
|
||||
executor = SimulatorExecutor(trade_exchange, verbose=verbose)
|
||||
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
|
||||
if return_order:
|
||||
multi_order_list = []
|
||||
# trading apart
|
||||
for pred_date, trade_date in zip(predict_dates, trade_dates):
|
||||
# for loop predict date and trading date
|
||||
@@ -103,6 +102,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
)
|
||||
else:
|
||||
order_list = []
|
||||
if return_order:
|
||||
multi_order_list.append((trade_account, order_list, trade_date))
|
||||
# 4. Get result after executing order list
|
||||
# NOTE: The following operation will modify order.amount.
|
||||
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
|
||||
@@ -115,7 +116,11 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
report_df["bench"] = bench
|
||||
positions = trade_account.get_positions()
|
||||
return report_df, positions
|
||||
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
if return_order:
|
||||
report_dict.update({"order_list": multi_order_list})
|
||||
return report_dict
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
|
||||
@@ -6,17 +6,16 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import inspect
|
||||
import warnings
|
||||
from ..log import get_module_logger
|
||||
from . import strategy as strategy_pool
|
||||
from .strategy.strategy import BaseStrategy
|
||||
from .backtest.exchange import Exchange
|
||||
from .backtest.backtest import backtest as backtest_func, get_date_range
|
||||
from .backtest import get_exchange, backtest as backtest_func
|
||||
from .backtest.backtest import get_date_range
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
@@ -46,144 +45,6 @@ def risk_analysis(r, N=252):
|
||||
return res
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy=None,
|
||||
topk=50,
|
||||
margin=0.5,
|
||||
n_drop=5,
|
||||
risk_degree=0.95,
|
||||
str_type="amount",
|
||||
adjust_dates=None,
|
||||
):
|
||||
"""get_strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Strategy
|
||||
an initialized strategy object
|
||||
"""
|
||||
if strategy is None:
|
||||
str_cls_dict = {
|
||||
"amount": "TopkAmountStrategy",
|
||||
"weight": "TopkWeightStrategy",
|
||||
"dropout": "TopkDropoutStrategy",
|
||||
}
|
||||
logger.info("Create new streategy ")
|
||||
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
|
||||
strategy = str_cls(
|
||||
topk=topk,
|
||||
buffer_margin=margin,
|
||||
n_drop=n_drop,
|
||||
risk_degree=risk_degree,
|
||||
adjust_dates=adjust_dates,
|
||||
)
|
||||
if not isinstance(strategy, BaseStrategy):
|
||||
raise TypeError("Strategy not supported")
|
||||
return strategy
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
extract_codes=False,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Exchange
|
||||
an initialized Exchange object
|
||||
"""
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values("instrument").unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
trade_dates=dates,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
min_cost=min_cost,
|
||||
trade_unit=trade_unit,
|
||||
)
|
||||
return exchange
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
@@ -249,30 +110,22 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
# check strategy:
|
||||
spec = inspect.getfullargspec(get_strategy)
|
||||
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
strategy = get_strategy(**str_args)
|
||||
|
||||
# init exchange:
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
# run backtest
|
||||
report_df, positions = backtest_func(
|
||||
pred=pred,
|
||||
strategy=strategy,
|
||||
trade_exchange=trade_exchange,
|
||||
shift=shift,
|
||||
verbose=verbose,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
warnings.warn(
|
||||
"this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning
|
||||
)
|
||||
# for compatibility of the old API. return the dict positions
|
||||
positions = {k: p.position for k, p in positions.items()}
|
||||
return report_df, positions
|
||||
report_dict = backtest_func(
|
||||
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
|
||||
)
|
||||
return report_dict.get("report_df"), report_dict.get("positions")
|
||||
|
||||
|
||||
def long_short_backtest(
|
||||
@@ -340,7 +193,7 @@ def long_short_backtest(
|
||||
|
||||
_pred_dates = pred.index.get_level_values(level="datetime")
|
||||
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
long_returns = {}
|
||||
short_returns = {}
|
||||
|
||||
@@ -204,8 +204,8 @@ class ALSTM(Model):
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -260,7 +260,7 @@ class ALSTM(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.ALSTM_model.eval()
|
||||
|
||||
@@ -249,8 +249,8 @@ class GATs(Model):
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -332,7 +332,7 @@ class GATs(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
sampler_test = DailyBatchSampler(dl_test)
|
||||
test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)
|
||||
|
||||
@@ -204,8 +204,8 @@ class GRU(Model):
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -260,7 +260,7 @@ class GRU(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.GRU_model.eval()
|
||||
|
||||
@@ -204,8 +204,8 @@ class LSTM(Model):
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -260,7 +260,7 @@ class LSTM(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.LSTM_model.eval()
|
||||
|
||||
@@ -296,7 +296,7 @@ class DNNModelPytorch(Model):
|
||||
self._fitted = True
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -464,7 +464,7 @@ class SFM(Model):
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -21,7 +21,7 @@ from .executor import SimulatorExecutor
|
||||
from .executor import save_score_series, load_score_series
|
||||
|
||||
|
||||
class Operator(object):
|
||||
class Operator:
|
||||
def __init__(self, client: str):
|
||||
"""
|
||||
Parameters
|
||||
|
||||
@@ -38,7 +38,7 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
|
||||
:param df:
|
||||
:return:
|
||||
"""
|
||||
|
||||
index_names = df.index.names
|
||||
df.index = df.index.strftime("%Y-%m-%d")
|
||||
|
||||
report_df = pd.DataFrame()
|
||||
@@ -58,6 +58,8 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
|
||||
|
||||
report_df["turnover"] = df["turnover"]
|
||||
report_df.sort_index(ascending=True, inplace=True)
|
||||
|
||||
report_df.index.names = index_names
|
||||
return report_df
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from plotly.figure_factory import create_distplot
|
||||
from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class BaseGraph(object):
|
||||
class BaseGraph:
|
||||
""""""
|
||||
|
||||
_name = None
|
||||
@@ -204,7 +204,7 @@ class HistogramGraph(BaseGraph):
|
||||
return _data
|
||||
|
||||
|
||||
class SubplotsGraph(object):
|
||||
class SubplotsGraph:
|
||||
"""Create subplots same as df.plot(subplots=True)
|
||||
|
||||
Simple package for `plotly.tools.subplots`
|
||||
|
||||
@@ -30,7 +30,7 @@ class BaseStrategy:
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Seires
|
||||
score_series : pd.Series
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current state of position.
|
||||
|
||||
@@ -6,7 +6,7 @@ import copy
|
||||
import os
|
||||
|
||||
|
||||
class TunerConfigManager(object):
|
||||
class TunerConfigManager:
|
||||
def __init__(self, config_path):
|
||||
|
||||
if not config_path:
|
||||
@@ -27,7 +27,7 @@ class TunerConfigManager(object):
|
||||
self.qlib_client_config = config.get("qlib_client", dict())
|
||||
|
||||
|
||||
class PipelineExperimentConfig(object):
|
||||
class PipelineExperimentConfig:
|
||||
def __init__(self, config, TUNER_CONFIG_MANAGER):
|
||||
"""
|
||||
:param config: The config dict for tuner experiment
|
||||
@@ -53,7 +53,7 @@ class PipelineExperimentConfig(object):
|
||||
yaml.dump(TUNER_CONFIG_MANAGER.config, fp)
|
||||
|
||||
|
||||
class OptimizationConfig(object):
|
||||
class OptimizationConfig:
|
||||
def __init__(self, config, TUNER_CONFIG_MANAGER):
|
||||
|
||||
self.report_type = config.get("report_type", "pred_long")
|
||||
|
||||
@@ -11,7 +11,7 @@ from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class Pipeline(object):
|
||||
class Pipeline:
|
||||
|
||||
GLOBAL_BEST_PARAMS_NAME = "global_best_params.json"
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from hyperopt import fmin, tpe
|
||||
from hyperopt import STATUS_OK, STATUS_FAIL
|
||||
|
||||
|
||||
class Tuner(object):
|
||||
class Tuner:
|
||||
def __init__(self, tuner_config, optim_config):
|
||||
|
||||
self.logger = get_module_logger("Tuner", sh_level=logging.INFO)
|
||||
|
||||
@@ -8,7 +8,7 @@ from libc.math cimport sqrt, isnan, NAN
|
||||
from libcpp.vector cimport vector
|
||||
|
||||
|
||||
cdef class Expanding(object):
|
||||
cdef class Expanding:
|
||||
"""1-D array expanding"""
|
||||
cdef vector[double] barv
|
||||
cdef int na_count
|
||||
|
||||
@@ -8,7 +8,7 @@ from libc.math cimport sqrt, isnan, NAN
|
||||
from libcpp.deque cimport deque
|
||||
|
||||
|
||||
cdef class Rolling(object):
|
||||
cdef class Rolling:
|
||||
"""1-D array rolling"""
|
||||
cdef int window
|
||||
cdef deque[double] barv
|
||||
|
||||
@@ -13,6 +13,7 @@ import pickle
|
||||
import traceback
|
||||
import redis_lock
|
||||
import contextlib
|
||||
import abc
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -39,36 +40,100 @@ class QlibCacheException(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class MemCacheUnit(OrderedDict):
|
||||
class MemCacheUnit(abc.ABC):
|
||||
"""Memory Cache Unit."""
|
||||
|
||||
# TODO: use min_heap to replace ordereddict for better performance
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.size_limit = kwargs.pop("size_limit", None)
|
||||
# limit_type: check size_limit type, length(call fun: len) or size(call fun: sys.getsizeof)
|
||||
self.limit_type = kwargs.pop("limit_type", "length")
|
||||
super(MemCacheUnit, self).__init__(*args, **kwargs)
|
||||
self._check_size_limit()
|
||||
self.size_limit = kwargs.pop("size_limit", 0)
|
||||
self._size = 0
|
||||
self.od = OrderedDict()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super(MemCacheUnit, self).__setitem__(key, value)
|
||||
self._check_size_limit()
|
||||
# TODO: thread safe?__setitem__ failure might cause inconsistent size?
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = super(MemCacheUnit, self).__getitem__(key)
|
||||
super(MemCacheUnit, self).__delitem__(key)
|
||||
super(MemCacheUnit, self).__setitem__(key, value)
|
||||
return value
|
||||
# precalculate the size after od.__setitem__
|
||||
self._adjust_size(key, value)
|
||||
|
||||
def _check_size_limit(self):
|
||||
if self.size_limit is not None:
|
||||
get_cur_size = lambda x: len(x) if self.limit_type == "length" else sum(map(sys.getsizeof, x.values()))
|
||||
while get_cur_size(self) > self.size_limit:
|
||||
self.od.__setitem__(key, value)
|
||||
|
||||
# move the key to end,make it latest
|
||||
self.od.move_to_end(key)
|
||||
|
||||
if self.limited:
|
||||
# pop the oldest items beyond size limit
|
||||
while self._size > self.size_limit:
|
||||
self.popitem(last=False)
|
||||
|
||||
def __getitem__(self, key):
|
||||
v = self.od.__getitem__(key)
|
||||
self.od.move_to_end(key)
|
||||
return v
|
||||
|
||||
class MemCache(object):
|
||||
def __contains__(self, key):
|
||||
return key in self.od
|
||||
|
||||
def __len__(self):
|
||||
return self.od.__len__()
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}<size_limit:{self.size_limit if self.limited else 'no limit'} total_size:{self._size}>\n{self.od.__repr__()}"
|
||||
|
||||
def set_limit_size(self, limit):
|
||||
self.size_limit = limit
|
||||
|
||||
@property
|
||||
def limited(self):
|
||||
"""whether memory cache is limited"""
|
||||
return self.size_limit > 0
|
||||
|
||||
@property
|
||||
def total_size(self):
|
||||
return self._size
|
||||
|
||||
def clear(self):
|
||||
self._size = 0
|
||||
self.od.clear()
|
||||
|
||||
def popitem(self, last=True):
|
||||
k, v = self.od.popitem(last=last)
|
||||
self._size -= self._get_value_size(v)
|
||||
|
||||
return k, v
|
||||
|
||||
def pop(self, key):
|
||||
v = self.od.pop(key)
|
||||
self._size -= self._get_value_size(v)
|
||||
|
||||
return v
|
||||
|
||||
def _adjust_size(self, key, value):
|
||||
if key in self.od:
|
||||
self._size -= self._get_value_size(self.od[key])
|
||||
|
||||
self._size += self._get_value_size(value)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_value_size(self, value):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MemCacheLengthUnit(MemCacheUnit):
|
||||
def __init__(self, size_limit=0):
|
||||
super().__init__(size_limit=size_limit)
|
||||
|
||||
def _get_value_size(self, value):
|
||||
return 1
|
||||
|
||||
|
||||
class MemCacheSizeofUnit(MemCacheUnit):
|
||||
def __init__(self, size_limit=0):
|
||||
super().__init__(size_limit=size_limit)
|
||||
|
||||
def _get_value_size(self, value):
|
||||
return sys.getsizeof(value)
|
||||
|
||||
|
||||
class MemCache:
|
||||
"""Memory cache."""
|
||||
|
||||
def __init__(self, mem_cache_size_limit=None, limit_type="length"):
|
||||
@@ -79,21 +144,19 @@ class MemCache(object):
|
||||
mem_cache_size_limit: cache max size.
|
||||
limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof).
|
||||
"""
|
||||
if limit_type not in ["length", "sizeof"]:
|
||||
|
||||
size_limit = C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit
|
||||
|
||||
if limit_type == "length":
|
||||
klass = MemCacheLengthUnit
|
||||
elif limit_type == "sizeof":
|
||||
klass = MemCacheSizeofUnit
|
||||
else:
|
||||
raise ValueError(f"limit_type must be length or sizeof, your limit_type is {limit_type}")
|
||||
|
||||
self.__calendar_mem_cache = MemCacheUnit(
|
||||
size_limit=C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit,
|
||||
limit_type=limit_type,
|
||||
)
|
||||
self.__instrument_mem_cache = MemCacheUnit(
|
||||
size_limit=C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit,
|
||||
limit_type=limit_type,
|
||||
)
|
||||
self.__feature_mem_cache = MemCacheUnit(
|
||||
size_limit=C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit,
|
||||
limit_type=limit_type,
|
||||
)
|
||||
self.__calendar_mem_cache = klass(size_limit)
|
||||
self.__instrument_mem_cache = klass(size_limit)
|
||||
self.__feature_mem_cache = klass(size_limit)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == "c":
|
||||
@@ -140,7 +203,7 @@ class MemCacheExpire:
|
||||
return value, expire
|
||||
|
||||
|
||||
class CacheUtils(object):
|
||||
class CacheUtils:
|
||||
LOCK_ID = "QLIB"
|
||||
|
||||
@staticmethod
|
||||
@@ -224,7 +287,7 @@ class CacheUtils(object):
|
||||
current_cache_wlock.release()
|
||||
|
||||
|
||||
class BaseProviderCache(object):
|
||||
class BaseProviderCache:
|
||||
"""Provider cache base class"""
|
||||
|
||||
def __init__(self, provider):
|
||||
|
||||
@@ -12,7 +12,7 @@ from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client(object):
|
||||
class Client:
|
||||
"""A client class
|
||||
|
||||
Provide the connection tool functions for ClientProvider.
|
||||
|
||||
@@ -1051,7 +1051,7 @@ def register_all_wrappers(C):
|
||||
if getattr(C, "calendar_cache", None) is not None:
|
||||
_calendar_provider = init_instance_by_config(C.calendar_cache, module, provide=_calendar_provider)
|
||||
register_wrapper(Cal, _calendar_provider, "qlib.data")
|
||||
logger.debug(f"registering Cal {C.calendar_provider}-{C.calenar_cache}")
|
||||
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
|
||||
|
||||
register_wrapper(Inst, C.instrument_provider, "qlib.data")
|
||||
logger.debug(f"registering Inst {C.instrument_provider}")
|
||||
|
||||
@@ -18,7 +18,9 @@ try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi
|
||||
except ImportError as err:
|
||||
print("Do not import qlib package in the repository directory!")
|
||||
print(
|
||||
"#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -96,6 +98,15 @@ class Sign(ElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(Sign, self).__init__(feature, "sign")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
"""
|
||||
To avoid error raised by bool type input, we transform the data into float32.
|
||||
"""
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
# TODO: More precision types should be configurable
|
||||
series = series.astype(np.float32)
|
||||
return getattr(np, self.func)(series)
|
||||
|
||||
|
||||
class Log(ElemOperator):
|
||||
"""Feature Log
|
||||
|
||||
@@ -36,7 +36,7 @@ def get_module_logger(module_name, level=None):
|
||||
return module_logger
|
||||
|
||||
|
||||
class TimeInspector(object):
|
||||
class TimeInspector:
|
||||
|
||||
timer_logger = get_module_logger("timer", level=logging.WARNING)
|
||||
|
||||
|
||||
@@ -30,11 +30,6 @@ class Model(BaseModel):
|
||||
The attribute names of learned model should `not` start with '_'. So that the model could be
|
||||
dumped to disk.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : Dataset
|
||||
dataset will generate the processed data from model training.
|
||||
|
||||
The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:
|
||||
|
||||
.. code-block:: Python
|
||||
@@ -53,6 +48,12 @@ class Model(BaseModel):
|
||||
except KeyError as e:
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : Dataset
|
||||
dataset will generate the processed data from model training.
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import scipy.optimize as so
|
||||
from typing import Optional, Union, Callable, List
|
||||
|
||||
|
||||
class PortfolioOptimizer(object):
|
||||
class PortfolioOptimizer:
|
||||
"""Portfolio Optimizer
|
||||
|
||||
The following optimization algorithms are supported:
|
||||
|
||||
@@ -31,20 +31,20 @@ class GetData:
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chuck_size = 1024
|
||||
chunk_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chuck in resp.iter_content(chunk_size=chuck_size):
|
||||
fp.write(chuck)
|
||||
p_bar.update(chuck_size)
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
fp.write(chunk)
|
||||
p_bar.update(chunk_size)
|
||||
|
||||
self._unzip(target_path, target_dir)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlike()
|
||||
target_path.unlink()
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path):
|
||||
|
||||
@@ -281,8 +281,10 @@ def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
def create_save_path(save_path=None):
|
||||
"""Create save path
|
||||
|
||||
:param save_path:
|
||||
:return:
|
||||
Parameters
|
||||
----------
|
||||
save_path: str
|
||||
|
||||
"""
|
||||
if save_path:
|
||||
if not os.path.exists(save_path):
|
||||
@@ -473,30 +475,28 @@ def is_tradable_date(cur_date):
|
||||
return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
|
||||
|
||||
|
||||
def get_date_range(trading_date, shift, future=False):
|
||||
def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
"""get trading date range by shift
|
||||
|
||||
:param trading_date:
|
||||
:param shift: int
|
||||
:param future: bool
|
||||
:return:
|
||||
Parameters
|
||||
----------
|
||||
trading_date: pd.Timestamp
|
||||
left_shift: int
|
||||
right_shift: int
|
||||
future: bool
|
||||
|
||||
"""
|
||||
|
||||
from ..data import D
|
||||
|
||||
calendar = D.calendar(future=future)
|
||||
if pd.to_datetime(trading_date) not in list(calendar):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
day_index = bisect.bisect_left(calendar, trading_date)
|
||||
if 0 <= (day_index + shift) < len(calendar):
|
||||
if shift > 0:
|
||||
return calendar[day_index + 1 : day_index + 1 + shift]
|
||||
else:
|
||||
return calendar[day_index + shift : day_index]
|
||||
else:
|
||||
return calendar
|
||||
start = get_date_by_shift(trading_date, left_shift, future=future)
|
||||
end = get_date_by_shift(trading_date, right_shift, future=future)
|
||||
|
||||
calendar = D.calendar(start, end, future=future)
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False):
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
"""get trading date with shift bias wil cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -504,8 +504,22 @@ def get_date_by_shift(trading_date, shift, future=False):
|
||||
trading_date : pandas.Timestamp
|
||||
current date
|
||||
shift : int
|
||||
clip_shift: bool
|
||||
|
||||
"""
|
||||
return get_date_range(trading_date, shift, future)[0 if shift < 0 else -1] if shift != 0 else trading_date
|
||||
from qlib.data import D
|
||||
|
||||
cal = D.calendar(future=future)
|
||||
if pd.to_datetime(trading_date) not in list(cal):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
shift_index = _index + shift
|
||||
if shift_index < 0 or shift_index >= len(cal):
|
||||
if clip_shift:
|
||||
shift_index = np.clip(shift_index, 0, len(cal) - 1)
|
||||
else:
|
||||
raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range")
|
||||
return cal[shift_index]
|
||||
|
||||
|
||||
def get_next_trading_date(trading_date, future=False):
|
||||
@@ -688,7 +702,7 @@ def flatten_dict(d, parent_key="", sep="."):
|
||||
|
||||
|
||||
#################### Wrapper #####################
|
||||
class Wrapper(object):
|
||||
class Wrapper:
|
||||
"""Wrapper class for anything that needs to set up during qlib.init"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -44,7 +44,7 @@ def sys_config(config, config_path):
|
||||
# worflow handler function
|
||||
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.Loader)
|
||||
config = yaml.load(fp, Loader=yaml.SafeLoader)
|
||||
|
||||
# config the `sys` section
|
||||
sys_config(config, config_path)
|
||||
|
||||
@@ -65,13 +65,13 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end` method.")
|
||||
|
||||
def create_recorder(self, name=None):
|
||||
def create_recorder(self, recorder_name=None):
|
||||
"""
|
||||
Create a recorder for each experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
recorder_name : str
|
||||
the name of the recorder to be created.
|
||||
|
||||
Returns
|
||||
|
||||
@@ -5,10 +5,9 @@ import re
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from ..contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from ..contrib.evaluate import risk_analysis
|
||||
from ..contrib.backtest import backtest as normal_backtest
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
@@ -213,6 +212,11 @@ class SigAnaRecord(SignalRecord):
|
||||
class PortAnaRecord(SignalRecord):
|
||||
"""
|
||||
This is the Portfolio Analysis Record class that generates the analysis results such as those of backtest. This class inherits the ``RecordTemp`` class.
|
||||
|
||||
The following files will be stored in recorder
|
||||
- report_normal.pkl & positions_normal.pkl:
|
||||
- The return report and detailed positions of the backtest, returned by `qlib/contrib/evaluate.py:backtest`
|
||||
- port_analysis.pkl : The risk analysis of your portfolio, returned by `qlib/contrib/evaluate.py:risk_analysis`
|
||||
"""
|
||||
|
||||
artifact_path = "portfolio_analysis"
|
||||
@@ -236,9 +240,14 @@ class PortAnaRecord(SignalRecord):
|
||||
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
order_normal = report_dict.get("order_list")
|
||||
if order_normal:
|
||||
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import shutil, os, pickle, tempfile, codecs
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from ..utils.objm import FileManager
|
||||
@@ -202,9 +202,6 @@ class MLflowRecorder(Recorder):
|
||||
super(MLflowRecorder, self).__init__(experiment_id, name)
|
||||
self._uri = uri
|
||||
self.artifact_uri = None
|
||||
# set up file manager for saving objects
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.fm = FileManager(Path(self.temp_dir).absolute())
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
# construct from mlflow run
|
||||
if mlflow_run is not None:
|
||||
@@ -248,16 +245,18 @@ class MLflowRecorder(Recorder):
|
||||
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.status != Recorder.STATUS_S:
|
||||
self.status = status
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
if local_path is not None:
|
||||
self.client.log_artifacts(self.id, local_path, artifact_path)
|
||||
else:
|
||||
temp_dir = Path(tempfile.mkdtemp()).resolve()
|
||||
for name, data in kwargs.items():
|
||||
self.fm.save_obj(data, name)
|
||||
self.client.log_artifact(self.id, self.fm.path / name, artifact_path)
|
||||
with (temp_dir / name).open("wb") as f:
|
||||
pickle.dump(data, f)
|
||||
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def load_object(self, name):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
|
||||
@@ -43,7 +43,7 @@ python get_data.py qlib_data --help
|
||||
|
||||
### US data
|
||||
|
||||
> Need to download data first: [Downlaod US Data](#Downlaod-US-Data)
|
||||
> Need to download data first: [Download US Data](#Download-US-Data)
|
||||
|
||||
```python
|
||||
import qlib
|
||||
|
||||
@@ -1,28 +1,71 @@
|
||||
import sys, platform
|
||||
import sys
|
||||
import platform
|
||||
import qlib
|
||||
import fire
|
||||
import pkg_resources
|
||||
from pathlib import Path
|
||||
|
||||
QLIB_PATH = Path(__file__).absolute().resolve().parent.parent
|
||||
|
||||
|
||||
def linux_distribution():
|
||||
try:
|
||||
return platform.linux_distribution()
|
||||
except:
|
||||
return "N/A"
|
||||
class InfoCollector:
|
||||
"""
|
||||
User could collect system info by following commands
|
||||
`cd scripts && python collect_info.py all`
|
||||
- NOTE: please avoid running this script in the project folder which contains `qlib`
|
||||
"""
|
||||
|
||||
def sys(self):
|
||||
"""collect system related info"""
|
||||
for method in ["system", "machine", "platform", "version"]:
|
||||
print(getattr(platform, method)())
|
||||
|
||||
def py(self):
|
||||
"""collect Python related info"""
|
||||
print("Python version: {}".format(sys.version.replace("\n", " ")))
|
||||
|
||||
def qlib(self):
|
||||
"""collect qlib related info"""
|
||||
print("Qlib version: {}".format(qlib.__version__))
|
||||
REQUIRED = [
|
||||
"numpy",
|
||||
"pandas",
|
||||
"scipy",
|
||||
"requests",
|
||||
"sacred",
|
||||
"python-socketio",
|
||||
"redis",
|
||||
"python-redis-lock",
|
||||
"schedule",
|
||||
"cvxpy",
|
||||
"hyperopt",
|
||||
"fire",
|
||||
"statsmodels",
|
||||
"xlrd",
|
||||
"plotly",
|
||||
"matplotlib",
|
||||
"tables",
|
||||
"pyyaml",
|
||||
"mlflow",
|
||||
"tqdm",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"tornado",
|
||||
"joblib",
|
||||
"fire",
|
||||
"ruamel.yaml",
|
||||
]
|
||||
|
||||
for package in REQUIRED:
|
||||
version = pkg_resources.get_distribution(package).version
|
||||
print(f"{package}=={version}")
|
||||
|
||||
def all(self):
|
||||
"""collect all info"""
|
||||
for method in ["sys", "py", "qlib"]:
|
||||
getattr(self, method)()
|
||||
print()
|
||||
|
||||
|
||||
print("Qlib version: {} \n".format(qlib.__version__))
|
||||
print(
|
||||
"""Python version: {} \n
|
||||
linux_distribution: {}
|
||||
system: {}
|
||||
machine: {}
|
||||
platform: {}
|
||||
version: {}
|
||||
""".format(
|
||||
sys.version.split("\n"),
|
||||
linux_distribution(),
|
||||
platform.system(),
|
||||
platform.machine(),
|
||||
platform.platform(),
|
||||
platform.version(),
|
||||
)
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(InfoCollector)
|
||||
|
||||
4
setup.py
4
setup.py
@@ -11,7 +11,7 @@ NAME = "pyqlib"
|
||||
DESCRIPTION = "A Quantitative-research Platform"
|
||||
REQUIRES_PYTHON = ">=3.5.0"
|
||||
|
||||
VERSION = "0.6.0.dev"
|
||||
VERSION = "0.6.1.dev"
|
||||
|
||||
# Detect Cython
|
||||
try:
|
||||
@@ -35,7 +35,6 @@ REQUIRED = [
|
||||
"scipy>=1.0.0",
|
||||
"requests>=2.18.0",
|
||||
"sacred>=0.7.4",
|
||||
"pymongo==3.7.2",
|
||||
"python-socketio==3.1.2",
|
||||
"redis>=3.0.1",
|
||||
"python-redis-lock>=3.3.1",
|
||||
@@ -55,7 +54,6 @@ REQUIRED = [
|
||||
"lightgbm",
|
||||
"tornado",
|
||||
"joblib>=0.17.0",
|
||||
"fire>=0.3.1",
|
||||
"ruamel.yaml>=0.16.12",
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user