1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-13 01:11:00 +08:00

Compare commits

..

67 Commits

Author SHA1 Message Date
Fivele-Li
5af99e1d3f optimize log (#1612) 2023-08-01 18:57:48 +08:00
Fivele-Li
70a066baf8 optimize workflow and output format 2023-07-20 12:15:04 +08:00
Xu Yang
f93f331a3b Merge pull request #1609 from microsoft/xuyang1/finetune_prompts
finetune prompts
2023-07-19 20:01:07 +08:00
Xu Yang
561086d9e1 commit 2023-07-19 20:00:09 +08:00
Young
8eb129358b Add prompt logger 2023-07-18 21:47:58 +08:00
Xu Yang
ce8cb517e9 hot fix one small bug in template 2023-07-18 11:52:43 +08:00
Xu Yang
1c5a73aa81 small refinement in finance knowledge 2023-07-17 21:33:40 +08:00
Xu Yang
d909d54362 Merge pull request #1603 from microsoft/xuyang1/add_idea_task
add idea task and round1
2023-07-17 20:38:43 +08:00
Xu Yang
13c63eee0a merge into one commit 2023-07-17 20:33:47 +08:00
you-n-g
b21e044513 Fix find class bug (#1601) 2023-07-17 20:09:13 +08:00
Fivele-Li
8c1905d1d7 Optimize KnowledgeBase to complete workflow (#1598)
* optimize KnowledgeBase to complete workflow;
* Update Knowledge methods of handle data IO;
* Update task to handle multi recorders;
* Integrate Knowledge to workflow;

* optimize KnowledgeBase to complete workflow
* Update TrainTask & AnalyseTask's recorder method;
* Update SummarizeTask;
* Update Workflow & Topic prompt;
2023-07-17 18:17:04 +08:00
you-n-g
1c9841b15e Connect TrainTask & Rolling & DDG-DA (#1599)
* Connect train task to ddg-da & rolling

* Pylint & black formatting

* Formatting
2023-07-17 09:58:58 +08:00
Xu Yang
5e0873ca81 Merge pull request #1592 from Fivele-Li/update_knowledge_module
update knowledge module;
2023-07-16 11:36:31 +08:00
Cadenza-Li
8a56cf69b4 add KnowledgeBase to workflow;
* Update CMDTask prompt example for Windows OS;
* Windows OS decode output of subprocess in gbk by default, specify encoding format explict;
* Add KnowledgeBase's 4 knowledge types to corresponding task;
2023-07-14 22:25:43 +08:00
you-n-g
a19e616bc3 Update test_utils.py 2023-07-14 16:43:43 +08:00
Cadenza-Li
025859acba Merge branch 'finco' into update_knowledge_module 2023-07-14 16:19:57 +08:00
Xu Yang
e5f685ce08 merge all commit (#1593)
Co-authored-by: Xu Yang <xuyang1@microsoft.com>
2023-07-14 16:17:24 +08:00
Cadenza-Li
b9b6938e71 Merge branch 'finco' into update_knowledge_module 2023-07-14 14:20:21 +08:00
Young
51a9403b15 Merge remote-tracking branch 'origin/main' into finco 2023-07-14 12:16:51 +08:00
Cadenza-Li
37d83fd747 update knowledge module;
* Knowledge.storage to storages list;
* optimize Knowledge & Storage save and load method;
* optimize Knowledge query prompt;
2023-07-13 17:20:22 +08:00
Cadenza-Li
d7ab6935dd update knowledge module;
* add storage class;
* new practice,execute,finance,infrastructure knowledge;
* add query method to KnowledgeBase;
2023-07-12 17:23:47 +08:00
Fivele-Li
effed382e9 Optimize prompt for entire learn loop (#1589)
* Adjust prompt and fix cases
* adjust summarizeTask & learn prompts;
* fix typos & drop duplicate task method;

* adjust learn prompts;
2023-07-11 18:13:52 +08:00
Fivele-Li
86ffd1799d Add knowledge module and tune summarizeTask (#1582)
* Add knowledge module
* add KnowledgeExperiment add KnowledgeBase;
* add knowledge associate prompts to template;

* Add Topic class
* add Topic to summarize knowledge;
* add recorder's metric to summarizeTask;

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-07-06 11:39:36 +08:00
Young
aef11536e3 rename & test 2023-07-04 20:28:08 +08:00
Xu Yang
8b0fdf1623 Merge pull request #1581 from microsoft/xuyang1/fix_singleton_bug
fix singleton bug
2023-07-04 16:51:51 +08:00
Xu Yang
9a36f8da20 fix singleton bug 2023-07-04 16:20:02 +08:00
Xu Yang
b7757d5008 Merge pull request #1580 from microsoft/xuyang1/refine_workflow_to_increase_success_rate
refine workflow to increase success rate
2023-07-03 17:59:54 +08:00
Xu Yang
ee5e5cfdd8 remove useless code 2023-07-03 17:57:13 +08:00
Xu Yang
6cb87ecfd1 refine code to use qrun 2023-07-03 17:56:22 +08:00
Xu Yang
9119bcdd3c Merge pull request #1576 from microsoft/xuyang1/add_config_and_code_dump_task
refine workflow and prompts
2023-06-30 14:43:49 +08:00
Xu Yang
4fccf8112d fix one workflow 2023-06-30 14:33:41 +08:00
Xu Yang
73bd79ca1a merge into one commit 2023-06-30 14:23:40 +08:00
Fivele-Li
7e84f3aae2 Add backtest and backforward task (#1568)
* * add TrainTask & BacktestTask;
* add BackForwardTask;
* adjust prompt_template.yaml which default config failed to backtest;
* run workflow in loop
* add update method to prompt_template.py

* remove debug code

* Adjust Learn Process
* add LearnManager class & use LearnManager to update system prompt;
* use qrun to replace recorder for training and backtesting;

* Adjust analyser
* analyser independent of recorder;
* rename analyser's workspace attribution;
* analyser load variable by recorder.

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-06-30 10:04:43 +08:00
Fivele-Li
1326ac614d Add docs to context and retrieve (#1566)
* add analyser docstring to context;
* add retrieve method to context manager;

* add notes to retrieve
2023-06-24 21:47:27 +08:00
Fivele-Li
f12184cc0f Add analyser task and optimize interact (#1552)
* * optimize interact
* add AnalyserTask
* optimize logger format and add render feature

* format optimize
2023-06-16 11:42:45 +08:00
Xu Yang
a70386ad52 Merge pull request #1550 from microsoft/xuyang1/refine_task_prompts
add datahandler and design action task according to component
2023-06-14 14:52:42 +08:00
Xu Yang
74619ed8d8 fix using defaut in record strategy and backtest 2023-06-14 14:52:16 +08:00
Fivele-Li
1a523df007 Optimize log and interact of FinCo (#1549)
* use FinCoLog for a better interact experience

* addition file changes

* optimize format

* optimize format
2023-06-14 14:48:17 +08:00
Xu Yang
f9cc8a5aaa remove useless prompt 2023-06-14 10:46:38 +08:00
Xu Yang
7762c5a1fd add datahandler and design action task according to component 2023-06-13 23:28:27 +08:00
Xu Yang
fa7ef29281 Merge pull request #1548 from microsoft/xuyang1/add_dump_to_file_task
add simple readme & move prompt templates to outer yaml file to make the code clean
2023-06-13 15:29:13 +08:00
Xu Yang
429c9a7c66 format 2023-06-13 15:27:59 +08:00
Xu Yang
80fbc00792 move prompt templates to yaml file to make code clean 2023-06-13 15:21:19 +08:00
Xu Yang
01accec24c update code 2023-06-12 16:25:16 +08:00
Fivele-Li
1d88830b0d Add recorder task and visualize (#1542)
* add recorder task

* add batch generate summarize report unittest.

* * add recorder to RecorderTask;
* add matplot figure to analyzer.py

* add image to markdown;

* Add some log

* update figure path.

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-06-12 15:48:00 +08:00
you-n-g
ad7498e287 Edit yaml task (#1538)
* Edit yaml task

* update comments
2023-06-02 00:44:41 +08:00
you-n-g
73d51f05b4 Init workspace and CMDTask (#1537)
* Update setup.py and config

* WIP

* init_workspace and CMDTask

* Delete test_sumarize.py
2023-06-01 23:32:35 +08:00
Fivele-Li
3b56b8e6c0 Optimize summarize task prompt and others (#1533)
* 1.update prompt;
2.update fetch information method.

* 1.update prompt;
2.save result to markdown;

* 1.get context info from context_manager;
2.run the entire process successfully.
2023-06-01 21:22:24 +08:00
you-n-g
40e0c329ba Add configurable dataset (#1535) 2023-06-01 20:05:02 +08:00
Xu Yang
e376648860 Merge pull request #1536 from microsoft/xuyang1/add_debug_mode_to_save_cache
add a debug mode to speed up debug process
2023-06-01 19:44:17 +08:00
Xu Yang
5f37f32184 update code 2023-06-01 19:38:26 +08:00
Xu Yang
d46b4c1ebf Merge pull request #1534 from microsoft/xuyang1/add_code_implementation_task
add code implementation task
2023-06-01 18:13:05 +08:00
Xu Yang
0515524b51 add code implementation code 2023-06-01 18:04:31 +08:00
Xu Yang
cda32d5703 Merge pull request #1532 from microsoft/xuyang1/add-plan-and-config-task-implementation
add the initial version of plan and config task implementation
2023-06-01 11:20:04 +08:00
Xu Yang
e2332a004b imporove some words in prompt 2023-06-01 01:09:14 +08:00
Xu Yang
08d9dbccc9 update v1 code containing SLplan and config action 2023-06-01 00:36:04 +08:00
Fivele-Li
e7cd93a36d add base method for summarization; (#1530) 2023-05-31 15:50:34 +08:00
Xu Yang
3919678028 split task into workflow and task to make the strcture more clear 2023-05-31 11:45:25 +08:00
Xu Yang
421b1403b2 Merge pull request #1528 from microsoft/xuyang1/refine_task_and_implement_workflow_task_as_example
Xuyang1/refine task and implement workflow task as example
2023-05-31 11:36:36 +08:00
Xu Yang
94102fb742 remove tasktype variable 2023-05-31 11:35:54 +08:00
Cadenza-Li
74a5d7c8af add parse method for summarization; 2023-05-31 00:08:21 +08:00
Xu Yang
ce39b4b6f8 add qlib auto init so logger can display info 2023-05-30 21:52:35 +08:00
Xu Yang
2af35d9c89 second commit 2023-05-30 20:20:16 +08:00
Xu Yang
f37643550b first round 2023-05-30 20:19:58 +08:00
Xu Yang
55611aa43e Merge pull request #1527 from microsoft/xuyang1/add_openai_api_support
add openai interface support
2023-05-30 13:44:10 +08:00
Xu Yang
f24253efd2 add openai interface support 2023-05-30 13:42:01 +08:00
Young
7c4f3b8a7d Initial interface for discussion 2023-05-24 12:18:31 +08:00
128 changed files with 18903 additions and 2950 deletions

View File

@@ -19,24 +19,7 @@ jobs:
steps:
- uses: actions/checkout@v2
# This is because on macos systems you can install pyqlib using
# `pip install pyqlib` installs, it does not recognize the
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
uses: actions/setup-python@v2
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
uses: actions/setup-python@v2
with:
python-version: "3.8.16"
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os != 'macos-11'
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
@@ -44,15 +27,15 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build wheel on ${{ matrix.os }}
- name: Build wheel on Windows
run: |
pip install numpy
pip install cython
python setup.py bdist_wheel
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
twine upload dist/*
@@ -72,10 +55,10 @@ jobs:
python-version: 3.7
- name: Install dependencies
run: |
pip install twine
pip install twine
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
twine upload dist/pyqlib-*-manylinux*.whl

View File

@@ -6,14 +6,8 @@ on:
branches:
- main
permissions:
contents: read
jobs:
update_release_draft:
permissions:
contents: write
pull-requests: read
runs-on: ubuntu-latest
steps:
# Drafts your next Release notes as Pull Requests are merged into "master"

View File

@@ -8,15 +8,13 @@ on:
jobs:
build:
if: ${{ false }} # FIXME: temporarily disable... Due to we are rushing a feature
timeout-minutes: 120
runs-on: ${{ matrix.os }}
strategy:
matrix:
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -45,10 +43,10 @@ jobs:
- name: Qlib installation test
run: |
# 2024-05-30 scs has released a new version: 3.2.4.post2,
# This will cause the CI to fail, so we have limited the version of scs for now.
python -m pip install "scs<=3.2.4"
python -m pip install pyqlib
# Specify the numpy version because the numpy upgrade caused the CI test to fail,
# and this line of code will be removed when the next version of qlib is released.
python -m pip install "numpy<1.23"
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
@@ -68,8 +66,5 @@ jobs:
cd qlib
- name: Test workflow by config
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
if: ${{ matrix.os != 'macos-11' }}
run: |
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -14,10 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -41,8 +38,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Update pip to the latest version
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0
run: |
python -m pip install --upgrade pip
python -m pip install pip==23.0
- name: Installing pytorch for macos
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
@@ -72,10 +71,8 @@ jobs:
black . -l 120 --check --diff
- name: Make html with sphinx
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
if: ${{ matrix.os == 'ubuntu-22.04' }}
run: |
cd docs
cd docs
sphinx-build -W --keep-going -b html . _build
cd ..
@@ -107,7 +104,6 @@ jobs:
- name: Check Qlib with pylint
run: |
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
# The following flake8 error codes were ignored:
# E501 line too long
@@ -161,16 +157,11 @@ jobs:
# Run after data downloads
- name: Check Qlib ipynb with nbconvert
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
if: ${{ matrix.os != 'macos-11' }}
run: |
# add more ipynb files in future
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
- name: Test workflow by config (install from source)
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
if: ${{ matrix.os != 'macos-11' }}
run: |
python -m pip install numba
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -14,10 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -41,8 +38,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Set up Python tools
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0
run: |
python -m pip install --upgrade pip
python -m pip install pip==23.0
pip install --upgrade cython numpy
pip install -e .[dev]

6
.gitignore vendored
View File

@@ -22,6 +22,10 @@ dist/
qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
qlib/finco/prompt_cache.json
qlib/finco/finco_workspace/
qlib/finco/knowledge/*/knowledge.pkl
qlib/finco/knowledge/*/storage.yml
examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
@@ -48,4 +52,4 @@ tags
*.swp
./pretrain
.idea/
.idea/

View File

@@ -5,12 +5,6 @@
# Required
version: 2
# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.7"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
@@ -20,6 +14,7 @@ formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
install:
- requirements: docs/requirements.txt
- method: pip

View File

@@ -40,7 +40,7 @@ Recent released features
Features released before 2021 are not listed here.
<p align="center">
<img src="docs/_static/img/logo/1.png" />
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
</p>
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
@@ -139,7 +139,7 @@ This table demonstrates the supported Python version of `Qlib`:
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
**Note**:
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
1. **Conda** is suggested for managing your Python environment.
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
@@ -166,29 +166,13 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
* Clone the repository and install ``Qlib`` as follows.
```bash
git clone https://github.com/microsoft/qlib.git && cd qlib
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
pip install .
```
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
## Data Preparation
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
Here is an example to download the data updated on 20220720.
```bash
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
mkdir -p ~/.qlib/qlib_data/cn_data
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
rm -f qlib_bin.tar.gz
```
The official dataset below will resume in short future.
----
Load and prepare data by running the following code:
### Get with module
@@ -337,7 +321,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
# Main Challenges & Solutions in Quant Research
Quant investment is a very unique scenario with lots of key challenges to be solved.
Quant investment is an very unique scenario with lots of key challenges to be solved.
Currently, Qlib provides some solutions for several of them.
## Forecasting: Finding Valuable Signals/Patterns
@@ -376,7 +360,7 @@ Here is a list of models built on `Qlib`.
Your PR of new Quant models is highly welcomed.
The performance of each model on the `Alpha158` and `Alpha360` datasets can be found [here](examples/benchmarks/README.md).
The performance of each model on the `Alpha158` and `Alpha360` dataset can be found [here](examples/benchmarks/README.md).
### Run a single model
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.

View File

@@ -52,7 +52,7 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
Qlib Format Dataset
-------------------
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows. User can also use numpy to load `.bin` file to validate data.
The price volume data look different from the actual dealing price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
Here are some discussions about the price adjusting of Qlib.
@@ -140,13 +140,12 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
where the data are in the following format:
+-----------+-------+
| symbol | close |
+===========+=======+
| SH600000 | 120 |
+-----------+-------+
.. code-block::
- CSV file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
symbol,close
SH600000,120
- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
.. code-block:: bash
@@ -154,13 +153,11 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
where the data are in the following format:
+---------+------------+-------+------+----------+
| symbol | date | close | open | volume |
+=========+============+=======+======+==========+
| SH600000| 2020-11-01 | 120 | 121 | 12300000 |
+---------+------------+-------+------+----------+
| SH600000| 2020-11-02 | 123 | 120 | 12300000 |
+---------+------------+-------+------+----------+
.. code-block::
symbol,date,close,open,volume
SH600000,2020-11-01,120,121,12300000
SH600000,2020-11-02,123,120,12300000
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.

View File

@@ -86,7 +86,7 @@ Example
},
}
# model initialization
# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

View File

@@ -60,4 +60,4 @@ The `[dev]` option will help you to install some related packages when developin
.. code-block:: bash
pip install -e ".[dev]"
pip install -e .[dev]

View File

@@ -36,7 +36,7 @@ Name Description
the training process of models which enable algorithms controlling the
training process.
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are trainable. They are trained
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are learnable. They are learned
based on the `Learning Framework` layer and then applied to multiple scenarios
in `Workflow` layer. The supported learning paradigms can be categorized into
reinforcement learning and supervised learning. The learning framework
@@ -51,7 +51,7 @@ Name Description
modules. With these signals `Decision Generator` will generate the target
trading decisions(i.e. portfolio, orders)
If RL-based Strategies are adopted, the `Policy` is learned in a end-to-end way,
the trading decisions are generated directly.
the trading deicsions are generated directly.
Decisions will be executed by `Execution Env`
(i.e. the trading market). There may be multiple levels of `Strategy`
and `Executor` (e.g. an *order executor trading strategy and intraday order executor*

View File

@@ -16,7 +16,7 @@ This ``Quick Start`` guide tries to demonstrate
Installation
============
Users can easily install ``Qlib`` according to the following steps:
Users can easily intsall ``Qlib`` according to the following steps:
- Before installing ``Qlib`` from source, users need to install some dependencies:

View File

@@ -5,4 +5,3 @@ scipy
scikit-learn
pandas
tianshou
sphinx_rtd_theme

View File

@@ -1,15 +0,0 @@
# Introduction
What is GeneralPtNN
- Fix previous design that fail to support both Time-series and tabular data
- Now you can just replace the Pytorch model structure to run a NN model.
We provide an example to demonstrate the effectiveness of the current design.
- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158 dataset)
- `workflow_config_mlp.yaml` align with previous results [MLP](../README.md#Alpha158 dataset)
# TODO
We will align existing models to current design.

View File

@@ -1,97 +0,0 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal: <PRED>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GeneralPTNN
module_path: qlib.contrib.model.pytorch_general_nn
kwargs:
d_feat: 20
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 2e-4
early_stop: 10
batch_size: 800
metric: loss
loss: mse
n_jobs: 20
GPU: 0
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,98 +0,0 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: [
{
"class" : "DropCol",
"kwargs":{"col_list": ["VWAP0"]}
},
{
"class" : "CSZFillna",
"kwargs":{"fields_group": "feature"}
}
]
learn_processors: [
{
"class" : "DropCol",
"kwargs":{"col_list": ["VWAP0"]}
},
{
"class" : "DropnaProcessor",
"kwargs":{"fields_group": "feature"}
},
"DropnaLabel",
{
"class": "CSZScoreNorm",
"kwargs": {"fields_group": "label"}
}
]
process_type: "independent"
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal: <PRED>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GeneralPTNN
module_path: qlib.contrib.model.pytorch_general_nn
kwargs:
loss: mse
lr: 0.002
optimizer: adam
max_steps: 8000
batch_size: 8192
GPU: 0
weight_decay: 0.0002
pt_model_kwargs:
input_dim: 157
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -136,7 +136,7 @@ If you want to contribute your new models, you can follow the steps below.
- `README.md`: a brief introduction to your models
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py#LL286C22-L286C22) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))

View File

@@ -324,6 +324,7 @@ class TRAModel(Model):
class LSTM(nn.Module):
"""LSTM Model
Args:
@@ -413,6 +414,7 @@ class PositionalEncoding(nn.Module):
class Transformer(nn.Module):
"""Transformer Model
Args:
@@ -473,6 +475,7 @@ class Transformer(nn.Module):
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,

View File

@@ -1,6 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Union
@@ -36,10 +35,6 @@ class DDGDABench(DDGDA):
if __name__ == "__main__":
kwargs = {}
if os.environ.get("PROVIDER_URI", "") == "":
GetData().qlib_data(exists_skip=True)
else:
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
auto_init(**kwargs)
GetData().qlib_data(exists_skip=True)
auto_init()
fire.Fire(DDGDABench)

View File

@@ -1,6 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Union
@@ -32,10 +31,6 @@ class RollingBenchmark(Rolling):
if __name__ == "__main__":
kwargs = {}
if os.environ.get("PROVIDER_URI", "") == "":
GetData().qlib_data(exists_skip=True)
else:
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
auto_init(**kwargs)
GetData().qlib_data(exists_skip=True)
auto_init()
fire.Fire(RollingBenchmark)

View File

@@ -16,7 +16,7 @@ Current version of script with default value tries to connect localhost **via de
Run following command to install necessary libraries
```
pip install pytest coverage gdown
pip install pytest coverage
pip install arctic # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.
```
@@ -27,12 +27,13 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
2. Please follow following steps to download example data
```bash
cd examples/orderbook_data/
gdown https://drive.google.com/uc?id=15nZF7tFT_eKVZAcMFL1qPS4jGyJflH7e # Proxies may be necessary here.
python ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .
wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2
tar xf highfreq_orderboook_example_data.tar.bz2
```
3. Please import the example data to your mongo db
```bash
cd examples/orderbook_data/
python create_dataset.py initialize_library # Initialization Libraries
python create_dataset.py import_data # Initialization Libraries
```
@@ -41,6 +42,7 @@ python create_dataset.py import_data # Initialization Libraries
After importing these data, you run `example.py` to create some high-frequency features.
```bash
cd examples/orderbook_data/
pytest -s --disable-warnings example.py # If you want run all examples
pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run specific example
```

View File

@@ -20,7 +20,7 @@ We use China stock market data for our example.
1. Prepare CSI300 weight:
```bash
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
rm -f csi300_weight.zip
```

View File

@@ -161,7 +161,7 @@
" },\n",
"}\n",
"\n",
"# model initialization\n",
"# model initiaiton\n",
"model = init_instance_by_config(task[\"model\"])\n",
"dataset = init_instance_by_config(task[\"dataset\"])\n",
"\n",

View File

@@ -1,2 +0,0 @@
[build-system]
requires = ["setuptools", "numpy", "Cython"]

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.9.5.99"
__version__ = "0.9.2.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -162,15 +162,13 @@ def create_account_instance(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config=(
{}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
}
),
benchmark_config={}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
)

View File

@@ -622,11 +622,9 @@ class Indicator:
print(
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq,
(
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S")
),
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
fulfill_rate,
price_advantage,
positive_rate,

View File

@@ -486,5 +486,8 @@ class QlibConfig(Config):
return self._registered
DEFAULT_QLIB_DOT_PATH = Path("~/.qlib/").expanduser()
# global config
C = QlibConfig(_default_config)

111
qlib/contrib/analyzer.py Normal file
View File

@@ -0,0 +1,111 @@
import logging
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from ..log import get_module_logger
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
logger = get_module_logger("analysis", logging.INFO)
class AnalyzerTemp:
def __init__(self, recorder, output_dir=None, **kwargs):
self.recorder = recorder
self.output_dir = Path(output_dir) if output_dir else "./"
def load(self, name: str):
"""
It behaves the same as self.recorder.load_object.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
Parameters
----------
name : str
the name for the file to be load.
Return
------
The stored records.
"""
return self.recorder.load_object(name)
def analyse(self, **kwargs):
"""
Analyse data index, distribution .etc
Parameters
----------
Return
------
The handled data.
"""
raise NotImplementedError(f"Please implement the `analysis` method.")
class HFAnalyzer(AnalyzerTemp):
"""
This is the Signal Analysis class that generates the analysis results such as IC and IR.
default output image filename is "HFAnalyzerTable.jpeg"
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def analyse(self):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
"IC": ic.mean(),
"ICIR": ic.mean() / ic.std(),
"Rank IC": ric.mean(),
"Rank ICIR": ric.mean() / ric.std(),
"Long precision": long_pre.mean(),
"Short precision": short_pre.mean(),
}
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
metrics.update(
{
"Long-Short Average Return": long_short_r.mean(),
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
}
)
table = [[k, v] for (k, v) in metrics.items()]
plt.table(cellText=table, loc="center")
plt.axis("off")
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
plt.clf()
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
plt.title("HFAnalyzer")
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
return "HFAnalyzer.jpeg"
class SignalAnalyzer(AnalyzerTemp):
"""
This is the Signal Analysis class that generates the analysis results such as IC and IR.
default output image filename is "signalAnalysis.jpeg"
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def analyse(self, dataset=None, **kwargs):
label = self.load("label.pkl")
plt.hist(label)
plt.title("SignalAnalyzer")
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
return "signalAnalysis.jpeg"

View File

@@ -1,7 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
from typing import Optional
from qlib.utils.data import update_config
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_callable_kwargs
@@ -58,16 +59,17 @@ class Alpha360(DataHandlerLP):
fit_end_time=None,
filter_pipe=None,
inst_processors=None,
data_loader: Optional[dict] = None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
_data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
"feature": Alpha360DL.get_feature_config(),
"feature": self.get_feature_config(),
"label": kwargs.pop("label", self.get_label_config()),
},
"filter_pipe": filter_pipe,
@@ -75,12 +77,14 @@ class Alpha360(DataHandlerLP):
"inst_processors": inst_processors,
},
}
if data_loader is not None:
update_config(_data_loader, data_loader)
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
data_loader=_data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
**kwargs
@@ -89,6 +93,51 @@ class Alpha360(DataHandlerLP):
def get_label_config(self):
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
@staticmethod
def get_feature_config():
# NOTE:
# Alpha360 tries to provide a dataset with original price data
# the original price data includes the prices and volume in the last 60 days.
# To make it easier to learn models from this dataset, all the prices and volume
# are normalized by the latest price and volume data ( dividing by $close, $volume)
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
fields = []
names = []
for i in range(59, 0, -1):
fields += ["Ref($close, %d)/$close" % i]
names += ["CLOSE%d" % i]
fields += ["$close/$close"]
names += ["CLOSE0"]
for i in range(59, 0, -1):
fields += ["Ref($open, %d)/$close" % i]
names += ["OPEN%d" % i]
fields += ["$open/$close"]
names += ["OPEN0"]
for i in range(59, 0, -1):
fields += ["Ref($high, %d)/$close" % i]
names += ["HIGH%d" % i]
fields += ["$high/$close"]
names += ["HIGH0"]
for i in range(59, 0, -1):
fields += ["Ref($low, %d)/$close" % i]
names += ["LOW%d" % i]
fields += ["$low/$close"]
names += ["LOW0"]
for i in range(59, 0, -1):
fields += ["Ref($vwap, %d)/$close" % i]
names += ["VWAP%d" % i]
fields += ["$vwap/$close"]
names += ["VWAP0"]
for i in range(59, 0, -1):
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
names += ["VOLUME%d" % i]
fields += ["$volume/($volume+1e-12)"]
names += ["VOLUME0"]
return fields, names
class Alpha360vwap(Alpha360):
def get_label_config(self):
@@ -109,12 +158,13 @@ class Alpha158(DataHandlerLP):
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processors=None,
data_loader: Optional[dict] = None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
_data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
@@ -126,11 +176,13 @@ class Alpha158(DataHandlerLP):
"inst_processors": inst_processors,
},
}
if data_loader is not None:
update_config(_data_loader, data_loader)
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
data_loader=_data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
process_type=process_type,
@@ -146,11 +198,242 @@ class Alpha158(DataHandlerLP):
},
"rolling": {},
}
return Alpha158DL.get_feature_config(conf)
return self.parse_config_to_fields(conf)
def get_label_config(self):
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
@staticmethod
def parse_config_to_fields(config):
"""create factors from config
config = {
'kbar': {}, # whether to use some hard-code kbar features
'price': { # whether to use raw price features
'windows': [0, 1, 2, 3, 4], # use price at n days ago
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
},
'volume': { # whether to use raw volume features
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
},
'rolling': { # whether to use rolling operator based features
'windows': [5, 10, 20, 30, 60], # rolling windows size
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
#if include is None we will use default operators
'exclude': ['RANK'], # rolling operator not to use
}
}
"""
fields = []
names = []
if "kbar" in config:
fields += [
"($close-$open)/$open",
"($high-$low)/$open",
"($close-$open)/($high-$low+1e-12)",
"($high-Greater($open, $close))/$open",
"($high-Greater($open, $close))/($high-$low+1e-12)",
"(Less($open, $close)-$low)/$open",
"(Less($open, $close)-$low)/($high-$low+1e-12)",
"(2*$close-$high-$low)/$open",
"(2*$close-$high-$low)/($high-$low+1e-12)",
]
names += [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
]
if "price" in config:
windows = config["price"].get("windows", range(5))
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
for field in feature:
field = field.lower()
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
names += [field.upper() + str(d) for d in windows]
if "volume" in config:
windows = config["volume"].get("windows", range(5))
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
names += ["VOLUME" + str(d) for d in windows]
if "rolling" in config:
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
include = config["rolling"].get("include", None)
exclude = config["rolling"].get("exclude", [])
# `exclude` in dataset config unnecessary filed
# `include` in dataset config necessary field
def use(x):
return x not in exclude and (include is None or x in include)
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
if use("ROC"):
# https://www.investopedia.com/terms/r/rateofchange.asp
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]
if use("MA"):
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
fields += ["Mean($close, %d)/$close" % d for d in windows]
names += ["MA%d" % d for d in windows]
if use("STD"):
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
fields += ["Std($close, %d)/$close" % d for d in windows]
names += ["STD%d" % d for d in windows]
if use("BETA"):
# The rate of close price change in the past d days, divided by latest close price to remove unit
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
fields += ["Slope($close, %d)/$close" % d for d in windows]
names += ["BETA%d" % d for d in windows]
if use("RSQR"):
# The R-sqaure value of linear regression for the past d days, represent the trend linear
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
# The max price for past d days, divided by latest close price to remove unit
fields += ["Max($high, %d)/$close" % d for d in windows]
names += ["MAX%d" % d for d in windows]
if use("LOW"):
# The low price for past d days, divided by latest close price to remove unit
fields += ["Min($low, %d)/$close" % d for d in windows]
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
# Used with MIN and MAX
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
# Get the percentile of current close price in past d day's close price.
# Represent the current price level comparing to past N days, add additional information to moving average.
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
if use("RSV"):
# Represent the price position between upper and lower resistent price for past d days.
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
names += ["RSV%d" % d for d in windows]
if use("IMAX"):
# The number of days between current date and previous highest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
# The number of days between current date and previous lowest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
if use("IMXD"):
# The time period between previous lowest-price date occur after highest price date.
# Large value suggest downward momemtum.
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
names += ["IMXD%d" % d for d in windows]
if use("CORR"):
# The correlation between absolute close price and log scaled trading volume
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
names += ["CORR%d" % d for d in windows]
if use("CORD"):
# The correlation between price change ratio and volume change ratio
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
names += ["CORD%d" % d for d in windows]
if use("CNTP"):
# The percentage of days in past d days that price go up.
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
names += ["CNTP%d" % d for d in windows]
if use("CNTN"):
# The percentage of days in past d days that price go down.
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
names += ["CNTN%d" % d for d in windows]
if use("CNTD"):
# The diff between past up day and past down day
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
names += ["CNTD%d" % d for d in windows]
if use("SUMP"):
# The total gain / the absolute total price changed
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMP%d" % d for d in windows]
if use("SUMN"):
# The total lose / the absolute total price changed
# Can be derived from SUMP by SUMN = 1 - SUMP
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMN%d" % d for d in windows]
if use("SUMD"):
# The diff ratio between total gain and total lose
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VMA%d" % d for d in windows]
if use("VSTD"):
# The standard deviation for volume in past d days.
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
# The volume weighted price change volatility
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
# The total volume increase / the absolute total volume changed
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
# The total volume increase / the absolute total volume changed
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
# The diff ratio between total volume increase and total volume decrease
# RSI indicator for volume
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]
return fields, names
class Alpha158vwap(Alpha158):
def get_label_config(self):

View File

@@ -1,310 +0,0 @@
from qlib.data.dataset.loader import QlibDataLoader
class Alpha360DL(QlibDataLoader):
"""Dataloader to get Alpha360"""
def __init__(self, config=None, **kwargs):
_config = {
"feature": self.get_feature_config(),
}
if config is not None:
_config.update(config)
super().__init__(config=_config, **kwargs)
@staticmethod
def get_feature_config():
# NOTE:
# Alpha360 tries to provide a dataset with original price data
# the original price data includes the prices and volume in the last 60 days.
# To make it easier to learn models from this dataset, all the prices and volume
# are normalized by the latest price and volume data ( dividing by $close, $volume)
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
fields = []
names = []
for i in range(59, 0, -1):
fields += ["Ref($close, %d)/$close" % i]
names += ["CLOSE%d" % i]
fields += ["$close/$close"]
names += ["CLOSE0"]
for i in range(59, 0, -1):
fields += ["Ref($open, %d)/$close" % i]
names += ["OPEN%d" % i]
fields += ["$open/$close"]
names += ["OPEN0"]
for i in range(59, 0, -1):
fields += ["Ref($high, %d)/$close" % i]
names += ["HIGH%d" % i]
fields += ["$high/$close"]
names += ["HIGH0"]
for i in range(59, 0, -1):
fields += ["Ref($low, %d)/$close" % i]
names += ["LOW%d" % i]
fields += ["$low/$close"]
names += ["LOW0"]
for i in range(59, 0, -1):
fields += ["Ref($vwap, %d)/$close" % i]
names += ["VWAP%d" % i]
fields += ["$vwap/$close"]
names += ["VWAP0"]
for i in range(59, 0, -1):
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
names += ["VOLUME%d" % i]
fields += ["$volume/($volume+1e-12)"]
names += ["VOLUME0"]
return fields, names
class Alpha158DL(QlibDataLoader):
"""Dataloader to get Alpha158"""
def __init__(self, config=None, **kwargs):
_config = {
"feature": self.get_feature_config(),
}
if config is not None:
_config.update(config)
super().__init__(config=_config, **kwargs)
@staticmethod
def get_feature_config(
config={
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
},
"rolling": {},
}
):
"""create factors from config
config = {
'kbar': {}, # whether to use some hard-code kbar features
'price': { # whether to use raw price features
'windows': [0, 1, 2, 3, 4], # use price at n days ago
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
},
'volume': { # whether to use raw volume features
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
},
'rolling': { # whether to use rolling operator based features
'windows': [5, 10, 20, 30, 60], # rolling windows size
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
#if include is None we will use default operators
'exclude': ['RANK'], # rolling operator not to use
}
}
"""
fields = []
names = []
if "kbar" in config:
fields += [
"($close-$open)/$open",
"($high-$low)/$open",
"($close-$open)/($high-$low+1e-12)",
"($high-Greater($open, $close))/$open",
"($high-Greater($open, $close))/($high-$low+1e-12)",
"(Less($open, $close)-$low)/$open",
"(Less($open, $close)-$low)/($high-$low+1e-12)",
"(2*$close-$high-$low)/$open",
"(2*$close-$high-$low)/($high-$low+1e-12)",
]
names += [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
]
if "price" in config:
windows = config["price"].get("windows", range(5))
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
for field in feature:
field = field.lower()
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
names += [field.upper() + str(d) for d in windows]
if "volume" in config:
windows = config["volume"].get("windows", range(5))
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
names += ["VOLUME" + str(d) for d in windows]
if "rolling" in config:
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
include = config["rolling"].get("include", None)
exclude = config["rolling"].get("exclude", [])
# `exclude` in dataset config unnecessary filed
# `include` in dataset config necessary field
def use(x):
return x not in exclude and (include is None or x in include)
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
if use("ROC"):
# https://www.investopedia.com/terms/r/rateofchange.asp
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]
if use("MA"):
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
fields += ["Mean($close, %d)/$close" % d for d in windows]
names += ["MA%d" % d for d in windows]
if use("STD"):
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
fields += ["Std($close, %d)/$close" % d for d in windows]
names += ["STD%d" % d for d in windows]
if use("BETA"):
# The rate of close price change in the past d days, divided by latest close price to remove unit
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
fields += ["Slope($close, %d)/$close" % d for d in windows]
names += ["BETA%d" % d for d in windows]
if use("RSQR"):
# The R-sqaure value of linear regression for the past d days, represent the trend linear
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
# The max price for past d days, divided by latest close price to remove unit
fields += ["Max($high, %d)/$close" % d for d in windows]
names += ["MAX%d" % d for d in windows]
if use("LOW"):
# The low price for past d days, divided by latest close price to remove unit
fields += ["Min($low, %d)/$close" % d for d in windows]
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
# Used with MIN and MAX
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
# Get the percentile of current close price in past d day's close price.
# Represent the current price level comparing to past N days, add additional information to moving average.
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
if use("RSV"):
# Represent the price position between upper and lower resistent price for past d days.
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
names += ["RSV%d" % d for d in windows]
if use("IMAX"):
# The number of days between current date and previous highest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
# The number of days between current date and previous lowest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
if use("IMXD"):
# The time period between previous lowest-price date occur after highest price date.
# Large value suggest downward momemtum.
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
names += ["IMXD%d" % d for d in windows]
if use("CORR"):
# The correlation between absolute close price and log scaled trading volume
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
names += ["CORR%d" % d for d in windows]
if use("CORD"):
# The correlation between price change ratio and volume change ratio
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
names += ["CORD%d" % d for d in windows]
if use("CNTP"):
# The percentage of days in past d days that price go up.
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
names += ["CNTP%d" % d for d in windows]
if use("CNTN"):
# The percentage of days in past d days that price go down.
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
names += ["CNTN%d" % d for d in windows]
if use("CNTD"):
# The diff between past up day and past down day
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
names += ["CNTD%d" % d for d in windows]
if use("SUMP"):
# The total gain / the absolute total price changed
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMP%d" % d for d in windows]
if use("SUMN"):
# The total lose / the absolute total price changed
# Can be derived from SUMP by SUMN = 1 - SUMP
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMN%d" % d for d in windows]
if use("SUMD"):
# The diff ratio between total gain and total lose
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VMA%d" % d for d in windows]
if use("VSTD"):
# The standard deviation for volume in past d days.
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
# The volume weighted price change volatility
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
# The total volume increase / the absolute total volume changed
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
# The total volume increase / the absolute total volume changed
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
# The diff ratio between total volume increase and total volume decrease
# RSI indicator for volume
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]
return fields, names

View File

@@ -3,7 +3,6 @@ Here is a batch of evaluation functions.
The interface should be redesigned carefully in the future.
"""
import pandas as pd
from typing import Tuple
from qlib import get_module_logger

View File

@@ -243,7 +243,7 @@ class MetaDatasetDS(MetaTaskDataset):
trunc_days: int = None,
rolling_ext_days: int = 0,
exp_name: Union[str, InternalData],
segments: Union[Dict[Text, Tuple], float, str],
segments: Union[Dict[Text, Tuple], float],
hist_step_n: int = 10,
task_mode: str = MetaTask.PROC_MODE_FULL,
fill_method: str = "max",
@@ -271,16 +271,12 @@ class MetaDatasetDS(MetaTaskDataset):
- str: the name of the experiment to store the performance of data
- InternalData: a prepared internal data
segments: Union[Dict[Text, Tuple], float]
if the segment is a Dict
the segments to divide data
both left and right are included
the segments to divide data
both left and right
if segments is a float:
the float represents the percentage of data for training
if segments is a string:
it will try its best to put its data in training and ensure that the date `segments` is in the test set
hist_step_n: int
length of historical steps for the meta infomation
Number of steps of the data similarity information
task_mode : str
Please refer to the docs of MetaTask
"""
@@ -387,30 +383,10 @@ class MetaDatasetDS(MetaTaskDataset):
if isinstance(self.segments, float):
train_task_n = int(len(self.meta_task_l) * self.segments)
if segment == "train":
train_tasks = self.meta_task_l[:train_task_n]
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
return train_tasks
return self.meta_task_l[:train_task_n]
elif segment == "test":
test_tasks = self.meta_task_l[train_task_n:]
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
return test_tasks
return self.meta_task_l[train_task_n:]
else:
raise NotImplementedError(f"This type of input is not supported")
elif isinstance(self.segments, str):
train_tasks = []
test_tasks = []
for t in self.meta_task_l:
test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1]
if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments):
train_tasks.append(t)
else:
test_tasks.append(t)
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
if segment == "train":
return train_tasks
elif segment == "test":
return test_tasks
raise NotImplementedError(f"This type of input is not supported")
else:
raise NotImplementedError(f"This type of input is not supported")

View File

@@ -53,12 +53,7 @@ class MetaModelDS(MetaTaskModel):
max_epoch=100,
seed=43,
alpha=0.0,
loss_skip_thresh=50,
):
"""
loss_skip_size: int
The number of threshold to skip the loss calculation for each day.
"""
self.step = step
self.hist_step_n = hist_step_n
self.clip_method = clip_method
@@ -68,7 +63,6 @@ class MetaModelDS(MetaTaskModel):
self.max_epoch = max_epoch
self.fitted = False
self.alpha = alpha
self.loss_skip_thresh = loss_skip_thresh
torch.manual_seed(seed)
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
@@ -94,14 +88,12 @@ class MetaModelDS(MetaTaskModel):
criterion = nn.MSELoss()
loss = criterion(pred, meta_input["y_test"])
elif self.criterion == "ic_loss":
criterion = ICLoss(self.loss_skip_thresh)
criterion = ICLoss()
try:
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"])
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50)
except ValueError as e:
get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss")
continue
else:
raise ValueError(f"Unknown criterion: {self.criterion}")
assert not np.isnan(loss.detach().item()), "NaN loss!"

View File

@@ -10,11 +10,7 @@ from qlib.log import get_module_logger
class ICLoss(nn.Module):
def __init__(self, skip_size=50):
super().__init__()
self.skip_size = skip_size
def forward(self, pred, y, idx):
def forward(self, pred, y, idx, skip_size=50):
"""forward.
FIXME:
- Some times it will be a slightly different from the result from `pandas.corr()`
@@ -37,7 +33,7 @@ class ICLoss(nn.Module):
skip_n = 0
for start_i, end_i in zip(diff_point, diff_point[1:]):
pred_focus = pred[start_i:end_i] # TODO: just for fake
if pred_focus.shape[0] < self.skip_size:
if pred_focus.shape[0] < skip_size:
# skip some days which have very small amount of stock.
skip_n += 1
continue
@@ -54,7 +50,6 @@ class ICLoss(nn.Module):
)
ic_all += ic_day
if len(diff_point) - 1 - skip_n <= 0:
__import__("ipdb").set_trace()
raise ValueError("No enough data for calculating IC")
if skip_n > 0:
get_module_logger("ICLoss").info(

View File

@@ -63,7 +63,6 @@ class LinearModel(Model):
df_train = pd.concat([df_train, df_valid])
except KeyError:
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
df_train = df_train.dropna()
if df_train.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
if reweighter is not None:

View File

@@ -160,10 +160,6 @@ class ALSTM(Model):
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
elif self.metric == "mse":
mask = ~torch.isnan(label)
weight = torch.ones_like(label)
return -self.mse(pred[mask], label[mask], weight[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View File

@@ -1,663 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
from torch.utils.data import DataLoader, RandomSampler, StackDataset
import os
import numpy as np
import pandas as pd
from typing import Callable, Optional, Text, Union
from sklearn.metrics import roc_auc_score, mean_squared_error
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import StackDataset
from qlib.data.dataset.weight import Reweighter
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from ...utils import (
auto_filter_kwargs,
init_instance_by_config,
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
)
from ...log import get_module_logger
from ...workflow import R
from qlib.contrib.meta.data_selection.utils import ICLoss
from torch.nn import DataParallel
class GeneralPTNN(Model):
"""General Pytorch Neural Network Model
Parameters
----------
input_dim : int
input dimension
output_dim : int
output dimension
layers : tuple
layer sizes
lr : float
learning rate
optimizer : str
optimizer name
GPU : int
the GPU ID used for training
"""
def __init__(
self,
lr=0.001,
max_steps=300,
batch_size=2000,
early_stop_rounds=50,
eval_steps=20,
optimizer="gd",
loss="mse",
GPU=0,
seed=None,
weight_decay=0.0,
data_parall=False,
scheduler: Optional[Union[Callable]] = "default", # when it is Callable, it accept one argument named optimizer
init_model=None,
eval_train_metric=False,
pt_model_uri="qlib.contrib.model.pytorch_nn.Net",
pt_model_kwargs={
"input_dim": 360,
"layers": (256,),
},
valid_key=DataHandlerLP.DK_L,
# TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing
):
# Set logger.
self.logger = get_module_logger("DNNModelPytorch")
self.logger.info("DNN pytorch version...")
# set hyper-parameters.
self.lr = lr
self.max_steps = max_steps
self.batch_size = batch_size
self.early_stop_rounds = early_stop_rounds
self.eval_steps = eval_steps
self.optimizer = optimizer.lower()
self.loss_type = loss
if isinstance(GPU, str):
self.device = torch.device(GPU)
else:
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.weight_decay = weight_decay
self.data_parall = data_parall
self.eval_train_metric = eval_train_metric
self.valid_key = valid_key
self.best_step = None
self.logger.info(
"DNN parameters setting:"
f"\nlr : {lr}"
f"\nmax_steps : {max_steps}"
f"\nbatch_size : {batch_size}"
f"\nearly_stop_rounds : {early_stop_rounds}"
f"\neval_steps : {eval_steps}"
f"\noptimizer : {optimizer}"
f"\nloss_type : {loss}"
f"\nseed : {seed}"
f"\ndevice : {self.device}"
f"\nuse_GPU : {self.use_gpu}"
f"\nweight_decay : {weight_decay}"
f"\nenable data parall : {self.data_parall}"
f"\npt_model_uri: {pt_model_uri}"
f"\npt_model_kwargs: {pt_model_kwargs}"
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
if loss not in {"mse", "binary"}:
raise NotImplementedError("loss {} is not supported!".format(loss))
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
if init_model is None:
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
if self.data_parall:
self.dnn_model = DataParallel(self.dnn_model).to(self.device)
else:
self.dnn_model = init_model
self.logger.info("model:\n{:}".format(self.dnn_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
if scheduler == "default":
# Reduce learning rate when loss has stopped decrease
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.train_optimizer,
mode="min",
factor=0.5,
patience=10,
verbose=True,
threshold=0.0001,
threshold_mode="rel",
cooldown=0,
min_lr=0.00001,
eps=1e-08,
)
elif scheduler is None:
self.scheduler = None
else:
self.scheduler = scheduler(optimizer=self.train_optimizer)
self.dnn_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def _eval_valid_dl(self, valid_loader, val_index):
with torch.no_grad():
self.dnn_model.eval()
val_loss = []
val_pred = []
val_label = []
for x_batch, y_batch in valid_loader:
x_batch = x_batch.to(self.device)
y_batch = y_batch.to(self.device)
cur_loss = self.get_loss(preds, y_batch, self.loss_type)
val_loss.append(cur_loss.detach().cpu().numpy().item())
val_loss = np.mean(val_loss)
val_pred = torch.cat(val_pred, axis=0).detach().cpu().numpy()
val_label = torch.cat(val_label, axis=0).detach().cpu().numpy()
val_metric = self.get_metric(val_pred, val_label, val_index).detach().cpu().numpy().item()
return val_loss, val_metric
def fit(
self,
dataset: Union[DatasetH, TSDatasetH],
verbose=True,
save_path=None,
):
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
# prepare training
train_x = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)
train_y = dataset.prepare("train", col_set="label", data_key=DataHandlerLP.DK_L)
train_ds = StackDataset(train_x, train_y)
train_sampler = RandomSampler(train_ds)
train_loader = DataLoader(train_ds, batch_size=self.batch_size, sampler=train_sampler)
# prepare validation
valid_x = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)
valid_y = dataset.prepare("train", col_set="label", data_key=DataHandlerLP.DK_L)
valid_ds = StackDataset(valid_x, valid_y)
valid_loader = DataLoader(valid_ds, batch_size=self.batch_size, shuffle=False)
if ists:
val_index = valid_x.data_index
else:
val_index = valid_x.index
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_loss = np.inf
# train
self.logger.info("training...")
for step in range(1, self.max_steps + 1):
if stop_steps >= self.early_stop_rounds:
if verbose:
self.logger.info("\tearly stop")
break
loss = AverageMeter()
self.dnn_model.train()
self.train_optimizer.zero_grad()
for x_batch, y_batch in train_loader:
x_batch = x_batch.to(self.device)
y_batch = y_batch.to(self.device)
# forward
preds = self.dnn_model(x_batch)
cur_loss = self.get_loss(preds, y_batch, self.loss_type)
cur_loss.backward()
self.train_optimizer.step()
loss.update(cur_loss.item())
R.log_metrics(train_loss=loss.avg, step=step)
# validation
train_loss += loss.val
# for every `eval_steps` steps or at the last steps, we will evaluate the model.
if step % self.eval_steps == 0 or step == self.max_steps:
stop_steps += 1
train_loss /= self.eval_steps
val_loss, val_metric = self._eval_valid_dl(valid_loader, val_index)
R.log_metrics(val_loss=val_loss, step=step)
R.log_metrics(val_metric=val_metric, step=step)
if val_loss < best_loss:
if verbose:
self.logger.info(
"\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
best_loss, val_loss
)
)
best_loss = val_loss
self.best_step = step
R.log_metrics(best_step=self.best_step, step=step)
stop_steps = 0
torch.save(self.dnn_model.state_dict(), save_path)
train_loss = 0
# update learning rate
if self.scheduler is not None:
auto_filter_kwargs(self.scheduler.step, warning=False)(metrics=val_loss, epoch=step)
R.log_metrics(lr=self.get_lr(), step=step)
# restore the optimal parameters after training
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
if self.use_gpu:
torch.cuda.empty_cache()
def get_lr(self):
assert len(self.train_optimizer.param_groups) == 1
return self.train_optimizer.param_groups[0]["lr"]
def get_loss(self, pred, target, loss_type, w=None):
pred, target = pred.reshape(-1), target.reshape(-1)
if w is None:
# make it ones and the same size with pred
w = torch.ones_like(pred).to(pred.device)
if loss_type == "mse":
sqr_loss = torch.mul(pred - target, pred - target)
loss = torch.mul(sqr_loss, w).mean()
return loss
elif loss_type == "binary":
loss = nn.BCEWithLogitsLoss(weight=w)
return loss(pred, target)
else:
raise NotImplementedError("loss {} is not supported!".format(loss_type))
def get_metric(self, pred, target, index):
# NOTE: the order of the index must follow <datetime, instrument> sorted order
return -ICLoss()(pred, target, index) # pylint: disable=E1130
def _nn_predict(self, data, return_cpu=True):
"""Reusing predicting NN.
Scenarios
1) test inference (data may come from CPU and expect the output data is on CPU)
2) evaluation on training (data may come from GPU)
"""
if not isinstance(data, torch.Tensor):
if isinstance(data, pd.DataFrame):
data = data.values
data = torch.Tensor(data)
data = data.to(self.device)
preds = []
self.dnn_model.eval()
with torch.no_grad():
batch_size = 8096
for i in range(0, len(data), batch_size):
x = data[i : i + batch_size]
preds.append(self.dnn_model(x.to(self.device)).detach().reshape(-1))
if return_cpu:
preds = np.concatenate([pr.cpu().numpy() for pr in preds])
else:
preds = torch.cat(preds, axis=0)
return preds
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
preds = self._nn_predict(x_test_pd)
return pd.Series(preds.reshape(-1), index=x_test_pd.index)
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
from ...model.utils import ConcatDataset
class GeneralPTNN(Model):
"""
Motivation:
We want to provide a Qlib General Pytorch Model Adaptor
You can reuse it for all kinds of Pytorch models.
It should include the training and predict process
Parameters
----------
d_feat : int
input dimension for each time step
metric: str
the evaluation metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
"""
def __init__(
self,
n_epochs=200,
lr=0.001,
metric="",
batch_size=2000,
early_stop=20,
loss="mse",
optimizer="adam",
n_jobs=10,
GPU=0,
seed=None,
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
pt_model_kwargs={
"d_feat":6,
"hidden_size":64,
"num_layers":2,
"dropout":0.,
},
):
# Set logger.
self.logger = get_module_logger("GeneralPTNN")
self.logger.info("GeneralPTNN pytorch version...")
# set hyper-parameters.
self.n_epochs = n_epochs
self.lr = lr
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
self.seed = seed
self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
self.logger.info(
"GeneralPTNN parameters setting:"
"\nn_epochs : {}"
"\nlr : {}"
"\nmetric : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\ndevice : {}"
"\nn_jobs : {}"
"\nuse_GPU : {}"
"\nseed : {}"
"\npt_model_uri: {}"
"\npt_model_kwargs: {}".format(
n_epochs,
lr,
metric,
batch_size,
early_stop,
optimizer.lower(),
loss,
self.device,
n_jobs,
self.use_gpu,
seed,
pt_model_uri,
pt_model_kwargs,
)
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.logger.info("model:\n{:}".format(self.dnn_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.dnn_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label, weight):
loss = weight * (pred - label) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label, weight=None):
mask = ~torch.isnan(label)
if weight is None:
weight = torch.ones_like(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask], weight[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def _get_fl(self, data: torch.Tensor):
"""
get feature and label from data
- Handle the different data shape of time series and tabular data
Parameters
----------
data : torch.Tensor
input data which maybe 3 dimension or 2 dimension
- 3dim: [batch_size, time_step, feature_dim]
- 2dim: [batch_size, feature_dim]
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
"""
if data.dim() == 3:
# it is a time series dataset
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
elif data.dim() == 2:
# it is a tabular dataset
feature = data[:, 0:-1].to(self.device)
label = data[:, -1].to(self.device)
else:
raise ValueError("Unsupported data shape.")
return feature, label
def train_epoch(self, data_loader):
self.dnn_model.train()
for data, weight in data_loader:
feature , label = self._get_fl(data)
pred = self.dnn_model(feature.float())
loss = self.loss_fn(pred, label, weight.to(self.device))
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_loader):
self.dnn_model.eval()
scores = []
losses = []
for data, weight in data_loader:
feature = data[:, :, 0:-1].to(self.device)
# feature[torch.isnan(feature)] = 0
label = data[:, -1, -1].to(self.device)
with torch.no_grad():
pred = self.dnn_model(feature.float())
loss = self.loss_fn(pred, label, weight.to(self.device))
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: Union[DatasetH, TSDatasetH],
evals_result=dict(),
save_path=None,
reweighter=None,
):
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
if reweighter is None:
wl_train = np.ones(len(dl_train))
wl_valid = np.ones(len(dl_valid))
elif isinstance(reweighter, Reweighter):
wl_train = reweighter.reweight(dl_train)
wl_valid = reweighter.reweight(dl_valid)
else:
raise ValueError("Unsupported reweighter type.")
# Preprocess for data. To align to Dataset Interface for DataLoader
if ists:
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
else:
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
dl_train = dl_train.values
dl_valid = dl_valid.values
train_loader = DataLoader(
ConcatDataset(dl_train, wl_train),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.n_jobs,
drop_last=True,
)
valid_loader = DataLoader(
ConcatDataset(dl_valid, wl_valid),
batch_size=self.batch_size,
shuffle=False,
num_workers=self.n_jobs,
drop_last=True,
)
del dl_train, dl_valid, wl_train, wl_valid
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.dnn_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.dnn_model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset: Union[DatasetH, TSDatasetH]):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.dnn_model.eval()
preds = []
for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
pred = self.dnn_model(feature.float()).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=dl_test.get_index())

View File

@@ -1,25 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import copy
from typing import Text, Union
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from qlib.workflow import R
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...log import get_module_logger
from ...model.base import Model
from ...utils import get_or_create_path
from .pytorch_utils import count_parameters
class GRU(Model):
@@ -212,31 +212,16 @@ class GRU(Model):
evals_result=dict(),
save_path=None,
):
# prepare training and validation data
dfs = {
k: dataset.prepare(
k,
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
for k in ["train", "valid"]
if k in dataset.segments
}
df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get("valid", pd.DataFrame())
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
# check if training data is empty
if df_train.empty:
raise ValueError("Empty training data from dataset, please check your dataset config.")
df_train = df_train.dropna()
x_train, y_train = df_train["feature"], df_train["label"]
# check if validation data is provided
if not df_valid.empty:
df_valid = df_valid.dropna()
x_valid, y_valid = df_valid["feature"], df_valid["label"]
else:
x_valid, y_valid = None, None
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
@@ -250,42 +235,32 @@ class GRU(Model):
self.logger.info("training...")
self.fitted = True
best_param = copy.deepcopy(self.gru_model.state_dict())
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
# evaluate on validation data if provided
if x_valid is not None and y_valid is not None:
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.gru_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.gru_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.gru_model.load_state_dict(best_param)
torch.save(best_param, save_path)
# Logging
rec = R.get_recorder()
for k, v_l in evals_result.items():
for i, v in enumerate(v_l):
rec.log_metrics(step=i, **{k: v})
if self.use_gpu:
torch.cuda.empty_cache()
@@ -317,7 +292,6 @@ class GRU(Model):
class GRUModel(nn.Module):
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
super().__init__()

View File

@@ -511,6 +511,7 @@ class TRAModel(Model):
class RNN(nn.Module):
"""RNN Model
Args:
@@ -600,6 +601,7 @@ class PositionalEncoding(nn.Module):
class Transformer(nn.Module):
"""Transformer Model
Args:
@@ -647,6 +649,7 @@ class Transformer(nn.Module):
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,

View File

@@ -1,17 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Here we have a comprehensive set of analysis classes.
Here is an example.
.. code-block:: python
from qlib.contrib.report.data.ana import FeaMeanStd
fa = FeaMeanStd(ret_df)
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
"""
import pandas as pd
import numpy as np
from qlib.contrib.report.data.base import FeaAnalyser
@@ -164,7 +152,6 @@ class FeaSkewTurt(NumFeaAnalyser):
self._kurt[col].plot(ax=right_ax, label="kurt", color="green")
right_ax.set_xlabel("")
right_ax.set_ylabel("kurt")
right_ax.grid(None) # set the grid to None to avoid two layer of grid
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = right_ax.get_legend_handles_labels()
@@ -184,15 +171,12 @@ class FeaMeanStd(NumFeaAnalyser):
ax.set_xlabel("")
ax.set_ylabel("mean")
ax.legend()
ax.tick_params(axis="x", rotation=90)
right_ax = ax.twinx()
self._std[col].plot(ax=right_ax, label="std", color="green")
right_ax.set_xlabel("")
right_ax.set_ylabel("std")
right_ax.tick_params(axis="x", rotation=90)
right_ax.grid(None) # set the grid to None to avoid two layer of grid
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = right_ax.get_legend_handles_labels()

View File

@@ -14,24 +14,6 @@ from qlib.contrib.report.utils import sub_fig_generator
class FeaAnalyser:
def __init__(self, dataset: pd.DataFrame):
"""
Parameters
----------
dataset : pd.DataFrame
We often have multiple columns for dataset. Each column corresponds to one sub figure.
There will be a datatime column in the index levels.
Aggretation will be used for more summarized metrics overtime.
Here is an example of data:
.. code-block::
return
datetime instrument
2007-02-06 equity_tpx 0.010087
equity_spx 0.000786
"""
self._dataset = dataset
with TimeInspector.logt("calc_stat_values"):
self.calc_stat_values()

View File

@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
import pandas as pd
def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
"""sub_fig_generator.
it will return a generator, each row contains <col_n> sub graph
@@ -13,7 +13,7 @@ def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace
Parameters
----------
sub_figsize :
sub_fs :
the figure size of each subgraph in <col_n> * <row_n> subgraphs
col_n :
the number of subgraph in each row; It will generating a new graph after generating <col_n> of subgraphs.
@@ -33,7 +33,7 @@ def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace
while True:
fig, axes = plt.subplots(
row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey
row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey
)
plt.subplots_adjust(wspace=wspace, hspace=hspace)
axes = axes.reshape(row_n, col_n)

View File

@@ -73,8 +73,8 @@ class Rolling:
The horizon of the prediction target.
This is used to override the prediction horizon of the file.
h_path : Optional[str]
It is other data source that is dumped as a handler. It will override the data handler section in the config.
If it is not given, it will create a customized cache for the handler when `enable_handler_cache=True`
the dumped data handler;
It may come from other data source. It will override the data handler in the config.
test_end : Optional[str]
the test end for the data. It is typically used together with the handler
You can do the same thing with task_ext_conf in a more complicated way
@@ -119,7 +119,7 @@ class Rolling:
with self.conf_path.open("r") as f:
return yaml.safe_load(f)
def _replace_handler_with_cache(self, task: dict):
def _replace_hanler_with_cache(self, task: dict):
"""
Due to the data processing part in original rolling is slow. So we have to
This class tries to add more feature
@@ -159,20 +159,13 @@ class Rolling:
# - get horizon automatically from the expression!!!!
raise NotImplementedError(f"This type of input is not supported")
else:
if enable_handler_cache and self.h_path is not None:
self.logger.info("Fail to override the horizon due to data handler cache")
else:
self.logger.info("The prediction horizon is overrided")
if isinstance(task["dataset"]["kwargs"]["handler"], dict):
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
]
else:
self.logger.warning("Try to automatically configure the lablel but failed.")
self.logger.info("The prediction horizon is overrided")
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
]
if self.h_path is not None or enable_handler_cache:
# if we already have provided data source or we want to create one
task = self._replace_handler_with_cache(task)
if enable_handler_cache:
task = self._replace_hanler_with_cache(task)
task = self._update_start_end_time(task)
if self.task_ext_conf is not None:
@@ -180,16 +173,6 @@ class Rolling:
self.logger.info(task)
return task
def run_basic_task(self):
"""
Run the basic task without rolling.
This is for fast testing for model tunning.
"""
task = self.basic_task()
print(task)
trainer = TrainerR(experiment_name=self.exp_name)
trainer([task])
def get_task_list(self) -> List[dict]:
"""return a batch of tasks for rolling."""
task = self.basic_task()

View File

@@ -80,11 +80,6 @@ class DDGDA(Rolling):
sim_task_model: UTIL_MODEL_TYPE = "gbdt",
meta_1st_train_end: Optional[str] = None,
alpha: float = 0.01,
loss_skip_thresh: int = 50,
fea_imp_n: Optional[int] = 30,
meta_data_proc: Optional[str] = "V01",
segments: Union[float, str] = 0.62,
hist_step_n: int = 30,
working_dir: Optional[Union[str, Path]] = None,
**kwargs,
):
@@ -99,15 +94,6 @@ class DDGDA(Rolling):
alpha: float
Setting the L2 regularization for ridge
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
loss_skip_thresh: int
The thresh to skip the loss calculation for each day. If the number of item is less than it, it will skip the loss on that day.
meta_data_proc : Optional[str]
How we process the meta dataset for learning meta model.
segments : Union[float, str]
if segments is a float:
The ratio of training data in the meta task dataset
if segments is a string:
it will try its best to put its data in training and ensure that the date `segments` is in the test set
"""
# NOTE:
# the horizon must match the meaning in the base task template
@@ -118,22 +104,14 @@ class DDGDA(Rolling):
super().__init__(**kwargs)
self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir)
self.proxy_hd = self.working_dir / "handler_proxy.pkl"
self.fea_imp_n = fea_imp_n
self.meta_data_proc = meta_data_proc
self.loss_skip_thresh = loss_skip_thresh
self.segments = segments
self.hist_step_n = hist_step_n
def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE):
"""
Base on the original task, we need to do some extra things.
some task are use for special purpose.
For example:
- GBDT for calculating feature importance
- Linear or GBDT for calculating similarity
- Datset (well processed) that aligned to Linear that for meta learning
So we may need to change the dataset and model for the special purpose and other settings remains the same.
"""
# NOTE: here is just for aligning with previous implementation
# It is not necessary for the current implementation
@@ -141,16 +119,12 @@ class DDGDA(Rolling):
if astype == "gbdt":
task["model"] = LGBM_MODEL
if isinstance(handler, dict):
# We don't need preprocessing when using GBDT model
for k in ["infer_processors", "learn_processors"]:
if k in handler.setdefault("kwargs", {}):
handler["kwargs"].pop(k)
elif astype == "linear":
task["model"] = LINEAR_MODEL
if isinstance(handler, dict):
handler["kwargs"].update(PROC_ARGS)
else:
self.logger.warning("The handler can't be adjusted.")
handler["kwargs"].update(PROC_ARGS)
else:
raise ValueError(f"astype not supported: {astype}")
return task
@@ -181,15 +155,12 @@ class DDGDA(Rolling):
The meta model will be trained upon the proxy forecasting model.
This dataset is for the proxy forecasting model.
"""
topk = 30
fi = self._get_feature_importance()
col_selected = fi.nlargest(topk)
# NOTE: adjusting to `self.sim_task_model` just for aligning with previous implementation.
# In previous version. The data for proxy model is using sim_task_model's way for processing
task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model)
task = replace_task_handler_with_cache(task, self.working_dir)
# if self.meta_data_proc is not None:
# else:
# # Otherwise, we don't need futher processing
# task = self.basic_task()
dataset = init_instance_by_config(task["dataset"])
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -197,18 +168,12 @@ class DDGDA(Rolling):
feature_df = prep_ds["feature"]
label_df = prep_ds["label"]
if self.fea_imp_n is not None:
fi = self._get_feature_importance()
col_selected = fi.nlargest(self.fea_imp_n)
feature_selected = feature_df.loc[:, col_selected.index]
else:
feature_selected = feature_df
feature_selected = feature_df.loc[:, col_selected.index]
if self.meta_data_proc == "V01":
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
lambda df: (df - df.mean()).div(df.std())
)
feature_selected = feature_selected.fillna(0.0)
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
lambda df: (df - df.mean()).div(df.std())
)
feature_selected = feature_selected.fillna(0.0)
df_all = {
"label": label_df.reindex(feature_selected.index),
@@ -258,10 +223,7 @@ class DDGDA(Rolling):
# 1) leverage the simplified proxy forecasting model to train meta model.
# - Only the dataset part is important, in current version of meta model will integrate the
# NOTE:
# - The train_start for training meta model does not necessarily align with final rolling
# But please select a right time to make sure the finnal rolling tasks are not leaked in the training data.
# - The test_start is automatically aligned to the next day of test_end. Validation is ignored.
# the train_start for training meta model does not necessarily align with final rolling
train_start = "2008-01-01" if self.train_start is None else self.train_start
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
@@ -287,9 +249,9 @@ class DDGDA(Rolling):
kwargs = dict(
task_tpl=proxy_forecast_model_task,
step=self.step,
segments=self.segments, # keep test period consistent with the dataset yaml
segments=0.62, # keep test period consistent with the dataset yaml
trunc_days=1 + self.horizon,
hist_step_n=self.hist_step_n,
hist_step_n=30,
fill_method=fill_method,
rolling_ext_days=0,
)
@@ -306,13 +268,7 @@ class DDGDA(Rolling):
with R.start(experiment_name=self.meta_exp_name):
R.log_params(**kwargs)
mm = MetaModelDS(
step=self.step,
hist_step_n=kwargs["hist_step_n"],
lr=0.001,
max_epoch=30,
seed=43,
alpha=self.alpha,
loss_skip_thresh=self.loss_skip_thresh,
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha
)
mm.fit(md)
R.save_objects(model=mm)

View File

@@ -373,6 +373,7 @@ class WeightStrategyBase(BaseSignalStrategy):
class EnhancedIndexingStrategy(WeightStrategyBase):
"""Enhanced Indexing Strategy
Enhanced indexing combines the arts of active management and passive management,

View File

@@ -35,7 +35,7 @@ class Client:
def connect_server(self):
"""Connect to server."""
try:
self.sio.connect(f"ws://{self.server_host}:{self.server_port}")
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
except socketio.exceptions.ConnectionError:
self.logger.error("Cannot connect to server - check your network or server status")

View File

@@ -536,6 +536,7 @@ class DatasetProvider(abc.ABC):
"""
if len(fields) == 0:
raise ValueError("fields cannot be empty")
fields = fields.copy()
column_names = [str(f) for f in fields]
return column_names
@@ -616,7 +617,7 @@ class DatasetProvider(abc.ABC):
data = pd.DataFrame(obj)
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]

View File

@@ -403,7 +403,7 @@ class TSDataSampler:
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
self.nan_idx = -1 # The last line is all NaN
# the data type will be changed
# The index of usable data is between start_idx and end_idx

View File

@@ -7,7 +7,7 @@ from pathlib import Path
import warnings
import pandas as pd
from typing import Tuple, Union, List, Dict
from typing import Tuple, Union, List
from qlib.data import D
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
@@ -247,14 +247,10 @@ class StaticDataLoader(DataLoader, Serializable):
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
self._maybe_load_raw_data()
# 1) Filter by instruments
if instruments is None:
df = self._data
else:
df = self._data.loc(axis=0)[:, instruments]
# 2) Filter by Datetime
if start_time is None and end_time is None:
return df # NOTE: avoid copy by loc
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
@@ -279,55 +275,6 @@ class StaticDataLoader(DataLoader, Serializable):
self._data = self._config
class NestedDataLoader(DataLoader):
"""
We have multiple DataLoader, we can use this class to combine them.
"""
def __init__(self, dataloader_l: List[Dict], join="left") -> None:
"""
Parameters
----------
dataloader_l : list[dict]
A list of dataloader, for exmaple
.. code-block:: python
nd = NestedDataLoader(
dataloader_l=[
{
"class": "qlib.contrib.data.loader.Alpha158DL",
}, {
"class": "qlib.contrib.data.loader.Alpha360DL",
"kwargs": {
"config": {
"label": ( ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
}
}
}
]
)
join :
it will pass to pd.concat when merging it.
"""
super().__init__()
self.data_loader_l = [
(dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l
]
self.join = join
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
df_full = None
for dl in self.data_loader_l:
df_current = dl.load(instruments, start_time, end_time)
if df_full is None:
df_full = df_current
else:
df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)
return df_full.sort_index(axis=1)
class DataLoaderDH(DataLoader):
"""DataLoaderDH
DataLoader based on (D)ata (H)andler

View File

@@ -318,13 +318,9 @@ class CSZScoreNorm(Processor):
# try not modify original dataframe
if not isinstance(self.fields_group, list):
self.fields_group = [self.fields_group]
# depress warning by references:
# https://stackoverflow.com/questions/20625582/how-to-deal-with-settingwithcopywarning-in-pandas
# https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html#getting-and-setting-options
with pd.option_context("mode.chained_assignment", None):
for g in self.fields_group:
cols = get_group_columns(df, g)
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
for g in self.fields_group:
cols = get_group_columns(df, g)
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
return df

View File

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from qlib.data.dataset import DataHandler
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
"""
get the level index of `df` given `level`

20
qlib/finco/.env.example Normal file
View File

@@ -0,0 +1,20 @@
OPENAI_API_KEY=your_api_key
# USE_AZURE=True
# AZURE_API_BASE=your_api_base
# AZURE_API_VERSION=your_api_version
# use gpt-4 means more token but more wait time
# MODEL=gpt-4
# MAX_TOKENS=1600
# MAX_RETRY=1000
MAX_TOKENS=1600
MAX_RETRY=120
CONTINOUS_MODE=True
DEBUG_MODE=True
# TEMPERATURE=

22
qlib/finco/README.md Normal file
View File

@@ -0,0 +1,22 @@
# This is an experimental branch of "`FI`nancial `CO`pilot of `Qlib`"
## Installation
- To run this module, you need to first install Qlib following the instruction in [install-from-source](/README.md#install-from-source) or follow:
```python
python -m pip install git+https://github.com/microsoft/qlib.git@finco
```
- then you need to install other dependencies of finco:
```python
python -m pip install pydantic openai python-dotenv
```
## Quick run
To run this module, you can start the workflow easily with one command:
```sh
cd qlib/finco; python cli.py "your prompt"
```

13
qlib/finco/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
DIRNAME = Path(__file__).absolute().resolve().parent
def get_finco_path() -> Path:
"""
return the template path
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
"""
return DIRNAME

15
qlib/finco/cli.py Normal file
View File

@@ -0,0 +1,15 @@
import fire
from qlib.finco.workflow import WorkflowManager
from dotenv import load_dotenv
from qlib import auto_init
def main(prompt=None):
load_dotenv(verbose=True, override=True)
wm = WorkflowManager()
wm.run(prompt)
if __name__ == "__main__":
auto_init()
fire.Fire(main)

15
qlib/finco/cli_learn.py Normal file
View File

@@ -0,0 +1,15 @@
import fire
from qlib.finco.workflow import LearnManager
from dotenv import load_dotenv
from qlib import auto_init
def main(prompt=None):
load_dotenv(verbose=True, override=True)
lm = LearnManager()
lm.run(prompt)
if __name__ == "__main__":
auto_init()
fire.Fire(main)

32
qlib/finco/conf.py Normal file
View File

@@ -0,0 +1,32 @@
# TODO: use pydantic for other modules in Qlib
# from pydantic_settings import BaseSettings
from qlib.finco.utils import SingletonBaseClass
import os
class Config(SingletonBaseClass):
"""
This config is for fast demo purpose.
Please use BaseSettings insetead in the future
"""
def __init__(self):
self.use_azure = os.getenv("USE_AZURE") == "True"
self.temperature = 0 if os.getenv("TEMPERATURE") is None else float(os.getenv("TEMPERATURE"))
self.max_tokens = 800 if os.getenv("MAX_TOKENS") is None else int(os.getenv("MAX_TOKENS"))
self.openai_api_key = os.getenv("OPENAI_API_KEY")
self.use_azure = os.getenv("USE_AZURE") == "True"
self.azure_api_base = os.getenv("AZURE_API_BASE")
self.azure_api_version = os.getenv("AZURE_API_VERSION")
self.model = os.getenv("MODEL") or ("gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo")
self.max_retry = int(os.getenv("MAX_RETRY")) if os.getenv("MAX_RETRY") is not None else None
self.continuous_mode = (
os.getenv("CONTINOUS_MODE") == "True" if os.getenv("CONTINOUS_MODE") is not None else False
)
self.debug_mode = os.getenv("DEBUG_MODE") == "True" if os.getenv("DEBUG_MODE") is not None else False
self.workspace = os.getenv("WORKSPACE") if os.getenv("WORKSPACE") is not None else "./finco_workspace"
self.max_past_message_include = int(os.getenv("MAX_PAST_MESSAGE_INCLUDE") or 6) // 2 * 2

97
qlib/finco/context.py Normal file
View File

@@ -0,0 +1,97 @@
from dataclasses import dataclass, field
import copy
from pathlib import Path
from typing import Optional, List
from qlib.finco.log import FinCoLog
from qlib.typehint import Literal
from qlib.finco.utils import similarity
@dataclass
class Design:
plan: str
classes: str
decision: str
@dataclass
class Exp:
"""Experiment"""
# compoments
dataset: Optional[Design] = None
datahandler: Optional[Design] = None
model: Optional[Design] = None
record: Optional[Design] = None
strategy: Optional[Design] = None
backtest: Optional[Design] = None
# basic
template: Optional[Path] = None
# rolling strategy. None indicates no rolling
rolling: Optional[Literal["base", "ddgda"]] = None
@dataclass
class StructContext:
"""Part of the context have clear meaning and structure, so they will be saved here and can be easily retrieved and understood"""
# TODO: move more content in WorkflowContextManager.context to here
workspace: Path
exp_list: List[Exp] = field(default_factory=list) # the planned experiments
class WorkflowContextManager:
"""Context Manager stores the context of the workflow"""
"""All context are key value pairs which saves the input, output and status of the whole workflow"""
def __init__(self, workspace: Path) -> None:
self.context = {}
self.logger = FinCoLog()
# this context is public
self.struct_context = StructContext(workspace) # TODO: move more content in context to here
self.set_context("workspace", workspace) # TODO: remove me
def set_context(self, key, value):
if key in self.context:
self.logger.warning("The key already exists in the context, the value will be overwritten")
self.context[key] = value
def get_context(self, key):
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
if key not in self.context:
self.logger.warning("The key doesn't exist in the context")
return None
return self.context[key]
def update_context(self, key, new_value):
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
if key not in self.context:
self.logger.warning("The key doesn't exist in the context")
self.context.update({key: new_value})
def get_all_context(self):
"""return a deep copy of the context"""
"""TODO: do we need to return a deep copy?"""
return copy.deepcopy(self.context)
def retrieve(self, query: str) -> dict:
if query in self.context.keys():
return {query: self.context.get(query)}
# Note: retrieve information from context by string similarity maybe abandon in future
scores = {}
for k, v in self.context.items():
scores.update({k: max(similarity(query, k), similarity(query, v))})
max_score_key = max(scores, key=scores.get)
return {max_score_key: self.context.get(max_score_key)}
def clear(self, reserve: list = None):
if reserve is None:
reserve = []
_context = {k: self.get_context(k) for k in reserve}
self.context = _context

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

307
qlib/finco/demo_failed.yml Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

539
qlib/finco/knowledge.py Normal file
View File

@@ -0,0 +1,539 @@
from pathlib import Path
from jinja2 import Template
from typing import List, Union
import pickle
import yaml
from qlib.workflow import R
from qlib.finco.log import FinCoLog
from qlib.finco.llm import APIBackend
from qlib.finco.utils import similarity, random_string, SingletonBaseClass
logger = FinCoLog()
class Storage:
"""
This class is responsible for storage and loading of Knowledge related data.
"""
def __init__(self, path: Union[str, Path], name: str = None):
self.path = path if isinstance(path, Path) else Path(path)
self.name = name if name else self.path.name
self.source = None
# todo: get document by key
self.documents = []
def add(self, documents: List):
self.documents.extend(documents)
self.save()
def load(self, **kwargs):
raise NotImplementedError(f"Please implement the `load` method.")
def save(self, **kwargs):
raise NotImplementedError(f"Please implement the `save` method.")
class PickleStorage(Storage):
"""
This class is responsible for storage and loading of Knowledge related data in pickle format.
"""
def __init__(self, path: Union[str, Path]):
super().__init__(path)
@classmethod
def load(cls, path: Union[str, Path]):
"""use pickle as the default load method"""
path = path if isinstance(path, Path) else Path(path)
with open(path, "rb") as f:
return pickle.load(f)
def save(self, **kwargs):
"""use pickle as the default save method"""
Path.mkdir(self.path.parent, exist_ok=True)
with open(self.path, "wb") as f:
pickle.dump(self, f)
class YamlStorage(Storage):
"""
This class is responsible for storage and loading of Knowledge related data in yaml format.
"""
DEFAULT_NAME = "storage.yml"
def __init__(self, path: Union[str, Path]):
super().__init__(path)
assert self.path.name, "Yaml storage should specify file name."
self.load()
def load(self):
"""load data from yaml format file"""
try:
self.documents = yaml.safe_load(self.path.open())
except FileNotFoundError:
logger.warning(f"YamlStorage: file {self.path} doesn't exist.")
def save(self, **kwargs):
"""use pickle as the default save method"""
Path.mkdir(self.path.parent, exist_ok=True, parents=True)
with open(self.path, 'w') as f:
yaml.dump(self.documents, f)
class ExperimentStorage(Storage):
"""
This class is responsible for storage and loading of mlflow related data.
"""
def __init__(self, exp_name, path=None):
super().__init__(path=path)
self.exp_name = exp_name
self.exp = None
self.recs = []
self.docs = []
def load(self, exp_name, rec_id=None):
recs = []
self.exp = R.get_exp(experiment_name=exp_name)
for r in self.exp.list_recorders(rtype=self.exp.RT_L):
if rec_id is not None and r.id != rec_id:
continue
recs.append(r)
self.recs.extend(recs)
class Knowledge:
"""
Use to handle knowledge in finCo such as experiment and outside domain information
"""
def __init__(self, storages: Union[List[Storage], Storage], name: str = None):
self.name = name if name else random_string()
self.workdir = Path.cwd().joinpath("knowledge")
self.storages = [storages] if isinstance(storages, Storage) else storages
self.knowledge = []
def get_storage(self, name: str):
"""
return first storage matched given name, else return None
"""
for storage in self.storages:
if storage.name == name:
return storage
return None
def summarize(self, **kwargs):
"""
summarize storage data to knowledge, default knowledge is storage.documents
Parameters
----------
Return
------
"""
knowledge = []
for storage in self.storages:
knowledge.extend(storage.documents)
self.knowledge = knowledge
@classmethod
def load(cls, path: Union[str, Path]):
"""
Load knowledge in memory
use pickle as the default file type
Parameters
----------
Return
------
"""
""""""
path = path if isinstance(path, Path) else Path(path)
with open(path, "rb") as f:
return pickle.load(f)
def brief(self, **kwargs):
"""
Return a brief summary of knowledge
Parameters
----------
Return
------
"""
raise NotImplementedError(f"Please implement the `load` method.")
def save(self, **kwargs):
"""save knowledge persistently"""
# todo: storages save index only
Path.mkdir(self.workdir.joinpath(self.name), exist_ok=True)
with open(self.workdir.joinpath(self.name).joinpath("knowledge.pkl"), "wb") as f:
pickle.dump(self, f)
class ExperimentKnowledge(Knowledge):
"""
Handle knowledge from experiments
"""
def __init__(self, storages: Union[List[ExperimentStorage], ExperimentStorage]):
super().__init__(storages=storages)
self.storage = storages
def brief(self):
docs = []
for recorder in self.storage.recs:
docs.append(
{
"exp_name": self.storage.exp.name,
"record_info": recorder.info,
"config": recorder.load_object("config"),
"context_summary": recorder.load_object("context_summary"),
}
)
return docs
class PracticeKnowledge(Knowledge):
"""
some template sentence for now
"""
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
super().__init__(storages=storages, name="practice")
self.summarize()
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
s = "\n".join(docs)
logger.info(f'Add to Practice Knowledge:\n {s}')
storage = self.get_storage(storage_name)
if storage is None:
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
storage.add(documents=docs)
self.storages.append(storage)
else:
storage.add(documents=docs)
self.summarize()
self.save()
class FinanceKnowledge(Knowledge):
"""
Knowledge from articles
"""
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
super().__init__(storages=storages, name="finance")
storage = self.get_storage(YamlStorage.DEFAULT_NAME)
if len(storage.documents) == 0:
docs = self.read_files_in_directory(self.workdir.joinpath(self.name))
self.add(docs)
self.summarize()
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
storage = self.get_storage(storage_name)
if storage is None:
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
storage.add(documents=docs)
self.storages.append(storage)
else:
storage.add(documents=docs)
self.summarize()
self.save()
@staticmethod
def read_files_in_directory(directory) -> List:
"""
read all .txt files under directory
"""
# todo: split article in trunks
file_contents = []
for file_path in Path(directory).rglob("*.txt"):
if file_path.is_file():
file_content = file_path.read_text(encoding="utf-8")
file_contents.append(file_content)
return file_contents
class ExecuteKnowledge(Knowledge):
"""
Config and associate execution result(pass or error message). We can regard the example in prompt as pass execution
"""
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
super().__init__(storages=storages, name="execute")
self.summarize()
storage = self.get_storage(YamlStorage.DEFAULT_NAME)
if len(storage.documents) == 0:
docs = [{"content": "[Success]: XXXX, the results looks reasonable # Keywords: supervised learning, data"},
{"content": "[Fail]: XXXX, it raise memory error due to YYYYY "
"# Keywords: supervised learning, data"}]
self.add(docs)
self.summarize()
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
storage = self.get_storage(storage_name)
if storage is None:
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
storage.add(documents=docs)
self.storages.append(storage)
else:
storage.add(documents=docs)
self.summarize()
self.save()
class InfrastructureKnowledge(Knowledge):
"""
Knowledge from sentences, docstring, and code
"""
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
super().__init__(storages=storages, name="infrastructure")
storage = self.get_storage(YamlStorage.DEFAULT_NAME)
if len(storage.documents) == 0:
docs = self.get_functions_and_docstrings(Path(__file__).parent.parent.parent)
docs.extend([{"docstring": "All the models can be import from `qlib.contrib.models` "
"# Keywords: supervised learning"},
{"docstring": "The API to run rolling models can be found in … #Keywords: control"},
{"docstring": "Here are a list of Qlibs available analyzers. #KEYWORDS: analysis"}])
self.add(docs)
self.summarize()
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
storage = self.get_storage(storage_name)
if storage is None:
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
storage.add(documents=docs)
self.storages.append(storage)
else:
storage.add(documents=docs)
self.summarize()
self.save()
def get_functions_and_docstrings(self, directory) -> List:
"""
get all method and docstring in .py files under directory
"""
functions = []
for py_file_path in Path(directory).rglob("*.py"):
for _functions in self.get_functions_with_docstrings(py_file_path):
functions.append(_functions)
return functions
@staticmethod
def get_functions_with_docstrings(file_path):
"""
Extract method name and docstring using string matching method
"""
with open(file_path, "r", encoding="utf8") as f:
lines = f.readlines()
functions = []
current_func = None
docstring = None
for line in lines:
if line.strip().startswith("def ") or line.strip().startswith("class "):
func = line.strip().split(" ")[1].split("(")[0]
if func.startswith("__"):
continue
if current_func is not None:
docstring = docstring.replace('"""', "") if docstring else docstring
functions.append({"function": current_func, "docstring": docstring})
current_func = f"{file_path.name.split('.')[0]}.{func}"
docstring = None
elif current_func is not None and docstring is None and line.strip().startswith('"""'):
docstring = line
elif current_func is not None and docstring is not None:
docstring += line.strip()
if line.strip().endswith('"""'):
docstring = docstring.replace('"""', "") if docstring else docstring
functions.append({"function": current_func, "docstring": docstring})
current_func = None
docstring = None
return functions
class Topic:
def __init__(self, name: str, system: Template, user: Template):
self.name = name
self.system_prompt_template = system
self.user_prompt_template = user
self.docs = []
self.knowledge = None
self.logger = FinCoLog()
def summarize(self, practice_knowlege, user_intention, target, diffrence, target_metrics):
system_prompt = self.system_prompt_template.render(topic=self.name)
user_prompt = self.user_prompt_template.render(
experiment_1_info = practice_knowlege[0],
experiment_2_info = practice_knowlege[1],
user_intention=user_intention,
target=target,
diffrence=diffrence,
target_metrics=target_metrics)
response = APIBackend().build_messages_and_create_chat_completion(user_prompt=user_prompt, system_prompt=system_prompt)
self.knowledge = response
self.docs = practice_knowlege
self.logger.info(f"Summary of {self.name}:\n{self.knowledge}")
class KnowledgeBase(SingletonBaseClass):
"""
Load knowledge, offer brief information of knowledge and common handle interfaces
"""
KT_EXECUTE = "execute"
KT_PRACTICE = "practice"
KT_FINANCE = "finance"
KT_INFRASTRUCTURE = "infrastructure"
def __init__(self, workdir=None):
self.logger = FinCoLog()
self.workdir = Path(workdir) if workdir else Path.cwd()
if not self.workdir.exists():
self.logger.warning(f"{self.workdir} not exist, create empty directory.")
Path.mkdir(self.workdir)
self.practice_knowledge = self.load_practice_knowledge(self.workdir)
self.execute_knowledge = self.load_execute_knowledge(self.workdir)
self.finance_knowledge = self.load_finance_knowledge(self.workdir)
self.infrastructure_knowledge = self.load_infrastructure_knowledge(self.workdir)
def load_experiment_knowledge(self, path) -> List:
# similar to practice knowledge, not use for now
if isinstance(path, str):
path = Path(path)
knowledge = []
path = path if path.name == "mlruns" else path.joinpath("mlruns")
# todo: check the influence of set uri
R.set_uri(path.as_uri())
for exp_name in R.list_experiments():
knowledge.append(ExperimentKnowledge(storages=ExperimentStorage(exp_name=exp_name)))
self.logger.plain_info(f"Load knowledge from: {path} finished.")
return knowledge
def load_practice_knowledge(self, path: Path) -> PracticeKnowledge:
self.practice_knowledge = PracticeKnowledge(
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_PRACTICE}/{YamlStorage.DEFAULT_NAME}")))
return self.practice_knowledge
def load_execute_knowledge(self, path: Path) -> ExecuteKnowledge:
self.execute_knowledge = ExecuteKnowledge(
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_EXECUTE}/{YamlStorage.DEFAULT_NAME}")))
return self.execute_knowledge
def load_finance_knowledge(self, path: Path) -> FinanceKnowledge:
self.finance_knowledge = FinanceKnowledge(
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_FINANCE}/{YamlStorage.DEFAULT_NAME}")))
return self.finance_knowledge
def load_infrastructure_knowledge(self, path: Path) -> InfrastructureKnowledge:
self.infrastructure_knowledge = InfrastructureKnowledge(
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_INFRASTRUCTURE}/{YamlStorage.DEFAULT_NAME}")))
return self.infrastructure_knowledge
def get_knowledge(self, knowledge_type: str = None):
if knowledge_type == self.KT_EXECUTE:
knowledge = self.execute_knowledge.knowledge
elif knowledge_type == self.KT_PRACTICE:
knowledge = self.practice_knowledge.knowledge
elif knowledge_type == self.KT_FINANCE:
knowledge = self.finance_knowledge.knowledge
elif knowledge_type == self.KT_INFRASTRUCTURE:
knowledge = self.infrastructure_knowledge.knowledge
else:
knowledge = (
self.execute_knowledge.knowledge
+ self.practice_knowledge.knowledge
+ self.finance_knowledge.knowledge
+ self.infrastructure_knowledge.knowledge
)
return knowledge
def query(self, knowledge_type: str = None, content: str = None, n: int = 5):
"""
@param knowledge_type: self.KT_EXECUTE, self.KT_PRACTICE or self.KT_FINANCE
@param content: content to query KnowledgeBase
@param n: top n knowledge to ask ChatGPT
@return:
"""
# todo: replace list with persistent storage strategy such as ES/pinecone to enable
# literal search/semantic search
knowledge = self.get_knowledge(knowledge_type=knowledge_type)
if len(knowledge) == 0 or knowledge_type == "infrastructure":
return ""
if knowledge_type == "practice":
knowledge = [line for line in knowledge if line.startswith("practice_knowledge on")]
scores = []
for k in knowledge:
scores.append(similarity(str(k), content))
sorted_indexes = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
similar_n_indexes = sorted_indexes[:n]
similar_n_docs = "\n".join([knowledge[i] for i in similar_n_indexes])
user_prompt_template = Template(
"""
query: '{{query}}'
paragraph:
{{paragraph}}.
"""
)
user_prompt = user_prompt_template.render(query=content, paragraph=similar_n_docs)
system_prompt = """
You are an assistant who find relevant sentences from a long paragraph to fit user's query sentence. Relevant means the sentence might provide userful information to explain user's query sentence. People after reading the relevant sentences might have a better understanding of the query sentence.
Please response no less than ten sentences, if paragraph is not enough, you can return less than ten. Don't pop out irrelevant sentences. Please list the sentences in a number index instead of a whole paragraph.
Example input:
query: what is the best model for image classification?
paragraph:
Image classification is the process of identifying and categorizing objects within an image into different groups or classes.
Machine learning is a type of artificial intelligence that enables computers to learn and make decisions without being explicitly programmed.
The solar system is a collection of celestial bodies, including the Sun, planets, moons, and other objects, that orbit around the Sun due to its gravitational pull.
A car is a wheeled vehicle, typically powered by an engine or electric motor, used for transportation of people and goods.
ResNet, short for Residual Network, is a type of deep learning architecture designed to improve the accuracy and training speed of neural networks for image recognition tasks.
Example output:
1. ResNet, short for Residual Network, is a type of deep learning architecture designed to improve the accuracy and training speed of neural networks for image recognition tasks.
2. Image classification is the process of identifying and categorizing objects within an image into different groups or classes.
3. Machine learning is a type of artificial intelligence that enables computers to learn and make decisions without being explicitly programmed.
"""
response = APIBackend().build_messages_and_create_chat_completion(
user_prompt=user_prompt, system_prompt=system_prompt
)
return response
# perhaps init KnowledgeBase in other place
KnowledgeBase(workdir=Path.cwd().joinpath('knowledge'))

View File

@@ -0,0 +1,47 @@
Quantitative investment research, often referred to as "quant," is an investment approach that uses mathematical and statistical models to analyze financial data and identify investment opportunities. This method relies heavily on computer algorithms and advanced data analysis techniques to develop trading strategies and make investment decisions.
One of the key aspects of quant investment research is the development of predictive models to forecast asset prices, market movements, and other financial variables. These models are typically built using historical data and refined through rigorous testing and validation processes.
In quant investment research, various metrics are used to evaluate the performance of a model or strategy. Some common metrics include annual return, information coefficient, maximum drawdown, and cumulative sum (cumsum) return.
Annual return is a measure of an investment's performance over the course of a year and is expressed as a percentage. It is an important metric to consider but can be controversial as higher annual returns are often associated with higher risks.
Maximum drawdown is the largest peak-to-trough decline in an investment's value over a specified period. It is a measure of the strategy's risk and can be controversial since increasing annual return often leads to a more dynamic strategy with larger drawdowns.
Information coefficient (IC) is a measure of the relationship between predicted returns and actual returns. A higher IC indicates a stronger relationship and suggests a more effective predictive model.
Cumulative sum return is the total return generated by an investment over a given period. It is useful for evaluating the overall performance of a strategy and is particularly relevant when comparing multiple strategies over the same time frame.
Another important aspect of quant investment research is portfolio optimization, which involves determining the optimal allocation of assets to maximize returns while minimizing risk.
Quantitative researchers often use techniques such as factor analysis to identify underlying drivers of asset returns. This helps them to build more robust models and better understand the relationships between various financial variables.
Machine learning has become increasingly popular in quant investment research, as it offers new ways to identify patterns and relationships in large datasets. Techniques such as neural networks, decision trees, and clustering algorithms are commonly used in this field.
Backtesting is a critical step in the development of a quantitative investment strategy. It involves applying a model or algorithm to historical data to see how it would have performed under various market conditions.
Risk management is a crucial component of quant investment research. Quantitative researchers must carefully consider the potential risks associated with their models and strategies and take steps to mitigate these risks.
In recent years, there has been a growing interest in alternative data sources for quant investment research. These can include social media sentiment, satellite imagery, and other non-traditional data sources that may provide unique insights into market trends and investment opportunities.
Despite its reliance on data and mathematical models, quant investment research is not without its challenges and limitations. Some critics argue that it is inherently backward-looking, as it relies on historical data to make predictions about future market movements.
Furthermore, quant models may struggle to adapt to rapidly changing market conditions and may be susceptible to overfitting, which occurs when a model is tailored too closely to the training data and performs poorly on new, unseen data.
In conclusion, quantitative investment research is a growing field that combines advanced mathematical and statistical techniques with financial expertise to develop sophisticated investment strategies and models. While it offers significant potential for generating returns and managing risk, it also comes with its own set of challenges and limitations that must be carefully considered.
A rolling meta-controller is a valuable tool in quantitative investment research, providing several benefits such as increased adaptability, enhanced risk management, and improved performance. By continuously updating the parameters of underlying trading models based on recent data, this approach helps strategies stay relevant in dynamic market conditions and reduces the risk of overfitting. Additionally, a rolling meta-controller can manage a diverse portfolio of trading models, leading to a more robust and diversified investment strategy.
However, incorporating a rolling meta-controller also comes with potential challenges, such as increased complexity, higher computational requirements, and the risk of parameter instability. It is essential for quantitative researchers to carefully weigh these factors when deciding whether to include a rolling meta-controller in their trading strategies. By balancing the potential advantages and disadvantages, researchers can optimize their strategies to achieve their specific investment objectives and risk tolerances.
Meta controller like rolling or DDGDA is proved to be relevant to the maximum drawdown but effect is still unknown. Because seeing the latest data is always good to avoid some new patterns recently. A static model is not good at doing it.
If rolling has been proved in practice knowledge that it's useful in minimizing max drawdown or maximize dropdown, always try DDGDA in the following experiments to further minimize the max drawdown.
DDG-DA, which stands for Data Distribution Generation for Predictable Concept Drift Adaptation, is a research paper focused on addressing the challenge of concept drift in machine learning and data mining. Concept drift occurs when the statistical properties of a target variable change over time, causing the model's performance to degrade. This is a common issue in various domains, including finance, where market conditions and economic factors can change rapidly.
The DDG-DA paper proposes a framework for generating synthetic datasets that simulate concept drift in a controlled and predictable manner. By creating these datasets, researchers can better understand how concept drift affects the performance of their machine learning models and develop strategies for adapting to these changes.
The main idea behind DDG-DA is to create synthetic data distributions that mimic the underlying data generating process while controlling the extent of concept drift. This is achieved by using a combination of data transformation techniques, such as scaling, rotation, and translation of the original data distribution.
By generating synthetic datasets with controlled concept drift, researchers can evaluate and compare the performance of various adaptation techniques in a more systematic and controlled manner. This can lead to the development of more robust and adaptive machine learning models that can better handle changing data distributions, ultimately improving the performance of these models in real-world applications, such as finance and investment.

139
qlib/finco/llm.py Normal file
View File

@@ -0,0 +1,139 @@
import re
import os
import time
import openai
import json
import yaml
from typing import Optional, Tuple, Union
from qlib.finco.conf import Config
from qlib.finco.utils import SingletonBaseClass
from qlib.finco.log import FinCoLog
from qlib.config import DEFAULT_QLIB_DOT_PATH
from pathlib import Path
class ConvManager:
"""
This is a conversation manager of LLM
It is for convenience of exporting conversation for debugging.
"""
def __init__(self, path: Union[Path, str] = DEFAULT_QLIB_DOT_PATH / "llm_conv", recent_n: int = 10) -> None:
self.path = Path(path)
self.path.mkdir(parents=True, exist_ok=True)
self.recent_n = recent_n
def _rotate_files(self):
pairs = []
for f in self.path.glob("*.json"):
m = re.match(r"(\d+).json", f.name)
if m is not None:
n = int(m.group(1))
pairs.append((n, f))
pass
pairs.sort(key=lambda x: x[0])
for n, f in pairs[: self.recent_n][::-1]:
f.rename(self.path / f"{n+1}.json")
def append(self, conv: Tuple[list, str]):
self._rotate_files()
json.dump(conv, open(self.path / "0.json", "w"))
# TODO: reseve line breaks to make it more convient to edit file directly.
class APIBackend(SingletonBaseClass):
def __init__(self):
self.cfg = Config()
openai.api_key = self.cfg.openai_api_key
if self.cfg.use_azure:
openai.api_type = "azure"
openai.api_base = self.cfg.azure_api_base
openai.api_version = self.cfg.azure_api_version
self.use_azure = self.cfg.use_azure
self.debug_mode = False
if self.cfg.debug_mode:
self.debug_mode = True
cwd = os.getcwd()
self.cache_file_location = os.path.join(cwd, "prompt_cache.json")
self.cache = (
json.load(open(self.cache_file_location, "r")) if os.path.exists(self.cache_file_location) else {}
)
def build_messages_and_create_chat_completion(self, user_prompt, system_prompt=None, former_messages=[], **kwargs):
"""build the messages to avoid implementing several redundant lines of code"""
cfg = Config()
# TODO: system prompt should always be provided. In development stage we can use default value
if system_prompt is None:
try:
system_prompt = cfg.system_prompt
except AttributeError:
FinCoLog().warning("system_prompt is not set, using default value.")
system_prompt = "You are an AI assistant who helps to answer user's questions about finance."
messages = [
{
"role": "system",
"content": system_prompt,
}
]
messages.extend(former_messages[-1 * cfg.max_past_message_include :])
messages.append(
{
"role": "user",
"content": user_prompt,
}
)
fcl = FinCoLog()
response = self.try_create_chat_completion(messages=messages, **kwargs)
fcl.log_message(messages)
fcl.log_response(response)
if self.debug_mode:
ConvManager().append((messages, response))
return response
def try_create_chat_completion(self, max_retry=10, **kwargs):
max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry
for i in range(max_retry):
try:
response = self.create_chat_completion(**kwargs)
return response
except (openai.error.RateLimitError, openai.error.Timeout, openai.error.APIError) as e:
print(e)
print(f"Retrying {i+1}th time...")
time.sleep(1)
continue
raise Exception(f"Failed to create chat completion after {max_retry} retries.")
def create_chat_completion(
self,
messages,
model=None,
temperature: float = None,
max_tokens: Optional[int] = None,
) -> str:
if self.debug_mode:
key = json.dumps(messages)
if key in self.cache:
return self.cache[key]
if temperature is None:
temperature = self.cfg.temperature
if max_tokens is None:
max_tokens = self.cfg.max_tokens
if self.cfg.use_azure:
response = openai.ChatCompletion.create(
engine=self.cfg.model,
messages=messages,
max_tokens=self.cfg.max_tokens,
)
else:
response = openai.ChatCompletion.create(
model=self.cfg.model,
messages=messages,
)
resp = response.choices[0].message["content"]
if self.debug_mode:
self.cache[key] = resp
json.dump(self.cache, open(self.cache_file_location, "w"))
return resp

139
qlib/finco/log.py Normal file
View File

@@ -0,0 +1,139 @@
"""
This module will base on Qlib's logger module and provides some interactive functions.
"""
import logging
import time
from typing import Dict, List
from qlib.finco.utils import SingletonBaseClass
from contextlib import contextmanager
class LogColors:
"""
ANSI color codes for use in console output.
"""
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
WHITE = "\033[97m"
GRAY = "\033[90m"
BLACK = "\033[30m"
BOLD = "\033[1m"
ITALIC = "\033[3m"
END = "\033[0m"
@classmethod
def get_all_colors(cls):
names = dir(cls)
names = [name for name in names if not name.startswith("__") and not callable(getattr(cls, name))]
var_values = [getattr(cls, name) for name in names]
return var_values
def render(self, text: str, color: str = "", style: str = ""):
"""
render text by input color and style. It's not recommend that input text is already rendered.
"""
# This method is called too frequently, which is not good.
colors = self.get_all_colors()
# Perhaps color and font should be distinguished here.
if color:
assert color in colors, f"color should be in: {colors} but now is: {color}"
if style:
assert style in colors, f"style should be in: {colors} but now is: {style}"
text = f"{color}{text}{self.END}"
text = f"{style}{text}{self.END}"
return text
@contextmanager
def formatting_log(logger, title="Info"):
"""
a context manager, print liens before and after a function
"""
length = {"Start": 90, "Round": 90, "Task": 90, "Info": 60, "Interact": 60, "End": 90}.get(title, 60)
color, bold = (
(LogColors.YELLOW, LogColors.BOLD)
if title in ["Start", "Round", "Task", "Info", "Interact", "End"]
else (LogColors.CYAN, "")
)
logger.info("")
logger.info(f"{color}{bold}{'-'} {title} {'-' * (length - len(title))}{LogColors.END}")
yield
if color == LogColors.YELLOW:
time.sleep(2)
logger.info("")
class FinCoLog(SingletonBaseClass):
# TODO:
# - config to file logger and save it into workspace
def __init__(self) -> None:
self.logger = logging.Logger("interactive")
# TODO: merge these with Qlib's default logger.
# We can do the same thing by changing the default log dict of Qlib.
# Reference: https://github.com/microsoft/qlib/blob/main/qlib/config.py#L155
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(message)s"))
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def log_message(self, messages: List[Dict[str, str]]):
"""
messages is some info like this [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
]
"""
with formatting_log(self.logger, "GPT Messages"):
for m in messages:
self.logger.info(
f"{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END} "
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
+ f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n"
)
def log_response(self, response: str):
with formatting_log(self.logger, "GPT Response"):
self.logger.info(f"{LogColors.CYAN}{response}{LogColors.END}\n")
time.sleep(1)
# TODO:
# It looks wierd if we only have logger
def info(self, *args, plain=False, title="Info"):
if plain:
return self.plain_info(*args)
with formatting_log(self.logger, title):
for arg in args:
self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")
def plain_info(self, *args):
for arg in args:
self.logger.info(
f"{LogColors.YELLOW}{LogColors.BOLD}Info:{LogColors.END}{LogColors.WHITE}{arg}{LogColors.END}"
)
def warning(self, *args):
for arg in args:
self.logger.warning(f"{LogColors.BLUE}{LogColors.BOLD}Warning:{LogColors.END}{arg}")
def error(self, *args):
for arg in args:
self.logger.error(f"{LogColors.RED}{LogColors.BOLD}Error:{LogColors.END}{arg}")

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,33 @@
from typing import Union
from pathlib import Path
from jinja2 import Template
import yaml
from qlib.finco.utils import SingletonBaseClass
from qlib.finco import get_finco_path
class PromptTemplate(SingletonBaseClass):
def __init__(self) -> None:
super().__init__()
_template = yaml.load(
open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"), Loader=yaml.FullLoader
)
for k, v in _template.items():
if k == "mods":
continue
self.__setattr__(k, Template(v))
def get(self, key: str):
return self.__dict__.get(key, Template(""))
def update(self, key: str, value):
self.__setattr__(key, value)
def save(self, file_path: Union[str, Path]):
if isinstance(file_path, str):
file_path = Path(file_path)
Path.mkdir(file_path.parent, exist_ok=True)
with open(file_path, "w") as f:
yaml.dump(self.__dict__, f)

File diff suppressed because it is too large Load Diff

2
qlib/finco/record.txt Normal file
View File

@@ -0,0 +1,2 @@
conda activate qlib38
python cli_learn.py "build an A-share stock market daily portfolio in quantitative investment and minimize the maximum drawdown while maintaining return."

1328
qlib/finco/task.py Normal file

File diff suppressed because it is too large Load Diff

12
qlib/finco/tpl/README.md Normal file
View File

@@ -0,0 +1,12 @@
This is a set of templates that should be copied for a new project.
Here are the explanations for the templates folder.
| folder | explanations |
|--------|------------------------------------------------------------------|
| sl | Default configuration for supervised learning |
| sl-cfg | Like configuration in sl. But the dataset is highly configurable |
# TODO
- [ ] [Copier](https://copier.readthedocs.io/en/stable/#quick-start) may be useful if the generation process becomes complicated

View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
DIRNAME = Path(__file__).absolute().resolve().parent
def get_tpl_path() -> Path:
"""
return the template path
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
"""
return DIRNAME

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,7 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
experiment_name: finCo
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
@@ -9,6 +10,7 @@ data_handler_config: &data_handler_config
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
label: ["Ref($close, -21) / Ref($close, -1) - 1"]
infer_processors:
- class: RobustZScoreNorm
kwargs:
@@ -27,9 +29,7 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal:
- <MODEL>
- <DATASET>
signal: <PRED>
topk: 50
n_drop: 5
backtest:
@@ -48,7 +48,8 @@ task:
class: LinearModel
module_path: qlib.contrib.model.linear
kwargs:
estimator: ols
estimator: ridge
alpha: 0.05
dataset:
class: DatasetH
module_path: qlib.data.dataset
@@ -72,7 +73,7 @@ task:
kwargs:
ana_long_short: True
ann_scaler: 252
- class: MultiPassPortAnaRecord
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

71
qlib/finco/utils.py Normal file
View File

@@ -0,0 +1,71 @@
import json
import string
import random
from typing import List
from pathlib import Path
from fuzzywuzzy import fuzz
class SingletonMeta(type):
_instance = None
def __call__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
return cls._instance
class SingletonBaseClass(metaclass=SingletonMeta):
"""
Because we try to support defining Singleton with `class A(SingletonBaseClass)` instead of `A(metaclass=SingletonMeta)`
This class becomes necessary
"""
# TODO: Add move this class to Qlib's general utils.
def parse_json(response):
try:
return json.loads(response)
except json.decoder.JSONDecodeError:
pass
raise Exception(f"Failed to parse response: {response}, please report it or help us to fix it.")
def similarity(text1, text2):
text1 = text1 if isinstance(text1, str) else ""
text2 = text2 if isinstance(text2, str) else ""
# Maybe we can use other similarity algorithm such as tfidf
return fuzz.ratio(text1, text2)
def random_string(length=10):
letters = string.ascii_letters + string.digits
return "".join(random.choice(letters) for i in range(length))
def directory_tree(root_dif, max_depth=None):
def _directory_tree(root_dir, padding="", deep=1, max_d=None) -> List:
_output = []
if max_d and deep > max_d:
return _output
files = sorted(root_dir.iterdir())
for i, file in enumerate(files):
if i == len(files) - 1:
_output.append(padding + '└── ' + file.name)
if file.is_dir():
_output.extend(_directory_tree(file, padding + " ", deep=deep + 1, max_d=max_d))
else:
_output.append(padding + '├── ' + file.name)
if file.is_dir():
_output.extend(_directory_tree(file, padding + "", deep=deep + 1, max_d=max_d))
return _output
output = _directory_tree(root_dif, max_d=max_depth)
return '\n'.join(output)

212
qlib/finco/workflow.py Normal file
View File

@@ -0,0 +1,212 @@
import sys
import time
import shutil
from typing import List
from pathlib import Path
from qlib.finco.task import IdeaTask, SummarizeTask
from qlib.finco.prompt_template import PromptTemplate, Template
from qlib.finco.log import FinCoLog, LogColors
from qlib.finco.llm import APIBackend
from qlib.finco.conf import Config
from qlib.finco.knowledge import KnowledgeBase, Topic
from qlib.finco.context import WorkflowContextManager
# TODO: it is not necessary in current phase
# class TaskDAG:
# """
# This is a Task manager. it maintains a graph and a stack stucture to manager the task
# The reason why the DGA relationship is maintained outside instead of inside the task is that
# - To make the creating of task simpler(user don't have to care about the relation-ship)
# - To manage the relation ship when poping and executing the tasks is relatively easier instead of scattering them everywhere
# """
# def __init__(self) -> None:
# self._finished = []
# self._stack = []
# self._dag = defaultdict(list) # from id(object) -> list of id(object)
#
# def pop(self):
# return self._stack.pop(0)
#
# def push(self, task: Union[Task, List[Task]], parent: Optional[Task] = None):
# if isinstance(task, Task):
# task = [task]
# if parent is not None:
# self._dag
#
# def done(self) -> bool:
# return len(self._stack) == 0
class WorkflowManager:
"""This manage the whole task automation workflow including tasks and actions"""
def __init__(self, workspace=None) -> None:
self.logger = FinCoLog()
if workspace is None:
self._workspace = Path.cwd() / "finco_workspace"
else:
self._workspace = Path(workspace)
self.conf = Config()
self._confirm_and_rm()
self.prompt_template = PromptTemplate()
self.context = WorkflowContextManager(workspace=self._workspace)
self.context.set_context("workspace", self._workspace)
self.default_user_prompt = "build an A-share stock market daily portfolio in quantitative investment and minimize the maximum drawdown while maintaining return."
def _confirm_and_rm(self):
# if workspace exists, please confirm and remove it. Otherwise exit.
if self._workspace.exists() and not self.conf.continuous_mode:
self.logger.info(title="Interact")
flag = input(
LogColors().render(
f"Will be deleted: \n\t{self._workspace}\n"
f"If you do not need to delete {self._workspace},"
f" please change the workspace dir or rename existing files\n"
f"Are you sure you want to delete, yes(Y/y), no (N/n):",
color=LogColors.WHITE,
)
)
if str(flag) not in ["Y", "y"]:
sys.exit()
else:
# remove self._workspace
shutil.rmtree(self._workspace)
elif self._workspace.exists() and self.conf.continuous_mode:
shutil.rmtree(self._workspace)
def set_context(self, key, value):
"""Direct call set_context method of the context manager"""
self.context.set_context(key, value)
def get_context(self) -> WorkflowContextManager:
return self.context
def run(self, prompt: str) -> Path:
"""
The workflow manager is supposed to generate a codebase based on the prompt
Parameters
----------
prompt: str
the prompt user gives
Returns
-------
Path
The workflow manager is expected to produce output that includes a codebase containing generated code, results, and reports in a designated location.
The path is returned
The output path should follow a specific format:
- TODO: design
There is a summarized report where user can start from.
"""
# NOTE: The following items are not designed to make the workflow very flexible.
# - The generated tasks can't be changed after geting new information from the execution retuls.
# - But it is required in some cases, if we want to build a external dataset, it maybe have to plan like autogpt...
# NOTE: default user prompt might be changed in the future and exposed to the user
if prompt is None:
self.set_context("user_intention", self.default_user_prompt)
else:
self.set_context("user_intention", prompt)
self.logger.info(f"user_intention: {self.get_context().get_context('user_intention')}", title="Start")
# NOTE: list may not be enough for general task list
task_list = [IdeaTask(), SummarizeTask()]
task_finished = []
while len(task_list):
task_list_info = [str(task) for task in task_list]
# task list is not long, so sort it is not a big problem
# TODO: sort the task list based on the priority of the task
# task_list = sorted(task_list, key=lambda x: x.task_type)
t = task_list.pop(0)
self.logger.info(
f"Task finished: {[str(task) for task in task_finished]}",
f"Task in queue: {task_list_info}",
f"Executing task: {str(t)}",
title="Task",
)
t.assign_context_manager(self.context)
res = t.execute()
t.summarize()
task_finished.append(t)
self.context.set_context("task_finished", task_finished)
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
task_list = res + task_list
return self._workspace
class LearnManager:
__DEFAULT_TOPICS = ["RollingModel"]
def __init__(self):
self.epoch = 0
self.wm = WorkflowManager()
self.topics = [
Topic(name=topic, system=self.wm.prompt_template.get(f"Topic_system"), user=self.wm.prompt_template.get(f"Topic_user")) for topic in self.__DEFAULT_TOPICS
]
self.knowledge_base = KnowledgeBase()
def run(self, prompt):
# todo: add early stop condition
for i in range(10):
self.wm.logger.info(f"Round: {self.epoch+1}", title="Round")
self.wm.run(prompt)
self.learn()
self.epoch += 1
def learn(self):
workspace = self.wm.context.get_context("workspace")
def _drop_duplicate_task(_task: List):
unique_task = {}
for obj in _task:
task_name = obj.__class__.__name__
if task_name not in unique_task:
unique_task[task_name] = obj
return list(unique_task.values())
# one task maybe run several times in workflow
task_finished = _drop_duplicate_task(self.wm.context.get_context("task_finished"))
user_intention = self.wm.context.get_context("user_intention")
summary = self.wm.context.get_context("summary")
target = self.wm.context.get_context(f"target")
diffrence = self.wm.context.get_context(f"experiments_difference")
target_metrics = self.wm.context.get_context(f"high_level_metrics")
[topic.summarize(self.knowledge_base.practice_knowledge.knowledge[-2:], user_intention, target, diffrence, target_metrics) for topic in self.topics]
[self.knowledge_base.practice_knowledge.add([f"practice_knowledge on {topic.name}:\,{topic.knowledge}"]) for topic in self.topics]
# knowledge_of_topics = [{topic.name: topic.knowledge} for topic in self.topics]
# for task in task_finished:
# prompt_workflow_selection = self.wm.prompt_template.get(f"{self.__class__.__name__}_user").render(
# summary=summary,
# brief=knowledge_of_topics,
# task_finished=[str(t) for t in task_finished],
# task=task.__class__.__name__, system=task.system.render(), user_intention=user_intention
# )
# response = APIBackend().build_messages_and_create_chat_completion(
# user_prompt=prompt_workflow_selection,
# system_prompt=self.wm.prompt_template.get(f"{self.__class__.__name__}_system").render(),
# )
# # todo: response assertion
# task.prompt_template.update(key=f"{task.__class__.__name__}_system", value=Template(response))
self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
self.wm.context.clear(reserve=["workspace"])

View File

@@ -30,6 +30,7 @@ class Ensemble:
class SingleKeyEnsemble(Ensemble):
"""
Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
@@ -63,6 +64,7 @@ class SingleKeyEnsemble(Ensemble):
class RollingEnsemble(Ensemble):
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".

View File

@@ -51,6 +51,3 @@ class MetaTask:
Return the **processed** meta_info
"""
return self.meta_info
def __repr__(self):
return f"MetaTask(task={self.task}, meta_info={self.meta_info})"

View File

@@ -247,7 +247,9 @@ class ShrinkCovEstimator(RiskModel):
v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
v3 = z.T.dot(z) / t - var_mkt * S
roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
roff3 = (
np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
)
roff = 2 * roff1 - roff3
rho = rdiag + roff

View File

@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initialization
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
reweighter: Reweighter = task_config.get("reweighter", None)

View File

@@ -168,9 +168,7 @@ class TrainingVessel(TrainingVesselBase):
self.policy.train()
with vector_env.collector_guard():
collector = Collector(
self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)), exploration_noise=True
)
collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
# Number of episodes collected in each training iteration can be overridden by fast dev run.
if self.trainer.fast_dev_run is not None:

View File

@@ -12,11 +12,15 @@ import datetime
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from cryptography.fernet import Fernet
from qlib.utils import exists_qlib_data
class GetData:
REMOTE_URL = "https://github.com/SunsetWolf/qlib_dataset/releases/download"
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
# "?" is not included in the token.
TOKEN = b"gAAAAABkmDhojHc0VSCDdNK1MqmRzNLeDFXe5hy8obHpa6SDQh4de6nW5gtzuD-fa6O_WZb0yyqYOL7ndOfJX_751W3xN5YB4-n-P22jK-t6ucoZqhT70KPD0Lf0_P328QPJVZ1gDnjIdjhi2YLOcP4BFTHLNYO0mvzszR8TKm9iT5AKRvuysWnpi8bbYwGU9zAcJK3x9EPL43hOGtxliFHcPNGMBoJW4g_ercdhi0-Qgv5_JLsV-29_MV-_AhuaYvJuN2dEywBy"
KEY = "EYcA8cgorA8X9OhyMwVfuFxn_1W3jGk6jCbs3L2oPoA="
def __init__(self, delete_zip_file=False):
"""
@@ -29,45 +33,9 @@ class GetData:
self.delete_zip_file = delete_zip_file
def merge_remote_url(self, file_name: str):
"""
Generate download links.
Parameters
----------
file_name: str
The name of the file to be downloaded.
The file name can be accompanied by a version number, (e.g.: v2/qlib_data_simple_cn_1d_latest.zip),
if no version number is attached, it will be downloaded from v0 by default.
"""
return f"{self.REMOTE_URL}/{file_name}" if "/" in file_name else f"{self.REMOTE_URL}/v0/{file_name}"
def download(self, url: str, target_path: [Path, str]):
"""
Download a file from the specified url.
Parameters
----------
url: str
The url of the data.
target_path: str
The location where the data is saved, including the file name.
"""
file_name = str(target_path).rsplit("/", maxsplit=1)[-1]
resp = requests.get(url, stream=True, timeout=60)
resp.raise_for_status()
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chunk_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{os.path.basename(file_name)} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chunk in resp.iter_content(chunk_size=chunk_size):
fp.write(chunk)
p_bar.update(chunk_size)
fernet = Fernet(self.KEY)
token = fernet.decrypt(self.TOKEN).decode()
return f"{self.REMOTE_URL}/{file_name}?{token}"
def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):
"""
@@ -102,7 +70,21 @@ class GetData:
target_path = target_dir.joinpath(_target_file_name)
url = self.merge_remote_url(file_name)
self.download(url=url, target_path=target_path)
resp = requests.get(url, stream=True, timeout=60)
resp.raise_for_status()
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chunk_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{os.path.basename(file_name)} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chunk in resp.iter_content(chunk_size=chunk_size):
fp.write(chunk)
p_bar.update(chunk_size)
self._unzip(target_path, target_dir, delete_old)
if self.delete_zip_file:
@@ -117,9 +99,7 @@ class GetData:
return status
@staticmethod
def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True):
file_path = Path(file_path)
target_dir = Path(target_dir)
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
if delete_old:
logger.warning(
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"

View File

@@ -25,12 +25,7 @@ import pandas as pd
from pathlib import Path
from typing import List, Union, Optional, Callable
from packaging import version
from .file import (
get_or_create_path,
save_multiple_parts_file,
unpack_archive_with_buffer,
get_tmp_file_with_buffer,
)
from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer
from ..config import C
from ..log import get_module_logger, set_log_with_config
@@ -42,12 +37,7 @@ is_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse("
#################### Server ####################
def get_redis_connection():
"""get redis connection instance."""
return redis.StrictRedis(
host=C.redis_host,
port=C.redis_port,
db=C.redis_task_db,
password=C.redis_password,
)
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
#################### Data ####################
@@ -106,14 +96,7 @@ def get_period_offset(first_year, period, quarterly):
return offset
def read_period_data(
index_path,
data_path,
period,
cur_date_int: int,
quarterly,
last_period_index: int = None,
):
def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None):
"""
At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).
Only the updating info before cur_date or at cur_date will be used.
@@ -290,10 +273,7 @@ def parse_field(field):
# \uff09 -> )
chinese_punctuation_regex = r"\u3001\uff1a\uff08\uff09"
for pattern, new in [
(
rf"\$\$([\w{chinese_punctuation_regex}]+)",
r'PFeature("\1")',
), # $$ must be before $
(rf"\$\$([\w{chinese_punctuation_regex}]+)", r'PFeature("\1")'), # $$ must be before $
(rf"\$([\w{chinese_punctuation_regex}]+)", r'Feature("\1")'),
(r"(\w+\s*)\(", r"Operators.\1("),
]: # Features # Operators
@@ -403,14 +383,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar
def get_date_by_shift(
trading_date,
shift,
future=False,
clip_shift=True,
freq="day",
align: Optional[str] = None,
):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
"""get trading date with shift bias will cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
@@ -596,38 +569,7 @@ def exists_qlib_data(qlib_dir):
# check instruments
code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))
_instrument = instruments_dir.joinpath("all.txt")
# Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0
miss_code = set(
pd.read_csv(
_instrument,
sep="\t",
header=None,
keep_default_na=False,
na_values={
0: [
" ",
"#N/A",
"#N/A N/A",
"#NA",
"-1.#IND",
"-1.#QNAN",
"-NaN",
"-nan",
"1.#IND",
"1.#QNAN",
"<NA>",
"N/A",
"NaN",
"None",
"n/a",
"nan",
"null ",
]
},
)
.loc[:, 0]
.apply(str.lower)
) - set(code_names)
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
return False

View File

@@ -108,12 +108,6 @@ class Index:
self.index_map = self.idx_list = np.arange(idx_list)
self._is_sorted = True
else:
# Check if all elements in idx_list are of the same type
if not all(isinstance(x, type(idx_list[0])) for x in idx_list):
raise TypeError("All elements in idx_list must be of the same type")
# Check if all elements in idx_list are of the same datetime64 precision
if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list):
raise TypeError("All elements in idx_list must be of the same datetime64 precision")
self.idx_list = np.array(idx_list)
# NOTE: only the first appearance is indexed
self.index_map = dict(zip(self.idx_list, range(len(self))))
@@ -137,12 +131,7 @@ class Index:
if self.idx_list.dtype.type is np.datetime64:
if isinstance(item, pd.Timestamp):
# This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp
return item.to_numpy().astype(self.idx_list.dtype)
elif isinstance(item, np.datetime64):
# This happens often when creating index based on np.datetime64 and query with another precision
return item.astype(self.idx_list.dtype)
# NOTE: It is hard to consider every case at first.
# We just try to cover part of cases to make it more user-friendly
return item.to_numpy()
return item
def index(self, item) -> int:

View File

@@ -161,13 +161,7 @@ def init_instance_by_config(
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
# To enable relative path like file://data/a/b/c.pkl. pr.netloc will be data
path = pr.path
if pr.netloc != "":
path = path.lstrip("/")
pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
with open(os.path.normpath(pr_path), "rb") as f:
return pickle.load(f)
else:
@@ -212,6 +206,9 @@ def find_all_classes(module_path: Union[str, ModuleType], cls: type) -> List[typ
>>> from qlib.data.dataset.handler import DataHandler
>>> find_all_classes("qlib.contrib.data.handler", DataHandler)
[<class 'qlib.contrib.data.handler.Alpha158'>, <class 'qlib.contrib.data.handler.Alpha158vwap'>, <class 'qlib.contrib.data.handler.Alpha360'>, <class 'qlib.contrib.data.handler.Alpha360vwap'>, <class 'qlib.data.dataset.handler.DataHandlerLP'>]
>>> from qlib.contrib.rolling.base import Rolling
>>> find_all_classes("qlib.contrib.rolling", Rolling)
[<class 'qlib.contrib.rolling.base.Rolling'>, <class 'qlib.contrib.rolling.ddgda.DDGDA'>]
TODO:
- skip import error
@@ -226,7 +223,7 @@ def find_all_classes(module_path: Union[str, ModuleType], cls: type) -> List[typ
def _append_cls(obj):
# Leverage the closure trick to reuse code
if isinstance(obj, type) and issubclass(obj, cls) and cls not in cls_list:
if isinstance(obj, type) and issubclass(obj, cls) and obj not in cls_list:
cls_list.append(obj)
for attr in dir(mod):

View File

@@ -1,20 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import sys
import os
from pathlib import Path
import sys
import fire
from jinja2 import Template, meta
import ruamel.yaml as yaml
import qlib
import fire
import ruamel.yaml as yaml
from qlib.config import C
from qlib.log import get_module_logger
from qlib.model.trainer import task_train
from qlib.utils import set_log_with_config
from qlib.utils.data import update_config
from qlib.log import get_module_logger
from qlib.utils import set_log_with_config
set_log_with_config(C.logging_config)
logger = get_module_logger("qrun", logging.INFO)
@@ -49,39 +47,6 @@ def sys_config(config, config_path):
sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
def render_template(config_path: str) -> str:
"""
render the template based on the environment
Parameters
----------
config_path : str
configuration path
Returns
-------
str
the rendered content
"""
with open(config_path, "r") as f:
config = f.read()
# Set up the Jinja2 environment
template = Template(config)
# Parse the template to find undeclared variables
env = template.environment
parsed_content = env.parse(config)
variables = meta.find_undeclared_variables(parsed_content)
# Get context from os.environ according to the variables
context = {var: os.getenv(var, "") for var in variables if var in os.environ}
logger.info(f"Render the template with the context: {context}")
# Render the template with the context
rendered_content = template.render(context)
return rendered_content
# workflow handler function
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
"""
@@ -102,9 +67,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
market: csi300
"""
# Render the template
rendered_yaml = render_template(config_path)
config = yaml.safe_load(rendered_yaml)
with open(config_path) as fp:
config = yaml.safe_load(fp)
base_config_path = config.get("BASE_CONFIG_PATH", None)
if base_config_path:

View File

@@ -90,6 +90,7 @@ class OnlineStrategy:
class RollingStrategy(OnlineStrategy):
"""
This example strategy always uses the latest rolling model sas online models.
"""

View File

@@ -4,10 +4,8 @@
import logging
import warnings
import pandas as pd
import numpy as np
from tqdm import trange
from pprint import pprint
from typing import Union, List, Optional, Dict
from typing import Union, List, Optional
from qlib.utils.exceptions import LoadObjectError
from ..contrib.evaluate import risk_analysis, indicator_analysis
@@ -19,9 +17,8 @@ from ..log import get_module_logger
from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
from ..utils.time import Freq
from ..utils.data import deepcopy_basic_type
from ..utils.exceptions import QlibException
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
from qlib.contrib.analyzer import HFAnalyzer, SignalAnalyzer
logger = get_module_logger("workflow", logging.INFO)
@@ -158,6 +155,9 @@ class RecordTemp:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
def analyse(self):
raise NotImplementedError(f"Please implement the `analysis` method.")
class SignalRecord(RecordTemp):
"""
@@ -233,16 +233,9 @@ class ACRecordTemp(RecordTemp):
except FileNotFoundError:
logger.warning("The dependent data does not exists. Generation skipped.")
return
artifact_dict = self._generate(*args, **kwargs)
if isinstance(artifact_dict, dict):
self.save(**artifact_dict)
return artifact_dict
return self._generate(*args, **kwargs)
def _generate(self, *args, **kwargs) -> Dict[str, object]:
"""
Run the concrete generating task, return the dictionary of the generated results.
The caller method will save the results to the recorder.
"""
def _generate(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `_generate` method")
@@ -346,8 +339,8 @@ class SigAnaRecord(ACRecordTemp):
}
)
self.recorder.log_metrics(**metrics)
self.save(**objects)
pprint(metrics)
return objects
def list(self):
paths = ["ic.pkl", "ric.pkl"]
@@ -478,18 +471,17 @@ class PortAnaRecord(ACRecordTemp):
if self.backtest_config["end_time"] is None:
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
artifact_objects = {}
# custom strategy and get backtest
portfolio_metric_dict, indicator_dict = normal_backtest(
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
for _freq, indicators_normal in indicator_dict.items():
artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -511,7 +503,7 @@ class PortAnaRecord(ACRecordTemp):
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -536,13 +528,12 @@ class PortAnaRecord(ACRecordTemp):
analysis_dict = analysis_df["value"].to_dict()
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
pprint(analysis_df)
return artifact_objects
def list(self):
list_path = []
@@ -565,124 +556,3 @@ class PortAnaRecord(ACRecordTemp):
else:
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
return list_path
class MultiPassPortAnaRecord(PortAnaRecord):
"""
This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class.
If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random.
The shuffle_init_score will only works when the signal is used as <PRED> placeholder. The placeholder will be replaced by pred.pkl saved in recorder.
Parameters
----------
recorder : Recorder
The recorder used to save the backtest results.
pass_num : int
The number of backtest passes.
shuffle_init_score : bool
Whether to shuffle the prediction score of the first backtest date.
"""
depend_cls = SignalRecord
def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):
"""
Parameters
----------
recorder : Recorder
The recorder used to save the backtest results.
pass_num : int
The number of backtest passes.
shuffle_init_score : bool
Whether to shuffle the prediction score of the first backtest date.
"""
self.pass_num = pass_num
self.shuffle_init_score = shuffle_init_score
super().__init__(recorder, **kwargs)
# Save original strategy so that pred df can be replaced in next generate
self.original_strategy = deepcopy_basic_type(self.strategy_config)
if not isinstance(self.original_strategy, dict):
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict")
if "signal" not in self.original_strategy.get("kwargs", {}):
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter")
def random_init(self):
pred_df = self.load("pred.pkl")
all_pred_dates = pred_df.index.get_level_values("datetime")
bt_start_date = pd.to_datetime(self.backtest_config.get("start_time"))
if bt_start_date is None:
first_bt_pred_date = all_pred_dates.min()
else:
first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()
# Shuffle the first backtest date's pred score
first_date_score = pred_df.loc[first_bt_pred_date]["score"]
np.random.shuffle(first_date_score.values)
# Use shuffled signal as the strategy signal
self.strategy_config = deepcopy_basic_type(self.original_strategy)
self.strategy_config["kwargs"]["signal"] = pred_df
def _generate(self, **kwargs):
risk_analysis_df_map = {}
# Collect each frequency's analysis df as df list
for i in trange(self.pass_num):
if self.shuffle_init_score:
self.random_init()
# Not check for cache file list
single_run_artifacts = super()._generate(**kwargs)
for _analysis_freq in self.risk_analysis_freq:
risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])
risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list
analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"]
analysis_df["run_id"] = i
risk_analysis_df_list.append(analysis_df)
result_artifacts = {}
# Concat df list
for _analysis_freq in self.risk_analysis_freq:
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
# Calculate return and information ratio's mean, std and mean/std
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
lambda x: pd.Series(
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
)
)
# Only look at "annualized_return" and "information_ratio"
multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[
(slice(None), ["annualized_return", "information_ratio"]), :
]
pprint(multi_pass_port_analysis_df)
# Save new df
result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df})
# Log metrics
metrics = flatten_dict(
{
"mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(),
"std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(),
"mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(),
}
)
self.recorder.log_metrics(**metrics)
return result_artifacts
def list(self):
list_path = []
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
return list_path

View File

@@ -242,7 +242,7 @@ class TimeAdjuster:
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
Shift the datetime of segment
Shift the datatime of segment
If there are None (which indicates unbounded index) in the segment, this method will return None.

View File

@@ -1,81 +0,0 @@
## Collector Data
### Get Qlib data(`bin file`)
- get data: `python scripts/get_data.py qlib_data`
- parameters:
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data_5min*
- `version`: dataset version, value from [`v2`], by default `v2`
- `v2` end date is *2022-12*
- `interval`: `5min`
- `region`: `hs300`
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
- examples:
```bash
# hs300 5min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/hs300_data_5min --region hs300 --interval 5min
```
### Collector *Baostock high frequency* data to qlib
> collector *Baostock high frequency* data and *dump* into `qlib` format.
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
1. download data to csv: `python scripts/data_collector/baostock_5min/collector.py download_data`
This will download the raw data such as date, symbol, open, high, low, close, volume, amount, adjustflag from baostock to a local directory. One file per symbol.
- parameters:
- `source_dir`: save the directory
- `interval`: `5min`
- `region`: `HS300`
- `start`: start datetime, by default *None*
- `end`: end datetime, by default *None*
- examples:
```bash
# cn 5min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
```
2. normalize data: `python scripts/data_collector/baostock_5min/collector.py normalize_data`
This will:
1. Normalize high, low, close, open price using adjclose.
2. Normalize the high, low, close, open price so that the first valid trading date's close price is 1.
- parameters:
- `source_dir`: csv directory
- `normalize_dir`: result directory
- `interval`: `5min`
> if **`interval == 5min`**, `qlib_data_1d_dir` cannot be `None`
- `region`: `HS300`
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
- `qlib_data_1d_dir`: qlib directory(1d data)
if interval==5min, qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
```
# qlib_data_1d can be obtained like this:
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
```
- examples:
```bash
# normalize 5min cn
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
```
3. dump data: `python scripts/dump_bin.py dump_all`
This will convert the normalized csv in `feature` directory as numpy array and store the normalized data one file per column and one symbol per directory.
- parameters:
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
- `qlib_dir`: qlib(dump) data director
- `freq`: transaction frequency, by default `day`
> `freq_map = {1d:day, 5mih: 5min}`
- `max_workers`: number of threads, by default *16*
- `include_fields`: dump fields, by default `""`
- `exclude_fields`: fields not dumped, by default `"""
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- examples:
```bash
# dump 5min cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
```

View File

@@ -1,328 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import copy
import fire
import numpy as np
import pandas as pd
import baostock as bs
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from typing import Iterable, List
import qlib
from qlib.data import D
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import generate_minutes_calendar_from_daily, calc_adjusted_price
class BaostockCollectorHS3005min(BaseCollector):
def __init__(
self,
save_dir: [str, Path],
start=None,
end=None,
interval="5min",
max_workers=4,
max_collector_count=2,
delay=0,
check_data_length: int = None,
limit_nums: int = None,
):
"""
Parameters
----------
save_dir: str
stock save dir
max_workers: int
workers, default 4
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [5min], default 5min
start: str
start datetime, default None
end: str
end datetime, default None
check_data_length: int
check data length, by default None
limit_nums: int
using for debug, by default None
"""
bs.login()
super(BaostockCollectorHS3005min, self).__init__(
save_dir=save_dir,
start=start,
end=end,
interval=interval,
max_workers=max_workers,
max_collector_count=max_collector_count,
delay=delay,
check_data_length=check_data_length,
limit_nums=limit_nums,
)
def get_trade_calendar(self):
_format = "%Y-%m-%d"
start = self.start_datetime.strftime(_format)
end = self.end_datetime.strftime(_format)
rs = bs.query_trade_dates(start_date=start, end_date=end)
calendar_list = []
while (rs.error_code == "0") & rs.next():
calendar_list.append(rs.get_row_data())
calendar_df = pd.DataFrame(calendar_list, columns=rs.fields)
trade_calendar_df = calendar_df[~calendar_df["is_trading_day"].isin(["0"])]
return trade_calendar_df["calendar_date"].values
@staticmethod
def process_interval(interval: str):
if interval == "1d":
return {"interval": "d", "fields": "date,code,open,high,low,close,volume,amount,adjustflag"}
if interval == "5min":
return {"interval": "5", "fields": "date,time,code,open,high,low,close,volume,amount,adjustflag"}
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
df = self.get_data_from_remote(
symbol=symbol, interval=interval, start_datetime=start_datetime, end_datetime=end_datetime
)
df.columns = ["date", "time", "symbol", "open", "high", "low", "close", "volume", "amount", "adjustflag"]
df["time"] = pd.to_datetime(df["time"], format="%Y%m%d%H%M%S%f")
df["date"] = df["time"].dt.strftime("%Y-%m-%d %H:%M:%S")
df["date"] = df["date"].map(lambda x: pd.Timestamp(x) - pd.Timedelta(minutes=5))
df.drop(["time"], axis=1, inplace=True)
df["symbol"] = df["symbol"].map(lambda x: str(x).replace(".", "").upper())
return df
@staticmethod
def get_data_from_remote(
symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
df = pd.DataFrame()
rs = bs.query_history_k_data_plus(
symbol,
BaostockCollectorHS3005min.process_interval(interval=interval)["fields"],
start_date=str(start_datetime.strftime("%Y-%m-%d")),
end_date=str(end_datetime.strftime("%Y-%m-%d")),
frequency=BaostockCollectorHS3005min.process_interval(interval=interval)["interval"],
adjustflag="3",
)
if rs.error_code == "0" and len(rs.data) > 0:
data_list = rs.data
columns = rs.fields
df = pd.DataFrame(data_list, columns=columns)
return df
def get_hs300_symbols(self) -> List[str]:
hs300_stocks = []
trade_calendar = self.get_trade_calendar()
with tqdm(total=len(trade_calendar)) as p_bar:
for date in trade_calendar:
rs = bs.query_hs300_stocks(date=date)
while rs.error_code == "0" and rs.next():
hs300_stocks.append(rs.get_row_data())
p_bar.update()
return sorted({e[1] for e in hs300_stocks})
def get_instrument_list(self):
logger.info("get HS stock symbols......")
symbols = self.get_hs300_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
def normalize_symbol(self, symbol: str):
return str(symbol).replace(".", "").upper()
class BaostockNormalizeHS3005min(BaseNormalize):
COLUMNS = ["open", "close", "high", "low", "volume"]
AM_RANGE = ("09:30:00", "11:29:00")
PM_RANGE = ("13:00:00", "14:59:00")
def __init__(
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
qlib_data_1d_dir: str, Path
the qlib data to be updated for yahoo, usually from: Normalised to 5min using local 1d data
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
bs.login()
qlib.init(provider_uri=qlib_data_1d_dir)
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name)
@staticmethod
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
df = df.copy()
_tmp_series = df["close"].fillna(method="ffill")
_tmp_shift_series = _tmp_series.shift(1)
if last_close is not None:
_tmp_shift_series.iloc[0] = float(last_close)
change_series = _tmp_series / _tmp_shift_series - 1
return change_series
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
return self.generate_5min_from_daily(self.calendar_list_1d)
@property
def calendar_list_1d(self):
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
if calendar_list_1d is None:
calendar_list_1d = self._get_1d_calendar_list()
setattr(self, "_calendar_list_1d", calendar_list_1d)
return calendar_list_1d
@staticmethod
def normalize_baostock(
df: pd.DataFrame,
calendar_list: list = None,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
last_close: float = None,
):
if df.empty:
return df
symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]
columns = copy.deepcopy(BaostockNormalizeHS3005min.COLUMNS)
df = df.copy()
df.set_index(date_field_name, inplace=True)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
if calendar_list is not None:
df = df.reindex(
pd.DataFrame(index=calendar_list)
.loc[pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(days=1)]
.index
)
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
df["change"] = BaostockNormalizeHS3005min.calc_change(df, last_close)
columns += ["change"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
df[symbol_field_name] = symbol
df.index.names = [date_field_name]
return df.reset_index()
def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index:
return generate_minutes_calendar_from_daily(
calendars, freq="5min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
df = calc_adjusted_price(
df=df,
_date_field_name=self._date_field_name,
_symbol_field_name=self._symbol_field_name,
frequence="5min",
_1d_data_all=self.all_1d_data,
)
return df
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
return list(D.calendar(freq="day"))
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
# normalize
df = self.normalize_baostock(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
# adjusted price
df = self.adjusted_price(df)
return df
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="5min", region="HS300"):
"""
Changed the default value of: scripts.data_collector.base.BaseRun.
"""
super().__init__(source_dir, normalize_dir, max_workers, interval)
self.region = region
@property
def collector_class_name(self):
return f"BaostockCollector{self.region.upper()}{self.interval}"
@property
def normalize_class_name(self):
return f"BaostockNormalize{self.region.upper()}{self.interval}"
@property
def default_base_dir(self) -> [Path, str]:
return CUR_DIR
def download_data(
self,
max_collector_count=2,
delay=0.5,
start=None,
end=None,
check_data_length=None,
limit_nums=None,
):
"""download data from Baostock
Notes
-----
check_data_length, example:
hs300 5min, a week: 4 * 60 * 5
Examples
---------
# get hs300 5min data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
"""
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
def normalize_data(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
end_date: str = None,
qlib_data_1d_dir: str = None,
):
"""normalize data
Attention
---------
qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
qlib_data_1d can be obtained like this:
$ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
or:
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
Examples
---------
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
"""
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
raise ValueError(
"If normalize 5min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
)
super(Run, self).normalize_data(
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
)
if __name__ == "__main__":
fire.Fire(Run)

View File

@@ -1,13 +0,0 @@
loguru
fire
requests
numpy
pandas
tqdm
lxml
yahooquery
joblib
beautifulsoup4
bs4
soupsieve
baostock

View File

@@ -8,7 +8,7 @@ import datetime
import importlib
from pathlib import Path
from typing import Type, Iterable
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import pandas as pd
from tqdm import tqdm
@@ -290,7 +290,7 @@ class Normalize:
# some symbol_field values such as TRUE, NA are decoded as True(bool), NaN(np.float) by pandas default csv parsing.
# manually defines dtype and na_values of the symbol_field.
default_na = pd._libs.parsers.STR_NA_VALUES # pylint: disable=I1101
default_na = pd._libs.parsers.STR_NA_VALUES
symbol_na = default_na.copy()
symbol_na.remove("NA")
columns = pd.read_csv(file_path, nrows=0).columns
@@ -301,7 +301,6 @@ class Normalize:
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
)
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
df = self._normalize_obj.normalize(df)
if df is not None and not df.empty:
if self._end_date is not None:

View File

@@ -3,6 +3,7 @@
from functools import partial
import sys
from pathlib import Path
import importlib
import datetime
import fire
@@ -97,7 +98,7 @@ class IBOVIndex(IndexBase):
now = datetime.datetime.now()
current_year = now.year
current_month = now.month
for year in [item for item in range(init_year, current_year)]: # pylint: disable=R1721
for year in [item for item in range(init_year, current_year)]:
for el in four_months_period:
self.years_4_month_periods.append(str(year) + "_" + el)
# For current year the logic must be a little different

View File

@@ -1,6 +1,6 @@
async-generator==1.10
attrs==21.4.0
certifi==2022.12.7
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.12
cryptography==36.0.1
@@ -8,7 +8,7 @@ fire==0.4.0
h11==0.13.0
idna==3.3
loguru==0.6.0
lxml==4.9.1
lxml==4.8.0
multitasking==0.0.10
numpy==1.22.2
outcome==1.1.0

View File

@@ -4,6 +4,7 @@
import re
import abc
import sys
import datetime
from io import BytesIO
from typing import List, Iterable
from pathlib import Path
@@ -38,7 +39,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None):
if exclude_status is None:
exclude_status = []
method_func = getattr(requests, method)
_resp = method_func(url, headers=REQ_HEADERS, timeout=None)
_resp = method_func(url, headers=REQ_HEADERS)
_status = _resp.status_code
if _status not in exclude_status and _status != 200:
raise ValueError(f"response status: {_status}, url={url}")
@@ -396,7 +397,14 @@ class CSI500Index(CSIIndex):
today = pd.Timestamp.now()
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
ret_list = []
col = ["date", "symbol", "code_name"]
for date in tqdm(date_range, desc="Download CSI500"):
rs = bs.query_zz500_stocks(date=str(date))
zz500_stocks = []
while (rs.error_code == "0") & rs.next():
zz500_stocks.append(rs.get_row_data())
result = pd.DataFrame(zz500_stocks, columns=col)
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
result = self.get_data_from_baostock(date)
ret_list.append(result[["date", "symbol"]])
bs.logout()

Some files were not shown because too many files have changed in this diff Show More