diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 08d41d198..8b94a2d3b 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: [windows-latest, macos-latest]
- python-version: [3.6, 3.7, 3.8]
+ python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 29265b1eb..af386f6ca 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,4 +1,4 @@
-name: Test
+name: Test
on:
push:
@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04]
- python-version: [3.6, 3.7, 3.8, 3.9]
+ python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -25,63 +25,29 @@ jobs:
- name: Lint with Black
run: |
- cd ..
- if [ "$RUNNER_OS" == "Windows" ]; then
- $CONDA\\python.exe -m pip install black
- $CONDA\\python.exe -m black qlib -l 120 --check --diff
- else
- sudo $CONDA/bin/python -m pip install black
- $CONDA/bin/python -m black qlib -l 120 --check --diff
- fi
- shell: bash
+ pip install --upgrade pip
+ pip install black wheel
+ black qlib -l 120 --check --diff
- # Test Qlib installed with pip
- # - name: Install Qlib with pip
- # run: |
- # if [ "$RUNNER_OS" == "Windows" ]; then
- # $CONDA\\python.exe -m pip install numpy==1.19.5
- # $CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
- # else
- # sudo $CONDA/bin/python -m pip install numpy==1.19.5
- # sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
- # fi
- # shell: bash
+ - name: Install Qlib with pip
+ run: |
+ pip install numpy==1.19.5 ruamel.yaml
+ pip install pyqlib --ignore-installed
- # - name: Test data downloads
- # run: |
- # if [ "$RUNNER_OS" == "Windows" ]; then
- # $CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- # else
- # $CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- # fi
- # shell: bash
+ - name: Test data downloads
+ run: |
+ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- # - name: Test workflow by config (install from pip)
- # run: |
- # if [ "$RUNNER_OS" == "Windows" ]; then
- # $CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
- # $CONDA\\python.exe -m pip uninstall -y pyqlib
- # else
- # $CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
- # sudo $CONDA/bin/python -m pip uninstall -y pyqlib
- # fi
- # shell: bash
-
- # Test Qlib installed from source
+ - name: Test workflow by config (install from pip)
+ run: |
+ python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
+ python -m pip uninstall -y pyqlib
+
+ # Test Qlib installed from source
- name: Install Qlib from source
run: |
- if [ "$RUNNER_OS" == "Windows" ]; then
- $CONDA\\python.exe -m pip install --upgrade cython
- $CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
- $CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
- $CONDA\\python.exe setup.py install
- else
- sudo $CONDA/bin/python -m pip install --upgrade cython
- sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
- sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
- sudo $CONDA/bin/python setup.py install
- fi
- shell: bash
+ pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
+ pip install -e .
- name: Test data downloads
run: |
@@ -94,30 +60,15 @@ jobs:
- name: Install test dependencies
run: |
- if [ "$RUNNER_OS" == "Windows" ]; then
- $CONDA\\python.exe -m pip install --upgrade pip
- $CONDA\\python.exe -m pip install black pytest
- else
- sudo $CONDA/bin/python -m pip install --upgrade pip
- sudo $CONDA/bin/python -m pip install black pytest
- fi
- shell: bash
+ pip install --upgrade pip
+ pip install black pytest
- name: Unit tests with Pytest
run: |
cd tests
- if [ "$RUNNER_OS" == "Windows" ]; then
- $CONDA\\python.exe -m pytest . --durations=0
- else
- $CONDA/bin/python -m pytest . --durations=0
- fi
- shell: bash
+ python -m pytest . --durations=10
- name: Test workflow by config (install from source)
run: |
- if [ "$RUNNER_OS" == "Windows" ]; then
- $CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
- else
- $CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
- fi
- shell: bash
+ python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
+
diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml
index e52c27786..b6003f668 100644
--- a/.github/workflows/test_macos.yml
+++ b/.github/workflows/test_macos.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8, 3.9]
+ python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -26,52 +26,46 @@ jobs:
- name: Lint with Black
run: |
cd ..
- sudo $CONDA/bin/python -m pip install black
- $CONDA/bin/python -m black qlib -l 120 --check --diff
-
+ python -m pip install pip --upgrade
+ python -m pip install wheel --upgrade
+ python -m pip install black
+ python -m black qlib -l 120 --check --diff
# Test Qlib installed with pip
- # - name: Install Qlib with pip
- # run: |
- # sudo $CONDA/bin/python -m pip install numpy==1.19.5
- # sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
+
+ - name: Install Qlib with pip
+ run: |
+ python -m pip install numpy==1.19.5
+ python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
- name: Install Lightgbm for MacOS
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
- # - name: Test data downloads
- # run: |
- # $CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
-
- # - name: Test workflow by config (install from pip)
- # run: |
- # $CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
- # sudo $CONDA/bin/python -m pip uninstall -y pyqlib
-
+ - name: Test data downloads
+ run: |
+ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
+ - name: Test workflow by config (install from pip)
+ run: |
+ python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
+ python -m pip uninstall -y pyqlib
# Test Qlib installed from source
- name: Install Qlib from source
run: |
- sudo $CONDA/bin/python -m pip install --upgrade cython
- sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
- sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
- sudo $CONDA/bin/python setup.py install
-
- - name: Test data downloads
- run: |
- $CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
+ python -m pip install --upgrade cython
+ python -m pip install numpy jupyter jupyter_contrib_nbextensions
+ python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
+ python setup.py install
- name: Install test dependencies
run: |
- sudo $CONDA/bin/python -m pip install --upgrade pip
- sudo $CONDA/bin/python -m pip install -U pyopenssl idna
- sudo $CONDA/bin/python -m pip install black pytest
-
+ python -m pip install --upgrade pip
+ python -m pip install -U pyopenssl idna
+ python -m pip install black pytest
- name: Unit tests with Pytest
run: |
cd tests
- $CONDA/bin/python -m pytest . --durations=0
-
+ python -m pytest . --durations=0
- name: Test workflow by config (install from source)
run: |
- $CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
+ python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 33a2a2530..a563ed5c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ dist/
.nvimrc
.vscode
+qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
examples/estimator/estimator_example/
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 000000000..8dd91c79d
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include qlib/VERSION.txt
diff --git a/README.md b/README.md
index 422046c13..6ceb26e66 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@
Recent released features
| Feature | Status |
| -- | ------ |
+|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
@@ -23,10 +24,8 @@ Recent released features
Features released before 2021 are not listed here.
-
-
-
+
@@ -45,7 +44,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [Data Preparation](#data-preparation)
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
-- [**Quant Model Zoo**](#quant-model-zoo)
+- [**Quant Model(Paper) Zoo**](#quant-model-paper-zoo)
- [Run a single model](#run-a-single-model)
- [Run multiple models](#run-multiple-models)
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
@@ -71,7 +70,7 @@ Your feedbacks about the features are very important.
# Framework of Qlib
-
+
@@ -107,8 +106,9 @@ This table demonstrates the supported Python version of `Qlib`:
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
**Note**:
+1. **Conda** is suggested for managing your Python environment.
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
-2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
+1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
### Install with pip
Users can easily install ``Qlib`` by pip according to the following command.
@@ -162,7 +162,7 @@ Users could create the same dataset with it.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
-### Automatic update of daily frequency data(from yahoo finance)
+### Automatic update of daily frequency data (from yahoo finance)
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
@@ -247,19 +247,19 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
- Forecasting signal (model prediction) analysis
- Cumulative Return of groups
- 
+ 
- Return distribution
- 
+ 
- Information Coefficient (IC)
- 
- 
- 
+ 
+ 
+ 
- Auto Correlation of forecasting signal (model prediction)
- 
+ 
- Portfolio analysis
- Backtest return
- 
+ 
+
+
\ No newline at end of file
diff --git a/docs/component/backtest.rst b/docs/component/backtest.rst
index 88e01e2de..e83e1023a 100644
--- a/docs/component/backtest.rst
+++ b/docs/component/backtest.rst
@@ -30,7 +30,7 @@ The simple example of the default strategy is as follows.
from qlib.contrib.evaluate import backtest
# pred_score is the prediction score
- report, positions = backtest(pred_score, topk=50, n_drop=0.5, verbose=False, limit_threshold=0.0095)
+ report, positions = backtest(pred_score, topk=50, n_drop=0.5, limit_threshold=0.0095)
To know more about backtesting with a specific ``Strategy``, please refer to `Portfolio Strategy `_.
diff --git a/docs/component/highfreq.rst b/docs/component/highfreq.rst
new file mode 100644
index 000000000..13ebb959d
--- /dev/null
+++ b/docs/component/highfreq.rst
@@ -0,0 +1,120 @@
+.. _highfreq:
+
+================================================
+Design of hierarchical order execution framework
+================================================
+.. currentmodule:: qlib
+
+Introduction
+===================
+
+In order to support reinforcement learning algorithms for high-frequency trading, a corresponding framework is required. None of the publicly available high-frequency trading frameworks now consider multi-layer trading mechanisms, and the currently designed algorithms cannot directly use existing frameworks.
+In addition to supporting the basic intraday multi-layer trading, the linkage with the day-ahead strategy is also a factor that affects the performance evaluation of the strategy. Different day strategies generate different order distributions and different patterns on different stocks. To verify that high-frequency trading strategies perform well on real trading orders, it is necessary to support day-frequency and high-frequency multi-level linkage trading. In addition to more accurate backtesting of high-frequency trading algorithms, if the distribution of day-frequency orders is considered when training a high-frequency trading model, the algorithm can also be optimized more for product-specific day-frequency orders.
+Therefore, innovation in the high-frequency trading framework is necessary to solve the various problems mentioned above, for which we designed a hierarchical order execution framework that can link daily-frequency and intra-day trading at different granularities.
+
+.. image:: ../_static/img/framework.svg
+
+The design of the framework is shown in the figure above. Each layer consists of a Trading Agent and an Execution Env. The Trading Agent has its own data processing module (Information Extractor), forecasting module (Forecast Model) and decision generator (Decision Generator). The trading algorithm generates the corresponding decisions through the Decision Generator based on the forecast signals output by the Forecast Model, and the decisions generated by the trading algorithm are passed to the Execution Env, which returns the execution results. Here the frequency of the trading algorithm, the decision content and the execution environment can be customized by users (e.g. intra-day trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with a finer-grained trading algorithm and execution environment inside (i.e. the sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The hierarchical order execution framework is user-defined in terms of hierarchy division and decision frequency, making it easy for users to explore the effects of combining different levels of trading algorithms and to break down the barriers between different levels of trading algorithm optimization.
+In addition to the innovation in the framework, the hierarchical order execution framework also takes into account various details of the real backtesting environment, minimizing the differences with the final real environment as much as possible. At the same time, the framework is designed to unify the interface between online and offline (e.g. data pre-processing level supports using the same set of code to process both offline and online data) to reduce the cost of strategy go-live as much as possible.
+
+Prepare Data
+===================
+Please refer to ``examples/highfreq/README.md`` for how to prepare the high-frequency data.
+
+
+Example
+===========================
+
+Here is an example of highfreq execution.
+
+.. code-block:: python
+
+ import qlib
+ # init qlib
+ provider_uri_day = "~/.qlib/qlib_data/cn_data"
+ provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
+ provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
+ qlib.init(provider_uri=provider_uri_day, expression_cache=None, dataset_cache=None)
+
+ # data freq and backtest time
+ freq = "1min"
+ inst_list = D.list_instruments(D.instruments("all"), as_list=True)
+ start_time = "2020-01-01"
+ end_time = "2020-01-31"
+
+When initializing qlib, if the default data is used, then both daily and minute frequency data need to be passed in.
+
+.. code-block:: python
+
+ # random order strategy config
+ strategy_config = {
+ "class": "RandomOrderStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ "kwargs": {
+ "trade_range": TradeRangeByTime("9:30", "15:00"),
+ "sample_ratio": 1.0,
+ "volume_ratio": 0.01,
+ "market": market,
+ },
+ }
+
+.. code-block:: python
+
+ backtest_config = { # backtest config
+ "start_time": start_time,
+ "end_time": end_time,
+ "account": 100000000,
+ "benchmark": None,
+ "exchange_kwargs": {
+ "freq": freq,
+ "limit_threshold": 0.095,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ "codes": market,
+ },
+ "pos_type": "InfPosition", # Position with infinite position
+ }
+
+For more details about the backtest configuration, please refer to the code in ``qlib/backtest``.
+
+.. code-block:: python
+
+ executor_config = { # executor config
+ "class": "NestedExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "day",
+ "inner_executor": {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": freq,
+ "generate_portfolio_metrics": True,
+ "verbose": False,
+ # "verbose": True,
+ "indicator_config": {
+ "show_indicator": False,
+ },
+ },
+ },
+ "inner_strategy": {
+ "class": "TWAPStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ },
+ "track_data": True,
+ "generate_portfolio_metrics": True,
+ "indicator_config": {
+ "show_indicator": True,
+ },
+ },
+ }
+
+``NestedExecutor`` indicates that the current executor is not the innermost layer, so its initialization parameters should contain ``inner_executor`` and ``inner_strategy``. ``SimulatorExecutor`` indicates that the current executor is the innermost layer. The innermost strategy used here is the TWAP strategy; the framework currently also supports the VWAP strategy.
+
+.. code-block:: python
+
+ portfolio_metrics_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config) # backtest
+
+The metrics of backtest are included in the portfolio_metrics_dict and indicator_dict.
diff --git a/docs/component/recorder.rst b/docs/component/recorder.rst
index cc425fa8e..5a7d195d6 100644
--- a/docs/component/recorder.rst
+++ b/docs/component/recorder.rst
@@ -123,7 +123,6 @@ Here is a simple exampke of what is done in ``PortAnaRecord``, which users can r
"n_drop": 5,
}
BACKTEST_CONFIG = {
- "verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
diff --git a/docs/component/strategy.rst b/docs/component/strategy.rst
index e4a5a94d1..c9d002ca1 100644
--- a/docs/component/strategy.rst
+++ b/docs/component/strategy.rst
@@ -93,7 +93,6 @@ Usage & Example
"n_drop": 5,
}
BACKTEST_CONFIG = {
- "verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
diff --git a/docs/component/workflow.rst b/docs/component/workflow.rst
index 2b7ec19ad..84522af99 100644
--- a/docs/component/workflow.rst
+++ b/docs/component/workflow.rst
@@ -54,7 +54,6 @@ Below is a typical config file of ``qrun``.
topk: 50
n_drop: 5
backtest:
- verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
@@ -242,7 +241,6 @@ The following script is the configuration of `backtest` and the `strategy` used
topk: 50
n_drop: 5
backtest:
- verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
diff --git a/docs/developer/code_standard.rst b/docs/developer/code_standard.rst
index 23ea713ba..27da3bfd1 100644
--- a/docs/developer/code_standard.rst
+++ b/docs/developer/code_standard.rst
@@ -6,15 +6,17 @@ Code Standard
Docstring
=================================
-Please use the Numpy Style.
+Please use the `Numpydoc Style `_.
Continuous Integration
=================================
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
+When you submit a pull request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
+
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
.. code-block:: python
pip install black
- python -m black . -l 120
\ No newline at end of file
+ python -m black . -l 120
diff --git a/docs/hidden/tuner.rst b/docs/hidden/tuner.rst
index 6d62f899f..8abf2ec7c 100644
--- a/docs/hidden/tuner.rst
+++ b/docs/hidden/tuner.rst
@@ -93,7 +93,6 @@ We write a simple configuration example as following,
fend_time: 2018-12-11
backtest:
normal_backtest_args:
- verbose: False
limit_threshold: 0.095
account: 500000
benchmark: SH000905
@@ -306,7 +305,6 @@ About the data and backtest
fend_time: 2018-12-11
backtest:
normal_backtest_args:
- verbose: False
limit_threshold: 0.095
account: 500000
benchmark: SH000905
diff --git a/docs/introduction/introduction.rst b/docs/introduction/introduction.rst
index 06fac46fa..a55edd5ec 100644
--- a/docs/introduction/introduction.rst
+++ b/docs/introduction/introduction.rst
@@ -15,7 +15,7 @@ With ``Qlib``, users can easily try their ideas to create better Quant investmen
Framework
===================
-.. image:: ../_static/img/framework.png
+.. image:: ../_static/img/framework.svg
:align: center
diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst
index 32c17ff83..eaf80f4a5 100644
--- a/docs/start/initialization.rst
+++ b/docs/start/initialization.rst
@@ -77,7 +77,8 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
})
- `mongo`
Type: dict, optional parameter, the setting of `MongoDB `_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
- Users need finished `installation `_ firstly, and run it in a fixed URL.
+ Users need to follow the steps in `installation `_ to install MongoDB firstly and then access it via a URI.
+ Users can access MongoDB with credentials by setting "task_url" to a string like `"mongodb://%s:%s@%s" % (user, pwd, host + ":" + port)`.
.. code-block:: Python
diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
index ea38ae19c..039040d8f 100755
--- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
@@ -93,8 +93,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
index 83720b4b2..88c6fcd07 100644
--- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
@@ -83,8 +83,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml
index 0ffe19e1b..18e19bd0f 100644
--- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml
+++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml
@@ -65,8 +65,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml
index 57c1751a1..a6cdd1882 100644
--- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml
+++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml
@@ -72,8 +72,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml
index 71f0d3e64..fb8cce74d 100644
--- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml
+++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml
@@ -90,8 +90,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml
index 8a185f05f..d1fbd7807 100644
--- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml
+++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml
@@ -97,8 +97,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
index 63aa1b429..5387adc24 100644
--- a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
+++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
@@ -91,8 +91,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml
index e06192b2b..1ffd6780e 100644
--- a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml
+++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml
@@ -83,8 +83,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
index 42286fecd..82c690889 100755
--- a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
+++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
@@ -92,8 +92,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml
index bd1a6e1bf..02c81c850 100644
--- a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml
+++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml
@@ -82,8 +82,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
index 687404419..f4412c262 100755
--- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
+++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
@@ -92,8 +92,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml
index e6c3b5736..10a1dc5df 100644
--- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml
+++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml
@@ -82,8 +82,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LightGBM/features_resample_N.py b/examples/benchmarks/LightGBM/features_resample_N.py
new file mode 100644
index 000000000..13061513c
--- /dev/null
+++ b/examples/benchmarks/LightGBM/features_resample_N.py
@@ -0,0 +1,18 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import pandas as pd
+
+from qlib.data.inst_processor import InstProcessor
+from qlib.utils.resam import resam_calendar
+
+
+class ResampleNProcessor(InstProcessor):
+ def __init__(self, target_frq: str, **kwargs):
+ self.target_frq = target_frq
+
+ def __call__(self, df: pd.DataFrame, *args, **kwargs):
+ df.index = pd.to_datetime(df.index)
+ res_index = resam_calendar(df.index, "1min", self.target_frq)
+ df = df.resample(self.target_frq).last().reindex(res_index)
+ return df
diff --git a/examples/benchmarks/LightGBM/multi_freq_handler.py b/examples/benchmarks/LightGBM/multi_freq_handler.py
new file mode 100644
index 000000000..07d7ac27c
--- /dev/null
+++ b/examples/benchmarks/LightGBM/multi_freq_handler.py
@@ -0,0 +1,135 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import pandas as pd
+
+from qlib.data.dataset.loader import QlibDataLoader
+from qlib.contrib.data.handler import DataHandlerLP, _DEFAULT_LEARN_PROCESSORS, check_transform_proc
+
+
+class Avg15minLoader(QlibDataLoader):
+ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
+ df = super(Avg15minLoader, self).load(instruments, start_time, end_time)
+ if self.is_group:
+ # feature_day(day freq) and feature_15min(1min freq, Average every 15 minutes) renamed feature
+ df.columns = df.columns.map(lambda x: ("feature", x[1]) if x[0].startswith("feature") else x)
+ return df
+
+
+class Avg15minHandler(DataHandlerLP):
+ def __init__(
+ self,
+ instruments="csi500",
+ start_time=None,
+ end_time=None,
+ freq="day",
+ infer_processors=[],
+ learn_processors=_DEFAULT_LEARN_PROCESSORS,
+ fit_start_time=None,
+ fit_end_time=None,
+ process_type=DataHandlerLP.PTYPE_A,
+ filter_pipe=None,
+ inst_processor=None,
+ **kwargs,
+ ):
+ infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
+ learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
+ data_loader = Avg15minLoader(
+ config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
+ )
+ super().__init__(
+ instruments=instruments,
+ start_time=start_time,
+ end_time=end_time,
+ data_loader=data_loader,
+ infer_processors=infer_processors,
+ learn_processors=learn_processors,
+ process_type=process_type,
+ )
+
+ def loader_config(self):
+
+ # Results for dataset: df: pd.DataFrame
+ # len(df.columns) == 6 + 6 * 16, len(df.index.get_level_values(level="datetime").unique()) == T
+ # df.columns: close0, close1, ..., close16, open0, ..., open16, ..., vwap16
+ # freq == day:
+ # close0, open0, low0, high0, volume0, vwap0
+ # freq == 1min:
+ # close1, ..., close16, ..., vwap1, ..., vwap16
+ # df.index.name == ["datetime", "instrument"]: pd.MultiIndex
+ # Example:
+ # feature ... label
+ # close0 open0 low0 ... vwap1 vwap16 LABEL0
+ # datetime instrument ...
+ # 2020-10-09 SH600000 11.794546 11.819587 11.769505 ... NaN NaN -0.005214
+ # 2020-10-15 SH600000 12.044961 11.944795 11.932274 ... NaN NaN -0.007202
+ # ... ... ... ... ... ... ... ...
+ # 2021-05-28 SZ300676 6.369684 6.495406 6.306568 ... NaN NaN -0.001321
+ # 2021-05-31 SZ300676 6.601626 6.465643 6.465130 ... NaN NaN -0.023428
+
+ # features day: len(columns) == 6, freq = day
+ # $close is the closing price of the current trading day:
+ # if the user needs to get the `close` before the last T days, use Ref($close, T-1), for example:
+ # $close Ref($close, 1) Ref($close, 2) Ref($close, 3) Ref($close, 4)
+ # instrument datetime
+ # SH600519 2021-06-01 244.271530
+ # 2021-06-02 242.205917 244.271530
+ # 2021-06-03 242.229889 242.205917 244.271530
+ # 2021-06-04 245.421524 242.229889 242.205917 244.271530
+ # 2021-06-07 247.547089 245.421524 242.229889 242.205917 244.271530
+
+ # WARNING: Ref($close, N), if N == 0, Ref($close, N) ==> $close
+
+ fields = ["$close", "$open", "$low", "$high", "$volume", "$vwap"]
+ # names: close0, open0, ..., vwap0
+ names = list(map(lambda x: x.strip("$") + "0", fields))
+
+ config = {"feature_day": (fields, names)}
+
+ # features 15min: len(columns) == 6 * 16, freq = 1min
+ # $close is the closing price of the current 1min bar (freq = 1min):
+ # to get the average 'close' of the i-th 15min period (1 <= i <= 16) of the T-th most recent day, use `Ref(Mean($close, 15), T * 240 - i * 15)`, for example (i = 1; T = 1, 2, 3):
+ # Ref(Mean($close, 15), 225) Ref(Mean($close, 15), 465) Ref(Mean($close, 15), 705)
+ # instrument datetime
+ # SH600519 2021-05-31 241.769897 243.077942 244.712997
+ # 2021-06-01 244.271530 241.769897 243.077942
+ # 2021-06-02 242.205917 244.271530 241.769897
+
+ # WARNING: Ref(Mean($close, 15), N), if N == 0, Ref(Mean($close, 15), N) ==> Mean($close, 15)
+
+ # Results of the current script:
+ # time: 09:00 --> 09:14, ..., 14:45 --> 14:59
+ # fields: Ref(Mean($close, 15), 225), ..., Mean($close, 15)
+ # name: close1, ..., close16
+ #
+
+ # Expression description: take close as an example
+ # Mean($close, 15) ==> df["$close"].rolling(15, min_periods=1).mean()
+ # Ref(Mean($close, 15), 15) ==> df["$close"].rolling(15, min_periods=1).mean().shift(15)
+
+ # NOTE: The last data of each trading day, which is the average of the i-th 15 minutes
+
+ # Average:
+ # Average of the i-th 15-minute period of each trading day: 1 <= i <= 240 // 15
+ # Avg(15minutes): Ref(Mean($close, 15), 240 - i * 15)
+ #
+ # Average of the first 15 minutes of each trading day; i = 1
+ # Avg(09:00 --> 09:14), df.index.loc["09:14"]: Ref(Mean($close, 15), 240- 1 * 15) ==> Ref(Mean($close, 15), 225)
+ # Average of the last 15 minutes of each trading day; i = 16
+ # Avg(14:45 --> 14:59), df.index.loc["14:59"]: Ref(Mean($close, 15), 240 - 16 * 15) ==> Ref(Mean($close, 15), 0) ==> Mean($close, 15)
+
+ # 15min resample to day
+ # df.resample("1d").last()
+ tmp_fields = []
+ tmp_names = []
+ for i, _f in enumerate(fields):
+ _fields = [f"Ref(Mean({_f}, 15), {j * 15})" for j in range(1, 240 // 15)]
+ _names = [f"{names[i][:-1]}{int(names[i][-1])+j}" for j in range(240 // 15 - 1, 0, -1)]
+ _fields.append(f"Mean({_f}, 15)")
+ _names.append(f"{names[i][:-1]}{int(names[i][-1])+240 // 15}")
+ tmp_fields += _fields
+ tmp_names += _names
+ config["feature_15min"] = (tmp_fields, tmp_names)
+ # label
+ config["label"] = (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
+ return config
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
index 9d6f45076..8bee2bd38 100644
--- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
@@ -66,8 +66,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml
index ba96b076c..b8af19ec1 100644
--- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml
@@ -73,8 +73,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml
index 0f71b2a36..a92f342a1 100644
--- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml
@@ -81,9 +81,7 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ config: *port_analysis_config
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml
new file mode 100644
index 000000000..829c87115
--- /dev/null
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml
@@ -0,0 +1,86 @@
+qlib_init:
+ provider_uri:
+ day: "~/.qlib/qlib_data/cn_data"
+ 1min: "~/.qlib/qlib_data/cn_data_1min"
+ region: cn
+ dataset_cache: null
+ maxtasksperchild: null
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ # 1min closing time is 15:00:00
+ end_time: "2020-08-01 15:00:00"
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ freq:
+ label: day
+ feature_15min: 1min
+ feature_day: day
+ # with label as reference
+ inst_processor:
+ feature_15min:
+ - class: ResampleNProcessor
+ module_path: features_resample_N.py
+ kwargs:
+ target_frq: 1d
+
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+ backtest:
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: LGBModel
+ module_path: qlib.contrib.model.gbdt
+ kwargs:
+ loss: mse
+ colsample_bytree: 0.8879
+ learning_rate: 0.2
+ subsample: 0.8789
+ lambda_l1: 205.6999
+ lambda_l2: 580.9768
+ max_depth: 8
+ num_leaves: 210
+ num_threads: 20
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Avg15minHandler
+ module_path: multi_freq_handler.py
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ model:
+ dataset:
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml
index 1cf28024e..9f055a62c 100644
--- a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml
+++ b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml
@@ -72,8 +72,6 @@ task:
kwargs:
ana_long_short: True
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml
index d7e967333..cd31ecd1e 100644
--- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml
+++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml
@@ -34,19 +34,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LocalformerModel
@@ -70,13 +74,15 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- ana_long_short: False
- ann_scaler: 252
+ ana_long_short: False
+ ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
+ config: *port_analysis_config
diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml
index 1c8489461..f9cc091fd 100644
--- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml
+++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LocalformerModel
@@ -59,15 +63,17 @@ task:
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- - class: SignalRecord
- module_path: qlib.workflow.record_temp
- kwargs: {}
- - class: SigAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- ana_long_short: False
- ann_scaler: 252
- - class: PortAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ model:
+ dataset:
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml
index bc005b43e..8303f3945 100644
--- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml
+++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml
@@ -95,8 +95,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
index a4ceab8da..f52c5930d 100644
--- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
+++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
@@ -82,8 +82,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md
index ee2c0a833..cd4276781 100644
--- a/examples/benchmarks/README.md
+++ b/examples/benchmarks/README.md
@@ -25,6 +25,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
+| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
@@ -43,6 +44,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
+| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
+| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.
diff --git a/examples/benchmarks/SFM/requirements.txt b/examples/benchmarks/SFM/requirements.txt
index 6a3d13097..16de0a438 100644
--- a/examples/benchmarks/SFM/requirements.txt
+++ b/examples/benchmarks/SFM/requirements.txt
@@ -1,4 +1,4 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
-torch==1.7.0
\ No newline at end of file
+torch==1.7.0
diff --git a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml
index e42f75aec..5c66400bb 100644
--- a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml
+++ b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml
@@ -85,8 +85,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/TCTS/requirements.txt b/examples/benchmarks/TCTS/requirements.txt
new file mode 100644
index 000000000..6a3d13097
--- /dev/null
+++ b/examples/benchmarks/TCTS/requirements.txt
@@ -0,0 +1,4 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
\ No newline at end of file
diff --git a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
index c6eac243c..7ca6e937f 100644
--- a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
+++ b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
@@ -90,8 +90,6 @@ task:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- model:
- dataset:
ana_long_short: False
ann_scaler: 252
label_col: 1
diff --git a/examples/benchmarks/TFT/README.md b/examples/benchmarks/TFT/README.md
index 5a6a9f153..991066b7f 100644
--- a/examples/benchmarks/TFT/README.md
+++ b/examples/benchmarks/TFT/README.md
@@ -8,7 +8,7 @@
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
### Notes
-1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
+1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
3. The model must run in GPU, or an error will be raised.
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
diff --git a/examples/benchmarks/TFT/requirements.txt b/examples/benchmarks/TFT/requirements.txt
index 04234aaed..f8bd00002 100644
--- a/examples/benchmarks/TFT/requirements.txt
+++ b/examples/benchmarks/TFT/requirements.txt
@@ -1,3 +1,2 @@
tensorflow-gpu==1.15.0
-numpy == 1.19.4
-pandas==1.1.0
\ No newline at end of file
+pandas==1.1.0
diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py
index e1205b0e0..a854c2dd9 100644
--- a/examples/benchmarks/TFT/tft.py
+++ b/examples/benchmarks/TFT/tft.py
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+from pathlib import Path
+from typing import Union
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
- print("Training completed.".format(dte.datetime.now()))
+ print("Training completed at {}.".format(dte.datetime.now()))
# ===========================Training Process===========================
def predict(self, dataset):
@@ -289,3 +291,25 @@ class TFTModel(ModelFT):
dataset for finetuning
"""
pass
+
+ def to_pickle(self, path: Union[Path, str]):
+     """
+     Tensorflow model can't be dumped directly, so it is left out of the pickle and should be saved separately.
+
+     **TODO**: Please implement the function to load the files
+
+     Parameters
+     ----------
+     path : Union[Path, str]
+         the target path to be dumped
+     """
+     # FIXME: implement saving the tensorflow model, e.g.:
+     # path = Path(path)
+     # path.mkdir(parents=True)
+     # self.model.save(path)
+     # save qlib model wrapper only: detach the model while dumping, then restore it so this object stays usable
+     model, self.model = self.model, None
+     try:
+         super(TFTModel, self).to_pickle(path)
+     finally:
+         self.model = model
diff --git a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml
index a396371dc..0508ce676 100644
--- a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml
+++ b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml
@@ -58,8 +58,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/TRA/README.md b/examples/benchmarks/TRA/README.md
index 070527ddb..5ff5b480e 100644
--- a/examples/benchmarks/TRA/README.md
+++ b/examples/benchmarks/TRA/README.md
@@ -1,53 +1,78 @@
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
-This code provides a PyTorch implementation for TRA (Temporal Routing Adaptor), as described in the paper [Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport](http://arxiv.org/abs/2106.12950).
+Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in the stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.
-* TRA (Temporal Routing Adaptor) is a lightweight module that consists of a set of independent predictors for learning multiple patterns as well as a router to dispatch samples to different predictors.
-* We also design a learning algorithm based on Optimal Transport (OT) to obtain the optimal sample to predictor assignment and effectively optimize the router with such assignment through an auxiliary loss term.
+If you find our work useful in your research, please cite:
+```
+@inproceedings{HengxuKDD2021,
+ author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
+ title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
+ booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
+ series = {KDD '21},
+ year = {2021},
+ publisher = {ACM},
+}
+@article{yang2020qlib,
+ title={Qlib: An AI-oriented Quantitative Investment Platform},
+ author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},
+ journal={arXiv preprint arXiv:2009.11189},
+ year={2020}
+}
+```
-# Running TRA
+## Usage (Recommended)
-## Requirements
-- Install `Qlib` main branch
+**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and `Alpha158/Alpha360` dataset.
-## Running
+Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files:
+
+- `workflow_config_tra_Alpha360.yaml`: running `TRA` with `Alpha360` dataset
+- `workflow_config_tra_Alpha158.yaml`: running `TRA` with `Alpha158` dataset (with feature subsampling)
+- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with `Alpha158` dataset (without feature subsampling)
+
+The performances of `TRA` are reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
+
+## Usage (Not Maintained)
+
+This section is used to reproduce the results in the paper.
+
+### Running
We attach our running scripts for the paper in `run.sh`.
And here are two ways to run the model:
* Running from scripts with default parameters
- You can directly run from Qlib command `qrun`:
- ```
- qrun configs/config_alstm.yaml
- ```
+
+ You can directly run from Qlib command `qrun`:
+ ```
+ qrun configs/config_alstm.yaml
+ ```
* Running from code with self-defined parameters
- Setting different parameters is also allowed. See codes in `example.py`:
- ```
- python example.py --config_file configs/config_alstm.yaml
- ```
+
+ Setting different parameters is also allowed. See codes in `example.py`:
+ ```
+ python example.py --config_file configs/config_alstm.yaml
+ ```
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
-# Results
-
-## Outputs
+### Results
After running the scripts, you can find result files in path `./output`:
-`info.json` - config settings and result metrics.
+* `info.json` - config settings and result metrics.
+* `log.csv` - running logs.
+* `model.bin` - the model parameter dictionary.
+* `pred.pkl` - the prediction scores and output for inference.
-`log.csv` - running logs.
+Evaluation metrics reported in the paper:
+This result is generated by qlib==0.7.1.
-`model.bin` - the model parameter dictionary.
-
-`pred.pkl` - the prediction scores and output for inference.
-
-## Our Results
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
-|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
+|-------|-------|------|-----|-----|-----|-----|-----|-----|
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
@@ -61,21 +86,8 @@ After running the scripts, you can find result files in path `./output`:
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
-# Common Issues
+## Common Issues
For help or issues using TRA, please submit a GitHub issue.
-Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
-
-# Citation
-If you find this repository useful in your research, please cite:
-```
-@inproceedings{HengxuKDD2021,
- author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
- title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
- booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
- series = {KDD '21},
- year = {2021},
- publisher = {ACM},
-}
-```
+Sometimes the loss may become `NaN`. If that happens, check the `epsilon` parameter of the Sinkhorn algorithm; adjusting `epsilon` to the scale of the input is important.
diff --git a/examples/benchmarks/TRA/requirements.txt b/examples/benchmarks/TRA/requirements.txt
new file mode 100644
index 000000000..ab819ec1c
--- /dev/null
+++ b/examples/benchmarks/TRA/requirements.txt
@@ -0,0 +1,5 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
+seaborn
diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
new file mode 100644
index 000000000..72b900127
--- /dev/null
+++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
@@ -0,0 +1,132 @@
+qlib_init:
+ provider_uri: "~/.qlib/qlib_data/cn_data"
+ region: cn
+
+market: &market csi300
+benchmark: &benchmark SH000300
+
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: FilterCol
+ kwargs:
+ fields_group: feature
+ col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
+ "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
+ "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+
+num_states: &num_states 3
+
+memory_mode: &memory_mode sample
+
+tra_config: &tra_config
+ num_states: *num_states
+ rnn_arch: LSTM
+ hidden_size: 32
+ num_layers: 1
+ dropout: 0.0
+ tau: 1.0
+ src_info: LR_TPE
+
+model_config: &model_config
+ input_size: 20
+ hidden_size: 64
+ num_layers: 2
+ rnn_arch: LSTM
+ use_attn: True
+ dropout: 0.0
+
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+ backtest:
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+
+task:
+ model:
+ class: TRAModel
+ module_path: qlib.contrib.model.pytorch_tra
+ kwargs:
+ tra_config: *tra_config
+ model_config: *model_config
+ model_type: RNN
+ lr: 0.001  # plain decimal: YAML 1.1 loaders (e.g. PyYAML) read `1e-3` as a string, not a float
+ n_epochs: 100
+ max_steps_per_epoch:
+ early_stop: 20
+ logdir: output/Alpha158
+ seed: 0
+ lamb: 1.0
+ rho: 0.99
+ alpha: 0.5
+ transport_method: router
+ memory_mode: *memory_mode
+ eval_train: False
+ eval_test: True
+ pretrain: True
+ init_state:
+ freeze_model: False
+ freeze_predictors: False
+ dataset:
+ class: MTSDatasetH
+ module_path: qlib.contrib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ seq_len: 60
+ horizon: 2
+ input_size:
+ num_states: *num_states
+ batch_size: 1024
+ n_samples:
+ memory_mode: *memory_mode
+ drop_last: True
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ model:
+ dataset:
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
new file mode 100644
index 000000000..ab8febc2f
--- /dev/null
+++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
@@ -0,0 +1,126 @@
+qlib_init:
+ provider_uri: "~/.qlib/qlib_data/cn_data"
+ region: cn
+
+market: &market csi300
+benchmark: &benchmark SH000300
+
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+
+num_states: &num_states 3
+
+memory_mode: &memory_mode sample
+
+tra_config: &tra_config
+ num_states: *num_states
+ rnn_arch: LSTM
+ hidden_size: 32
+ num_layers: 1
+ dropout: 0.0
+ tau: 1.0
+ src_info: LR_TPE
+
+model_config: &model_config
+ input_size: 158
+ hidden_size: 256
+ num_layers: 2
+ rnn_arch: LSTM
+ use_attn: True
+ dropout: 0.2
+
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+ backtest:
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+
+task:
+ model:
+ class: TRAModel
+ module_path: qlib.contrib.model.pytorch_tra
+ kwargs:
+ tra_config: *tra_config
+ model_config: *model_config
+ model_type: RNN
+ lr: 0.001  # plain decimal: YAML 1.1 loaders (e.g. PyYAML) read `1e-3` as a string, not a float
+ n_epochs: 100
+ max_steps_per_epoch:
+ early_stop: 20
+ logdir: output/Alpha158_full
+ seed: 0
+ lamb: 1.0
+ rho: 0.99
+ alpha: 0.5
+ transport_method: router
+ memory_mode: *memory_mode
+ eval_train: False
+ eval_test: True
+ pretrain: True
+ init_state:
+ freeze_model: False
+ freeze_predictors: False
+ dataset:
+ class: MTSDatasetH
+ module_path: qlib.contrib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ seq_len: 60
+ horizon: 2
+ input_size:
+ num_states: *num_states
+ batch_size: 1024
+ n_samples:
+ memory_mode: *memory_mode
+ drop_last: True
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ model:
+ dataset:
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
new file mode 100644
index 000000000..64e3c91cb
--- /dev/null
+++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
@@ -0,0 +1,126 @@
+qlib_init:
+ provider_uri: "~/.qlib/qlib_data/cn_data"
+ region: cn
+
+market: &market csi300
+benchmark: &benchmark SH000300
+
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+
+num_states: &num_states 3
+
+memory_mode: &memory_mode sample
+
+tra_config: &tra_config
+ num_states: *num_states
+ rnn_arch: LSTM
+ hidden_size: 32
+ num_layers: 1
+ dropout: 0.0
+ tau: 1.0
+ src_info: LR_TPE
+
+model_config: &model_config
+ input_size: 6
+ hidden_size: 64
+ num_layers: 2
+ rnn_arch: LSTM
+ use_attn: True
+ dropout: 0.0
+
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+ backtest:
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+
+task:
+  model:
+    class: TRAModel
+    module_path: qlib.contrib.model.pytorch_tra
+    kwargs:
+      tra_config: *tra_config
+      model_config: *model_config
+      model_type: RNN
+      lr: 0.001  # written with a decimal point: YAML 1.1 loaders (e.g. PyYAML) resolve `1e-3` as a string
+      n_epochs: 100
+      max_steps_per_epoch:
+      early_stop: 20
+      logdir: output/Alpha360
+      seed: 0
+      lamb: 1.0
+      rho: 0.99
+      alpha: 0.5
+      transport_method: router
+      memory_mode: *memory_mode
+      eval_train: False
+      eval_test: True
+      pretrain: True
+      init_state:
+      freeze_model: False
+      freeze_predictors: False
+  dataset:
+    class: MTSDatasetH
+    module_path: qlib.contrib.data.dataset
+    kwargs:
+      handler:
+        class: Alpha360
+        module_path: qlib.contrib.data.handler
+        kwargs: *data_handler_config
+      segments:
+        train: [2008-01-01, 2014-12-31]
+        valid: [2015-01-01, 2016-12-31]
+        test: [2017-01-01, 2020-08-01]
+      seq_len: 60
+      horizon: 2
+      input_size: 6
+      num_states: *num_states
+      batch_size: 1024
+      n_samples:
+      memory_mode: *memory_mode
+      drop_last: True
+  record:
+    - class: SignalRecord
+      module_path: qlib.workflow.record_temp
+      kwargs:
+        model:
+        dataset:
+    - class: SigAnaRecord
+      module_path: qlib.workflow.record_temp
+      kwargs:
+        ana_long_short: False
+        ann_scaler: 252
+    - class: PortAnaRecord
+      module_path: qlib.workflow.record_temp
+      kwargs:
+        config: *port_analysis_config
diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml
index 71d41be63..0fa1b23d5 100644
--- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml
+++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml
@@ -75,8 +75,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml
index f43af104c..0c798ae30 100644
--- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml
+++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml
@@ -75,8 +75,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml
index 54707386f..6174abf2e 100644
--- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml
+++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml
@@ -34,19 +34,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TransformerModel
@@ -70,7 +74,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
index e568a1b30..883c18cdc 100644
--- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
+++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TransformerModel
@@ -61,7 +65,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
@@ -70,4 +76,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ config: *port_analysis_config
diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml
index dee169f18..502a5e73c 100644
--- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml
+++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml
@@ -64,8 +64,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml
index 926224f84..a2e40eefb 100644
--- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml
+++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml
@@ -71,8 +71,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
- model:
- dataset:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml
index 0152cfd63..93d9dde56 100644
--- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml
+++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml
@@ -60,8 +60,6 @@ task:
- class: "SignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs:
- model:
- dataset:
- class: "HFSignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}
\ No newline at end of file
diff --git a/examples/model_rolling/requirements.txt b/examples/model_rolling/requirements.txt
new file mode 100644
index 000000000..10ddd5b71
--- /dev/null
+++ b/examples/model_rolling/requirements.txt
@@ -0,0 +1 @@
+xgboost
diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py
index 844f18198..091a87862 100644
--- a/examples/model_rolling/task_manager_rolling.py
+++ b/examples/model_rolling/task_manager_rolling.py
@@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
-from qlib.model.trainer import TrainerRM
+from qlib.model.trainer import TrainerRM, task_train
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py
index b6c1362fd..ef6906018 100644
--- a/examples/nested_decision_execution/workflow.py
+++ b/examples/nested_decision_execution/workflow.py
@@ -19,7 +19,7 @@ class NestedDecisionExecutionWorkflow:
benchmark = "SH000300"
data_handler_config = {
"start_time": "2008-01-01",
- "end_time": "2020-12-31",
+ "end_time": "2021-05-31",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
@@ -53,7 +53,7 @@ class NestedDecisionExecutionWorkflow:
"segments": {
"train": ("2007-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
- "test": ("2020-01-01", "2020-12-31"),
+ "test": ("2020-01-01", "2021-05-31"),
},
},
},
@@ -75,7 +75,7 @@ class NestedDecisionExecutionWorkflow:
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "5min",
- "generate_report": True,
+ "generate_portfolio_metrics": True,
"verbose": True,
"indicator_config": {
"show_indicator": True,
@@ -86,7 +86,7 @@ class NestedDecisionExecutionWorkflow:
"class": "TWAPStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
},
- "generate_report": True,
+ "generate_portfolio_metrics": True,
"indicator_config": {
"show_indicator": True,
},
@@ -101,15 +101,15 @@ class NestedDecisionExecutionWorkflow:
},
},
"track_data": True,
- "generate_report": True,
+ "generate_portfolio_metrics": True,
"indicator_config": {
"show_indicator": True,
},
},
},
"backtest": {
- "start_time": "2020-01-01",
- "end_time": "2020-12-31",
+ "start_time": "2020-09-20",
+ "end_time": "2021-05-20",
"account": 100000000,
"exchange_kwargs": {
"freq": "1min",
@@ -124,8 +124,6 @@ class NestedDecisionExecutionWorkflow:
def _init_qlib(self):
"""initialize qlib"""
- # provider_uri_day = "/data/stock_data/huaxia/qlib"
- # provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib"
provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True)
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
@@ -133,31 +131,7 @@ class NestedDecisionExecutionWorkflow:
target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True
)
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
- client_config = {
- "calendar_provider": {
- "class": "LocalCalendarProvider",
- "module_path": "qlib.data.data",
- "kwargs": {
- "backend": {
- "class": "FileCalendarStorage",
- "module_path": "qlib.data.storage.file_storage",
- "kwargs": {"provider_uri_map": provider_uri_map},
- }
- },
- },
- "feature_provider": {
- "class": "LocalFeatureProvider",
- "module_path": "qlib.data.data",
- "kwargs": {
- "backend": {
- "class": "FileFeatureStorage",
- "module_path": "qlib.data.storage.file_storage",
- "kwargs": {"provider_uri_map": provider_uri_map},
- }
- },
- },
- }
- qlib.init(provider_uri=provider_uri_day, **client_config, redis_port=-1)
+ qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None)
def _train_model(self, model, dataset):
with R.start(experiment_name="train"):
@@ -186,9 +160,8 @@ class NestedDecisionExecutionWorkflow:
},
}
self.port_analysis_config["strategy"] = strategy_config
- self.port_analysis_config["backtest"]["benchmark"] = D.list_instruments(
- instruments=D.instruments(market=self.market), as_list=True
- )
+ self.port_analysis_config["backtest"]["benchmark"] = self.benchmark
+
with R.start(experiment_name="backtest"):
recorder = R.get_recorder()
@@ -201,6 +174,7 @@ class NestedDecisionExecutionWorkflow:
)
par.generate()
+ # users can use the following methods to analyze the position
# report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")
# from qlib.contrib.report import analysis_position
# analysis_position.report_graph(report_normal_df)
@@ -212,7 +186,7 @@ class NestedDecisionExecutionWorkflow:
self._train_model(model, dataset)
executor_config = self.port_analysis_config["executor"]
backtest_config = self.port_analysis_config["backtest"]
- backtest_config["benchmark"] = D.list_instruments(instruments=D.instruments(market=self.market), as_list=True)
+ backtest_config["benchmark"] = self.benchmark
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
diff --git a/examples/run_all_model.py b/examples/run_all_model.py
index 1284d8e99..41aba091e 100644
--- a/examples/run_all_model.py
+++ b/examples/run_all_model.py
@@ -6,6 +6,7 @@ import sys
import fire
import time
import glob
+import yaml
import shutil
import signal
import inspect
@@ -23,22 +24,6 @@ from qlib.config import REG_CN
from qlib.workflow import R
from qlib.tests.data import GetData
-# init qlib
-provider_uri = "~/.qlib/qlib_data/cn_data"
-exp_folder_name = "run_all_model_records"
-exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
-exp_manager = {
- "class": "MLflowExpManager",
- "module_path": "qlib.workflow.expm",
- "kwargs": {
- "uri": "file:" + exp_path,
- "default_exp_name": "Experiment",
- },
-}
-
-GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
-qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
-
# decorator to check the arguments
def only_allow_defined_args(function_to_decorate):
@@ -88,11 +73,11 @@ def create_env():
sys.stderr.write("\n")
# get anaconda activate path
conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
- return env_path, python_path, conda_activate
+ return temp_dir, env_path, python_path, conda_activate
# function to execute the cmd
-def execute(cmd, wait_when_err=False):
+def execute(cmd, wait_when_err=False, raise_err=True):
print("Running CMD:", cmd)
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
for line in p.stdout:
@@ -105,6 +90,8 @@ def execute(cmd, wait_when_err=False):
if p.returncode != 0:
if wait_when_err:
input("Press Enter to Continue")
+ if raise_err:
+ raise RuntimeError(f"Error when executing command: {cmd}")
return p.stderr
else:
return None
@@ -134,14 +121,23 @@ def get_all_folders(models, exclude) -> dict:
def get_all_files(folder_path, dataset) -> (str, str):
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
req_path = str(Path(f"{folder_path}") / f"*.txt")
- return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
+ yaml_file = glob.glob(yaml_path)
+ req_file = glob.glob(req_path)
+ if len(yaml_file) == 0:
+ return None, None
+ else:
+ return yaml_file[0], req_file[0]
# function to retrieve all the results
def get_all_results(folders) -> dict:
results = dict()
for fn in folders:
- exp = R.get_exp(experiment_name=fn, create=False)
+ try:
+ exp = R.get_exp(experiment_name=fn, create=False)
+ except ValueError:
+ # No experiment results
+ continue
recorders = exp.list_recorders()
result = dict()
result["annualized_return_with_cost"] = list()
@@ -155,9 +151,9 @@ def get_all_results(folders) -> dict:
if recorders[recorder_id].status == "FINISHED":
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
metrics = recorder.list_metrics()
- result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
- result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
- result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
+ result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"])
+ result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"])
+ result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"])
result["ic"].append(metrics["IC"])
result["icir"].append(metrics["ICIR"])
result["rank_ic"].append(metrics["Rank IC"])
@@ -185,6 +181,25 @@ def gen_and_save_md_table(metrics, dataset):
return table
+# load a workflow yaml, drop the model's `seed` kwarg, and save the result in temp_dir (original path if no seed)
+def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
+    with open(yaml_path, "r") as fp:
+        config = yaml.safe_load(fp)  # safe_load: no arbitrary object construction; bare yaml.load(fp) needs a Loader on PyYAML >= 6
+    try:
+        del config["task"]["model"]["kwargs"]["seed"]
+    except (KeyError, TypeError):
+        # No seed to remove (a key is missing, or an intermediate value is null): use the original yaml
+        # NOTE: it is very important if the model must run in original path(when sys.rel_path is used)
+        return yaml_path
+    else:
+        # otherwise, write a copy of the config without the random seed into temp_dir
+        file_name = os.path.basename(yaml_path)
+        temp_path = os.path.join(temp_dir, file_name)
+        with open(temp_path, "w") as fp:
+            yaml.dump(config, fp)
+        return temp_path
+
+
# function to run the all the models
@only_allow_defined_args
def run(
@@ -193,12 +208,13 @@ def run(
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
+ exp_folder_name: str = "run_all_model_records",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
- Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
+ Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
for multiple times, and this will be fixed in the future development.
Parameters:
@@ -214,6 +230,8 @@ def run(
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path
+ exp_folder_name: str
+ the name of the experiment folder
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
@@ -240,26 +258,58 @@ def run(
# Case 5 - run specific models for one time
python run_all_model.py --models=[mlp,lightgbm]
- # Case 6 - run other models except those are given as aruments for one time
+ # Case 6 - run other models except those are given as arguments for one time
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
"""
+ # init qlib
+ GetData().qlib_data(exists_skip=True)
+ qlib.init(
+ exp_manager={
+ "class": "MLflowExpManager",
+ "module_path": "qlib.workflow.expm",
+ "kwargs": {
+ "uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
+ "default_exp_name": "Experiment",
+ },
+ }
+ )
+
# get all folders
folders = get_all_folders(models, exclude)
# init error messages:
errors = dict()
# run all the model for iterations
for fn in folders:
- # create env by anaconda
- env_path, python_path, conda_activate = create_env()
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn], dataset)
+ if yaml_path is None:
+ sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
+ continue
sys.stderr.write("\n")
+ # create env by anaconda
+ temp_dir, env_path, python_path, conda_activate = create_env()
+
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
- execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
+ with open(req_path) as f:
+ content = f.read()
+ if "torch" in content:
+ # automatically install pytorch according to nvidia's version
+ execute(
+ f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
+ ) # for automatically installing torch according to the nvidia driver
+ execute(
+ f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
+ wait_when_err=wait_when_err,
+ )
+ else:
+ execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
+
+ # read yaml, remove seed kwargs of model, and then save file in the temp_dir
+ yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
# setup gpu for tft
if fn == "TFT":
execute(
@@ -302,19 +352,20 @@ def run(
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
- # calculating the mean and std
- sys.stderr.write(f"Calculating the mean and std of results...\n")
- results = cal_mean_std(results)
- # generating md table
- sys.stderr.write(f"Generating markdown table...\n")
- gen_and_save_md_table(results, dataset)
- sys.stderr.write("\n")
- # print erros
+ if len(results) > 0:
+ # calculating the mean and std
+ sys.stderr.write(f"Calculating the mean and std of results...\n")
+ results = cal_mean_std(results)
+ # generating md table
+ sys.stderr.write(f"Generating markdown table...\n")
+ gen_and_save_md_table(results, dataset)
+ sys.stderr.write("\n")
+ # print errors
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
sys.stderr.write("\n")
# move results folder
- shutil.move(exp_path, exp_path + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
+ shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb
index 1658565d6..907245ade 100644
--- a/examples/workflow_by_code.ipynb
+++ b/examples/workflow_by_code.ipynb
@@ -20,9 +20,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "scrolled": true
- },
+ "metadata": {},
"outputs": [],
"source": [
"import sys, site\n",
@@ -201,7 +199,7 @@
" \"module_path\": \"qlib.backtest.executor\",\n",
" \"kwargs\": {\n",
" \"time_per_step\": \"day\",\n",
- " \"generate_report\": True,\n",
+ " \"generate_portfolio_metrics\": True,\n",
" },\n",
" },\n",
" \"strategy\": {\n",
@@ -362,7 +360,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -375,8 +373,7 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.3"
+ "pygments_lexer": "ipython3"
},
"toc": {
"base_numbering": 1,
@@ -394,4 +391,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
-}
\ No newline at end of file
+}
diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py
index d7bb544f9..486e694a7 100644
--- a/examples/workflow_by_code.py
+++ b/examples/workflow_by_code.py
@@ -26,7 +26,7 @@ if __name__ == "__main__":
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "day",
- "generate_report": True,
+ "generate_portfolio_metrics": True,
},
},
"strategy": {
diff --git a/qlib/__init__.py b/qlib/__init__.py
index 6f76bbcaa..efa89b153 100644
--- a/qlib/__init__.py
+++ b/qlib/__init__.py
@@ -1,17 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+from pathlib import Path
-
-__version__ = "0.7.0.99"
+_version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is copied from setup.py
+__version__ = _version_path.read_text(encoding="utf-8").strip()
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
-
-
import os
import yaml
import logging
import platform
import subprocess
-from pathlib import Path
from .log import get_module_logger
@@ -33,69 +31,71 @@ def init(default_conf="client", **kwargs):
H.clear()
C.set(default_conf, **kwargs)
- # check path if server/local
- if C.get_uri_type() == C.LOCAL_URI:
- if not os.path.exists(C["provider_uri"]):
- if C["auto_mount"]:
- logger.error(
- f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
- )
- else:
- logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
- elif C.get_uri_type() == C.NFS_URI:
- _mount_nfs_uri(C)
- else:
- raise NotImplementedError(f"This type of URI is not supported")
+ # mount nfs
+ for _freq, provider_uri in C.provider_uri.items():
+ mount_path = C["mount_path"][_freq]
+ # check path if server/local
+ uri_type = C.dpm.get_uri_type(provider_uri)
+ if uri_type == C.LOCAL_URI:
+ if not Path(provider_uri).exists():
+ if C["auto_mount"]:
+ logger.error(
+ f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
+ )
+ else:
+ logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
+ elif uri_type == C.NFS_URI:
+ _mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
+ else:
+ raise NotImplementedError(f"This type of URI is not supported")
C.register()
if "flask_server" in C:
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
- logger.info(f"data_path={C.get_data_path()}")
+ data_path = {_freq: C.dpm.get_data_uri(_freq) for _freq in C.dpm.provider_uri.keys()}
+ logger.info(f"data_path={data_path}")
-def _mount_nfs_uri(C):
+def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
LOG = get_module_logger("mount nfs", level=logging.INFO)
-
+ if mount_path is None:
+ raise ValueError(f"Invalid mount path: {mount_path}!")
# FIXME: the C["provider_uri"] is modified in this function
# If it is not modified, we can pass only provider_uri or mount_path instead of C
- mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
+ mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
# If the provider uri looks like this 172.23.233.89//data/csdesign'
# It will be a nfs path. The client provider will be used
- if not C["auto_mount"]:
- if not os.path.exists(C["mount_path"]):
+ if not auto_mount:
+ if not Path(mount_path).exists():
raise FileNotFoundError(
- f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
+ f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
)
else:
# Judging system type
sys_type = platform.system()
if "win" in sys_type.lower():
# system: window
- exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
+ exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
result = exec_result.read()
if "85" in result:
- LOG.warning("already mounted or window mount path already exists")
+ LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
elif "53" in result:
raise OSError("not find network path")
elif "error" in result or "错误" in result:
raise OSError("Invalid mount path")
- elif C["provider_uri"] in result:
+ elif provider_uri in result:
LOG.info("window success mount..")
else:
raise OSError(f"unknown error: {result}")
- # config mount path
- C["mount_path"] = C["mount_path"] + ":\\"
else:
# system: linux/Unix/Mac
# check mount
- _remote_uri = C["provider_uri"]
- _remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
- _mount_path = C["mount_path"]
- _mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
+ _remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
+ _mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
_check_level_num = 2
_is_mount = False
while _check_level_num:
@@ -121,11 +121,9 @@ def _mount_nfs_uri(C):
if not _is_mount:
try:
- os.makedirs(C["mount_path"], exist_ok=True)
+ Path(mount_path).mkdir(parents=True, exist_ok=True)
except Exception:
- raise OSError(
- f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
- )
+ raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
# check nfs-common
command_res = os.popen("dpkg -l | grep nfs-common")
@@ -136,11 +134,11 @@ def _mount_nfs_uri(C):
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
- f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
+ f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
)
elif command_status == 32512:
# LOG.error("Command error")
- raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
+ raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
elif command_status == 0:
LOG.info("Mount finished")
else:
diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py
index d4a19eb25..38541d768 100644
--- a/qlib/backtest/__init__.py
+++ b/qlib/backtest/__init__.py
@@ -9,13 +9,13 @@ from .account import Account
if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
from .executor import BaseExecutor
- from .order import BaseTradeDecision
-from .order import Order
+ from .decision import BaseTradeDecision
from .position import Position
from .exchange import Exchange
from .backtest import backtest_loop
from .backtest import collect_data_loop
-from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
+from .utils import CommonInfrastructure
+from .decision import Order
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
@@ -231,10 +231,9 @@ def backtest(
Returns
-------
- report: Report
- it records the trading report information
- It is organized in a dict format
- indicator: Indicator
+ portfolio_metrics_dict: Dict[PortfolioMetrics]
+ it records the trading portfolio_metrics information
+ indicator_dict: Dict[Indicator]
it computes the trading indicator
It is organized in a dict format
@@ -249,9 +248,8 @@ def backtest(
exchange_kwargs,
pos_type=pos_type,
)
- report, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
-
- return report, indicator
+ portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
+ return portfolio_metrics, indicator
def collect_data(
diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py
index 163ee8c26..aa503ebc2 100644
--- a/qlib/backtest/account.py
+++ b/qlib/backtest/account.py
@@ -4,22 +4,19 @@ from __future__ import annotations
import copy
from typing import Dict, List, Tuple, TYPE_CHECKING
from qlib.utils import init_instance_by_config
-import warnings
import pandas as pd
from .position import BasePosition, InfPosition, Position
-from .report import Report, Indicator
-from .order import BaseTradeDecision, Order
-
-if TYPE_CHECKING:
- from .exchange import Exchange
+from .report import PortfolioMetrics, Indicator
+from .decision import BaseTradeDecision, Order
+from .exchange import Exchange
"""
rtn & earning in the Account
rtn:
from order's view
1.change if any order is executed, sell order or buy order
- 2.change at the end of today, (today_clse - stock_price) * amount
+ 2.change at the end of today, (today_close - stock_price) * amount
earning
from value of current position
earning will be updated at the end of trade date
@@ -32,7 +29,7 @@ rtn & earning in the Account
class AccumulatedInfo:
- """accumulated trading info, including accumulated return\cost\turnover"""
+ """accumulated trading info, including accumulated return/cost/turnover"""
def __init__(self):
self.reset()
@@ -94,9 +91,12 @@ class Account:
self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
+ self.benchmark_config = None # avoid no attribute error
+ self.init_vars(init_cash, position_dict, freq, benchmark_config)
+ def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
self.init_cash = init_cash
- self.current: BasePosition = init_instance_by_config(
+ self.current_position: BasePosition = init_instance_by_config(
{
"class": self._pos_type,
"kwargs": {
@@ -106,37 +106,33 @@ class Account:
"module_path": "qlib.backtest.position",
}
)
- self.report = None
- self.positions = {}
-
- # in of reset ignore None values
- self.benchmark_config = benchmark_config
- self.freq = freq
-
- self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
+ self.portfolio_metrics = None
+ self.hist_positions = {}
+ self.reset(freq=freq, benchmark_config=benchmark_config)
def is_port_metr_enabled(self):
"""
Is portfolio-based metrics enabled.
"""
- return self._port_metr_enabled and not self.current.skip_update()
+ return self._port_metr_enabled and not self.current_position.skip_update()
def reset_report(self, freq, benchmark_config):
# portfolio related metrics
if self.is_port_metr_enabled():
self.accum_info = AccumulatedInfo()
- self.report = Report(freq, benchmark_config)
- self.positions = {}
+ self.portfolio_metrics = PortfolioMetrics(freq, benchmark_config)
+ self.hist_positions = {}
+
# fill stock value
# The frequency of account may not align with the trading frequency.
# This may result in obscure bugs when data quality is low.
if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None:
- self.current.fill_stock_value(self.benchmark_config["start_time"], self.freq)
+ self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq)
# trading related metrics(e.g. high-frequency trading)
self.indicator = Indicator()
- def reset(self, freq=None, benchmark_config=None, init_report=False, port_metr_enabled: bool = None):
+ def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None):
"""reset freq and report of account
Parameters
@@ -145,27 +141,23 @@ class Account:
frequency of account & report, by default None
benchmark_config : {}, optional
benchmark config of report, by default None
- init_report : bool, optional
- whether to initialize the report, by default False
"""
if freq is not None:
self.freq = freq
if benchmark_config is not None:
self.benchmark_config = benchmark_config
-
if port_metr_enabled is not None:
self._port_metr_enabled = port_metr_enabled
- if freq is not None or benchmark_config is not None or init_report:
- self.reset_report(self.freq, self.benchmark_config)
+ self.reset_report(self.freq, self.benchmark_config)
- def get_positions(self):
- return self.positions
+ def get_hist_positions(self):
+ return self.hist_positions
def get_cash(self):
- return self.current.get_cash()
+ return self.current_position.get_cash()
- def _update_accum_info_from_order(self, order, trade_val, cost, trade_price):
+ def _update_state_from_order(self, order, trade_val, cost, trade_price):
if self.is_port_metr_enabled():
# update turnover
self.accum_info.add_turnover(trade_val)
@@ -176,17 +168,17 @@ class Account:
trade_amount = trade_val / trade_price
if order.direction == Order.SELL: # 0 for sell
# when sell stock, get profit from price change
- profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount
+ profit = trade_val - self.current_position.get_stock_price(order.stock_id) * trade_amount
self.accum_info.add_return_value(profit) # note here do not consider cost
elif order.direction == Order.BUY: # 1 for buy
# when buy stock, we get return for the rtn computing method
# profit in buy order is to make rtn is consistent with earning at the end of bar
- profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val
+ profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val
self.accum_info.add_return_value(profit) # note here do not consider cost
def update_order(self, order, trade_val, cost, trade_price):
- if self.current.skip_update():
+ if self.current_position.skip_update():
# TODO: supporting polymorphism for account
# updating order for infinite position is meaningless
return
@@ -196,65 +188,61 @@ class Account:
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
if order.direction == Order.SELL:
# sell stock
- self._update_accum_info_from_order(order, trade_val, cost, trade_price)
+ self._update_state_from_order(order, trade_val, cost, trade_price)
# update current position
# for may sell all of stock_id
- self.current.update_order(order, trade_val, cost, trade_price)
+ self.current_position.update_order(order, trade_val, cost, trade_price)
else:
# buy stock
# deal order, then update state
- self.current.update_order(order, trade_val, cost, trade_price)
- self._update_accum_info_from_order(order, trade_val, cost, trade_price)
+ self.current_position.update_order(order, trade_val, cost, trade_price)
+ self._update_state_from_order(order, trade_val, cost, trade_price)
- def update_bar_count(self):
- """at the end of the trading bar, update holding bar, count of stock"""
- # update holding day count
- # NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
- if not self.current.skip_update():
- self.current.add_count_all(bar=self.freq)
-
- def update_current(self, trade_start_time, trade_end_time, trade_exchange):
- """update current to make rtn consistent with earning at the end of bar"""
+ def update_current_position(self, trade_start_time, trade_end_time, trade_exchange):
+ """update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock"""
# update price for stock in the position and the profit from changed_price
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
- if not self.current.skip_update():
- stock_list = self.current.get_stock_list()
+ if not self.current_position.skip_update():
+ stock_list = self.current_position.get_stock_list()
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
- self.current.update_stock_price(stock_id=code, price=bar_close)
+ self.current_position.update_stock_price(stock_id=code, price=bar_close)
+ # update holding day count
+ # NOTE: updating bar_count does not only serve portfolio metrics, it also serves the strategy
+ self.current_position.add_count_all(bar=self.freq)
- def update_report(self, trade_start_time, trade_end_time):
- """update position history, report"""
+ def update_portfolio_metrics(self, trade_start_time, trade_end_time):
+ """update portfolio_metrics"""
# calculate earning
# account_value - last_account_value
# for the first trade date, account_value - init_cash
- # self.report.is_empty() to judge is_first_trade_date
+ # self.portfolio_metrics.is_empty() to judge is_first_trade_date
# get last_account_value, last_total_cost, last_total_turnover
- if self.report.is_empty():
+ if self.portfolio_metrics.is_empty():
last_account_value = self.init_cash
last_total_cost = 0
last_total_turnover = 0
else:
- last_account_value = self.report.get_latest_account_value()
- last_total_cost = self.report.get_latest_total_cost()
- last_total_turnover = self.report.get_latest_total_turnover()
+ last_account_value = self.portfolio_metrics.get_latest_account_value()
+ last_total_cost = self.portfolio_metrics.get_latest_total_cost()
+ last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()
# get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
- now_account_value = self.current.calculate_value()
- now_stock_value = self.current.calculate_stock_value()
+ now_account_value = self.current_position.calculate_value()
+ now_stock_value = self.current_position.calculate_stock_value()
now_earning = now_account_value - last_account_value
now_cost = self.accum_info.get_cost - last_total_cost
now_turnover = self.accum_info.get_turnover - last_total_turnover
- # update report for today
+ # update portfolio_metrics for today
# judge whether the the trading is begin.
- # and don't add init account state into report, due to we don't have excess return in those days.
- self.report.update_report_record(
+ # and don't add init account state into portfolio_metrics, due to we don't have excess return in those days.
+ self.portfolio_metrics.update_portfolio_metrics_record(
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
account_value=now_account_value,
- cash=self.current.position["cash"],
+ cash=self.current_position.position["cash"],
return_rate=(now_earning + now_cost) / last_account_value,
# here use earning to calculate return, position's view, earning consider cost, true return
# in order to make same definition with original backtest in evaluate.py
@@ -264,12 +252,51 @@ class Account:
cost_rate=now_cost / last_account_value,
stock_value=now_stock_value,
)
+
+ def update_hist_positions(self, trade_start_time):
+ """update history position"""
+ now_account_value = self.current_position.calculate_value()
# set now_account_value to position
- self.current.position["now_account_value"] = now_account_value
- self.current.update_weight_all()
- # update positions
+ self.current_position.position["now_account_value"] = now_account_value
+ self.current_position.update_weight_all()
+ # update hist_positions
# note use deepcopy
- self.positions[trade_start_time] = copy.deepcopy(self.current)
+ self.hist_positions[trade_start_time] = copy.deepcopy(self.current_position)
+
+ def update_indicator(
+ self,
+ trade_start_time: pd.Timestamp,
+ trade_exchange: Exchange,
+ atomic: bool,
+ outer_trade_decision: BaseTradeDecision,
+ trade_info: list = None,
+ inner_order_indicators: List[Dict[str, pd.Series]] = None,
+ decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
+ indicator_config: dict = {},
+ ):
+ """update trade indicators and order indicators in each bar end"""
+ # TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
+
+ # indicator is trading (e.g. high-frequency order execution) related analysis
+ self.indicator.reset()
+
+ # aggregate the information for each order
+ if atomic:
+ self.indicator.update_order_indicators(trade_info)
+ else:
+ self.indicator.agg_order_indicators(
+ inner_order_indicators,
+ decision_list=decision_list,
+ outer_trade_decision=outer_trade_decision,
+ trade_exchange=trade_exchange,
+ indicator_config=indicator_config,
+ )
+
+ # aggregate all the order metrics a single step
+ self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
+
+ # record the metrics
+ self.indicator.record(trade_start_time)
def update_bar_end(
self,
@@ -316,44 +343,34 @@ class Account:
elif atomic is False and inner_order_indicators is None:
raise ValueError("inner_order_indicators is necessary in un-atomic executor")
- # TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
- self.update_bar_count()
- self.update_current(trade_start_time, trade_end_time, trade_exchange)
+ # update current position and hold bar count in each bar end
+ self.update_current_position(trade_start_time, trade_end_time, trade_exchange)
+
if self.is_port_metr_enabled():
- # report is portfolio related analysis
- self.update_report(trade_start_time, trade_end_time)
+ # portfolio_metrics is portfolio related analysis
+ self.update_portfolio_metrics(trade_start_time, trade_end_time)
+ self.update_hist_positions(trade_start_time)
- # TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
+ # update indicator in each bar end
+ self.update_indicator(
+ trade_start_time=trade_start_time,
+ trade_exchange=trade_exchange,
+ atomic=atomic,
+ outer_trade_decision=outer_trade_decision,
+ trade_info=trade_info,
+ inner_order_indicators=inner_order_indicators,
+ decision_list=decision_list,
+ indicator_config=indicator_config,
+ )
- # indicator is trading (e.g. high-frequency order execution) related analysis
- self.indicator.reset()
-
- # aggregate the information for each order
- if atomic:
- self.indicator.update_order_indicators(trade_info)
- else:
- self.indicator.agg_order_indicators(
- inner_order_indicators,
- decision_list=decision_list,
- outer_trade_decision=outer_trade_decision,
- trade_exchange=trade_exchange,
- indicator_config=indicator_config,
- )
-
- # aggregate all the order metrics a single step
- self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
-
- # record the metrics
- self.indicator.record(trade_start_time)
-
- def get_report(self):
- """get the history report and postions instance"""
+ def get_portfolio_metrics(self):
+ """get the history portfolio_metrics and positions instance"""
if self.is_port_metr_enabled():
- _report = self.report.generate_report_dataframe()
- _positions = self.get_positions()
- return _report, _positions
+ _portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
+ _positions = self.get_hist_positions()
+ return _portfolio_metrics, _positions
else:
- raise ValueError("generate_report should be True if you want to generate report")
+ raise ValueError("generate_portfolio_metrics should be True if you want to generate portfolio_metrics")
def get_trade_indicator(self) -> Indicator:
"""get the trade indicator instance, which has pa/pos/ffr info."""
diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py
index c707aa1f0..fa4063bc9 100644
--- a/qlib/backtest/backtest.py
+++ b/qlib/backtest/backtest.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from __future__ import annotations
-from qlib.backtest.order import BaseTradeDecision
+from qlib.backtest.decision import BaseTradeDecision
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -19,15 +19,15 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
Returns
-------
- report: Report
- it records the trading report information
+ portfolio_metrics: PortfolioMetrics
+ it records the trading portfolio_metrics information
indicator: Indicator
it computes the trading indicator
"""
return_value = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass
- return return_value.get("report"), return_value.get("indicator")
+ return return_value.get("portfolio_metrics"), return_value.get("indicator")
def collect_data_loop(
@@ -68,9 +68,8 @@ def collect_data_loop(
if return_value is not None:
all_executors = trade_executor.get_all_executors()
-
- all_reports = {
- "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_report()
+ all_portfolio_metrics = {
+ "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
for _executor in all_executors
if _executor.trade_account.is_port_metr_enabled()
}
@@ -79,4 +78,4 @@ def collect_data_loop(
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
- return_value.update({"report": all_reports, "indicator": all_indicators})
+ return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})
diff --git a/qlib/backtest/order.py b/qlib/backtest/decision.py
similarity index 99%
rename from qlib/backtest/order.py
rename to qlib/backtest/decision.py
index a1b21be0a..049e56c00 100644
--- a/qlib/backtest/order.py
+++ b/qlib/backtest/decision.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# TODO: rename it with decision.py
+
from __future__ import annotations
from enum import IntEnum
from qlib.data.data import Cal
diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py
index 8b539de8b..9e40e1877 100644
--- a/qlib/backtest/exchange.py
+++ b/qlib/backtest/exchange.py
@@ -15,10 +15,9 @@ import pandas as pd
from ..data.data import D
from ..config import C, REG_CN
-from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
-from .order import Order, OrderDir, OrderHelper
-from .high_performance_ds import BaseQuote, PandasQuote, CN1minNumpyQuote
+from .decision import Order, OrderDir, OrderHelper
+from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
class Exchange:
@@ -36,29 +35,24 @@ class Exchange:
close_cost=0.0025,
min_cost=5,
extra_quote=None,
- quote_cls=CN1minNumpyQuote,
+ quote_cls=NumpyQuote,
**kwargs,
):
"""__init__
-
:param freq: frequency of data
:param start_time: closed start time for backtest
:param end_time: closed end time for backtest
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
-
:param deal_price: Union[str, Tuple[str, str], List[str]]
The `deal_price` supports following two types of input
- : str
- (, ): Tuple[str] or List[str]
-
, or := := str
- for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
"$" to the expression)
-
:param subscribe_fields: list, subscribe fields. This expressions will be added to the query and `self.quote`.
It is useful when users want more fields to be queried
-
:param limit_threshold: Union[Tuple[str, str], float, None]
1) `None`: no limitation
2) float, 0.1 for example, default None
@@ -66,7 +60,6 @@ class Exchange:
)
`False` value indicates the stock is tradable
`True` value indicates the stock is limited and not tradable
-
:param volume_threshold: Union[
Dict[
"all": ("cum" or "current", limit_str),
@@ -85,26 +78,22 @@ class Exchange:
- "current" means that this is a real-time value and will not accumulate over time,
so it can be directly used as a capacity limit.
e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
-
2) "all" means the volume limits are both buying and selling.
"buy" means the volume limits of buying. "sell" means the volume limits of selling.
Different volume limits will be aggregated with min(). If volume_threshold is only
("cum" or "current", limit_str) instead of a dict, the volume limits are for
both by deault. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
-
3) e.g. "volume_threshold": {
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
"buy": ("current", "$askV1"),
"sell": ("current", "$bidV1"),
}
-
:param open_cost: cost rate for open, default 0.0015
:param close_cost: cost rate for close, default 0.0025
:param trade_unit: trade unit, 100 for China A market.
None for disable trade unit.
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
distinguish `not set` and `disable trade_unit`
-
:param min_cost: min cost, default 5
:param extra_quote: pandas, dataframe consists of
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
@@ -185,7 +174,7 @@ class Exchange:
# init quote by quote_df
self.quote_cls = quote_cls
- self.quote: BaseQuote = self.quote_cls(self.quote_df)
+ self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
def get_quote_from_qlib(self):
# get stock data from qlib
@@ -273,12 +262,10 @@ class Exchange:
preproccess the volume limit.
get the fields need to get from qlib.
get the volume limit list of buying and selling which is composed of all limits.
-
Parameters
----------
volume_threshold :
please refer to the doc of exchange.
-
Returns
-------
fields: set
@@ -287,7 +274,6 @@ class Exchange:
all volume limits of buying.
sell_vol_limit: List[Tuple[str]]
all volume limits of selling.
-
Raises
------
ValueError
@@ -324,7 +310,6 @@ class Exchange:
- if direction is None, check if tradable for buying and selling.
- if direction == Order.BUY, check the if tradable for buying
- if direction == Order.SELL, check the sell limit for selling.
-
"""
if direction is None:
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
@@ -372,9 +357,7 @@ class Exchange:
):
"""
Deal order when the actual transaction
-
the results section in `Order` will be changed.
-
:param order: Deal the order.
:param trade_account: Trade account to be updated after dealing the order.
:param position: position to be updated after dealing the order.
@@ -393,12 +376,12 @@ class Exchange:
# NOTE: order will be changed in this function
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
- order, trade_account.current if trade_account else position, dealt_order_amount
+ order, trade_account.current_position if trade_account else position, dealt_order_amount
)
- if order.deal_amount > 1e-5:
- # If the order can only be deal 0 amount. Nothing to be updated
+ if trade_val > 1e-5:
+ # If the order can only deal 0 value, nothing needs to be updated
# Otherwise, it will result in
- # 1) some stock with 0 amount in the position
+ # 1) some stock with 0 value in the position
# 2) `trade_unit` of trade_cost will be lost in user account
if trade_account:
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
@@ -407,16 +390,17 @@ class Exchange:
return trade_val, trade_cost, trade_price
- def get_quote_info(self, stock_id, start_time, end_time, method=ts_data_last):
+ def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"):
return self.quote.get_data(stock_id, start_time, end_time, method=method)
- def get_close(self, stock_id, start_time, end_time, method=ts_data_last):
+ def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
- def get_volume(self, stock_id, start_time, end_time, method="sum"):
- return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
+ def get_volume(self, stock_id, start_time, end_time):
+ """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
+ return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method="sum")
- def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method=ts_data_last):
+ def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
if direction == OrderDir.SELL:
pstr = self.sell_price
elif direction == OrderDir.BUY:
@@ -441,7 +425,7 @@ class Exchange:
assert start_time is not None and end_time is not None, "the time range must be given"
if stock_id not in self.quote.get_all_stock():
return None
- return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method=ts_data_last)
+ return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
def generate_amount_position_from_weight_position(
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
@@ -449,7 +433,6 @@ class Exchange:
"""
The generate the target position according to the weight and the cash.
NOTE: All the cash will assigned to the tadable stock.
-
Parameter:
weight_position : dict {stock_id : weight}; allocate cash by weight_position
among then, weight must be in this range: 0 < weight < 1
@@ -493,7 +476,6 @@ class Exchange:
def get_real_deal_amount(self, current_amount, target_amount, factor):
"""
Calculate the real adjust deal amount when considering the trading unit
-
:param current_amount:
:param target_amount:
:param factor:
@@ -516,7 +498,6 @@ class Exchange:
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
"""
Note: some future information is used in this function
-
Parameter:
target_position : dict { stock_id : amount }
current_postion : dict { stock_id : amount}
@@ -590,8 +571,10 @@ class Exchange:
value = 0
for stock_id in amount_dict:
if (
- self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
+ only_tradable is True
+ and self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
+ or only_tradable is False
):
value += (
self.get_deal_price(
@@ -613,10 +596,8 @@ class Exchange:
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
"""
get the trade unit of amount based on **factor**
-
the factor can be given directly or calculated in given time range and stock id.
`factor` has higher priority than `stock_id`, `start_time` and `end_time`
-
Parameters
----------
factor : float
@@ -641,7 +622,6 @@ class Exchange:
):
"""Parameter
Please refer to the docs of get_amount_of_trade_unit
-
deal_amount : float, adjusted amount
factor : float, adjusted factor
return : float, real amount
@@ -656,11 +636,9 @@ class Exchange:
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
"""parse the capacity limit string and return the actual amount of orders that can be executed.
-
NOTE:
this function will change the order.deal_amount **inplace**
- This will make the order info more accurate
-
Parameters
----------
order : Order
@@ -694,7 +672,7 @@ class Exchange:
order.start_time,
order.end_time,
field=limit[1],
- method=ts_data_last,
+ method="ts_data_last",
)
vol_limit_num.append(limit_value - dealt_order_amount[order.stock_id])
else:
@@ -709,12 +687,10 @@ class Exchange:
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
"""return the real order amount after cash limit for buying.
-
Parameters
----------
trade_price : float
position : cash
-
Return
----------
float
@@ -735,9 +711,7 @@ class Exchange:
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
"""
Calculation of trade info
-
**NOTE**: Order will be changed in this function
-
:param order:
:param position: Position
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
@@ -745,18 +719,27 @@ class Exchange:
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
+ order.deal_amount = order.amount # set to full amount and clip it step by step
+ # Clipping amount first
+ # - It simulates that the order is rejected directly by the exchange due to large order
+ # Another choice is placing it after rounding the order
+ # - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
+ self._clip_amount_by_volume(order, dealt_order_amount)
+
if order.direction == Order.SELL:
cost_ratio = self.close_cost
# sell
+ # if we don't know current position, we choose to sell all
+ # Otherwise, we clip the amount based on current position
if position is not None:
current_amount = (
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
)
- if np.isclose(order.amount, current_amount):
- # when selling last stock. The amount don't need rounding
- order.deal_amount = order.amount
- else:
- order.deal_amount = self.round_amount_by_trade_unit(min(current_amount, order.amount), order.factor)
+ if not np.isclose(order.deal_amount, current_amount):
+ # when not selling the last stock, rounding is necessary
+ order.deal_amount = self.round_amount_by_trade_unit(
+ min(current_amount, order.deal_amount), order.factor
+ )
# in case of negative value of cash
if position.get_cash() + order.deal_amount * trade_price < max(
@@ -765,33 +748,30 @@ class Exchange:
):
order.deal_amount = 0
self.logger.debug(f"Order clipped due to cash limitation: {order}")
- else:
- # TODO: We don't know current position.
- # We choose to sell all
- order.deal_amount = order.amount
elif order.direction == Order.BUY:
cost_ratio = self.open_cost
# buy
if position is not None:
cash = position.get_cash()
- trade_val = order.amount * trade_price
+ trade_val = order.deal_amount * trade_price
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
# The money is not enough
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
- order.deal_amount = self.round_amount_by_trade_unit(max_buy_amount, order.factor)
+ order.deal_amount = self.round_amount_by_trade_unit(
+ min(max_buy_amount, order.deal_amount), order.factor
+ )
self.logger.debug(f"Order clipped due to cash limitation: {order}")
else:
# The money is enough
- order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
+ order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
else:
# Unknown amount of money. Just round the amount
- order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
+ order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
else:
raise NotImplementedError("order type {} error".format(order.type))
- self._clip_amount_by_volume(order, dealt_order_amount)
trade_val = order.deal_amount * trade_price
trade_cost = max(trade_val * cost_ratio, self.min_cost)
if trade_val <= 1e-5:
diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py
index e7882714a..44f3e8db0 100644
--- a/qlib/backtest/executor.py
+++ b/qlib/backtest/executor.py
@@ -11,7 +11,7 @@ from collections import defaultdict
from qlib.backtest.report import Indicator
-from .order import EmptyTradeDecision, Order, BaseTradeDecision
+from .decision import EmptyTradeDecision, Order, BaseTradeDecision
from .exchange import Exchange
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
@@ -29,7 +29,7 @@ class BaseExecutor:
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
indicator_config: dict = {},
- generate_report: bool = False,
+ generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
@@ -77,8 +77,8 @@ class BaseExecutor:
'weight_method': 'value_weighted',
}
}
- generate_report : bool, optional
- whether to generate report, by default False
+ generate_portfolio_metrics : bool, optional
+ whether to generate portfolio_metrics, by default False
verbose : bool, optional
whether to print trading info, by default False
track_data : bool, optional
@@ -87,8 +87,8 @@ class BaseExecutor:
- Else, `trade_decision` will not be generated
trade_exchange : Exchange
- exchange that provides market info, used to generate report
- - If generate_report is None, trade_exchange will be ignored
+ exchange that provides market info, used to generate portfolio_metrics
+ - If generate_portfolio_metrics is None, trade_exchange will be ignored
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
common_infra : CommonInfrastructure, optional:
@@ -103,7 +103,7 @@ class BaseExecutor:
"""
self.time_per_step = time_per_step
self.indicator_config = indicator_config
- self.generate_report = generate_report
+ self.generate_portfolio_metrics = generate_portfolio_metrics
self.verbose = verbose
self.track_data = track_data
self._trade_exchange = trade_exchange
@@ -132,7 +132,7 @@ class BaseExecutor:
# NOTE: there is a trick in the code.
# copy is used instead of deepcopy. So positions are shared
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
- self.trade_account.reset(freq=self.time_per_step, init_report=True, port_metr_enabled=self.generate_report)
+ self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
@property
def trade_exchange(self) -> Exchange:
@@ -246,7 +246,7 @@ class BaseExecutor:
raise ValueError("atomic executor doesn't support specify `range_limit`")
if self._settle_type != BasePosition.ST_NO:
- self.trade_account.current.settle_start(self._settle_type)
+ self.trade_account.current_position.settle_start(self._settle_type)
obj = self._collect_data(trade_decision=trade_decision, level=level)
@@ -271,7 +271,7 @@ class BaseExecutor:
self.trade_calendar.step()
if self._settle_type != BasePosition.ST_NO:
- self.trade_account.current.settle_commit()
+ self.trade_account.current_position.settle_commit()
if return_value is not None:
return_value.update({"execute_result": res})
@@ -296,7 +296,7 @@ class NestedExecutor(BaseExecutor):
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
indicator_config: dict = {},
- generate_report: bool = False,
+ generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
skip_empty_decision: bool = True,
@@ -335,7 +335,7 @@ class NestedExecutor(BaseExecutor):
start_time=start_time,
end_time=end_time,
indicator_config=indicator_config,
- generate_report=generate_report,
+ generate_portfolio_metrics=generate_portfolio_metrics,
verbose=verbose,
track_data=track_data,
common_infra=common_infra,
@@ -444,7 +444,7 @@ class SimulatorExecutor(BaseExecutor):
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
indicator_config: dict = {},
- generate_report: bool = False,
+ generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: CommonInfrastructure = None,
@@ -462,7 +462,7 @@ class SimulatorExecutor(BaseExecutor):
start_time=start_time,
end_time=end_time,
indicator_config=indicator_config,
- generate_report=generate_report,
+ generate_portfolio_metrics=generate_portfolio_metrics,
verbose=verbose,
track_data=track_data,
common_infra=common_infra,
diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py
index 97310ffb6..235bd054b 100644
--- a/qlib/backtest/high_performance_ds.py
+++ b/qlib/backtest/high_performance_ds.py
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-
from functools import lru_cache
import logging
from typing import List, Text, Union, Callable, Iterable, Dict
@@ -14,12 +13,12 @@ import numpy as np
from ..utils.index_data import IndexData, SingleData
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
-from ..utils.time import is_single_value
+from ..utils.time import is_single_value, Freq
import qlib.utils.index_data as idd
class BaseQuote:
- def __init__(self, quote_df: pd.DataFrame):
+ def __init__(self, quote_df: pd.DataFrame, freq):
self.logger = get_module_logger("online operator", level=logging.INFO)
def get_all_stock(self) -> Iterable:
@@ -39,7 +38,7 @@ class BaseQuote:
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
field: Union[str],
- method: Union[str, Callable, None] = None,
+ method: Union[str, None] = None,
) -> Union[None, int, float, bool, IndexData]:
"""get the specific field of stock data during start time and end_time,
and apply method to the data.
@@ -83,9 +82,9 @@ class BaseQuote:
closed end time for backtest
field : str
the columns of data to fetch
- method : Union[str, Callable, None]
+ method : Union[str, None]
the method apply to data.
- e.g [None, "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last]
+ e.g [None, "last", "all", "sum", "mean", "ts_data_last"]
Return
----------
@@ -99,8 +98,8 @@ class BaseQuote:
class PandasQuote(BaseQuote):
- def __init__(self, quote_df: pd.DataFrame):
- super().__init__(quote_df=quote_df)
+ def __init__(self, quote_df: pd.DataFrame, freq):
+ super().__init__(quote_df=quote_df, freq=freq)
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
@@ -110,6 +109,8 @@ class PandasQuote(BaseQuote):
return self.data.keys()
def get_data(self, stock_id, start_time, end_time, field, method=None):
+ if method == "ts_data_last":
+ method = ts_data_last
stock_data = resam_ts_data(self.data[stock_id][field], start_time, end_time, method=method)
if stock_data is None:
return None
@@ -121,9 +122,9 @@ class PandasQuote(BaseQuote):
raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame")
-class CN1minNumpyQuote(BaseQuote):
- def __init__(self, quote_df: pd.DataFrame):
- """CN1minNumpyQuote
+class NumpyQuote(BaseQuote):
+ def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
+ """NumpyQuote
Parameters
----------
@@ -131,13 +132,19 @@ class CN1minNumpyQuote(BaseQuote):
the init dataframe from qlib.
self.data : Dict(stock_id, IndexData.DataFrame)
"""
- super().__init__(quote_df=quote_df)
+ super().__init__(quote_df=quote_df, freq=freq)
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level="instrument"))
quote_dict[stock_id].sort_index() # To support more flexible slicing, we must sort data first
self.data = quote_dict
- self.freq = pd.Timedelta(minutes=1)
+
+ n, unit = Freq.parse(freq)
+ if unit in Freq.SUPPORT_CAL_LIST:
+ self.freq = Freq.get_timedelta(1, unit)
+ else:
+ raise ValueError(f"{freq} is not supported in NumpyQuote")
+ self.region = region
def get_all_stock(self):
return self.data.keys()
@@ -150,7 +157,7 @@ class CN1minNumpyQuote(BaseQuote):
# single data
# If it don't consider the classification of single data, it will consume a lot of time.
- if is_single_value(start_time, end_time, self.freq):
+ if is_single_value(start_time, end_time, self.freq, self.region):
# this is a very special case.
# skip aggregating function to speed-up the query calculation
try:
@@ -178,9 +185,7 @@ class CN1minNumpyQuote(BaseQuote):
return data[-1]
elif method == "all":
return data.all()
- elif method == "any":
- return data.any()
- elif method == ts_data_last:
+ elif method == "ts_data_last":
valid_data = data.loc[~data.isna().data.astype(bool)]
if len(valid_data) == 0:
return None
diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py
index 234ec08b9..2bfb20893 100644
--- a/qlib/backtest/position.py
+++ b/qlib/backtest/position.py
@@ -10,7 +10,7 @@ import pandas as pd
from datetime import timedelta
import numpy as np
-from .order import Order
+from .decision import Order
from ..data.data import D
@@ -151,7 +151,8 @@ class BasePosition:
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
"""
generate stock weight dict {stock_id : value weight of stock in the position}
- it is meaningful in the beginning or the end of each trade date
+ it is meaningful in the beginning or the end of each trade step
+ - During execution of each trading step, the weight may not be consistent with the portfolio value
Parameters
----------
@@ -408,7 +409,7 @@ class Position(BasePosition):
return self.position[code]["price"]
def get_stock_amount(self, code):
- return self.position[code]["amount"]
+ return self.position[code]["amount"] if code in self.position else 0
def get_stock_count(self, code, bar):
"""the days the account has been hold, it may be used in some special strategies"""
@@ -531,7 +532,7 @@ class InfPosition(BasePosition):
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
def add_count_all(self, bar):
- raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
+ raise NotImplementedError(f"InfPosition doesn't support add_count_all")
def update_weight_all(self):
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
diff --git a/qlib/backtest/profit_attribution.py b/qlib/backtest/profit_attribution.py
index 05ee138cb..895f5c78b 100644
--- a/qlib/backtest/profit_attribution.py
+++ b/qlib/backtest/profit_attribution.py
@@ -18,6 +18,7 @@ def get_benchmark_weight(
start_date=None,
end_date=None,
path=None,
+ freq="day",
):
"""get_benchmark_weight
@@ -27,6 +28,7 @@ def get_benchmark_weight(
:param start_date:
:param end_date:
:param path:
+ :param freq:
:return: The weight distribution of the the benchmark described by a pandas dataframe
Every row corresponds to a trading day.
@@ -35,7 +37,7 @@ def get_benchmark_weight(
"""
if not path:
- path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
+ path = Path(C.dpm.get_data_uri(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegent way
# TODO: The benchmark is not consistant with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
@@ -224,6 +226,7 @@ def brinson_pa(
group_method="category",
group_n=None,
deal_price="vwap",
+ freq="day",
):
"""brinson profit attribution
@@ -245,7 +248,7 @@ def brinson_pa(
start_date, end_date = min(dates), max(dates)
- bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
+ bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq=freq)
# The attributes for allocation will not
if not group_field.startswith("$"):
@@ -261,13 +264,14 @@ def brinson_pa(
start_time=shift_start_date,
end_time=end_date,
as_list=True,
+ freq=freq,
)
stock_df = D.features(
instruments,
[group_field, deal_price],
start_time=shift_start_date,
end_time=end_date,
- freq="day",
+ freq=freq,
)
stock_df.columns = [group_field, "deal_price"]
diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py
index a364b10db..03fb85344 100644
--- a/qlib/backtest/report.py
+++ b/qlib/backtest/report.py
@@ -10,21 +10,24 @@ import numpy as np
import pandas as pd
from qlib.backtest.exchange import Exchange
-from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
+from .decision import IdxTradeRange
+from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
+from qlib.backtest.utils import TradeCalendarManager
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
+from ..data import D
from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
-from .order import IdxTradeRange
import qlib.utils.index_data as idd
-class Report:
+class PortfolioMetrics:
"""
Motivation:
- Report is for supporting portfolio related metrics.
+ PortfolioMetrics is for supporting portfolio related metrics.
Implementation:
- daily report of the account
+
+ daily portfolio metrics of the account
contain those followings: return, cost, turnover, account, cash, bench, value
For each step(bar/day/minute), each column represents
- return: the return of the portfolio generated by strategy **without transaction fee**.
@@ -33,7 +36,7 @@ class Report:
- cash: the amount of cash in user's account.
- bench: the return of the benchmark
- value: the total value of securities/stocks/instruments (cash is excluded).
-
+
update report
"""
@@ -79,7 +82,7 @@ class Report:
self.values = OrderedDict() # value for each trade time
self.cashes = OrderedDict()
self.benches = OrderedDict()
- self.latest_report_time = None # pd.TimeStamp
+ self.latest_pm_time = None # pd.TimeStamp
def init_bench(self, freq=None, benchmark_config=None):
if freq is not None:
@@ -123,18 +126,18 @@ class Report:
return len(self.accounts) == 0
def get_latest_date(self):
- return self.latest_report_time
+ return self.latest_pm_time
def get_latest_account_value(self):
- return self.accounts[self.latest_report_time]
+ return self.accounts[self.latest_pm_time]
def get_latest_total_cost(self):
- return self.total_costs[self.latest_report_time]
+ return self.total_costs[self.latest_pm_time]
def get_latest_total_turnover(self):
- return self.total_turnovers[self.latest_report_time]
+ return self.total_turnovers[self.latest_pm_time]
- def update_report_record(
+ def update_portfolio_metrics_record(
self,
trade_start_time=None,
trade_end_time=None,
@@ -169,7 +172,7 @@ class Report:
elif bench_value is None:
bench_value = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
- # update report data
+ # update pm data
self.accounts[trade_start_time] = account_value
self.returns[trade_start_time] = return_rate
self.total_turnovers[trade_start_time] = total_turnover
@@ -179,30 +182,30 @@ class Report:
self.values[trade_start_time] = stock_value
self.cashes[trade_start_time] = cash
self.benches[trade_start_time] = bench_value
- # update latest_report_date
- self.latest_report_time = trade_start_time
- # finish report update in each step
+ # update latest_pm_time
+ self.latest_pm_time = trade_start_time
+ # finish pm update in each step
- def generate_report_dataframe(self):
- report = pd.DataFrame()
- report["account"] = pd.Series(self.accounts)
- report["return"] = pd.Series(self.returns)
- report["total_turnover"] = pd.Series(self.total_turnovers)
- report["turnover"] = pd.Series(self.turnovers)
- report["total_cost"] = pd.Series(self.total_costs)
- report["cost"] = pd.Series(self.costs)
- report["value"] = pd.Series(self.values)
- report["cash"] = pd.Series(self.cashes)
- report["bench"] = pd.Series(self.benches)
- report.index.name = "datetime"
- return report
+ def generate_portfolio_metrics_dataframe(self):
+ pm = pd.DataFrame()
+ pm["account"] = pd.Series(self.accounts)
+ pm["return"] = pd.Series(self.returns)
+ pm["total_turnover"] = pd.Series(self.total_turnovers)
+ pm["turnover"] = pd.Series(self.turnovers)
+ pm["total_cost"] = pd.Series(self.total_costs)
+ pm["cost"] = pd.Series(self.costs)
+ pm["value"] = pd.Series(self.values)
+ pm["cash"] = pd.Series(self.cashes)
+ pm["bench"] = pd.Series(self.benches)
+ pm.index.name = "datetime"
+ return pm
- def save_report(self, path):
- r = self.generate_report_dataframe()
+ def save_portfolio_metrics(self, path):
+ r = self.generate_portfolio_metrics_dataframe()
r.to_csv(path)
- def load_report(self, path):
- """load report from a file
+ def load_portfolio_metrics(self, path):
+ """load pm from a file
should have format like
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
:param
@@ -215,7 +218,7 @@ class Report:
index = r.index
self.init_vars()
for trade_start_time in index:
- self.update_report_record(
+ self.update_portfolio_metrics_record(
trade_start_time=trade_start_time,
account_value=r.loc[trade_start_time]["account"],
cash=r.loc[trade_start_time]["cash"],
@@ -376,8 +379,6 @@ class Indicator:
price = pa_config.get("price", "deal_price").lower()
if decision.trade_range is not None:
- if isinstance(decision.trade_range, IdxTradeRange):
- raise TypeError(f"IdxTradeRange is not supported")
trade_start_time, trade_end_time = decision.trade_range.clip_time_range(
start_time=trade_start_time, end_time=trade_end_time
)
diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py
index b5ff84c54..51130712d 100644
--- a/qlib/backtest/utils.py
+++ b/qlib/backtest/utils.py
@@ -1,18 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+
from __future__ import annotations
import bisect
from qlib.utils.time import epsilon_change
-from typing import Union, TYPE_CHECKING, Tuple, Union, List, Set
+from typing import TYPE_CHECKING, Tuple, Union
if TYPE_CHECKING:
- from qlib.backtest.order import BaseTradeDecision
- from qlib.strategy.base import BaseStrategy
+ from qlib.backtest.decision import BaseTradeDecision
import pandas as pd
import warnings
-from ..utils.resam import get_resam_calendar
from ..data.data import Cal
@@ -56,9 +55,9 @@ class TradeCalendarManager:
self.start_time = pd.Timestamp(start_time) if start_time else None
self.end_time = pd.Timestamp(end_time) if end_time else None
- _calendar, freq, freq_sam = get_resam_calendar(freq=freq)
+ _calendar = Cal.calendar(freq=freq)
self._calendar = _calendar
- _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
+ _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq)
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
diff --git a/qlib/config.py b/qlib/config.py
index 796cc5ca6..029434a88 100644
--- a/qlib/config.py
+++ b/qlib/config.py
@@ -15,8 +15,10 @@ import os
import re
import copy
import logging
+import platform
import multiprocessing
from pathlib import Path
+from typing import Union
class Config:
@@ -73,6 +75,12 @@ REG_US = "us"
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
+DISK_DATASET_CACHE = "DiskDatasetCache"
+SIMPLE_DATASET_CACHE = "SimpleDatasetCache"
+DISK_EXPRESSION_CACHE = "DiskExpressionCache"
+
+DEPENDENCY_REDIS_CACHE = (DISK_DATASET_CACHE, DISK_EXPRESSION_CACHE)
+
_default_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
@@ -82,6 +90,15 @@ _default_config = {
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
+ # "provider_uri" str or dict:
+ # # str
+ # "~/.qlib/stock_data/cn_data"
+ # # dict
+ # {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
+ # NOTE: provider_uri priority:
+ # 1. backend_config: backend_obj["kwargs"]["provider_uri"]
+ # 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
+ # 3. qlib.init: provider_uri
"provider_uri": "",
# cache
"expression_cache": None,
@@ -173,8 +190,9 @@ MODE_CONF = {
"redis_task_db": 1,
"kernels": NUM_USABLE_CPU,
# cache
- "expression_cache": "DiskExpressionCache",
- "dataset_cache": "DiskDatasetCache",
+ "expression_cache": DISK_EXPRESSION_CACHE,
+ "dataset_cache": DISK_DATASET_CACHE,
+ "local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
"mount_path": None,
},
"client": {
@@ -189,8 +207,10 @@ MODE_CONF = {
"provider_uri": "~/.qlib/qlib_data/cn_data",
# cache
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
- "expression_cache": "DiskExpressionCache",
- "dataset_cache": "DiskDatasetCache",
+ "expression_cache": DISK_EXPRESSION_CACHE,
+ "dataset_cache": DISK_DATASET_CACHE,
+ # SimpleDatasetCache directory
+ "local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
"calendar_cache": None,
# client config
"kernels": NUM_USABLE_CPU,
@@ -234,11 +254,43 @@ class QlibConfig(Config):
# URI_TYPE
LOCAL_URI = "local"
NFS_URI = "nfs"
+ DEFAULT_FREQ = "__DEFAULT_FREQ"
def __init__(self, default_conf):
super().__init__(default_conf)
self._registered = False
+ class DataPathManager:
+ def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
+ self.provider_uri = provider_uri
+ self.mount_path = mount_path
+
+ @staticmethod
+ def get_uri_type(uri: Union[str, Path]):
+ uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
+ is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
+ # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
+ is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
+
+ if is_nfs_or_win and not is_win:
+ return QlibConfig.NFS_URI
+ else:
+ return QlibConfig.LOCAL_URI
+
+ def get_data_uri(self, freq: str = None) -> Path:
+ if freq is None or freq not in self.provider_uri:
+ freq = QlibConfig.DEFAULT_FREQ
+ _provider_uri = self.provider_uri[freq]
+ if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
+ return Path(_provider_uri)
+ elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
+ if "win" in platform.system().lower():
+ # windows, mount_path is the drive
+ return Path(f"{self.mount_path[freq]}:\\")
+ return Path(self.mount_path[freq])
+ else:
+ raise NotImplementedError(f"This type of uri is not supported")
+
def set_mode(self, mode):
# raise KeyError
self.update(MODE_CONF[mode])
@@ -248,13 +300,42 @@ class QlibConfig(Config):
# raise KeyError
self.update(_default_region_config[region])
+ @staticmethod
+ def is_depend_redis(cache_name: str):
+ return cache_name in DEPENDENCY_REDIS_CACHE
+
+ @property
+ def dpm(self):
+ return self.DataPathManager(self["provider_uri"], self["mount_path"])
+
def resolve_path(self):
# resolve path
- if self["mount_path"] is not None:
- self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve())
+ _mount_path = self["mount_path"]
+ _provider_uri = self["provider_uri"]
+ if _provider_uri is None:
+ raise ValueError("provider_uri cannot be None")
+ if not isinstance(_provider_uri, dict):
+ _provider_uri = {self.DEFAULT_FREQ: _provider_uri}
+ if not isinstance(_mount_path, dict):
+ _mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
- if self.get_uri_type() == QlibConfig.LOCAL_URI:
- self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
+ # check provider_uri and mount_path
+ _miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
+ assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
+
+ # resolve
+ for _freq, _uri in _provider_uri.items():
+ # provider_uri
+ if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
+ _provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
+ # mount_path
+ _mount_path[_freq] = (
+ _mount_path[_freq]
+ if _mount_path[_freq] is None
+ else str(Path(_mount_path[_freq]).expanduser().resolve())
+ )
+ self["provider_uri"] = _provider_uri
+ self["mount_path"] = _mount_path
def get_uri_type(self):
path = self["provider_uri"]
@@ -270,14 +351,6 @@ class QlibConfig(Config):
else:
return QlibConfig.LOCAL_URI
- def get_data_path(self):
- if self.get_uri_type() == QlibConfig.LOCAL_URI:
- return self["provider_uri"]
- elif self.get_uri_type() == QlibConfig.NFS_URI:
- return self["mount_path"]
- else:
- raise NotImplementedError(f"This type of uri is not supported")
-
def set(self, default_conf: str = "client", **kwargs):
"""
configure qlib based on the input parameters
@@ -325,11 +398,20 @@ class QlibConfig(Config):
if not (self["expression_cache"] is None and self["dataset_cache"] is None):
# check redis
if not can_use_cache():
- logger.warning(
- f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
- )
- self["expression_cache"] = None
- self["dataset_cache"] = None
+ log_str = ""
+ # check expression cache
+ if self.is_depend_redis(self["expression_cache"]):
+ log_str += self["expression_cache"]
+ self["expression_cache"] = None
+ # check dataset cache
+ if self.is_depend_redis(self["dataset_cache"]):
+ log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"]
+ self["dataset_cache"] = None
+ if log_str:
+ logger.warning(
+ f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), "
+ f"{log_str} will not be used!"
+ )
def register(self):
from .utils import init_instance_by_config
diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py
new file mode 100644
index 000000000..af4893acf
--- /dev/null
+++ b/qlib/contrib/data/dataset.py
@@ -0,0 +1,346 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import copy
+import torch
+import warnings
+import numpy as np
+import pandas as pd
+
+from qlib.utils import init_instance_by_config
+from qlib.data.dataset import DatasetH, DataHandler
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def _to_tensor(x):
+ if not isinstance(x, torch.Tensor):
+ return torch.tensor(x, dtype=torch.float, device=device)
+ return x
+
+
+def _create_ts_slices(index, seq_len):
+ """
+ create time series slices from pandas index
+
+ Args:
+ index (pd.MultiIndex): pandas multiindex with order
+ seq_len (int): sequence length
+ """
+ assert isinstance(index, pd.MultiIndex), "unsupported index type"
+ assert seq_len > 0, "sequence length should be larger than 0"
+ assert index.is_monotonic_increasing, "index should be sorted"
+
+ # number of dates for each instrument
+ sample_count_by_insts = index.to_series().groupby(level=0).size().values
+
+ # start index for each instrument
+ start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
+ start_index_of_insts[0] = 0
+
+ # all the [start, stop) indices of features
+ # features between [start, stop) will be used to predict label at `stop - 1`
+ slices = []
+ for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):
+ for stop in range(1, cur_cnt + 1):
+ end = cur_loc + stop
+ start = max(end - seq_len, 0)
+ slices.append(slice(start, end))
+ slices = np.array(slices, dtype="object")
+
+ assert len(slices) == len(index) # the i-th slice = index[i]
+
+ return slices
+
+
+def _get_date_parse_fn(target):
+ """get date parse function
+
+ This method is used to parse date arguments as target type.
+
+ Example:
+ _get_date_parse_fn('20120101')('2017-01-01') => '20170101'
+ _get_date_parse_fn(20120101)('2017-01-01') => 20170101
+ """
+ if isinstance(target, pd.Timestamp):
+ _fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
+ elif isinstance(target, int):
+ _fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
+ elif isinstance(target, str) and len(target) == 8:
+ _fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
+ else:
+ _fn = lambda x: x # '2021-01-01'
+ return _fn
+
+
+def _maybe_padding(x, seq_len, zeros=None):
+ """padding 2d