1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-29 00:51:19 +08:00

Compare commits

...

46 Commits

Author SHA1 Message Date
Dong Zhou
5ac9dd7221 temporarily fix create exp conflicts for remote mlflow 2021-11-12 05:16:17 +00:00
you-n-g
7efec6bbc4 Fix private import 2021-11-08 09:52:55 +08:00
Young
3fa48d7017 simplify record tmp 2021-11-05 12:57:14 +00:00
Young
4f2d6b0d84 fix pytorch memory amount error 2021-11-02 20:41:39 +08:00
Young
3943b7001f fix CI bug for AyncCaller 2021-11-02 14:32:09 +08:00
Young
2593185721 Simplify TSDataset and async recorder 2021-11-02 11:07:40 +08:00
Young
7a884fa9f2 remove redundant file only when remote artifact 2021-11-01 18:55:44 +08:00
Dong Zhou
d929d4bb21 rm recorder temp file 2021-11-01 09:29:44 +00:00
Young
e54b019ee2 solve init kwargs conflictions 2021-11-01 06:22:25 +00:00
Young
426b98a3bc make the logic of online manager cleaner 2021-11-01 02:40:54 +00:00
Young
82f8ff9066 Update seperate dataframe 2021-11-01 00:51:21 +08:00
Young
31e9d529de Add multi horizon task generator 2021-10-28 00:01:19 +08:00
Young
5fa56703ae add handler pickle attr, enhance init_instance_by_config 2021-10-26 23:32:33 +08:00
Dong Zhou
c6bb11fe56 avoid trade without enough cash 2021-10-25 05:46:19 +00:00
Dong Zhou
3d7ebd1fe0 add back trade_val 2021-10-22 10:13:15 +00:00
Dong Zhou
7313b4dad0 fix impact cost 2021-10-22 08:58:37 +00:00
Dong Zhou
b70caff522 add doc 2021-10-22 08:49:20 +00:00
Dong Zhou
96b422a906 support market impact cost 2021-10-22 08:44:47 +00:00
Young
64130d9407 Fix the aggregation function of IndexData 2021-10-22 15:20:45 +08:00
Young
a58bc03a8e add sepdf(make mini project only rely on qlib) 2021-10-21 13:15:02 +00:00
Young
f537222ce3 make handler seperable 2021-10-21 12:38:24 +00:00
Dong Zhou
c427c64845 fix calendar 2021-10-19 06:17:53 +00:00
Young
22ff8fdc44 simple change log 2021-10-16 17:14:37 +00:00
Young
4efb0a75c1 Being compatible with previous Qlib version 2021-10-16 16:43:38 +00:00
Young
052aad7982 simplify signal parameter 2021-10-15 14:48:31 +00:00
Young
12f05c7182 Merge branch 'backtest_improve' of github.com:microsoft/qlib into backtest_improve 2021-10-15 11:27:33 +00:00
Young
ac08468330 Make static prediction easier 2021-10-15 11:21:03 +00:00
Dong Zhou
df9745f134 support empty order 2021-10-15 09:07:03 +00:00
Dong Zhou
2e49a5f7c0 fix order generator 2021-10-15 07:04:47 +00:00
you-n-g
3ab5721448 Fix OrderGenerator's return value 2021-10-15 14:28:08 +08:00
you-n-g
6a94b45503 Update order_generator.py 2021-10-15 13:52:55 +08:00
you-n-g
7c31012b50 Auto injecting model and dataset for Recorder (#645)
* Auto injecting model and dataset for Recorder

* Support using Feature in expression
2021-10-15 13:50:24 +08:00
you-n-g
334b92ace7 Checking dataset empty (#647)
* Checking dataset empty

* add dataset checker
2021-10-14 23:35:12 +08:00
you-n-g
9a175d7507 improve the doc of auto init (#541)
* improve the doc of auto init

* Update setup.py

* Update setup.py

* change cvxpy version

Co-authored-by: Wangwuyi123 <51237097+Wangwuyi123@users.noreply.github.com>
2021-10-12 11:58:27 +08:00
Lewen Wang
17ea44e0cf Update TCTS. (#643)
* Update TCTS.

* Update TCTS README.

* Update TCTS README.

* Update TCTS.

Co-authored-by: lewwang <lwwang@microsoft.com>
2021-10-12 10:08:48 +08:00
you-n-g
c0ce712be9 more detailed docs for workflow (#639)
* more detailed docs for workflow

* add more detailed docs for workflow
2021-10-11 15:38:18 +08:00
demon143
8e81a017c1 Update manage.py (#628)
* Update manage.py

* Update manage.py

* Update manage.py

* Create manage.py

* Update manage.py

* Update qlib/workflow/task/manage.py

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-10-11 15:37:50 +08:00
you-n-g
706727988c Update README.md 2021-10-09 23:37:07 +08:00
you-n-g
e99224e5c2 Update benchmark based on new backtest (#634)
* free random seed

* update model baselines

* more robust for parameters
2021-10-07 22:57:19 +08:00
Pengrong Zhu
8c8d1336de fix workflow_config_lightgbm_multi_freq.yaml (#635) 2021-10-06 17:18:27 +08:00
Pengrong Zhu
d01de411a8 add support for macos-11 (#630)
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-10-03 16:49:17 +08:00
Young
28fe4d4bb4 update file strategy test 2021-10-03 14:58:37 +08:00
Young
873129aa9b update fix CI tests bugs 2021-10-03 14:58:37 +08:00
Young
3a152f9b8b fix CI 2021-10-03 14:58:37 +08:00
Young
2b75b41a08 remove 3.6 2021-10-03 14:58:37 +08:00
you-n-g
00d17f0a52 Update python-publish.yml 2021-10-01 03:03:26 +08:00
109 changed files with 1562 additions and 674 deletions

View File

@@ -12,8 +12,9 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, macos-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
os: [windows-latest, macos-latest, macos-11]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -44,7 +45,8 @@ jobs:
- name: Build wheel on Linux
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
with:
python-versions: 'cp36-cp36m cp37-cp37m cp38-cp38'
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-versions: 'cp37-cp37m cp38-cp38'
build-requirements: 'numpy cython'
- name: Set up Python
uses: actions/setup-python@v2

View File

@@ -13,7 +13,8 @@ jobs:
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
python-version: [3.6, 3.7, 3.8]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -49,15 +50,6 @@ jobs:
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
pip install -e .
- name: Test data downloads
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
else
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
fi
shell: bash
- name: Install test dependencies
run: |
pip install --upgrade pip

View File

@@ -10,10 +10,12 @@ on:
jobs:
build:
runs-on: macos-latest
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
os: [macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2

View File

@@ -159,6 +159,21 @@ Version 0.5.0
- Add baselines
- public data crawler
Version greater than Version 0.5.0
Version 0.8.0
--------------------
- The backtest is greatly refactored.
- Nested decision execution framework is supported
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
- The trading limitation is more accurate;
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`_, longing and shorting actions share the same action.
- In `current verison <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`_, the trading limitation is different between loging and shorting action.
- The constant is different when calculating annualized metrics.
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`_
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again.
- Users could chec kout the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`_ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`_
Other Versions
----------------------------------
Please refer to `Github release Notes <https://github.com/microsoft/qlib/releases>`_

View File

@@ -100,7 +100,6 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
This table demonstrates the supported Python version of `Qlib`:
| | install with pip | install from source | plot |
| ------------- |:---------------------:|:--------------------:|:----:|
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
@@ -307,7 +306,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py run --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
## Run multiple models
@@ -317,7 +316,7 @@ The script will create a unique virtual environment for each model, and delete t
Here is an example of running all the models for 10 iterations:
```python
python run_all_model.py 10
python run_all_model.py run 10
```
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).

View File

@@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``.
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
limit_threshold: 0.095
account: 100000000
@@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
limit_threshold: 0.095
account: 100000000

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -86,4 +87,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -100,4 +101,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -94,4 +95,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -14,7 +14,7 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5

View File

@@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
verbose: False
limit_threshold: 0.095
@@ -80,4 +83,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -76,4 +77,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -31,18 +31,22 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
limit_threshold: 0.095
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -98,4 +99,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -3,49 +3,58 @@
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
<!--
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
>
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ -->
> NOTE:
> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference.
## Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0659±0.00 | -0.7072±0.00| -0.2955±0.00 |
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
| XGBoost (Tianqi Chen, et al.) | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
| LightGBM (Guolin Ke, et al.) | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
| MLP | Alpha360 | 0.0285±0.00 | 0.1981±0.02| 0.0402±0.00 | 0.2993±0.02 | 0.0073±0.02 | 0.0880±0.22| -0.1446±0.03 |
| GRU (Kyunghyun Cho, et al.) | Alpha360 | 0.0490±0.01 | 0.3787±0.05| 0.0581±0.00 | 0.4664±0.04 | 0.0726±0.02 | 0.9817±0.34| -0.0902±0.03 |
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
| Linear | Alpha158 | 0.0393±0.00 | 0.2980±0.00| 0.0475±0.00 | 0.3546±0.00 | 0.0795±0.00 | 1.0712±0.00| -0.1449±0.00 |
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1561±0.00| -0.0787±0.00 |
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
| LSTM(Sepp Hochreiter, et al.) | Alpha158(with selected 20 features) | 0.0318±0.00 | 0.2367±0.04 | 0.0435±0.00 | 0.3389±0.03 | 0.0381±0.03 | 0.5561±0.46 | -0.1207±0.04 |
| Localformer(Juyong Jiang, et al.) | Alpha158 | 0.0356±0.00 | 0.2756±0.03 | 0.0468±0.00 | 0.3784±0.03 | 0.0438±0.02 | 0.6600±0.33 | -0.0952±0.02 |
| SFM(Liheng Zhang, et al.) | Alpha158 | 0.0379±0.00 | 0.2959±0.04 | 0.0464±0.00 | 0.3825±0.04 | 0.0465±0.02 | 0.5672±0.29 | -0.1282±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha158(with selected 20 features) | 0.0362±0.01 | 0.2789±0.06 | 0.0463±0.01 | 0.3661±0.05 | 0.0470±0.03 | 0.6992±0.47 | -0.1072±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158(with selected 20 features) | 0.0349±0.00 | 0.2511±0.01 | 0.0462±0.00 | 0.3564±0.01 | 0.0497±0.01 | 0.7338±0.19 | -0.0777±0.02 |
| TRA(Hengxu Lin, et al.) | Alpha158(with selected 20 features) | 0.0404±0.00 | 0.3197±0.05 | 0.0490±0.00 | 0.4047±0.04 | 0.0649±0.02 | 1.0091±0.30 | -0.0860±0.02 |
| Linear | Alpha158 | 0.0397±0.00 | 0.3000±0.00 | 0.0472±0.00 | 0.3531±0.00 | 0.0692±0.00 | 0.9209±0.00 | -0.1509±0.00 |
| TRA(Hengxu Lin, et al.) | Alpha158 | 0.0440±0.00 | 0.3535±0.05 | 0.0540±0.00 | 0.4451±0.03 | 0.0718±0.02 | 1.0835±0.35 | -0.0760±0.02 |
| CatBoost(Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0481±0.00 | 0.3366±0.00 | 0.0454±0.00 | 0.3311±0.00 | 0.0765±0.00 | 0.8032±0.01 | -0.1092±0.00 |
| XGBoost(Tianqi Chen, et al.) | Alpha158 | 0.0498±0.00 | 0.3779±0.00 | 0.0505±0.00 | 0.4131±0.00 | 0.0780±0.00 | 0.9070±0.00 | -0.1168±0.00 |
| TFT (Bryan Lim, et al.) | Alpha158(with selected 20 features) | 0.0358±0.00 | 0.2160±0.03 | 0.0116±0.01 | 0.0720±0.03 | 0.0847±0.02 | 0.8131±0.19 | -0.1824±0.03 |
| MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
| LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
## Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
| Transformer(Ashish Vaswani, et al.) | Alpha360 | 0.0114±0.00 | 0.0716±0.03 | 0.0327±0.00 | 0.2248±0.02 | -0.0270±0.03 | -0.3378±0.37 | -0.1653±0.05 |
| TabNet(Sercan O. Arik, et al.) | Alpha360 | 0.0099±0.00 | 0.0593±0.00 | 0.0290±0.00 | 0.1887±0.00 | -0.0369±0.00 | -0.3892±0.00 | -0.2145±0.00 |
| MLP | Alpha360 | 0.0273±0.00 | 0.1870±0.02 | 0.0396±0.00 | 0.2910±0.02 | 0.0029±0.02 | 0.0274±0.23 | -0.1385±0.03 |
| Localformer(Juyong Jiang, et al.) | Alpha360 | 0.0404±0.00 | 0.2932±0.04 | 0.0542±0.00 | 0.4110±0.03 | 0.0246±0.02 | 0.3211±0.21 | -0.1095±0.02 |
| CatBoost((Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0378±0.00 | 0.2714±0.00 | 0.0467±0.00 | 0.3659±0.00 | 0.0292±0.00 | 0.3781±0.00 | -0.0862±0.00 |
| XGBoost(Tianqi Chen, et al.) | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00 | 0.4527±0.02 | -0.1004±0.00 |
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0404±0.00 | 0.3023±0.00 | 0.0495±0.00 | 0.3898±0.00 | 0.0468±0.01 | 0.6302±0.20 | -0.0860±0.01 |
| LightGBM(Guolin Ke, et al.) | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00 | 0.7632±0.00 | -0.0659±0.00 |
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
| LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 |
| GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.
- The base model of TCTS is GRU.

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -1,52 +1,38 @@
# Temporally Correlated Task Scheduling for Sequence Learning
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
### Background
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
<p align="center">
<img src="task_description.png" width="600" height="200"/>
</p>
### Method
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. This work introduces a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in the current minibatch) and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
<p align="center">
<img src="workflow.png"/>
</p>
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
### DataSet
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
At step <img src="https://latex.codecogs.com/png.latex?s" title="s" />, with training data <img src="https://latex.codecogs.com/png.latex?x_s,y_s" title="x_s,y_s" />, the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> chooses a suitable task <img src="https://latex.codecogs.com/png.latex?T_{i_s}" title="T_{i_s}" /> (green solid lines) to update the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> (blue solid lines). After <img src="https://latex.codecogs.com/png.latex?S" title="S" /> steps, we evaluate the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> on the validation set and update the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> (green dashed lines).
### Experiments
#### Task Description
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
Due to different data versions and different Qlib versions, the original data and data preprocessing methods of the experimental settings in the paper are different from those experimental settings in the existing Qlib version. Therefore, we provide two versions of the code according to the two kinds of settings, 1) the [code](https://github.com/lwwang1995/tcts) that can be used to reproduce the experimental results and 2) the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) in the current Qlib baseline.
#### Setting1
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
* The main tasks <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as following,
<div align=center>
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t&plus;k}}{price_i^{t&plus;k-1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+k}}{price_i^{t+k-1}}-1" />
</div>
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
#### Baselines
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
#### Result
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
| :----: | :----: | :----: | :----: |
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
* Temporally correlated task sets <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_k&space;=&space;\{T_1,&space;T_2,&space;...&space;,&space;T_k\}" title="\mathcal{T}_k = \{T_1, T_2, ... , T_k\}" />, in this paper, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" />, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_5" title="\mathcal{T}_5" /> and <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_{10}" title="\mathcal{T}_{10}" /> are used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />, <img src="https://latex.codecogs.com/png.latex?T_2" title="T_2" />, and <img src="https://latex.codecogs.com/png.latex?T_3" title="T_3" />.
#### Setting2
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2014), validation (01/01/2015-12/31/2016), and test sets (01/01/2017-08/01/2020) based on the transaction time.
* The main tasks <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as following,
<div align=center>
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t&plus;1&plus;k}}{price_i^{t&plus;1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+1+k}}{price_i^{t+1}}-1" />
</div>
* In Qlib baseline, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" />, is used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />.
### Experimental Result
You can find the experimental result of setting1 in the [paper](http://proceedings.mlr.press/v139/wu21e/wu21e.pdf) and the experimental result of setting2 in this [page](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

View File

@@ -22,16 +22,17 @@ data_handler_config: &data_handler_config
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -1) / $close - 1",
"Ref($close, -2) / Ref($close, -1) - 1",
"Ref($close, -3) / Ref($close, -2) - 1"]
label: ["Ref($close, -2) / Ref($close, -1) - 1",
"Ref($close, -3) / Ref($close, -1) - 1",
"Ref($close, -4) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -53,9 +54,8 @@ task:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
dropout: 0.3
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
@@ -64,10 +64,10 @@ task:
fore_optimizer: adam
weight_optimizer: adam
output_dim: 3
fore_lr: 5e-4
weight_lr: 5e-4
fore_lr: 2e-3
weight_lr: 2e-3
steps: 3
target_label: 1
target_label: 0
lowest_valid_performance: 0.993
dataset:
class: DatasetH
@@ -92,7 +92,6 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
label_col: 1
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:

View File

@@ -195,7 +195,8 @@ class Alpha158Formatter(GenericDataFormatter):
for col in column_names:
if col not in {"forecast_time", "identifier"}:
output[col] = self._target_scaler.inverse_transform(predictions[col])
# Using [col] is for aligning with the format when fitting
output[col] = self._target_scaler.inverse_transform(predictions[[col]])
return output

View File

@@ -311,5 +311,11 @@ class TFTModel(ModelFT):
# self.model.save(path)
# save qlib model wrapper
self.model = None
drop_attrs = ["model", "tf_graph", "sess", "data_formatter"]
orig_attr = {}
for attr in drop_attrs:
orig_attr[attr] = getattr(self, attr)
setattr(self, attr, None)
super(TFTModel, self).to_pickle(path)
for attr in drop_attrs:
setattr(self, attr, orig_attr[attr])

View File

@@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -38,7 +38,7 @@ class TRAModel(Model):
model_init_state=None,
lamb=0.0,
rho=0.99,
seed=0,
seed=None,
logdir=None,
eval_train=True,
eval_test=False,

View File

@@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -50,6 +51,7 @@ task:
kwargs:
d_feat: 158
pretrain: True
seed: 993
dataset:
class: DatasetH
module_path: qlib.data.dataset

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -50,6 +51,7 @@ task:
kwargs:
d_feat: 360
pretrain: True
seed: 993
dataset:
class: DatasetH
module_path: qlib.data.dataset

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -151,10 +151,9 @@ class NestedDecisionExecutionWorkflow:
self._train_model(model, dataset)
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
@@ -189,10 +188,9 @@ class NestedDecisionExecutionWorkflow:
backtest_config["benchmark"] = self.benchmark
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},

View File

@@ -151,6 +151,9 @@ def get_all_results(folders) -> dict:
if recorders[recorder_id].status == "FINISHED":
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
metrics = recorder.list_metrics()
if "1day.excess_return_with_cost.annualized_return" not in metrics:
print(f"{recorder_id} is skipped due to incomplete result")
continue
result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"])
result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"])
result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"])
@@ -200,174 +203,183 @@ def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
return temp_path
# function to run the all the models
@only_allow_defined_args
def run(
times=1,
models=None,
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
exp_folder_name: str = "run_all_model_records",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
for multiple times, and this will be fixed in the future development.
class ModelRunner:
def _init_qlib(self, exp_folder_name):
# init qlib
GetData().qlib_data(exists_skip=True)
qlib.init(
exp_manager={
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
"default_exp_name": "Experiment",
},
}
)
Parameters:
-----------
times : int
determines how many times the model should be running.
models : str or list
determines the specific model or list of models to run or exclude.
exclude : boolean
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path
exp_folder_name: str
the name of the experiment folder
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
wait when errors raised when executing commands
# function to run the all the models
@only_allow_defined_args
def run(
self,
times=1,
models=None,
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
exp_folder_name: str = "run_all_model_records",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
for multiple times, and this will be fixed in the future development.
Usage:
-------
Here are some use cases of the function in the bash:
Parameters:
-----------
times : int
determines how many times the model should be running.
models : str or list
determines the specific model or list of models to run or exclude.
exclude : boolean
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path
exp_folder_name: str
the name of the experiment folder
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
wait when errors raised when executing commands
.. code-block:: bash
Usage:
-------
Here are some use cases of the function in the bash:
# Case 1 - run all models multiple times
python run_all_model.py 3
.. code-block:: bash
# Case 2 - run specific models multiple times
python run_all_model.py 3 mlp
# Case 1 - run all models multiple times
python run_all_model.py run 3
# Case 3 - run specific models multiple times with specific dataset
python run_all_model.py 3 mlp Alpha158
# Case 2 - run specific models multiple times
python run_all_model.py run 3 mlp
# Case 4 - run other models except those are given as arguments for multiple times
python run_all_model.py 3 [mlp,tft,lstm] --exclude=True
# Case 3 - run specific models multiple times with specific dataset
python run_all_model.py run 3 mlp Alpha158
# Case 5 - run specific models for one time
python run_all_model.py --models=[mlp,lightgbm]
# Case 4 - run other models except those are given as arguments for multiple times
python run_all_model.py run 3 [mlp,tft,lstm] --exclude=True
# Case 6 - run other models except those are given as arguments for one time
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
# Case 5 - run specific models for one time
python run_all_model.py run --models=[mlp,lightgbm]
"""
# init qlib
GetData().qlib_data(exists_skip=True)
qlib.init(
exp_manager={
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
"default_exp_name": "Experiment",
},
}
)
# Case 6 - run other models except those are given as arguments for one time
python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True
# get all folders
folders = get_all_folders(models, exclude)
# init error messages:
errors = dict()
# run all the model for iterations
for fn in folders:
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn], dataset)
if yaml_path is None:
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
continue
sys.stderr.write("\n")
# create env by anaconda
temp_dir, env_path, python_path, conda_activate = create_env()
"""
self._init_qlib(exp_folder_name)
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
with open(req_path) as f:
content = f.read()
if "torch" in content:
# automatically install pytorch according to nvidia's version
execute(
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
) # for automatically installing torch according to the nvidia driver
execute(
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
wait_when_err=wait_when_err,
)
else:
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
# setup gpu for tft
if fn == "TFT":
execute(
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
wait_when_err=wait_when_err,
)
# get all folders
folders = get_all_folders(models, exclude)
# init error messages:
errors = dict()
# run all the model for iterations
for fn in folders:
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn], dataset)
if yaml_path is None:
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
continue
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
if fn == "TFT":
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
else:
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
errs = execute(
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
wait_when_err=wait_when_err,
)
if errs is not None:
_errs = errors.get(fn, {})
_errs.update({i: errs})
errors[fn] = _errs
# create env by anaconda
temp_dir, env_path, python_path, conda_activate = create_env()
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
with open(req_path) as f:
content = f.read()
if "torch" in content:
# automatically install pytorch according to nvidia's version
execute(
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
) # for automatically installing torch according to the nvidia driver
execute(
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
wait_when_err=wait_when_err,
)
else:
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
# setup gpu for tft
if fn == "TFT":
execute(
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
wait_when_err=wait_when_err,
)
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
if fn == "TFT":
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
else:
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
errs = execute(
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
wait_when_err=wait_when_err,
)
if errs is not None:
_errs = errors.get(fn, {})
_errs.update({i: errs})
errors[fn] = _errs
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
if wait_before_rm_env:
input("Press Enter to Continue")
shutil.rmtree(env_path)
# print errors
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
self._collect_results(exp_folder_name, dataset)
def _collect_results(self, exp_folder_name, dataset):
folders = get_all_folders(exp_folder_name, dataset)
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
if len(results) > 0:
# calculating the mean and std
sys.stderr.write(f"Calculating the mean and std of results...\n")
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results, dataset)
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
if wait_before_rm_env:
input("Press Enter to Continue")
shutil.rmtree(env_path)
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
if len(results) > 0:
# calculating the mean and std
sys.stderr.write(f"Calculating the mean and std of results...\n")
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results, dataset)
sys.stderr.write("\n")
# print errors
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
sys.stderr.write("\n")
# move results folder
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
# move results folder
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
if __name__ == "__main__":
fire.Fire(run) # run all the model
fire.Fire(ModelRunner) # run all the model

View File

@@ -31,10 +31,9 @@ if __name__ == "__main__":
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},

View File

@@ -6,6 +6,7 @@ _version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is
__version__ = _version_path.read_text(encoding="utf-8").strip()
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union
import yaml
import logging
import platform
@@ -151,14 +152,17 @@ def init_from_yaml_conf(conf_path, **kwargs):
:param conf_path: A path to the qlib config in yml format
"""
with open(conf_path) as f:
config = yaml.safe_load(f)
if conf_path is None:
config = {}
else:
with open(conf_path) as f:
config = yaml.safe_load(f)
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
def get_project_path(config_name="config.yaml", cur_path: Union[Path, str, None] = None) -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
@@ -187,6 +191,7 @@ def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
"""
if cur_path is None:
cur_path = Path(__file__).absolute().resolve()
cur_path = Path(cur_path)
while True:
if (cur_path / config_name).exists():
return cur_path
@@ -202,6 +207,40 @@ def auto_init(**kwargs):
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
- Skip initialization if already initialized
:**kwargs: it may contain following parameters
cur_path: the start path to find the project path
Here are two examples of the configuration
Example 1)
If you want create a new project-specific config based on a shared configure, you can use `conf_type: ref`
.. code-block:: yaml
conf_type: ref
qlib_cfg: '<shared_yaml_config_path>' # this could be null reference no config from other files
# following configs in `qlib_cfg_update` is project=specific
qlib_cfg_update:
exp_manager:
class: "MLflowExpManager"
module_path: "qlib.workflow.expm"
kwargs:
uri: "file://<your mlflow experiment path>"
default_exp_name: "Experiment"
Example 2)
If you wan to create simple a stand alone config, you can use following config(a.k.a `conf_type: origin`)
.. code-block:: python
exp_manager:
class: "MLflowExpManager"
module_path: "qlib.workflow.expm"
kwargs:
uri: "file://<your mlflow experiment path>"
default_exp_name: "Experiment"
"""
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
@@ -210,6 +249,7 @@ def auto_init(**kwargs):
except FileNotFoundError:
init(**kwargs)
else:
logger = get_module_logger("Initialization")
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
@@ -223,8 +263,14 @@ def auto_init(**kwargs):
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
qlib_conf_path = conf.get("qlib_cfg", None)
# merge the arguments
qlib_conf_update = conf.get("qlib_cfg_update", {})
for k, v in kwargs.items():
if k in qlib_conf_update:
logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'")
qlib_conf_update.update(kwargs)
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)
logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -34,6 +34,7 @@ class Exchange:
open_cost=0.0015,
close_cost=0.0025,
min_cost=5,
impact_cost=0.0,
extra_quote=None,
quote_cls=NumpyQuote,
**kwargs,
@@ -95,6 +96,7 @@ class Exchange:
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
distinguish `not set` and `disable trade_unit`
:param min_cost: min cost, default 5
:param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1.
:param extra_quote: pandas, dataframe consists of
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
The limit indicates that the etf is tradable on a specific day.
@@ -164,9 +166,12 @@ class Exchange:
all_fields = list(all_fields | set(subscribe_fields))
self.all_fields = all_fields
self.open_cost = open_cost
self.close_cost = close_cost
self.min_cost = min_cost
self.impact_cost = impact_cost
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
self.volume_threshold = volume_threshold
self.extra_quote = extra_quote
@@ -685,12 +690,14 @@ class Exchange:
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
)
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
"""return the real order amount after cash limit for buying.
Parameters
----------
trade_price : float
position : cash
cost_ratio : float
Return
----------
float
@@ -699,10 +706,10 @@ class Exchange:
max_trade_amount = 0
if cash >= self.min_cost:
# critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / self.open_cost + self.min_cost
critical_price = self.min_cost / cost_ratio + self.min_cost
if cash >= critical_price:
# the service fee is equal to open_cost * trade_amount
max_trade_amount = cash / (1 + self.open_cost) / trade_price
# the service fee is equal to cost_ratio * trade_amount
max_trade_amount = cash / (1 + cost_ratio) / trade_price
else:
# the service fee is equal to min_cost
max_trade_amount = (cash - self.min_cost) / trade_price
@@ -718,6 +725,7 @@ class Exchange:
:return: trade_price, trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
order.deal_amount = order.amount # set to full amount and clip it step by step
# Clipping amount first
@@ -726,8 +734,12 @@ class Exchange:
# - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
self._clip_amount_by_volume(order, dealt_order_amount)
# TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
trade_val = order.deal_amount * trade_price
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
if order.direction == Order.SELL:
cost_ratio = self.close_cost
cost_ratio = self.close_cost + adj_cost_ratio
# sell
# if we don't know current position, we choose to sell all
# Otherwise, we clip the amount based on current position
@@ -750,14 +762,18 @@ class Exchange:
self.logger.debug(f"Order clipped due to cash limitation: {order}")
elif order.direction == Order.BUY:
cost_ratio = self.open_cost
cost_ratio = self.open_cost + adj_cost_ratio
# buy
if position is not None:
cash = position.get_cash()
trade_val = order.deal_amount * trade_price
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
if cash < max(trade_val * cost_ratio, self.min_cost):
# cash cannot cover cost
order.deal_amount = 0
self.logger.debug(f"Order clipped due to cost higher than cash: {order}")
elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
# The money is not enough
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
order.deal_amount = self.round_amount_by_trade_unit(
min(max_buy_amount, order.deal_amount), order.factor
)

View File

@@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote):
if is_single_value(start_time, end_time, self.freq, self.region):
# this is a very special case.
# skip aggregating function to speed-up the query calculation
# FIXME:
# it will go to the else logic when it comes to the
# 1) the day before holiday when daily trading
# 2) the last minute of the day when intraday trading
try:
return self.data[stock_id].loc[start_time, field]
except KeyError:

View File

@@ -345,15 +345,19 @@ class Position(BasePosition):
if stock_id not in self.position:
raise KeyError("{} not in current position".format(stock_id))
else:
# decrease the amount of stock
self.position[stock_id]["amount"] -= trade_amount
# check if to delete
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
)
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
if np.isclose(self.position[stock_id]["amount"], trade_amount):
# Selling all the stocks
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
self._del_stock(stock_id)
else:
# decrease the amount of stock
self.position[stock_id]["amount"] -= trade_amount
# check if to delete
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
)
new_cash = trade_val - cost
if self._settle_type == self.ST_CASH:

102
qlib/backtest/signal.py Normal file
View File

@@ -0,0 +1,102 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils import init_instance_by_config
from typing import Dict, List, Text, Tuple, Union
from ..model.base import BaseModel
from ..data.dataset import Dataset
from ..data.dataset.utils import convert_index_format
from ..utils.resam import resam_ts_data
import pandas as pd
import abc
class Signal(metaclass=abc.ABCMeta):
"""
Some trading strategy make decisions based on other prediction signals
The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset)
This interface is tries to provide unified interface for those different sources
"""
@abc.abstractmethod
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
"""
get the signal at the end of the decision step(from `start_time` to `end_time`)
Returns
-------
Union[pd.Series, pd.DataFrame, None]:
returns None if no signal in the specific day
"""
...
class SignalWCache(Signal):
"""
Signal With pandas with based Cache
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
"""
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
"""
Parameters
----------
signal : Union[pd.Series, pd.DataFrame]
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
instrument datetime
SH600000 2008-01-02 0.079704
2008-01-03 0.120125
2008-01-04 0.878860
2008-01-07 0.505539
2008-01-08 0.395004
"""
self.signal_cache = convert_index_format(signal, level="datetime")
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not algin with the decision frequency of strategy
# so resampling from the data is necessary
# the latest signal leverage more recent data and therefore is used in trading.
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
return signal
class ModelSignal(SignalWCache):
def __init__(self, model: BaseModel, dataset: Dataset):
self.model = model
self.dataset = dataset
pred_scores = self.model.predict(dataset)
if isinstance(pred_scores, pd.DataFrame):
pred_scores = pred_scores.iloc[:, 0]
super().__init__(pred_scores)
def _update_model(self):
"""
When using online data, update model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
# TODO: this method is not included in the framework and could be refactor later
raise NotImplementedError("_update_model is not implemented!")
def create_signal_from(
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
) -> Signal:
"""
create signal from diverse information
This method will choose the right method to create a signal based on `obj`
Please refer to the code below.
"""
if isinstance(obj, Signal):
return obj
elif isinstance(obj, (tuple, list)):
return ModelSignal(*obj)
elif isinstance(obj, (dict, str)):
return init_instance_by_config(obj)
elif isinstance(obj, (pd.DataFrame, pd.Series)):
return SignalWCache(signal=obj)
else:
raise NotImplementedError(f"This type of signal is not supported")

View File

@@ -70,7 +70,7 @@ class TradeCalendarManager:
- If self.trade_step >= self.self.trade_len, it means the trading is finished
- If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step
"""
return self.trade_step >= self.trade_len
return self.trade_step >= self.trade_len - 1
def step(self):
if self.finished():

View File

View File

@@ -0,0 +1,183 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Dict, Iterable
def align_index(df_dict, join):
res = {}
for k, df in df_dict.items():
if join is not None and k != join:
df = df.reindex(df_dict[join].index)
res[k] = df
return res
# Mocking the pd.DataFrame class
class SepDataFrame:
"""
(Sep)erate DataFrame
We usually concat multiple dataframe to be processed together(Such as feature, label, weight, filter).
However, they are usally be used seperately at last.
This will result in extra cost for concating and spliting data(reshaping and copying data in the memory is very expensive)
SepDataFrame tries to act like a DataFrame whose column with multiindex
"""
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
"""
initialize the data based on the dataframe dictionary
Parameters
----------
df_dict : Dict[str, pd.DataFrame]
dataframe dictionary
join : str
how to join the data
It will reindex the dataframe based on the join key.
If join is None, the reindex step will be skipped
skip_align :
for some cases, we can improve performance by skipping aligning index
"""
self.join = join
if skip_align:
self._df_dict = df_dict
else:
self._df_dict = align_index(df_dict, join)
@property
def loc(self):
return SDFLoc(self, join=self.join)
@property
def index(self):
return self._df_dict[self.join].index
def apply_each(self, method: str, skip_align=True, *args, **kwargs):
"""
Assumptions:
- inplace methods will return None
"""
inplace = False
df_dict = {}
for k, df in self._df_dict.items():
df_dict[k] = getattr(df, method)(*args, **kwargs)
if df_dict[k] is None:
inplace = True
if not inplace:
return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align)
def sort_index(self, *args, **kwargs):
return self.apply_each("sort_index", True, *args, **kwargs)
def copy(self, *args, **kwargs):
return self.apply_each("copy", True, *args, **kwargs)
def _update_join(self):
if self.join not in self:
self.join = next(iter(self._df_dict.keys()))
def __getitem__(self, item):
return self._df_dict[item]
def __setitem__(self, item: str, df: pd.DataFrame):
# TODO: consider the join behavior
self._df_dict[item] = df
def __delitem__(self, item: str):
del self._df_dict[item]
self._update_join()
def __contains__(self, item):
return item in self._df_dict
def __len__(self):
return len(self._df_dict[self.join])
def droplevel(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `droplevel` method")
@property
def columns(self):
dfs = []
for k, df in self._df_dict.items():
df = df.head(0)
df.columns = pd.MultiIndex.from_product([[k], df.columns])
dfs.append(df)
return pd.concat(dfs, axis=1).columns
# Useless methods
@staticmethod
def merge(df_dict: Dict[str, pd.DataFrame], join: str):
all_df = df_dict[join]
for k, df in df_dict.items():
if k != join:
all_df = all_df.join(df)
return all_df
class SDFLoc:
"""Mock Class"""
def __init__(self, sdf: SepDataFrame, join):
self._sdf = sdf
self.axis = None
self.join = join
def __call__(self, axis):
self.axis = axis
return self
def __getitem__(self, args):
if self.axis == 1:
if isinstance(args, str):
return self._sdf[args]
elif isinstance(args, (tuple, list)):
new_df_dict = {k: self._sdf[k] for k in args}
return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True)
else:
raise NotImplementedError(f"This type of input is not supported")
elif self.axis == 0:
return SepDataFrame(
{k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True
)
else:
df = self._sdf
if isinstance(args, tuple):
ax0, *ax1 = args
if len(ax1) == 0:
ax1 = None
if ax1 is not None:
df = df.loc(axis=1)[ax1]
if ax0 is not None:
df = df.loc(axis=0)[ax0]
return df
else:
return df.loc(axis=0)[args]
# Patch pandas DataFrame
# Tricking isinstance to accept SepDataFrame as its subclass
import builtins
def _isinstance(instance, cls):
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
if isinstance(cls, Iterable):
for c in cls:
if c is pd.DataFrame:
return True
elif cls is pd.DataFrame:
return True
return isinstance_orig(instance, cls) # pylint: disable=E0602
builtins.isinstance_orig = builtins.isinstance
builtins.isinstance = _isinstance
if __name__ == "__main__":
sdf = SepDataFrame({}, join=None)
print(isinstance(sdf, (pd.DataFrame,)))
print(isinstance(sdf, pd.DataFrame))

View File

@@ -38,6 +38,8 @@ class CatBoostModel(Model, FeatureInt):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -64,6 +64,8 @@ class DEnsembleModel(Model, FeatureInt):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
# initialize the sample weights
N, F = x_train.shape

View File

@@ -25,6 +25,8 @@ class LGBModel(ModelFT, LightGBMFInt):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
@@ -83,6 +85,8 @@ class LGBModel(ModelFT, LightGBMFInt):
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
if dtrain.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
self.model = lgb.train(
self.params,
dtrain,

View File

@@ -82,6 +82,8 @@ class HFLGBModel(ModelFT, LightGBMFInt):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_train["feature"], df_valid["label"]

View File

@@ -51,6 +51,8 @@ class LinearModel(Model):
def fit(self, dataset: DatasetH):
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if df_train.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
X, y = df_train["feature"].values, np.squeeze(df_train["label"].values)
if self.estimator in [self.OLS, self.RIDGE, self.LASSO]:

View File

@@ -224,6 +224,8 @@ class ALSTM(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -207,6 +207,8 @@ class ALSTM(Model):
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

View File

@@ -237,6 +237,8 @@ class GATs(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -245,6 +245,8 @@ class GATs(Model):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

View File

@@ -224,6 +224,8 @@ class GRU(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -206,6 +206,8 @@ class GRU(Model):
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

View File

@@ -176,6 +176,8 @@ class LocalformerModel(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -153,6 +153,8 @@ class LocalformerModel(Model):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

View File

@@ -219,6 +219,8 @@ class LSTM(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -201,6 +201,8 @@ class LSTM(Model):
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

View File

@@ -374,6 +374,8 @@ class SFM(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -169,6 +169,8 @@ class TabnetModel(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
df_train.fillna(df_train.mean(), inplace=True)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -61,8 +61,9 @@ class TCTS(Model):
weight_lr=5e-7,
steps=3,
GPU=0,
seed=0,
target_label=0,
mode="soft",
seed=None,
lowest_valid_performance=0.993,
**kwargs
):
@@ -87,6 +88,7 @@ class TCTS(Model):
self.weight_lr = weight_lr
self.steps = steps
self.target_label = target_label
self.mode = mode
self.lowest_valid_performance = lowest_valid_performance
self._fore_optimizer = fore_optimizer
self._weight_optimizer = weight_optimizer
@@ -100,6 +102,8 @@ class TCTS(Model):
"\nn_epochs : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\ntarget_label : {}"
"\nmode : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}"
@@ -111,6 +115,8 @@ class TCTS(Model):
n_epochs,
batch_size,
early_stop,
target_label,
mode,
loss,
GPU,
self.use_gpu,
@@ -120,9 +126,17 @@ class TCTS(Model):
def loss_fn(self, pred, label, weight):
loc = torch.argmax(weight, 1)
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
return torch.mean(loss)
if self.mode == "hard":
loc = torch.argmax(weight, 1)
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
return torch.mean(loss)
elif self.mode == "soft":
loss = (pred - label.transpose(0, 1)) ** 2
return torch.mean(loss * weight.transpose(0, 1))
else:
raise NotImplementedError("mode {} is not supported!".format(self.mode))
def train_epoch(self, x_train, y_train, x_valid, y_valid):
@@ -132,6 +146,10 @@ class TCTS(Model):
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)
task_embedding = torch.zeros([self.batch_size, self.output_dim])
task_embedding[:, self.target_label] = 1
task_embedding = task_embedding.to(self.device)
init_fore_model = copy.deepcopy(self.fore_model)
for p in init_fore_model.parameters():
p.init_fore_model = False
@@ -155,12 +173,13 @@ class TCTS(Model):
init_pred = init_fore_model(feature)
pred = self.fore_model(feature)
dis = init_pred - label.transpose(0, 1)
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
weight_feature = torch.cat(
(feature, dis.transpose(0, 1), label, init_pred.view(-1, 1), task_embedding), 1
)
weight = self.weight_model(weight_feature)
loss = self.loss_fn(pred, label, weight) # hard
loss = self.loss_fn(pred, label, weight)
self.fore_optimizer.zero_grad()
loss.backward()
@@ -188,11 +207,11 @@ class TCTS(Model):
pred = self.fore_model(feature)
dis = pred - label.transpose(0, 1)
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1), task_embedding), 1)
weight = self.weight_model(weight_feature)
loc = torch.argmax(weight, 1)
valid_loss = torch.mean((pred - label[:, 0]) ** 2)
loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
valid_loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
loss = torch.mean(valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
self.weight_optimizer.zero_grad()
loss.backward()
@@ -207,7 +226,6 @@ class TCTS(Model):
self.fore_model.eval()
scores = []
losses = []
indices = np.arange(len(x_values))
@@ -237,6 +255,8 @@ class TCTS(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
@@ -277,7 +297,7 @@ class TCTS(Model):
dropout=self.dropout,
)
self.weight_model = MLPModel(
d_feat=360 + 2 * self.output_dim + 1,
d_feat=360 + 3 * self.output_dim + 1,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
dropout=self.dropout,
@@ -303,8 +323,6 @@ class TCTS(Model):
best_loss = np.inf
best_epoch = 0
stop_round = 0
fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
for epoch in range(self.n_epochs):
print("Epoch:", epoch)

View File

@@ -74,7 +74,7 @@ class TRAModel(Model):
lamb=0.0,
rho=0.99,
alpha=1.0,
seed=0,
seed=None,
logdir=None,
eval_train=False,
eval_test=False,
@@ -99,8 +99,9 @@ class TRAModel(Model):
if transport_method == "router" and not eval_train:
self.logger.warning("`eval_train` will be ignored when using TRA.router")
np.random.seed(seed)
torch.manual_seed(seed)
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
self.model_config = model_config
self.tra_config = tra_config

View File

@@ -175,6 +175,8 @@ class TransformerModel(Model):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -151,6 +151,9 @@ class TransformerModel(Model):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from .model_strategy import (
from .signal_strategy import (
TopkDropoutStrategy,
WeightStrategyBase,
)

View File

@@ -6,7 +6,7 @@ This strategy is not well maintained
from .order_generator import OrderGenWInteract
from .model_strategy import WeightStrategyBase
from .signal_strategy import WeightStrategyBase
import copy

View File

@@ -80,18 +80,22 @@ class OrderGenWInteract(OrderGenerator):
:rtype: list
"""
if target_weight_position is None:
return []
# calculate current_tradable_value
current_amount_dict = current.get_stock_amount_dict()
current_total_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
only_tradable=False,
)
current_tradable_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
only_tradable=True,
)
# add cash
@@ -105,9 +109,7 @@ class OrderGenWInteract(OrderGenerator):
# value. Then just sell all the stocks
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
for stock_id in list(target_amount_dict.keys()):
if trade_exchange.is_stock_tradable(
stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
):
if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time):
del target_amount_dict[stock_id]
else:
# consider cost rate
@@ -118,16 +120,16 @@ class OrderGenWInteract(OrderGenerator):
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
weight_position=target_weight_position,
cash=current_tradable_value,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
)
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=target_amount_dict,
current_position=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
)
return TradeDecisionWO(order_list, self)
return order_list
class OrderGenWOInteract(OrderGenerator):
@@ -163,8 +165,11 @@ class OrderGenWOInteract(OrderGenerator):
:param trade_date:
:type trade_date: pd.Timestamp
:rtype: list
:rtype: list of generated orders
"""
if target_weight_position is None:
return []
risk_total_value = risk_degree * current.calculate_value()
current_stock = current.get_stock_list()
@@ -172,13 +177,17 @@ class OrderGenWOInteract(OrderGenerator):
for stock_id in target_weight_position:
# Current rule will ignore the stock that not hold and cannot be traded at predict date
if trade_exchange.is_stock_tradable(
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time
) and trade_exchange.is_stock_tradable(
stock_id=stock_id, start_time=pred_start_time, end_time=pred_end_time
):
amount_dict[stock_id] = (
risk_total_value
* target_weight_position[stock_id]
/ trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time)
/ trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time)
)
# TODO: Qlib use None to represent trading suspension. So last close price can't be the estimated trading price.
# Maybe a close price with forward fill will be a better solution.
elif stock_id in current_stock:
amount_dict[stock_id] = (
risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id)
@@ -188,7 +197,7 @@ class OrderGenWOInteract(OrderGenerator):
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=amount_dict,
current_position=current.get_stock_amount_dict(),
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
)
return TradeDecisionWO(order_list, self)
return order_list

View File

@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import warnings
import numpy as np

View File

@@ -1,27 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
from qlib.backtest.signal import Signal, create_signal_from
from typing import Dict, List, Text, Tuple, Union
from qlib.data.dataset import Dataset
from qlib.model.base import BaseModel
from qlib.backtest.position import Position
import warnings
import numpy as np
import pandas as pd
from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ...strategy.base import BaseStrategy
from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
from .order_generator import OrderGenWInteract
class TopkDropoutStrategy(ModelStrategy):
class TopkDropoutStrategy(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
topk,
n_drop,
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None,
method_sell="bottom",
method_buy="top",
risk_degree=0.95,
@@ -30,6 +36,8 @@ class TopkDropoutStrategy(ModelStrategy):
trade_exchange=None,
level_infra=None,
common_infra=None,
model=None,
dataset=None,
**kwargs,
):
"""
@@ -39,6 +47,9 @@ class TopkDropoutStrategy(ModelStrategy):
the number of stocks in the portfolio.
n_drop : int
number of stocks to be replaced in each trading date.
signal :
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
the decision of the strategy will base on the given signal
method_sell : str
dropout method_sell, random/bottom.
method_buy : str
@@ -64,7 +75,7 @@ class TopkDropoutStrategy(ModelStrategy):
"""
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
self.topk = topk
self.n_drop = n_drop
@@ -74,6 +85,13 @@ class TopkDropoutStrategy(ModelStrategy):
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
# This is trying to be compatible with previous version of qlib task config
if model is not None and dataset is not None:
warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning)
signal = model, dataset
self.signal: Signal = create_signal_from(signal)
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -87,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
if self.only_tradable:
@@ -235,15 +253,15 @@ class TopkDropoutStrategy(ModelStrategy):
return TradeDecisionWO(sell_order_list + buy_order_list, self)
class WeightStrategyBase(ModelStrategy):
class WeightStrategyBase(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
level_infra=None,
@@ -251,6 +269,9 @@ class WeightStrategyBase(ModelStrategy):
**kwargs,
):
"""
signal :
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
the decision of the strategy will base on the given signal
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
@@ -260,13 +281,15 @@ class WeightStrategyBase(ModelStrategy):
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(WeightStrategyBase, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
self.signal: Signal = create_signal_from(signal)
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -298,7 +321,7 @@ class WeightStrategyBase(ModelStrategy):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
current_temp = copy.deepcopy(self.trade_position)

View File

@@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp):
if save:
save_name = "results-{:}.pkl".format(key)
self.recorder.save_objects(**{save_name: results})
self.save(**{save_name: results})
logger.info(
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
save_name, self.recorder.experiment_id
@@ -57,22 +57,20 @@ class MultiSegRecord(RecordTemp):
)
class SignalMseRecord(SignalRecord):
class SignalMseRecord(RecordTemp):
"""
This is the Signal MSE Record class that computes the mean squared error (MSE).
This class inherits the ``SignalMseRecord`` class.
"""
artifact_path = "sig_analysis"
depend_cls = SignalRecord
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder, **kwargs)
def generate(self, **kwargs):
try:
self.check(parent=True)
except FileExistsError:
super().generate()
def generate(self):
self.check()
pred = self.load("pred.pkl")
label = self.load("label.pkl")
@@ -81,9 +79,8 @@ class SignalMseRecord(SignalRecord):
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
def list(self):
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
return paths
return ["mse.pkl", "rmse.pkl"]

View File

@@ -320,6 +320,7 @@ class TSDataSampler:
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.idx_map = self.idx_map2arr(self.idx_map)
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
@@ -328,6 +329,25 @@ class TSDataSampler:
del self.data # save memory
@staticmethod
def idx_map2arr(idx_map):
# pytorch data sampler will have better memory control without large dict or list
# - https://github.com/pytorch/pytorch/issues/13243
# - https://github.com/airctic/icevision/issues/613
# So we convert the dict into int array.
# The arr_map is expected to behave the same as idx_map
dtype = np.int32
# set a index out of bound to indicate the none existing
no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max)
max_idx = max(idx_map.keys())
arr_map = []
for i in range(max_idx + 1):
arr_map.append(idx_map.get(i, no_existing_idx))
arr_map = np.array(arr_map, dtype=dtype)
return arr_map
@staticmethod
def flt_idx_map(flt_data, idx_map):
idx = 0
@@ -385,6 +405,10 @@ class TSDataSampler:
idx_map[real_idx] = (i, j)
return idx_df, idx_map
@property
def empty(self):
return self.__len__() == 0
def _get_indices(self, row: int, col: int) -> np.array:
"""
get series indices of self.data_arr from the row, col indices of self.idx_df
@@ -520,20 +544,18 @@ class TSDatasetH(DatasetH):
def setup_data(self, **kwargs):
super().setup_data(**kwargs)
# make sure the calendar is updated to latest when loading data from new config
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
self.cal = cal
self.cal = sorted(cal)
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
@staticmethod
def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
pad_start_idx = max(0, start_idx - self.step_len)
pad_start = self.cal[pad_start_idx]
# TSDatasetH will retrieve more data for complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
return data
start_idx = bisect.bisect_left(cal, pd.Timestamp(start))
pad_start_idx = max(0, start_idx - step_len)
pad_start = cal[pad_start_idx]
return slice(pad_start, end)
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
"""
@@ -542,13 +564,15 @@ class TSDatasetH(DatasetH):
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete
data = self._prepare_raw_seg(slc, **kwargs)
# TSDatasetH will retrieve more data for complete time-series
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
data = super()._prepare_seg(ext_slice, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None

View File

@@ -82,8 +82,6 @@ class DataHandler(Serializable):
fetch_orig : bool
Return the original data instead of copy if possible.
"""
# Set logger
self.logger = get_module_logger("DataHandler")
# Setup data loader
assert data_loader is not None # to make start_time end_time could have None default value
@@ -302,6 +300,7 @@ class DataHandlerLP(DataHandler):
DK_R = "raw"
DK_I = "infer"
DK_L = "learn"
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
# process type
PTYPE_I = "independent"
@@ -543,7 +542,7 @@ class DataHandlerLP(DataHandler):
raise AttributeError(
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
)
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
df = getattr(self, self.ATTR_MAP[data_key])
return df
def fetch(
@@ -624,3 +623,33 @@ class DataHandlerLP(DataHandler):
df = self._get_df_by_key(data_key).head()
df = fetch_df_by_col(df, col_set)
return df.columns.to_list()
@classmethod
def cast(cls, handler: "DataHandlerLP") -> "DataHandlerLP":
"""
Motivation
- A user create a datahandler in his customized package. Then he want to share the processed handler to other users without introduce the package dependency and complicated data processing logic.
- This class make it possible by casting the class to DataHandlerLP and only keep the processed data
Parameters
----------
handler : DataHandlerLP
A subclass of DataHandlerLP
Returns
-------
DataHandlerLP:
the converted processed data
"""
new_hd: DataHandlerLP = object.__new__(DataHandlerLP)
new_hd.from_cast = True # add a mark for the casted instance
for key in list(DataHandlerLP.ATTR_MAP.values()) + [
"instruments",
"start_time",
"end_time",
"fetch_orig",
"drop_raw",
]:
setattr(new_hd, key, getattr(handler, key, None))
return new_hd

View File

@@ -13,7 +13,7 @@ import pandas as pd
from typing import Union, List, Type
from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps
from .base import Expression, ExpressionOps, Feature
from ..log import get_module_logger
from ..utils import get_callable_kwargs
@@ -1485,6 +1485,7 @@ OpsList = [
IdxMax,
IdxMin,
If,
Feature,
]
@@ -1517,7 +1518,7 @@ class OpsWrapper:
else:
_ops_class = _operator
if not issubclass(_ops_class, ExpressionOps):
if not issubclass(_ops_class, Expression):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
if _ops_class.__name__ in self._ops:

View File

@@ -70,9 +70,9 @@ def fill_placeholder(config: dict, config_extend: dict):
# bfs
top = 0
tail = 1
item_quene = [config]
item_queue = [config]
while top < tail:
now_item = item_quene[top]
now_item = item_queue[top]
top += 1
if isinstance(now_item, list):
item_keys = range(len(now_item))
@@ -80,9 +80,9 @@ def fill_placeholder(config: dict, config_extend: dict):
item_keys = now_item.keys()
for key in item_keys:
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
item_quene.append(now_item[key])
item_queue.append(now_item[key])
tail += 1
elif now_item[key] in config_extend.keys():
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
now_item[key] = config_extend[now_item[key]]
return config
@@ -114,10 +114,19 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
task_config = fill_placeholder(task_config, placehorder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
if isinstance(records, dict): # uniform the data format to list
records = [records]
for record in records:
r = init_instance_by_config(record, recorder=rec)
# Some recorder require the parameter `model` and `dataset`.
# try to automatically pass in them to the initialization function
# to make defining the tasking easier
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
return rec

View File

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import List, Tuple, Union
import pandas as pd
from ..model.base import BaseModel
from ..data.dataset import DatasetH
@@ -16,7 +17,7 @@ from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..backtest.decision import BaseTradeDecision
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
class BaseStrategy:
@@ -193,43 +194,6 @@ class BaseStrategy:
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
class ModelStrategy(BaseStrategy):
"""Model-based trading strategy, use model to make predictions for trading"""
def __init__(
self,
model: BaseModel,
dataset: DatasetH,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
model : BaseModel
the model used in when making predictions
dataset : DatasetH
provide test data for model
kwargs : dict
arguments that will be passed into `reset` method
"""
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
self.model = model
self.dataset = dataset
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
def _update_model(self):
"""
When using online data, pdate model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
raise NotImplementedError("_update_model is not implemented!")
class RLStrategy(BaseStrategy):
"""RL-based strategy"""

View File

@@ -44,36 +44,10 @@ class TestAutoData(unittest.TestCase):
)
provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day}
client_config = {
"calendar_provider": {
"class": "LocalCalendarProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileCalendarStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
"feature_provider": {
"class": "LocalFeatureProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileFeatureStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
}
},
},
}
init(
provider_uri=cls.provider_uri,
provider_uri=provider_uri_map,
region=REG_CN,
expression_cache=None,
dataset_cache=None,
**client_config,
**cls._setup_kwargs,
)

View File

@@ -35,6 +35,10 @@ RECORD_CONFIG = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": {
"dataset": "<DATASET>",
"model": "<MODEL>",
},
},
{
"class": "SigAnaRecord",

View File

@@ -27,7 +27,7 @@ import collections
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Any, Text, Optional
from typing import Dict, Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse
@@ -199,6 +199,7 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
----------
config : [dict, str]
similar to config
please refer to the doc of init_instance_by_config
default_module : Python module or str
It should be a python module to load the class type
@@ -219,9 +220,12 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
_callable = config["class"] # the class type itself is passed in
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
# a.b.c.ClassName
*m_path, cls = config.split(".")
m_path = ".".join(m_path)
module = get_module_by_module_path(default_module if m_path == "" else m_path)
_callable = getattr(module, config)
_callable = getattr(module, cls)
kwargs = {}
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -232,7 +236,11 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
def init_instance_by_config(
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
config: Union[str, dict, object],
default_module=None,
accept_types: Union[type, Tuple[type]] = (),
try_kwargs: Dict = {},
**kwargs,
) -> Any:
"""
get initialized instance with config
@@ -256,7 +264,9 @@ def init_instance_by_config(
1) specify a pickle object
- path like 'file:///<path to pickle file>/obj.pkl'
2) specify a class name
- "ClassName": getattr(module, config)() will be used.
- "ClassName": getattr(module, "ClassName")() will be used.
3) specify module path with class name
- "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
object example:
instance of accept_types
default_module : Python module
@@ -270,6 +280,10 @@ def init_instance_by_config(
Optional. If the config is a instance of specific type, return the config directly.
This will be passed into the second parameter of isinstance.
try_kwargs: Dict
Try to pass in kwargs in `try_kwargs` when initialized the instance
If error occurred, it will fail back to initialization without try_kwargs.
Returns
-------
object:
@@ -286,7 +300,33 @@ def init_instance_by_config(
return pickle.load(f)
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)
try:
return klass(**cls_kwargs, **try_kwargs, **kwargs)
except (TypeError,):
# TypeError for handling errors like
# 1: `XXX() got multiple values for keyword argument 'YYY'`
# 2: `XXX() got an unexpected keyword argument 'YYY'
return klass(**cls_kwargs, **kwargs)
@contextlib.contextmanager
def class_casting(obj: object, cls: type):
"""
Python doesn't provide the downcasting mechanism.
We use the trick here to downcast the class
Parameters
----------
obj : object
the object to be cast
cls : type
the target class type
"""
orig_cls = obj.__class__
obj.__class__ = cls
yield
obj.__class__ = orig_cls
def compare_dict_value(src_data: dict, dst_data: dict):

View File

@@ -18,3 +18,9 @@ class LoadObjectError(QlibException):
"""Error type for Recorder when can not load object"""
pass
class ExpAlreadyExistError(Exception):
"""Experiment already exists"""
pass

View File

@@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator):
def columns(self):
return self.indices[1]
def __getitem__(self, args):
# NOTE: this tries to behave like a numpy array to be compatible with numpy aggregating function like nansum and nanmean
return self.iloc[args]
def _align_indices(self, other: "IndexData") -> "IndexData":
"""
Align all indices of `other` to `self` before performing the arithmetic operations.
@@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator):
Parameters
----------
other : "IndexData"
the index in `other` is to be chagned
the index in `other` is to be changed
Returns
-------
@@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator):
"""
return len(self.data)
def sum(self, axis=None):
def sum(self, axis=None, dtype=None, out=None):
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
# FIXME: weird logic and not general
if axis is None:
return np.nansum(self.data)
@@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator):
else:
raise ValueError(f"axis must be None, 0 or 1")
def mean(self, axis=None):
def mean(self, axis=None, dtype=None, out=None):
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
# FIXME: weird logic and not general
if axis is None:
return np.nanmean(self.data)

View File

@@ -1,9 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from functools import partial
from threading import Thread
from typing import Callable
from joblib import Parallel, delayed
from joblib._parallel_backends import MultiprocessingBackend
import pandas as pd
from queue import Queue
class ParallelExt(Parallel):
@@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
return pd.concat(dfs, axis=axis).sort_index()
else:
return _naive_group_apply(df)
class AsyncCaller:
"""
This AsyncCaller tries to make it easier to async call
Currently, it is used in MLflowRecorder to make functions like `log_params` async
NOTE:
- This caller didn't consider the return value
"""
STOP_MARK = "__STOP"
def __init__(self) -> None:
self._q = Queue()
self._stop = False
self._t = Thread(target=self.run)
self._t.start()
def close(self):
self._q.put(self.STOP_MARK)
def run(self):
while True:
data = self._q.get()
if data == self.STOP_MARK:
break
else:
data()
def __call__(self, func, *args, **kwargs):
self._q.put(partial(func, *args, **kwargs))
def wait(self, close=True):
if close:
self.close()
self._t.join()
@staticmethod
def async_dec(ac_attr):
def decorator_func(func):
def wrapper(self, *args, **kwargs):
if isinstance(getattr(self, ac_attr, None), Callable):
return getattr(self, ac_attr)(func, self, *args, **kwargs)
else:
return func(self, *args, **kwargs)
return wrapper
return decorator_func

View File

@@ -2,8 +2,8 @@
# Licensed under the MIT License.
from contextlib import contextmanager
from typing import Text, Optional
from .expm import MLflowExpManager
from typing import Any, Dict, Text, Optional
from .expm import ExpManager
from .exp import Experiment
from .recorder import Recorder
from ..utils import Wrapper
@@ -16,7 +16,7 @@ class QlibRecorder:
"""
def __init__(self, exp_manager):
self.exp_manager = exp_manager
self.exp_manager: ExpManager = exp_manager
def __repr__(self):
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)
@@ -334,6 +334,26 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
@contextmanager
def uri_context(self, uri: Text):
"""
Temporarily set the exp_manager's uri to uri
NOTE:
- Please refer to the NOTE in the `set_uri`
Parameters
----------
uri : Text
the temporal uri
"""
prev_uri = self.exp_manager._current_uri
self.exp_manager.set_uri(uri)
try:
yield
finally:
self.exp_manager.set_uri(prev_uri)
def get_recorder(
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
) -> Recorder:
@@ -360,11 +380,11 @@ class QlibRecorder:
.. code-block:: Python
# Case 1
with R.start('test'):
with R.start(experiment_name='test'):
recorder = R.get_recorder()
# Case 2
with R.start('test'):
with R.start(experiment_name='test'):
recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
# Case 3
@@ -413,12 +433,18 @@ class QlibRecorder:
"""
self.get_exp().delete_recorder(recorder_id, recorder_name)
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
def save_objects(self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any]):
"""
Method for saving objects as artifacts in the experiment to the uri. It supports either saving
from a local file/directory, or directly saving objects. User can use valid python's keywords arguments
to specify the object to be saved as well as its name (name: value).
In summary, this API is designs for saving **objects** to **the experiments management backend path**,
1. Qlib provide two methods to specify **objects**
- Passing in the object directly by passing with `**kwargs` (e.g. R.save_objects(trained_model=model))
- Passing in the local path to the object, i.e. `local_path` parameter.
2. `artifact_path` represents the **the experiments management backend path**
- If `active recorder` exists: it will save the objects through the active recorder.
- If `active recorder` not exists: the system will create a default experiment, and a new recorder and save objects under it.
@@ -431,13 +457,20 @@ class QlibRecorder:
.. code-block:: Python
# Case 1
with R.start('test'):
with R.start(experiment_name='test'):
pred = model.predict(dataset)
R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
rid = R.get_recorder().id
...
R.get_recorder(recorder_id=rid).load_object("prediction/pred.pkl") # after saving objects, you can load the previous object with this api
# Case 2
with R.start('test'):
R.save_objects(local_path='results/pred.pkl')
with R.start(experiment_name='test'):
R.save_objects(local_path='results/pred.pkl', artifact_path="prediction")
rid = R.get_recorder().id
...
R.get_recorder(recorder_id=rid).load_object("prediction/pred.pkl") # after saving objects, you can load the previous object with this api
Parameters
----------
@@ -445,7 +478,14 @@ class QlibRecorder:
if provided, them save the file or directory to the artifact URI.
artifact_path : str
the relative path for the artifact to be stored in the URI.
**kwargs: Dict[Text, Any]
the object to be saved.
For example, `{"pred.pkl": pred}`
"""
if local_path is not None and len(kwargs) > 0:
raise ValueError(
"You can choose only one of `local_path`(save the files in a path) or `kwargs`(pass in the objects directly)"
)
self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
def load_object(self, name: Text):

View File

@@ -4,7 +4,7 @@
from urllib.parse import urlparse
import mlflow
from filelock import FileLock
from mlflow.exceptions import MlflowException
from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.entities import ViewType
import os, logging
from pathlib import Path
@@ -15,6 +15,7 @@ from .exp import MLflowExperiment, Experiment
from ..config import C
from .recorder import Recorder
from ..log import get_module_logger
from ..utils.exceptions import ExpAlreadyExistError
logger = get_module_logger("workflow", logging.INFO)
@@ -94,6 +95,10 @@ class ExpManager:
Returns
-------
An experiment object.
Raise
-----
ExpAlreadyExistError
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
@@ -200,7 +205,14 @@ class ExpManager:
if pr.scheme == "file":
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f:
return self.create_exp(experiment_name), True
return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
try:
return self.create_exp(experiment_name), True
except ExpAlreadyExistError:
return (
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
False,
)
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
"""
@@ -345,10 +357,15 @@ class MLflowExpManager(ExpManager):
def create_exp(self, experiment_name: Optional[Text] = None):
assert experiment_name is not None
# init experiment
experiment_id = self.client.create_experiment(experiment_name)
try:
experiment_id = self.client.create_experiment(experiment_name)
except MlflowException as e:
if e.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
raise ExpAlreadyExistError()
raise e
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
experiment._default_name = self._default_exp_name
return experiment
def _get_exp(self, experiment_id=None, experiment_name=None):

View File

@@ -21,19 +21,65 @@ Situations Description
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
will train models task by task and strategy by strategy.
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
nothing until all tasks have been prepared. It makes user can train all tasks in
the end of `routine` or `first_train`.
Online + DelayTrainer DelayTrainer will skip concrete training until all tasks have been prepared by
different strategies. It makes users can parallelly train all tasks at the end of
`routine` or `first_train`. Otherwise, these functions will get stuck when each
strategy prepare tasks.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
every routine and prepare signals for every routine.
Simulation + Trainer It will behave in the same way as `Online + Trainer`. The only difference is that it
is for simulation/backtesting instead of online trading
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
for the ability to multitasking. It means all tasks in all routines
can be REAL trained at the end of simulating. The signals will be prepared well at
different time segments (based on whether or not any new model is online).
========================= ===================================================================================
Here is some pseudo code the demonstrate the workflow of each situation
For simplicity
- Only one strategy is used in the strategy
- `update_online_pred` is only called in the online mode and is ignored
1) `Online + Trainer`
.. code-block:: python
tasks = first_train()
models = trainer.train(tasks)
trainer.end_train(models)
for day in online_trading_days:
# OnlineManager.routine
models = trainer.train(strategy.prepare_tasks()) # for each strategy
strategy.prepare_online_models(models) # for each strategy
trainer.end_train(models)
prepare_signals() # prepare trading signals daily
`Online + DelayTrainer`: the workflow is the same as `Online + Trainer`.
2) `Simulation + DelayTrainer`
.. code-block:: python
# simulate
tasks = first_train()
models = trainer.train(tasks)
for day in historical_calendars:
# OnlineManager.routine
models = trainer.train(strategy.prepare_tasks()) # for each strategy
strategy.prepare_online_models(models) # for each strategy
# delay_prepare()
# FIXME: Currently the delay_prepare is not implemented in a proper way.
trainer.end_train(<for all previous models>)
prepare_signals()
# Can we simplify current workflow?
- Can reduce the number of state of tasks?
- For each task, we have three phases (i.e. task, partly trained task, final trained task)
"""
import logging
@@ -58,7 +104,7 @@ class OnlineManager(Serializable):
"""
STATUS_SIMULATING = "simulating" # when calling `simulate`
STATUS_NORMAL = "normal" # the normal status
STATUS_ONLINE = "online" # the normal status. It is used when online trading
def __init__(
self,
@@ -87,12 +133,24 @@ class OnlineManager(Serializable):
self.begin_time = pd.Timestamp(begin_time)
self.cur_time = self.begin_time
# OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}.
# It records the online servnig models of each strategy for each day.
self.history = {}
if trainer is None:
trainer = TrainerR()
self.trainer = trainer
self.signals = None
self.status = self.STATUS_NORMAL
self.status = self.STATUS_ONLINE
def _postpone_action(self):
"""
Should the workflow to postpone the following actions to the end (in delay_prepare)
- trainer.end_train
- prepare_signals
Postpone these actions is to support simulating/backtest online strategies without time dependencies.
All the actions can be done parallelly at the end.
"""
return self.status == self.STATUS_SIMULATING and self.trainer.is_delay()
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
"""
@@ -113,12 +171,12 @@ class OnlineManager(Serializable):
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
# FIXME: Train multiple online models at `first_train` will result in getting too much online models at the
# start.
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
if not self._postpone_action():
for strategy, models in zip(strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
@@ -160,10 +218,10 @@ class OnlineManager(Serializable):
# The online model may changes in the above processes
# So updating the predictions of online models should be the last step
if self.status == self.STATUS_NORMAL:
if self.status == self.STATUS_ONLINE:
strategy.tool.update_online_pred()
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
if not self._postpone_action():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.prepare_signals(**signal_kwargs)
@@ -278,13 +336,13 @@ class OnlineManager(Serializable):
signal_kwargs=signal_kwargs,
)
# delay prepare the models and signals
if self.trainer.is_delay():
if self._postpone_action():
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
# FIXME: get logging level firstly and restore it here
set_global_logger_level(logging.DEBUG)
self.logger.info(f"Finished preparing signals")
self.status = self.STATUS_NORMAL
self.status = self.STATUS_ONLINE
return self.get_signals()
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
@@ -295,6 +353,8 @@ class OnlineManager(Serializable):
model_kwargs: the params for `end_train`
signal_kwargs: the params for `prepare_signals`
"""
# FIXME:
# This method is not implemented in the proper way!!!
last_models = {}
signals_time = D.calendar()[0]
need_prepare = False

View File

@@ -9,6 +9,9 @@ import pandas as pd
from pathlib import Path
from pprint import pprint
from typing import Union, List
from collections import defaultdict
from qlib.utils.exceptions import LoadObjectError
from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis
from ..data.dataset import DatasetH
@@ -16,7 +19,7 @@ from ..data.dataset.handler import DataHandlerLP
from ..backtest import backtest as normal_backtest
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
from ..utils import flatten_dict, class_casting
from ..utils.time import Freq
from ..strategy.base import BaseStrategy
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
@@ -32,6 +35,7 @@ class RecordTemp:
"""
artifact_path = None
depend_cls = None # the depend class of the record; the record will depend on the results generated by `depend_cls`
@classmethod
def get_path(cls, path=None):
@@ -44,6 +48,16 @@ class RecordTemp:
return "/".join(names)
def save(self, **kwargs):
"""
It behaves the same as self.recorder.save_objects.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
"""
art_path = self.get_path()
if art_path == "":
art_path = None
self.recorder.save_objects(artifact_path=art_path, **kwargs)
def __init__(self, recorder):
self._recorder = recorder
@@ -66,31 +80,37 @@ class RecordTemp:
"""
raise NotImplementedError(f"Please implement the `generate` method.")
def load(self, name):
def load(self, name: str, parents: bool = True):
"""
Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
in the future::
sar = SigAnaRecord(recorder)
ic = sar.load(sar.get_path("ic.pkl"))
It behaves the same as self.recorder.load_object.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
Parameters
----------
name : str
the name for the file to be load.
parents : bool
Each recorder has different `artifact_path`.
So parents recursively find the path in parents
Sub classes has higher priority
Return
------
The stored records.
"""
# try to load the saved object
obj = self.recorder.load_object(name)
return obj
try:
return self.recorder.load_object(self.get_path(name))
except LoadObjectError:
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
return self.load(name, parents=True)
def list(self):
"""
List the supported artifacts.
Users don't have to consider self.get_path
Return
------
@@ -98,21 +118,45 @@ class RecordTemp:
"""
return []
def check(self, cls="self"):
def check(self, include_self: bool = False, parents: bool = True):
"""
Check if the records is properly generated and saved.
It is useful in following examples
- checking if the depended files complete before genrating new things.
- checking if the final files is completed
Parameters
----------
include_self : bool
is the file generated by self included
parents : bool
will we check parents
Raise
------
FileExistsError: whether the records are stored properly.
FileNotFoundError
: whether the records are stored properly.
"""
artifacts = set(self.recorder.list_artifacts())
if cls == "self":
cls = self
flist = cls.list()
for item in flist:
if item not in artifacts:
raise FileExistsError(item)
if include_self:
# Some mlflow backend will not list the directly recursively.
# So we force to the directly
artifacts = {}
def _get_arts(dirn):
if dirn not in artifacts:
artifacts[dirn] = self.recorder.list_artifacts(dirn)
return artifacts[dirn]
for item in self.list():
ps = self.get_path(item).split("/")
dirn, fn = "/".join(ps[:-1]), ps[-1]
if self.get_path(item) not in _get_arts(dirn):
raise FileNotFoundError
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
class SignalRecord(RecordTemp):
@@ -127,26 +171,20 @@ class SignalRecord(RecordTemp):
@staticmethod
def generate_label(dataset):
# NOTE:
# Python doesn't provide the downcasting mechanism.
# We use the trick here to downcast the class
orig_cls = dataset.__class__
dataset.__class__ = DatasetH
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
dataset.__class__ = orig_cls
with class_casting(dataset, DatasetH):
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
return raw_label
def generate(self, **kwargs):
@@ -154,7 +192,7 @@ class SignalRecord(RecordTemp):
pred = self.model.predict(self.dataset)
if isinstance(pred, pd.Series):
pred = pred.to_frame("score")
self.recorder.save_objects(**{"pred.pkl": pred})
self.save(**{"pred.pkl": pred})
logger.info(
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
@@ -165,15 +203,11 @@ class SignalRecord(RecordTemp):
if isinstance(self.dataset, DatasetH):
raw_label = self.generate_label(self.dataset)
self.recorder.save_objects(**{"label.pkl": raw_label})
self.save(**{"label.pkl": raw_label})
@staticmethod
def list():
def list(self):
return ["pred.pkl", "label.pkl"]
def load(self, name="pred.pkl"):
return super().load(name)
class HFSignalRecord(SignalRecord):
"""
@@ -214,19 +248,11 @@ class HFSignalRecord(SignalRecord):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)
def list(self):
paths = [
self.get_path("ic.pkl"),
self.get_path("ric.pkl"),
self.get_path("long_pre.pkl"),
self.get_path("short_pre.pkl"),
self.get_path("long_short_r.pkl"),
self.get_path("long_avg_r.pkl"),
]
return paths
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]
class SigAnaRecord(RecordTemp):
@@ -235,16 +261,26 @@ class SigAnaRecord(RecordTemp):
"""
artifact_path = "sig_analysis"
pre_class = SignalRecord
depend_cls = SignalRecord
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):
super().__init__(recorder=recorder)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
self.label_col = label_col
self.skip_existing = skip_existing
def generate(self, **kwargs):
self.check(self.pre_class)
if self.skip_existing:
try:
self.check(include_self=True, parents=False)
except FileNotFoundError:
pass # continue to generating metrics
else:
logger.info("The results has previously generated, generation skipped.")
return
self.check()
pred = self.load("pred.pkl")
label = self.load("label.pkl")
@@ -276,13 +312,13 @@ class SigAnaRecord(RecordTemp):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)
def list(self):
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
paths = ["ic.pkl", "ric.pkl"]
if self.ana_long_short:
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
paths.extend(["long_short_r.pkl", "long_avg_r.pkl"])
return paths
@@ -369,17 +405,11 @@ class PortAnaRecord(RecordTemp):
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
self.recorder.save_objects(
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
)
self.recorder.save_objects(
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
for _freq, indicators_normal in indicator_dict.items():
self.recorder.save_objects(
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -401,9 +431,7 @@ class PortAnaRecord(RecordTemp):
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -428,9 +456,7 @@ class PortAnaRecord(RecordTemp):
analysis_dict = analysis_df["value"].to_dict()
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -442,20 +468,19 @@ class PortAnaRecord(RecordTemp):
for _freq in self.all_freq:
list_path.extend(
[
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
f"report_normal_{_freq}.pkl",
f"positions_normal_{_freq}.pkl",
]
)
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
list_path.append(f"port_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
for _analysis_freq in self.indicator_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
return list_path

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from qlib.utils.serial import Serializable
import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
@@ -8,8 +9,10 @@ from pathlib import Path
from datetime import datetime
from qlib.utils.exceptions import LoadObjectError
from qlib.utils.paral import AsyncCaller
from ..utils.objm import FileManager
from ..log import get_module_logger
from ..log import TimeInspector, get_module_logger
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
logger = get_module_logger("workflow", logging.INFO)
@@ -65,6 +68,8 @@ class Recorder:
Save objects such as prediction file or model checkpoints to the artifact URI. User
can save object through keywords arguments (name:value).
Please refer to the docs of qlib.workflow:R.save_objects
Parameters
----------
local_path : str
@@ -225,6 +230,7 @@ class MLflowRecorder(Recorder):
if mlflow_run.info.end_time is not None
else None
)
self.async_log = None
def __repr__(self):
name = self.__class__.__name__
@@ -283,6 +289,10 @@ class MLflowRecorder(Recorder):
self.status = Recorder.STATUS_R
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
# NOTE: making logging async.
# - This may cause delay when uploading results
# - The logging time may not be accurate
self.async_log = AsyncCaller()
return run
def end_run(self, status: str = Recorder.STATUS_S):
@@ -296,6 +306,9 @@ class MLflowRecorder(Recorder):
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self.status != Recorder.STATUS_S:
self.status = status
with TimeInspector.logt("waiting `async_log`"):
self.async_log.wait()
self.async_log = None
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
@@ -331,18 +344,27 @@ class MLflowRecorder(Recorder):
try:
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
data = pickle.load(f)
ar = self.client._tracking_client._get_artifact_repo(self.id)
if isinstance(ar, AzureBlobArtifactRepository):
# for saving disk space
# For safety, only remove redundant file for specific ArtifactRepository
shutil.rmtree(Path(path).absolute().parent)
return data
except Exception as e:
raise LoadObjectError(message=str(e))
@AsyncCaller.async_dec(ac_attr="async_log")
def log_params(self, **kwargs):
for name, data in kwargs.items():
self.client.log_param(self.id, name, data)
@AsyncCaller.async_dec(ac_attr="async_log")
def log_metrics(self, step=None, **kwargs):
for name, data in kwargs.items():
self.client.log_metric(self.id, name, data, step=step)
@AsyncCaller.async_dec(ac_attr="async_log")
def set_tags(self, **kwargs):
for name, data in kwargs.items():
self.client.set_tag(self.id, name, data)

Some files were not shown because too many files have changed in this diff Show More