mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
426 Commits
v0.6.3
...
qlib_monit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
77ba7b4e91 | ||
|
|
7a639eeea7 | ||
|
|
cddaf90ef5 | ||
|
|
4ff0c4fb0f | ||
|
|
b3eece155f | ||
|
|
02e34eb9e9 | ||
|
|
3033fdf4b7 | ||
|
|
ef11a9d95c | ||
|
|
98eacf8f88 | ||
|
|
43cad1ec27 | ||
|
|
e7aa7ffcdd | ||
|
|
ed3c9d9212 | ||
|
|
2f3fbae73b | ||
|
|
e409bee9b9 | ||
|
|
7ceec37848 | ||
|
|
c12c861b7a | ||
|
|
0a4e241608 | ||
|
|
5a382d7e99 | ||
|
|
9b431bc503 | ||
|
|
cbbf6cd822 | ||
|
|
928bae08f4 | ||
|
|
c65fc226bd | ||
|
|
114162693f | ||
|
|
b884c8c571 | ||
|
|
6222940b9c | ||
|
|
bb0c555803 | ||
|
|
5da33562dd | ||
|
|
db3aa8b887 | ||
|
|
ae32d79549 | ||
|
|
a80369b80a | ||
|
|
369177acf9 | ||
|
|
2ac5ceb4de | ||
|
|
602f78b568 | ||
|
|
669f6bd6f5 | ||
|
|
3d71fd1966 | ||
|
|
b887d2ec32 | ||
|
|
8e6c744a1b | ||
|
|
9e296a8a4e | ||
|
|
4ba4512619 | ||
|
|
2fa7ef32fb | ||
|
|
c72ee9091e | ||
|
|
19eda8f4f0 | ||
|
|
d08146c30f | ||
|
|
8c3a08b18d | ||
|
|
142a9dca3c | ||
|
|
41ab130807 | ||
|
|
8f67010b58 | ||
|
|
a986379deb | ||
|
|
aef3f186c1 | ||
|
|
ebd01e0de5 | ||
|
|
f51e04a1cc | ||
|
|
d71a666904 | ||
|
|
8b15ffc027 | ||
|
|
f15ca39df8 | ||
|
|
bd37f5d953 | ||
|
|
76c5c5d1b6 | ||
|
|
b8e64dc526 | ||
|
|
df7c882fe3 | ||
|
|
9bd77bd89f | ||
|
|
c43a0b208d | ||
|
|
8ba5e93d04 | ||
|
|
370b6aad74 | ||
|
|
f5ded06a15 | ||
|
|
4c232610f1 | ||
|
|
81bd2ca8fb | ||
|
|
143c257fa2 | ||
|
|
724f9ba8d2 | ||
|
|
aa1f9b464b | ||
|
|
d2daba99d3 | ||
|
|
1c605e505a | ||
|
|
060a32e0f6 | ||
|
|
08edb92461 | ||
|
|
bec65ddf94 | ||
|
|
9dfd001f6f | ||
|
|
95a4a98de8 | ||
|
|
d4639b7df9 | ||
|
|
846c64f6c6 | ||
|
|
84c56f13bd | ||
|
|
1c99fb35da | ||
|
|
5bc2b96346 | ||
|
|
ee269b0914 | ||
|
|
2a2d2cf709 | ||
|
|
5eb9dfff16 | ||
|
|
694ae34027 | ||
|
|
51b649ec39 | ||
|
|
ca92cb980c | ||
|
|
f58c61a2e0 | ||
|
|
2b7ffa100f | ||
|
|
67c5740c83 | ||
|
|
6f669348a8 | ||
|
|
40cf83e557 | ||
|
|
fa4511cb0a | ||
|
|
45c6dfc5da | ||
|
|
36ab078fbd | ||
|
|
5a7f9ef720 | ||
|
|
8b8d21107c | ||
|
|
eab19de080 | ||
|
|
42f510024c | ||
|
|
5a7eecabee | ||
|
|
0058f7d0dc | ||
|
|
9a74fe34f6 | ||
|
|
e15ea06122 | ||
|
|
319396c815 | ||
|
|
50be7a9171 | ||
|
|
e410caaa8f | ||
|
|
fbff4c271a | ||
|
|
ee91503973 | ||
|
|
de0a0c083d | ||
|
|
8adfafa6aa | ||
|
|
aafaff45d2 | ||
|
|
6a05d4e255 | ||
|
|
cbf1fa721e | ||
|
|
4ebf684794 | ||
|
|
f4bfe8e619 | ||
|
|
cec318fbfe | ||
|
|
78bb8882cd | ||
|
|
848d953226 | ||
|
|
a3a2b5ae0b | ||
|
|
941c980d06 | ||
|
|
fe190dec4b | ||
|
|
5095b2a470 | ||
|
|
317357b50d | ||
|
|
b15e5e33fd | ||
|
|
cca43cf102 | ||
|
|
a366c11d67 | ||
|
|
18bf4b5477 | ||
|
|
c20eb5c8a6 | ||
|
|
71605794a2 | ||
|
|
1dbb561744 | ||
|
|
cb42e99bee | ||
|
|
431a9c92c1 | ||
|
|
bd7a1c11b9 | ||
|
|
70fc58104b | ||
|
|
edcd7b1ff9 | ||
|
|
3724273d73 | ||
|
|
544365f3a9 | ||
|
|
bed1175e24 | ||
|
|
70c84cbc77 | ||
|
|
da59b35c0a | ||
|
|
eae94d1ee8 | ||
|
|
1f2d2c9b69 | ||
|
|
b6df11b6b4 | ||
|
|
ae57110f64 | ||
|
|
7a2203f116 | ||
|
|
023603479c | ||
|
|
f8da79b802 | ||
|
|
136830bc2b | ||
|
|
45f78676ea | ||
|
|
1074284666 | ||
|
|
d18c367497 | ||
|
|
8743576f72 | ||
|
|
fb7f84f31e | ||
|
|
31bc85bf86 | ||
|
|
968930e85f | ||
|
|
4b66304978 | ||
|
|
253378a44e | ||
|
|
f809f0a063 | ||
|
|
0386df7b16 | ||
|
|
8a2e7b62af | ||
|
|
9d04ae4676 | ||
|
|
9b8acd9a82 | ||
|
|
ee45a7833e | ||
|
|
d395c904f2 | ||
|
|
9bf819e653 | ||
|
|
46cd57688e | ||
|
|
0387eaf7ab | ||
|
|
4ee0240c24 | ||
|
|
5f60d18dfe | ||
|
|
194217fb07 | ||
|
|
d6ff764bb2 | ||
|
|
9cc3b18e4e | ||
|
|
56eaacd931 | ||
|
|
e119c8576c | ||
|
|
68246b3b6d | ||
|
|
a04c6bd6c9 | ||
|
|
efe134e9f4 | ||
|
|
4ec300787e | ||
|
|
3886022669 | ||
|
|
8264033a72 | ||
|
|
4861552d28 | ||
|
|
834f9bd9b8 | ||
|
|
f6dc25b229 | ||
|
|
1fcfe8e4ba | ||
|
|
b1a28358ad | ||
|
|
1ca3c6a61c | ||
|
|
e3739bb980 | ||
|
|
419629e4d2 | ||
|
|
e490e83a16 | ||
|
|
fda144e66f | ||
|
|
4dc10d27e0 | ||
|
|
0a0c6a3185 | ||
|
|
d66d4ec93d | ||
|
|
4b56a4e907 | ||
|
|
7370d5af9e | ||
|
|
c6b67cb8fe | ||
|
|
3bf6c7f95f | ||
|
|
1ad237f89f | ||
|
|
2b74b4dfa4 | ||
|
|
598ee875a0 | ||
|
|
84d5318bda | ||
|
|
ba56e4071e | ||
|
|
d3160e9439 | ||
|
|
06c90d654d | ||
|
|
f72771cc81 | ||
|
|
8abdd63869 | ||
|
|
38f35658e7 | ||
|
|
d245242f2f | ||
|
|
6ef204f190 | ||
|
|
dad18074ac | ||
|
|
3cf84f8859 | ||
|
|
0403237232 | ||
|
|
689774c6be | ||
|
|
d78e42e2fe | ||
|
|
4de628c736 | ||
|
|
023c1fedfe | ||
|
|
9be6866972 | ||
|
|
be55e0e3fe | ||
|
|
619a3bb25d | ||
|
|
4bd2cd4611 | ||
|
|
aa552fdb20 | ||
|
|
5520463395 | ||
|
|
872ddc6f95 | ||
|
|
88b0871c12 | ||
|
|
d4aa681652 | ||
|
|
34f0be2836 | ||
|
|
447fed8e54 | ||
|
|
4cb74d77d1 | ||
|
|
b0fd0d2395 | ||
|
|
6559d44c7d | ||
|
|
9f57681032 | ||
|
|
d33041dc24 | ||
|
|
5953365af3 | ||
|
|
e3730b32d7 | ||
|
|
08b44ed727 | ||
|
|
83fb482f1e | ||
|
|
734bb9ee3c | ||
|
|
d47e35d64e | ||
|
|
0bc49dab60 | ||
|
|
646d899f8d | ||
|
|
07434da8b0 | ||
|
|
53a6b72ce5 | ||
|
|
a51dafcb4c | ||
|
|
8362780e22 | ||
|
|
358de88602 | ||
|
|
32a7be9964 | ||
|
|
d5f9395e51 | ||
|
|
4e7a147759 | ||
|
|
1344c40598 | ||
|
|
1d2b2f4f01 | ||
|
|
373f6e0900 | ||
|
|
ba64758c24 | ||
|
|
abddcfccdf | ||
|
|
6d5381f9b1 | ||
|
|
e4e8a4abcd | ||
|
|
e41373b8ad | ||
|
|
9d84d389ab | ||
|
|
6d8aa215d6 | ||
|
|
0969c3e7e0 | ||
|
|
5de7870f9b | ||
|
|
44a7dc004d | ||
|
|
5f8d0e0436 | ||
|
|
4fbb5a03c1 | ||
|
|
0cffb87cbc | ||
|
|
df56e3bdf9 | ||
|
|
1d435248e2 | ||
|
|
593553f573 | ||
|
|
d38b8d6001 | ||
|
|
db59713d36 | ||
|
|
67fbdafe76 | ||
|
|
42be8ac312 | ||
|
|
0df88c07f6 | ||
|
|
f6b019dcec | ||
|
|
e626264d5a | ||
|
|
b99de068f8 | ||
|
|
e8beaa5257 | ||
|
|
0ef7c8e0e6 | ||
|
|
48f0fc147f | ||
|
|
cda96be8c3 | ||
|
|
f6ed175070 | ||
|
|
2ca2071d95 | ||
|
|
0054a4db2a | ||
|
|
e2f58274ba | ||
|
|
119fe90570 | ||
|
|
e2817ab87c | ||
|
|
2e37033e35 | ||
|
|
105fe1d3ed | ||
|
|
78bc2c8748 | ||
|
|
83dbdfb45e | ||
|
|
81987bb143 | ||
|
|
53cf89d7c2 | ||
|
|
8b9065c166 | ||
|
|
6a305c73ae | ||
|
|
7022675d00 | ||
|
|
2f9af1af8f | ||
|
|
fc89fec46d | ||
|
|
c6675be792 | ||
|
|
351d598c9f | ||
|
|
81b86f8022 | ||
|
|
4d5a30b30b | ||
|
|
79c1142d3e | ||
|
|
e061443560 | ||
|
|
03ef918dd8 | ||
|
|
ca48345b29 | ||
|
|
def132e140 | ||
|
|
7bed3b4c2e | ||
|
|
4266492a34 | ||
|
|
91eef93386 | ||
|
|
a244f87f95 | ||
|
|
9df0361262 | ||
|
|
6bcd88973b | ||
|
|
d13c9ae018 | ||
|
|
11412727ef | ||
|
|
73b7107ee8 | ||
|
|
91fd53ab4d | ||
|
|
aab5c5b311 | ||
|
|
dc86a6abc5 | ||
|
|
a62d1a1b36 | ||
|
|
5015d218ff | ||
|
|
6f034ccb5d | ||
|
|
07eef18337 | ||
|
|
f277a66582 | ||
|
|
49697b1f15 | ||
|
|
131f0e2e67 | ||
|
|
b14a559a52 | ||
|
|
b115ca5353 | ||
|
|
b40bfb1ea5 | ||
|
|
4f980a0266 | ||
|
|
19d93744f3 | ||
|
|
e327f404e3 | ||
|
|
452fb8f013 | ||
|
|
c4d6e00470 | ||
|
|
0f3e3d206b | ||
|
|
83c6e74783 | ||
|
|
2bff6eb781 | ||
|
|
ee7eb79277 | ||
|
|
592db903b3 | ||
|
|
34b7da1dd8 | ||
|
|
2882929c5d | ||
|
|
fd2c1ba1ed | ||
|
|
05cf0e1edc | ||
|
|
229a39d0d3 | ||
|
|
a9a70dfddf | ||
|
|
b8cf229b05 | ||
|
|
7258340e0c | ||
|
|
b84156fde8 | ||
|
|
d1d70616a3 | ||
|
|
5378d261b4 | ||
|
|
a1fb10f7cf | ||
|
|
dbc8ca7379 | ||
|
|
c48b4c9971 | ||
|
|
b592669d1f | ||
|
|
0bcaab3a5a | ||
|
|
1de4def444 | ||
|
|
ee4692a355 | ||
|
|
6e2ce6f1dc | ||
|
|
c4733f601f | ||
|
|
82353b20e1 | ||
|
|
51baf57b40 | ||
|
|
3082f6ac1b | ||
|
|
db80b620d8 | ||
|
|
6e56396217 | ||
|
|
24024d51c7 | ||
|
|
a96f0c2e5f | ||
|
|
1e5cf1c174 | ||
|
|
719074d306 | ||
|
|
70575e8a1c | ||
|
|
ce60097722 | ||
|
|
1a990fdd25 | ||
|
|
527718a440 | ||
|
|
d3caea60ee | ||
|
|
f947a2fdef | ||
|
|
dc4aa67503 | ||
|
|
37871389b9 | ||
|
|
2f9d45e03a | ||
|
|
b8647c13c7 | ||
|
|
164687d54b | ||
|
|
58f74cfd84 | ||
|
|
f7d3e56561 | ||
|
|
42f882504e | ||
|
|
9448a6e2c7 | ||
|
|
2cc057e438 | ||
|
|
b2e2142594 | ||
|
|
4000518698 | ||
|
|
fa8f1cba06 | ||
|
|
a72911e4f8 | ||
|
|
cd5b721bc6 | ||
|
|
42590972e4 | ||
|
|
d27dc8bab8 | ||
|
|
50d5fcf61e | ||
|
|
77830a546e | ||
|
|
83237ba4ed | ||
|
|
04b916c8ae | ||
|
|
b90bd66ac6 | ||
|
|
63d05e4a1a | ||
|
|
0b11dc5167 | ||
|
|
9c2653f125 | ||
|
|
7b01c5cae7 | ||
|
|
988b42e159 | ||
|
|
12c8bfa545 | ||
|
|
c948385e76 | ||
|
|
07b905c153 | ||
|
|
0192f28bf4 | ||
|
|
cbf97f56a4 | ||
|
|
d702c8bcb1 | ||
|
|
b84686b215 | ||
|
|
6a670828a5 | ||
|
|
ca6c2ffc27 | ||
|
|
914637b3ef | ||
|
|
d8da94de10 | ||
|
|
477a548fe9 | ||
|
|
35af9ad954 | ||
|
|
8a91e7d34d | ||
|
|
4ed8b8e233 | ||
|
|
c71b645777 | ||
|
|
f2ffb80a0b | ||
|
|
cda1d4be40 | ||
|
|
fc1431cd4e | ||
|
|
06158fb621 | ||
|
|
1e2e02368c | ||
|
|
d87d29aca9 | ||
|
|
005da6306c | ||
|
|
090b68e44e | ||
|
|
bf748ba4b7 | ||
|
|
fd5c68a7d1 | ||
|
|
8c3ec164ff | ||
|
|
acdc469e39 |
12
.deepsource.toml
Normal file
12
.deepsource.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
version = 1
|
||||
|
||||
test_patterns = ["tests/test_*.py"]
|
||||
|
||||
exclude_patterns = ["examples/**"]
|
||||
|
||||
[[analyzers]]
|
||||
name = "python"
|
||||
enabled = true
|
||||
|
||||
[analyzers.meta]
|
||||
runtime_version = "3.x.x"
|
||||
62
.github/stale.yml
vendored
62
.github/stale.yml
vendored
@@ -1,62 +0,0 @@
|
||||
# Configuration for probot-stale - https://github.com/probot/stale
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 60
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
|
||||
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
|
||||
daysUntilClose: 7
|
||||
|
||||
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
|
||||
onlyLabels: []
|
||||
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
exemptLabels:
|
||||
- bug
|
||||
- pinned
|
||||
- security
|
||||
- "[Status] Maybe Later"
|
||||
|
||||
# Set to true to ignore issues in a project (defaults to false)
|
||||
exemptProjects: false
|
||||
|
||||
# Set to true to ignore issues in a milestone (defaults to false)
|
||||
exemptMilestones: false
|
||||
|
||||
# Set to true to ignore issues with an assignee (defaults to false)
|
||||
exemptAssignees: false
|
||||
|
||||
# Label to use when marking as stale
|
||||
staleLabel: wontfix
|
||||
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you
|
||||
for your contributions.
|
||||
|
||||
# Comment to post when removing the stale label.
|
||||
# unmarkComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Comment to post when closing a stale Issue or Pull Request.
|
||||
# closeComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Limit the number of actions per hour, from 1-30. Default is 30
|
||||
limitPerRun: 30
|
||||
|
||||
# Limit to only `issues` or `pulls`
|
||||
# only: issues
|
||||
|
||||
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
|
||||
# pulls:
|
||||
# daysUntilStale: 30
|
||||
# markComment: >
|
||||
# This pull request has been automatically marked as stale because it has not had
|
||||
# recent activity. It will be closed if no further activity occurs. Thank you
|
||||
# for your contributions.
|
||||
|
||||
# issues:
|
||||
# exemptLabels:
|
||||
# - confirmed
|
||||
24
.github/workflows/stale.yml
vendored
Normal file
24
.github/workflows/stale.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Mark stale issues and pull requests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0/3 * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v3
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue otherwise this will be closed in 5 days'
|
||||
stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR otherwise this will be closed in 5 days'
|
||||
stale-issue-label: 'stale'
|
||||
stale-pr-label: 'stale'
|
||||
days-before-stale: 90
|
||||
days-before-close: 5
|
||||
operations-per-run: 100
|
||||
exempt-issue-labels: 'bug,enhancement'
|
||||
remove-stale-when-updated: true
|
||||
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -39,9 +39,11 @@ jobs:
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml --user
|
||||
$CONDA\\python.exe -m pip install numpy==1.19.5
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml
|
||||
sudo $CONDA/bin/python -m pip install numpy==1.19.5
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,3 +34,7 @@ tags
|
||||
|
||||
.pytest_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
|
||||
64
README.md
64
README.md
@@ -7,6 +7,20 @@
|
||||
[](LICENSE)
|
||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
## :newspaper: **What's NEW!** :sparkling_heart:
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
| High-frequency trading example | [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
|
||||
| High-frequency data(1min) | [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
|
||||
| Tabnet Model | [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
@@ -17,10 +31,11 @@ Qlib is an AI-oriented quantitative investment platform, which aims to realize t
|
||||
|
||||
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
|
||||
|
||||
With Qlib, user can easily try ideas to create better Quant investment strategies.
|
||||
With Qlib, users can easily try ideas to create better Quant investment strategies.
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
- [**Plans**](#plans)
|
||||
- [Framework of Qlib](#framework-of-qlib)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Installation](#installation)
|
||||
@@ -35,9 +50,20 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contact Us](#contact-us)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
|
||||
# Plans
|
||||
New features under development(order by estimated release time).
|
||||
Your feedbacks about the features are very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
|
||||
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
|
||||
| Meta-Learning-based data selection | Initial opensource version under development |
|
||||
|
||||
# Framework of Qlib
|
||||
|
||||
@@ -46,11 +72,11 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
@@ -118,14 +144,20 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
```bash
|
||||
# get 1d data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# get 1min data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
|
||||
```
|
||||
|
||||
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
|
||||
the same repository.
|
||||
Users could create the same dataset with it.
|
||||
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
@@ -213,9 +245,10 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
- Rank Label
|
||||

|
||||
-->
|
||||
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
@@ -232,6 +265,7 @@ Here is a list of models built on `Qlib`.
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -241,10 +275,10 @@ The performance of each model on the `Alpha158` and `Alpha360` dataset can be fo
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
|
||||
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
@@ -307,17 +341,27 @@ which creates a dataset (14 features/factors) from the basic OHLCV daily data of
|
||||
* `+(-)E` indicates with (out) `ExpressionCache`
|
||||
* `+(-)D` indicates with (out) `DatasetCache`
|
||||
|
||||
Most general-purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Most general-purpose databases take too much time to load data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
# Contact Us
|
||||
- If you have any issues, please create issue [here](https://github.com/microsoft/qlib/issues/new/choose) or send messages in [gitter](https://gitter.im/Microsoft/qlib).
|
||||
- If you want to make contributions to `Qlib`, please [create pull requests](https://github.com/microsoft/qlib/compare).
|
||||
- For other reasons, you are welcome to contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)).
|
||||
- We are recruiting new members(both FTEs and interns), your resumes are welcome!
|
||||
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
|
||||
# Contributing
|
||||
|
||||
|
||||
@@ -70,3 +70,31 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
|
||||
|
||||
|
||||
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
|
||||
|
||||
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
File "qlib/qlib/__init__.py", line 19, in init
|
||||
from .data.cache import H
|
||||
File "qlib/qlib/data/__init__.py", line 8, in <module>
|
||||
from .data import (
|
||||
File "qlib/qlib/data/data.py", line 20, in <module>
|
||||
from .cache import H
|
||||
File "qlib/qlib/data/cache.py", line 36, in <module>
|
||||
from .ops import Operators
|
||||
File "qlib/qlib/data/ops.py", line 19, in <module>
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with ``PyCharm`` IDE, users can execute the following command in the project root folder to compile Cython files and generate executable files:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
BIN
docs/_static/img/online_serving.png
vendored
Normal file
BIN
docs/_static/img/online_serving.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 440 KiB |
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
Normal file
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.2 KiB |
45
docs/advanced/serial.rst
Normal file
45
docs/advanced/serial.rst
Normal file
@@ -0,0 +1,45 @@
|
||||
.. _serial:
|
||||
|
||||
=================================
|
||||
Serialization
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.
|
||||
|
||||
Serializable Class
|
||||
========================
|
||||
|
||||
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
|
||||
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
|
||||
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
|
||||
|
||||
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
|
||||
|
||||
Example
|
||||
==========================
|
||||
``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``.
|
||||
Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
##=============dump dataset=============
|
||||
dataset.to_pickle(path="dataset.pkl") # dataset is an instance of qlib.data.dataset.DatasetH
|
||||
|
||||
##=============reload dataset=============
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = pickle.load(file_dataset)
|
||||
|
||||
.. note::
|
||||
Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.
|
||||
|
||||
After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk).
|
||||
|
||||
A more detailed example is in this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
|
||||
|
||||
API
|
||||
===================
|
||||
Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.
|
||||
89
docs/advanced/task_management.rst
Normal file
89
docs/advanced/task_management.rst
Normal file
@@ -0,0 +1,89 @@
|
||||
.. _task_management:
|
||||
|
||||
=================================
|
||||
Task Management
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
|
||||
Task Generating
|
||||
===============
|
||||
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
|
||||
The specific task template can be viewed in
|
||||
`Task Section <../component/workflow.html#task-section>`_.
|
||||
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
|
||||
|
||||
Here is the base class of ``TaskGen``:
|
||||
|
||||
.. autoclass:: qlib.workflow.task.gen.TaskGen
|
||||
:members:
|
||||
|
||||
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
|
||||
|
||||
Task Storing
|
||||
===============
|
||||
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.config import C
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
|
||||
"task_db_name" : "rolling_db" # database name
|
||||
}
|
||||
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
:members:
|
||||
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
|
||||
|
||||
Task Training
|
||||
===============
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
|
||||
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
|
||||
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
|
||||
|
||||
.. autofunction:: qlib.workflow.task.manage.run_task
|
||||
|
||||
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
|
||||
|
||||
.. autoclass:: qlib.model.trainer.Trainer
|
||||
:members:
|
||||
|
||||
``Trainer`` will train a list of tasks and return a list of model recorders.
|
||||
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
|
||||
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
|
||||
|
||||
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
|
||||
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
@@ -31,7 +31,7 @@ Qlib Format Data
|
||||
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
|
||||
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
|
||||
|
||||
``Qlib`` provides two different off-the-shelf dataset, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
|
||||
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
|
||||
|
||||
======================== ================= ================
|
||||
Dataset US Market China Market
|
||||
@@ -41,6 +41,7 @@ Alpha360 √ √
|
||||
Alpha158 √ √
|
||||
======================== ================= ================
|
||||
|
||||
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
@@ -48,15 +49,19 @@ Qlib Format Dataset
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# download 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# download 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
|
||||
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` directory and ``~/.qlib/qlib_data/us_data`` directory respectively.
|
||||
|
||||
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
|
||||
|
||||
@@ -67,12 +72,19 @@ Converting CSV Format into Qlib Format
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
|
||||
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
|
||||
Here are some example:
|
||||
|
||||
.. code-block:: bash
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
@@ -140,6 +152,16 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
|
||||
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
|
||||
|
||||
|
||||
Multiple Stock Modes
|
||||
--------------------------------
|
||||
|
||||
@@ -158,7 +180,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- If users use ``Qlib`` in china-stock mode, china-stock data is required. Users can use ``Qlib`` in china-stock mode according to the following steps:
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in china-stock mode
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/csv_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -167,9 +189,9 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
|
||||
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -177,6 +199,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
|
||||
|
||||
|
||||
Data API
|
||||
========================
|
||||
|
||||
@@ -213,6 +240,25 @@ Filter
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
filter: &filter
|
||||
filter_type: ExpressionDFilter
|
||||
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
|
||||
filter_start_time: 2010-01-01
|
||||
filter_end_time: 2010-01-07
|
||||
keep: False
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2010-01-01
|
||||
end_time: 2021-01-22
|
||||
fit_start_time: 2010-01-01
|
||||
fit_end_time: 2015-12-31
|
||||
instruments: *market
|
||||
filter_pipe: [*filter]
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
Reference
|
||||
@@ -274,9 +320,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
|
||||
|
||||
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
|
||||
|
||||
|
||||
Processor
|
||||
@@ -313,7 +360,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
|
||||
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
import qlib
|
||||
@@ -340,6 +386,9 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
|
||||
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
@@ -364,8 +413,7 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
|
||||
|
||||
|
||||
Cache
|
||||
|
||||
46
docs/component/online.rst
Normal file
46
docs/component/online.rst
Normal file
@@ -0,0 +1,46 @@
|
||||
.. _online:
|
||||
|
||||
=================================
|
||||
Online Serving
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
.. image:: ../_static/img/online_serving.png
|
||||
:align: center
|
||||
|
||||
|
||||
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
|
||||
``Online Serving`` is a set of modules for online models using the latest data,
|
||||
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
|
||||
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
|
||||
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
@@ -34,6 +34,7 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
@@ -94,6 +95,52 @@ The ``RecordTemp`` class is a class that enables generate experiment results suc
|
||||
|
||||
- ``SignalRecord``: This class generates the `prediction` results of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
|
||||
|
||||
Here is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, Long-Short Return with their own prediction and label.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
Here is a simple exampke of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
# backtest
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
|
||||
@@ -101,7 +101,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -111,8 +111,6 @@ Usage & Example
|
||||
pred_score, strategy=strategy, **BACKTEST_CONFIG
|
||||
)
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
|
||||
@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ Document Structure
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
@@ -49,6 +50,8 @@ Document Structure
|
||||
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -53,6 +53,34 @@ Cache
|
||||
.. autoclass:: qlib.data.cache.DiskDatasetCache
|
||||
:members:
|
||||
|
||||
|
||||
Storage
|
||||
-------------
|
||||
.. autoclass:: qlib.data.storage.storage.BaseStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.CalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.FeatureStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
|
||||
:members:
|
||||
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
|
||||
@@ -152,4 +180,81 @@ Recorder
|
||||
Record Template
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
:members:
|
||||
|
||||
Task Management
|
||||
====================
|
||||
|
||||
|
||||
TaskGen
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.gen
|
||||
:members:
|
||||
|
||||
TaskManager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.manage
|
||||
:members:
|
||||
|
||||
Trainer
|
||||
--------------------
|
||||
.. automodule:: qlib.model.trainer
|
||||
:members:
|
||||
|
||||
Collector
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.collect
|
||||
:members:
|
||||
|
||||
Group
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.group
|
||||
:members:
|
||||
|
||||
Ensemble
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.ensemble
|
||||
:members:
|
||||
|
||||
Utils
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.utils
|
||||
:members:
|
||||
|
||||
|
||||
Online Serving
|
||||
====================
|
||||
|
||||
|
||||
Online Manager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
|
||||
Utils
|
||||
====================
|
||||
|
||||
Serializable
|
||||
--------------------
|
||||
|
||||
.. automodule:: qlib.utils.serial.Serializable
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"default_exp_name": "Experiment",
|
||||
}
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, you can initialize qlib below
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `finetune` method (Optional)
|
||||
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- The parameters must include the parameter `dataset`.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
4
examples/benchmarks/DoubleEnsemble/README.md
Normal file
4
examples/benchmarks/DoubleEnsemble/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# DoubleEnsemble
|
||||
* DoubleEnsemble is an ensemble framework leveraging learning trajectory based sample reweighting and shuffling based feature selection, to solve both the low signal-to-noise ratio and increasing number of features problems. They identify the key samples based on the training dynamics on each sample and elicit key features based on the ablation impact of each feature via shuffling. The model is applicable to a wide range of base models, capable of extracting complex patterns, while mitigating the overfitting and instability issues for financial market prediction.
|
||||
* This code used in Qlib is implemented by ourselves.
|
||||
* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf).
|
||||
3
examples/benchmarks/DoubleEnsemble/requirements.txt
Normal file
3
examples/benchmarks/DoubleEnsemble/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
@@ -0,0 +1,90 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,97 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -29,7 +29,7 @@ data_handler_config: &data_handler_config
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
instruments: *market
|
||||
data_loader:
|
||||
class: QlibDataLoader
|
||||
kwargs:
|
||||
config:
|
||||
feature:
|
||||
- ["Resi($close, 15)/$close", "Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)", "Rsquare($close, 5)", "($high-$low)/$open", "Rsquare($close, 10)", "Corr($close, Log($volume+1), 5)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)", "Corr($close, Log($volume+1), 10)", "Rsquare($close, 20)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)", "Corr($close, Log($volume+1), 20)", "(Less($open, $close)-$low)/$open"]
|
||||
- ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
label:
|
||||
- ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
- ["LABEL0"]
|
||||
freq: day
|
||||
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSZScoreNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandlerLP
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -16,6 +16,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -25,11 +27,13 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
|
||||
Binary file not shown.
@@ -44,6 +44,7 @@ task:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 158
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
@@ -55,7 +56,7 @@ task:
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 360
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
208
examples/data/monitor.py
Normal file
208
examples/data/monitor.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
This script is the demonstrating the implementation of Metric Extractor and Detector
|
||||
|
||||
NOTE: A lot of details is not considered in this script
|
||||
- Corner case that will raise error( std == 0)
|
||||
|
||||
|
||||
|
||||
The following functions are used to demonstrate the following examples
|
||||
|
||||
|
||||
· Metric Extractor:
|
||||
case 1) Basic statistics on different slices of the DataFrame df:
|
||||
1) The statistics include:
|
||||
· STD, Mean, Skewnes, Kurtosis
|
||||
2) The above statistics can be calculated on the following data slices:
|
||||
· df.groupby(['datetime'])
|
||||
· df.groupby(['datetime', 'industry' ])
|
||||
3) The statistics could be calculated on the time dimension for each instruments and factor(the factor can be represented by experssion)
|
||||
· <df implemented by expresion>.groupby(['instrument', 'factor'])
|
||||
case 2) Advanced statistics on different slices of the DataFrame df:
|
||||
1) Auto-correlation:
|
||||
· Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
|
||||
2) Correlation between factors:
|
||||
· For any pair of factors (i, j): calculate corr(df.loc[t, :, i], df.loc[t, :, j]). The result is a correlation matrix with each element corresponds to a correlation value between a pair of factors.
|
||||
|
||||
· Detector: detect the abnormality of the extracted metric;
|
||||
a) Algorithms:
|
||||
§ Basic checks: NaN.
|
||||
§ Point anomaly detection.
|
||||
§ Segment anomaly detection.
|
||||
b) Scenarios:
|
||||
§ Online anomaly detection: monitoring streaming data.
|
||||
The usage of the detectors are demonstrated in the `case_1_*`and `case_2_*`
|
||||
|
||||
|
||||
case 3): Examples to use MetricExt to monitor IC and rank IC
|
||||
1) IC(Information Coefficient) #case_3_1
|
||||
2) RankIC #case_3_2
|
||||
"""
|
||||
|
||||
# AUTO download data
|
||||
from typing import List, Union
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.config import REG_CN
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.data.dataset.loader import QlibDataLoader
|
||||
from qlib.data.monitor.metric import format_conv
|
||||
from qlib.data.monitor.metric import MeanM, SkewM, KurtM, StdM, AutoCM, CorrM
|
||||
from qlib.data.monitor.detector import NDDetector, SWNDD, ThresholdD
|
||||
from qlib.data import D
|
||||
import fire
|
||||
|
||||
UNIVERSE = "csi300"
|
||||
START_TIME = "20200101"
|
||||
|
||||
# ------------------ a helper function to get data to demonstrate the functionality --------------------
|
||||
|
||||
|
||||
def get_data_df(col_idx: Union[int, List[int]] = 0, verbose: bool = True):
|
||||
"""
|
||||
a helper function to get data to demonstrate the functionality.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
col_idx : Union[int, List[int]]
|
||||
column index of the metrics
|
||||
"""
|
||||
dh = Alpha158(instruments=UNIVERSE, infer_processors=[], learn_processors=[], start_time=START_TIME)
|
||||
df = dh.fetch()
|
||||
|
||||
if verbose:
|
||||
print(df.head())
|
||||
|
||||
# We don't have industries in dataframe, we generate the with fake data
|
||||
industry = pd.Series(df.index.get_level_values("instrument").str.slice(stop=2).to_list(), index=df.index)
|
||||
|
||||
# select a factor
|
||||
factor_df = format_conv(df.iloc[:, col_idx], industry=industry)
|
||||
if verbose:
|
||||
print(f"Selected metric: {df.columns[col_idx]}")
|
||||
print(factor_df)
|
||||
return factor_df
|
||||
|
||||
|
||||
def get_target(horizon=5):
|
||||
target = f"Ref($close, -{horizon + 1})/Ref($close, -1) - 1" # There are lots of targets: return is one of them
|
||||
qdl = QlibDataLoader(config=([target], ["target"]))
|
||||
df = qdl.load(instruments=UNIVERSE, start_time=START_TIME) # Aligning with factor will improve performance
|
||||
df = format_conv(df["target"])
|
||||
return df
|
||||
|
||||
|
||||
# ----------------- Cases to demonstrate the usage of detector and examples ----------------------
|
||||
|
||||
|
||||
def case_1_1():
|
||||
factor_df = get_data_df()
|
||||
# 1) Extract metrics
|
||||
|
||||
# 1.1) df.groupby(["datetime"])
|
||||
mtrc = MeanM()
|
||||
m_mean = mtrc.extract(factor_df)
|
||||
print(m_mean)
|
||||
|
||||
ndd = NDDetector()
|
||||
ndd.fit(m_mean) # use historical data to fit detector
|
||||
check_res = ndd.check(m_mean)
|
||||
print(check_res) # detecting on new data or historical data
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_1_2():
|
||||
factor_df = get_data_df()
|
||||
# 1.2) df.groupby("datetime", "industry")
|
||||
mtrc = MeanM(group=["industry"])
|
||||
m_multi = mtrc.extract(factor_df)
|
||||
print(m_multi)
|
||||
|
||||
for col_name, s in m_multi.iteritems():
|
||||
print(col_name)
|
||||
ndd = NDDetector()
|
||||
ndd.fit(s) # use historical data to fit detector
|
||||
check_res = ndd.check(s)
|
||||
print(check_res) # detecting on new data or historical data
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_1_3():
|
||||
# case 1.3
|
||||
# factor_df = get_data_df()
|
||||
qdl = QlibDataLoader(config=(["$close/Ref($close, 1) - 1"], ["return"]))
|
||||
df = qdl.load(instruments=["SH600519"], start_time=START_TIME)
|
||||
df = format_conv(df)
|
||||
s = df.iloc[:, 0]
|
||||
print(s)
|
||||
dtc = SWNDD(window=20)
|
||||
dtc.fit(s) # fit use historical data (TODO: updating will be supported in the future)
|
||||
check_res = dtc.check(s) #
|
||||
print(check_res)
|
||||
print(check_res.value_counts())
|
||||
print(check_res[check_res])
|
||||
|
||||
|
||||
def case_2_1():
|
||||
# · Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
|
||||
factor_df = get_data_df()
|
||||
acm = AutoCM()
|
||||
mtrc = acm.extract(factor_df)
|
||||
print(mtrc)
|
||||
|
||||
thd = ThresholdD(0.0, reverse=True)
|
||||
check_res = thd.check(mtrc)
|
||||
|
||||
print(check_res)
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_2_2():
|
||||
factor_df1, factor_df2 = get_data_df(0), get_data_df(1)
|
||||
|
||||
cm = CorrM()
|
||||
mtrc = cm.extract(factor_df1, factor_df2)
|
||||
print(mtrc)
|
||||
|
||||
thd = ThresholdD(0.0, reverse=True)
|
||||
check_res = thd.check(mtrc)
|
||||
|
||||
print(check_res)
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_3_1_3_2():
|
||||
target, factor = get_target(), get_data_df(0)
|
||||
ic_m, rank_ic_m = CorrM(), CorrM(mode="spearman")
|
||||
ic, rank_ic = ic_m.extract(factor, target), rank_ic_m.extract(factor, target)
|
||||
print(pd.DataFrame({"ic": ic, "rank_ic": rank_ic}))
|
||||
|
||||
|
||||
def run(test_list=["case_1_1", "case_1_2", "case_1_3", "case_2_1", "case_2_2", "case_3_1_3_2"]):
|
||||
"""
|
||||
run the specific tests
|
||||
|
||||
python monitor.py case_3_1_3_2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
test_list : str[]
|
||||
The tests to run
|
||||
"""
|
||||
if isinstance(test_list, str):
|
||||
test_list = [test_list]
|
||||
for fn in test_list:
|
||||
globals()[fn]()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
qlib.init()
|
||||
fire.Fire(run)
|
||||
130
examples/data/monitor_analyser_demo.ipynb
Normal file
130
examples/data/monitor_analyser_demo.ipynb
Normal file
@@ -0,0 +1,130 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e62a81e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"%matplotlib inline\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c503217b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.data.monitor.analyser import Analyser\n",
|
||||
"import qlib\n",
|
||||
"qlib.init()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9c276470",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SimpleDFA(Analyser):\n",
|
||||
" \"\"\"Simple (D)ata(F)rame (A)nalyser\"\"\"\n",
|
||||
" def analyse(self, data: pd.DataFrame, *args, **kwargs):\n",
|
||||
" data.plot(*args, **kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "110262e4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from monitor import get_data_df, AutoCM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0ea38c62",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# get data\n",
|
||||
"factor_df = get_data_df([1], verbose=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dbded6fe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# metric extractor\n",
|
||||
"acm = AutoCM()\n",
|
||||
"mtrc = acm.extract(factor_df)\n",
|
||||
"print(mtrc)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "65517c81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Analyser\n",
|
||||
"sa = SimpleDFA()\n",
|
||||
"sa.analyse(mtrc, title='Auto Correlation')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dab6fb2e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
35
examples/highfreq/README.md
Normal file
35
examples/highfreq/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# High-Frequency Dataset
|
||||
|
||||
This dataset is an example for RL high frequency trading.
|
||||
|
||||
## Get High-Frequency Data
|
||||
|
||||
Get high-frequency data by running the following command:
|
||||
```bash
|
||||
python workflow.py get_data
|
||||
```
|
||||
|
||||
## Dump & Reload & Reinitialize the Dataset
|
||||
|
||||
|
||||
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in the `workflow.py`. `DatatsetH` is the subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped in or loaded from disk in `pickle` format.
|
||||
|
||||
### About Reinitialization
|
||||
|
||||
After reloading `Dataset` from disk, `Qlib` also support reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states.
|
||||
|
||||
The example is given in `workflow.py`, users can run the code as follows.
|
||||
|
||||
### Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
@@ -62,9 +62,9 @@ class HighFreqHandler(DataHandlerLP):
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
"""Get normalized price feature ops"""
|
||||
if shift == 0:
|
||||
template_norm = "{0}/Ref(DayLast({1}), 240)"
|
||||
template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)"
|
||||
template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)"
|
||||
|
||||
feature_ops = template_norm.format(
|
||||
template_if.format(
|
||||
@@ -90,7 +90,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
fields += [
|
||||
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(
|
||||
"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
@@ -101,7 +101,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume"]
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format(
|
||||
"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
@@ -112,7 +112,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += [template_paused.format("Date($close)")]
|
||||
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
|
||||
names += ["date"]
|
||||
return fields, names
|
||||
|
||||
@@ -149,18 +149,20 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
|
||||
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
"Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))),
|
||||
]
|
||||
names += ["$close0"]
|
||||
fields += [
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
"Cut({0}, 240, None)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
fields += [
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
template_paused.format("$low"),
|
||||
|
||||
@@ -8,6 +8,20 @@ from qlib.data.data import Cal
|
||||
|
||||
|
||||
def get_calendar_day(freq="day", future=False):
|
||||
"""Load High-Freq Calendar Date Using Memcache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file.
|
||||
future : bool
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
_calendar:
|
||||
array of date.
|
||||
"""
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
@@ -18,6 +32,19 @@ def get_calendar_day(freq="day", future=False):
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
|
||||
"""DayLast Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value equals the last value of its day
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
@@ -25,18 +52,57 @@ class DayLast(ElemOperator):
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
"""FFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a forward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
"""BFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a backfoward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
"""Date Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value is the date corresponding to feature.index
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
@@ -44,6 +110,22 @@ class Date(ElemOperator):
|
||||
|
||||
|
||||
class Select(PairOperator):
|
||||
"""Select Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_left : Expression
|
||||
feature instance, select condition
|
||||
feature_right : Expression
|
||||
feature instance, select value
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
value(feature_right) that meets the condition(feature_left)
|
||||
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
@@ -51,6 +133,58 @@ class Select(PairOperator):
|
||||
|
||||
|
||||
class IsNull(ElemOperator):
|
||||
"""IsNull Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series indicating whether the feature is nan
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.isnull()
|
||||
|
||||
|
||||
class Cut(ElemOperator):
|
||||
"""Cut Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
l : int
|
||||
l > 0, delete the first l elements of feature (default is None, which means 0)
|
||||
r : int
|
||||
r < 0, delete the last -r elements of feature (default is None, which means 0)
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series with the first l and last -r elements deleted from the feature.
|
||||
Note: It is deleted from the raw data, not the sliced data
|
||||
"""
|
||||
|
||||
def __init__(self, feature, l=None, r=None):
|
||||
self.l = l
|
||||
self.r = r
|
||||
if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0):
|
||||
raise ValueError("Cut operator l shoud > 0 and r should < 0")
|
||||
|
||||
super(Cut, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.iloc[self.l : self.r]
|
||||
|
||||
def get_extended_window_size(self):
|
||||
ll = 0 if self.l is None else self.l
|
||||
rr = 0 if self.r is None else abs(self.r)
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
lft_etd = lft_etd + ll
|
||||
rght_etd = rght_etd + rr
|
||||
return lft_etd, rght_etd
|
||||
|
||||
@@ -1,40 +1,28 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import fire
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
class HighfreqWorkflow:
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None}
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-14 00:00:00"
|
||||
start_time = "2020-09-15 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
train_end_time = "2020-11-30 16:00:00"
|
||||
test_start_time = "2020-12-01 00:00:00"
|
||||
@@ -42,7 +30,6 @@ class HighfreqWorkflow(object):
|
||||
DATA_HANDLER_CONFIG0 = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"freq": "1min",
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": MARKET,
|
||||
@@ -51,7 +38,6 @@ class HighfreqWorkflow(object):
|
||||
DATA_HANDLER_CONFIG1 = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"freq": "1min",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
@@ -99,9 +85,7 @@ class HighfreqWorkflow(object):
|
||||
# use yahoo_cn_1min data
|
||||
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
|
||||
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
|
||||
qlib.init(**QLIB_INIT_CONFIG)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
@@ -125,8 +109,7 @@ class HighfreqWorkflow(object):
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
print(backtest_train, backtest_test)
|
||||
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
return
|
||||
|
||||
def dump_and_load_dataset(self):
|
||||
"""dump and load dataset state on disk"""
|
||||
@@ -148,18 +131,44 @@ class HighfreqWorkflow(object):
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reload_dataset=============
|
||||
dataset.init(init_type=DataHandlerLP.IT_LS)
|
||||
dataset_backtest.init()
|
||||
##=============reinit dataset=============
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
},
|
||||
)
|
||||
dataset_backtest.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
xtest = dataset.prepare("test")
|
||||
backtest_test = dataset_backtest.prepare("test")
|
||||
|
||||
print(xtrain, xtest)
|
||||
print(backtest_train, backtest_test)
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
market: &market 'csi300'
|
||||
start_time: &start_time "2020-09-15 00:00:00"
|
||||
end_time: &end_time "2021-01-18 16:00:00"
|
||||
train_end_time: &train_end_time "2020-11-15 16:00:00"
|
||||
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
|
||||
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
|
||||
test_start_time: &test_start_time "2020-12-01 00:00:00"
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: *start_time
|
||||
end_time: *end_time
|
||||
fit_start_time: *start_time
|
||||
fit_end_time: *train_end_time
|
||||
instruments: *market
|
||||
freq: '1min'
|
||||
infer_processors:
|
||||
- class: 'RobustZScoreNorm'
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
clip_outlier: false
|
||||
- class: "Fillna"
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
learn_processors:
|
||||
- class: 'DropnaLabel'
|
||||
- class: 'CSRankNorm'
|
||||
kwargs:
|
||||
fields_group: 'label'
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
task:
|
||||
model:
|
||||
class: "HFLGBModel"
|
||||
module_path: "qlib.contrib.model.highfreq_gdbt_model"
|
||||
kwargs:
|
||||
objective: 'binary'
|
||||
metric: ['binary_logloss','auc']
|
||||
verbosity: -1
|
||||
learning_rate: 0.01
|
||||
max_depth: 8
|
||||
num_leaves: 150
|
||||
lambda_l1: 1.5
|
||||
lambda_l2: 1
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: "DatasetH"
|
||||
module_path: "qlib.data.dataset"
|
||||
kwargs:
|
||||
handler:
|
||||
class: "Alpha158"
|
||||
module_path: "qlib.contrib.data.handler"
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [*start_time, *train_end_time]
|
||||
valid: [*train_end_time, *valid_end_time]
|
||||
test: [*test_start_time, *end_time]
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
23
examples/hyperparameter/LightGBM/Readme.md
Normal file
23
examples/hyperparameter/LightGBM/Readme.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# LightGBM hyperparameter
|
||||
|
||||
## Alpha158
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_158.py
|
||||
```
|
||||
|
||||
## Alpha360
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_360.py
|
||||
```
|
||||
46
examples/hyperparameter/LightGBM/hyperparameter_158.py
Normal file
46
examples/hyperparameter/LightGBM/hyperparameter_158.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import qlib
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.config import CSI300_DATASET_CONFIG
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
def objective(trial):
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
|
||||
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
|
||||
"subsample": trial.suggest_uniform("subsample", 0, 1),
|
||||
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
|
||||
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
|
||||
"max_depth": 10,
|
||||
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
|
||||
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
|
||||
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
|
||||
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
|
||||
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
|
||||
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
|
||||
},
|
||||
},
|
||||
}
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
49
examples/hyperparameter/LightGBM/hyperparameter_360.py
Normal file
49
examples/hyperparameter/LightGBM/hyperparameter_360.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import qlib
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
|
||||
|
||||
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
|
||||
|
||||
|
||||
def objective(trial):
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
|
||||
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
|
||||
"subsample": trial.suggest_uniform("subsample", 0, 1),
|
||||
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
|
||||
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
|
||||
"max_depth": 10,
|
||||
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
|
||||
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
|
||||
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
|
||||
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
|
||||
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
|
||||
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
dataset = init_instance_by_config(DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
5
examples/hyperparameter/LightGBM/requirements.txt
Normal file
5
examples/hyperparameter/LightGBM/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
optuna==2.7.0
|
||||
optuna-dashboard==0.4.1
|
||||
32
examples/model_interpreter/feature.py
Normal file
32
examples/model_interpreter/feature.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
# get model feature importance
|
||||
feature_importance = model.get_feature_importance()
|
||||
print("feature importance:")
|
||||
print(feature_importance)
|
||||
105
examples/model_rolling/task_manager_rolling.py
Normal file
105
examples/model_rolling/task_manager_rolling.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingTaskExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region=REG_CN,
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
):
|
||||
# TaskManager config
|
||||
if task_config is None:
|
||||
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def task_generating(self):
|
||||
print("========== task_generating ==========")
|
||||
tasks = task_generator(
|
||||
tasks=self.task_config,
|
||||
generators=self.rolling_gen, # generate different date segments
|
||||
)
|
||||
pprint(tasks)
|
||||
return tasks
|
||||
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
collector = RecorderCollector(
|
||||
experiment=self.experiment_name,
|
||||
process_list=RollingGroup(),
|
||||
rec_key_func=rec_key,
|
||||
rec_filter_func=my_filter,
|
||||
)
|
||||
print(collector())
|
||||
|
||||
def main(self):
|
||||
self.reset()
|
||||
tasks = self.task_generating()
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python task_manager_rolling.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(RollingTaskExample)
|
||||
92
examples/online_srv/online_management_simulate.py
Normal file
92
examples/online_srv/online_management_simulate.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=None,
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
|
||||
Args:
|
||||
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
|
||||
region (str, optional): the stock region. Defaults to "cn".
|
||||
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
|
||||
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
|
||||
task_db_name (str, optional): database name. Defaults to "rolling_db".
|
||||
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
|
||||
rolling_step (int, optional): the step for rolling. Defaults to 80.
|
||||
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
begin_time=self.start_time,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this to run all workflow automatically
|
||||
def main(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== simulate ==========")
|
||||
self.rolling_online_manager.simulate(end_time=self.end_time)
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
|
||||
fire.Fire(OnlineSimulationExample)
|
||||
130
examples/online_srv/rolling_online_management.py
Normal file
130
examples/online_srv/rolling_online_management.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineManager works with rolling tasks.
|
||||
There are four parts including first train, routine 1, add strategy and routine 2.
|
||||
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
|
||||
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
|
||||
Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.tasks = tasks
|
||||
self.add_tasks = add_tasks
|
||||
self.rolling_step = rolling_step
|
||||
strategies = []
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def first_run(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== first_run ==========")
|
||||
self.rolling_online_manager.first_train()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def routine(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== routine ==========")
|
||||
self.rolling_online_manager.routine()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def add_strategy(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== add strategy ==========")
|
||||
strategies = []
|
||||
for task in self.add_tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
self.rolling_online_manager.add_strategy(strategies=strategies)
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
self.routine()
|
||||
self.add_strategy()
|
||||
self.routine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
# python rolling_online_management.py first_run
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python rolling_online_management.py routine
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
54
examples/online_srv/update_online_pred.py
Normal file
54
examples/online_srv/update_online_pred.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineTool works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
}
|
||||
|
||||
|
||||
class UpdatePredExample:
|
||||
def __init__(
|
||||
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
|
||||
):
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
self.experiment_name = experiment_name
|
||||
self.online_tool = OnlineToolR(self.experiment_name)
|
||||
self.task_config = task_config
|
||||
|
||||
def first_train(self):
|
||||
rec = task_train(self.task_config, experiment_name=self.experiment_name)
|
||||
self.online_tool.reset_online_tag(rec) # set to online model
|
||||
|
||||
def update_online_pred(self):
|
||||
self.online_tool.update_online_pred()
|
||||
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.update_online_pred()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
## to update online predictions once a day, use the command below
|
||||
# python update_online_pred.py update_online_pred
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(UpdatePredExample)
|
||||
17
examples/rolling_process_data/README.md
Normal file
17
examples/rolling_process_data/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Rolling Process Data
|
||||
|
||||
This workflow is an example for `Rolling Process Data`.
|
||||
|
||||
## Background
|
||||
|
||||
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
|
||||
|
||||
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
|
||||
|
||||
|
||||
## Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py rolling_process
|
||||
```
|
||||
32
examples/rolling_process_data/rolling_handler.py
Normal file
32
examples/rolling_process_data/rolling_handler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.loader import DataLoaderDH
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class RollingDataHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
data_loader_kwargs={},
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "DataLoaderDH",
|
||||
"kwargs": {**data_loader_kwargs},
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
)
|
||||
137
examples/rolling_process_data/workflow.py
Normal file
137
examples/rolling_process_data/workflow.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class RollingDataWorkflow:
|
||||
|
||||
MARKET = "csi300"
|
||||
start_time = "2010-01-01"
|
||||
end_time = "2019-12-31"
|
||||
rolling_cnt = 5
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def _dump_pre_handler(self, path):
|
||||
handler_config = {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"instruments": self.MARKET,
|
||||
"infer_processors": [],
|
||||
"learn_processors": [],
|
||||
},
|
||||
}
|
||||
pre_handler = init_instance_by_config(handler_config)
|
||||
pre_handler.config(dump_all=True)
|
||||
pre_handler.to_pickle(path)
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
self._init_qlib()
|
||||
self._dump_pre_handler("pre_handler.pkl")
|
||||
pre_handler = self._load_pre_handler("pre_handler.pkl")
|
||||
|
||||
train_start_time = (2010, 1, 1)
|
||||
train_end_time = (2012, 12, 31)
|
||||
valid_start_time = (2013, 1, 1)
|
||||
valid_end_time = (2013, 12, 31)
|
||||
test_start_time = (2014, 1, 1)
|
||||
test_end_time = (2014, 12, 31)
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "RollingDataHandler",
|
||||
"module_path": "rolling_handler",
|
||||
"kwargs": {
|
||||
"start_time": datetime(*train_start_time),
|
||||
"end_time": datetime(*test_end_time),
|
||||
"fit_start_time": datetime(*train_start_time),
|
||||
"fit_end_time": datetime(*train_end_time),
|
||||
"infer_processors": [
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
{"class": "DropnaLabel"},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
"data_loader_kwargs": {
|
||||
"handler_config": pre_handler,
|
||||
},
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": (datetime(*train_start_time), datetime(*train_end_time)),
|
||||
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
|
||||
"test": (datetime(*test_start_time), datetime(*test_end_time)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
for rolling_offset in range(self.rolling_cnt):
|
||||
|
||||
print(f"===========rolling{rolling_offset} start===========")
|
||||
if rolling_offset:
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
"processor_kwargs": {
|
||||
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
},
|
||||
},
|
||||
segments={
|
||||
"train": (
|
||||
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
),
|
||||
"valid": (
|
||||
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
|
||||
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
|
||||
),
|
||||
"test": (
|
||||
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
|
||||
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_FIT_SEQ,
|
||||
}
|
||||
)
|
||||
|
||||
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
|
||||
print(dtrain, dvalid, dtest)
|
||||
## print or dump data
|
||||
print(f"===========rolling{rolling_offset} end===========")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(RollingDataWorkflow)
|
||||
@@ -5,13 +5,11 @@ import os
|
||||
import sys
|
||||
import fire
|
||||
import time
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
@@ -23,8 +21,7 @@ from pprint import pprint
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
# init qlib
|
||||
@@ -39,12 +36,8 @@ exp_manager = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
# decorator to check the arguments
|
||||
|
||||
@@ -28,11 +28,17 @@
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"################################# NOTE #################################\n",
|
||||
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
|
||||
"# in this cell, users should RESTART the runtime in order to run the #\n",
|
||||
"# following cells successfully. #\n",
|
||||
"########################################################################\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
@@ -238,9 +244,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
@@ -359,7 +363,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
"version": "3.8.3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
@@ -377,4 +381,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -1,82 +1,22 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
@@ -90,7 +30,7 @@ if __name__ == "__main__":
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
@@ -99,9 +39,9 @@ if __name__ == "__main__":
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
|
||||
# NOTE: This line is optional
|
||||
# It demonstrates that the dataset can be used standalone.
|
||||
@@ -110,14 +50,16 @@ if __name__ == "__main__":
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# backtest
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.6.3"
|
||||
__version__ = "0.6.3.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
import os
|
||||
@@ -10,12 +11,13 @@ import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .log import get_module_logger
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
from .log import get_module_logger
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
@@ -147,7 +148,78 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
"""
|
||||
|
||||
with open(conf_path) as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
config = yaml.safe_load(f)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
"""
|
||||
If users are building a project follow the following pattern.
|
||||
- Qlib is a sub folder in project path
|
||||
- There is a file named `config.yaml` in qlib.
|
||||
|
||||
For example:
|
||||
If your project file system stucuture follows such a pattern
|
||||
|
||||
<project_path>/
|
||||
- config.yaml
|
||||
- ...some folders...
|
||||
- qlib/
|
||||
|
||||
This folder will return <project_path>
|
||||
|
||||
NOTE: link is not supported here.
|
||||
|
||||
|
||||
This method is often used when
|
||||
- user want to use a relative config path instead of hard-coding qlib config path in code
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError:
|
||||
If project path is not found
|
||||
"""
|
||||
if cur_path is None:
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
while True:
|
||||
if (cur_path / config_name).exists():
|
||||
return cur_path
|
||||
if cur_path == cur_path.parent:
|
||||
raise FileNotFoundError("We can't find the project path")
|
||||
cur_path = cur_path.parent
|
||||
|
||||
|
||||
def auto_init(**kwargs):
|
||||
"""
|
||||
This function will init qlib automatically with following priority
|
||||
- Find the project configuration and init qlib
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
"""
|
||||
|
||||
try:
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
conf_type = conf.get("conf_type", "origin")
|
||||
if conf_type == "origin":
|
||||
# The type of config is just like original qlib config
|
||||
init_from_yaml_conf(conf_pp, **kwargs)
|
||||
elif conf_type == "ref":
|
||||
# This config type will be more convenient in following scenario
|
||||
# - There is a shared configure file and you don't want to edit it inplace.
|
||||
# - The shared configure may be updated later and you don't want to copy it.
|
||||
# - You have some customized config.
|
||||
qlib_conf_path = conf["qlib_cfg"]
|
||||
qlib_conf_update = conf.get("qlib_cfg_update")
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
|
||||
logger = get_module_logger("Initialization")
|
||||
logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -33,6 +33,9 @@ class Config:
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._config")
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.__dict__["_config"].get(key, default)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_config"][key] = value
|
||||
|
||||
@@ -105,7 +108,7 @@ _default_config = {
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
# Global configuration of qlib log
|
||||
# logging_level can control the logging level more finely
|
||||
"logging_config": {
|
||||
@@ -124,14 +127,14 @@ _default_config = {
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"level": logging.DEBUG,
|
||||
"formatter": "logger_format",
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
@@ -140,6 +143,11 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
},
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
@@ -185,7 +193,7 @@ MODE_CONF = {
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
# serversS(such as PAI) [auto_mount:True]
|
||||
"timeout": 100,
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
"custom_ops": [],
|
||||
@@ -310,8 +318,22 @@ class QlibConfig(Config):
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
|
||||
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
|
||||
self.reset_qlib_version()
|
||||
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
|
||||
import qlib
|
||||
|
||||
reset_version = self.get("qlib_reset_version", None)
|
||||
if reset_version is not None:
|
||||
qlib.__version__ = reset_version
|
||||
else:
|
||||
qlib.__version__ = getattr(qlib, "__version__bak")
|
||||
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
|
||||
# Using __version__bak instead of __version__
|
||||
|
||||
@property
|
||||
def registered(self):
|
||||
return self._registered
|
||||
|
||||
@@ -104,10 +104,9 @@ class Account:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trader.check_stock_suspended(code, today):
|
||||
continue
|
||||
else:
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
self.rtn += profit
|
||||
# update holding day count
|
||||
self.current.add_count_all()
|
||||
|
||||
@@ -15,7 +15,8 @@ LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column
|
||||
@@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account,
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
"""Update the account and strategy
|
||||
"""
|
||||
Update the account and strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import pandas as pd
|
||||
import copy
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from .order import Order
|
||||
|
||||
"""
|
||||
@@ -128,7 +128,7 @@ class Position:
|
||||
return self.position["cash"]
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
"""generate stock amount dict {stock_id : amount of stock} """
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
for stock_code in stock_list:
|
||||
@@ -166,7 +166,7 @@ class Position:
|
||||
def save_position(self, path, last_trade_date):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series(dtype=np.float)
|
||||
cash = pd.Series(dtype=float)
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["today_account_value"] = p["today_account_value"]
|
||||
|
||||
@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
# FIXME: the `module_path` parameter is missed.
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
|
||||
@@ -8,6 +8,59 @@ import pandas as pd
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def calc_long_short_prec(
|
||||
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
|
||||
) -> Tuple[pd.Series, pd.Series]:
|
||||
"""
|
||||
calculate the precision for long and short operation
|
||||
|
||||
|
||||
:param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.
|
||||
|
||||
.. code-block:: python
|
||||
score
|
||||
datetime instrument
|
||||
2020-12-01 09:30:00 SH600068 0.553634
|
||||
SH600195 0.550017
|
||||
SH600276 0.540321
|
||||
SH600584 0.517297
|
||||
SH600715 0.544674
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
long precision and short precision in time level
|
||||
"""
|
||||
if is_alpha:
|
||||
label = label - label.mean(level=date_col)
|
||||
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
|
||||
raise ValueError("Need more instruments to calculate precision")
|
||||
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
|
||||
groupll = long.groupby(date_col)
|
||||
l_dom = groupll.apply(lambda x: x > 0)
|
||||
l_c = groupll.count()
|
||||
|
||||
groups = short.groupby(date_col)
|
||||
s_dom = groups.apply(lambda x: x < 0)
|
||||
s_c = groups.count()
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ def get_position_value(evaluate_date, position):
|
||||
# load close price for position
|
||||
# position should also consider cash
|
||||
instruments = list(position.keys())
|
||||
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
|
||||
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
|
||||
fields = ["$close"]
|
||||
close_data_df = D.features(
|
||||
instruments,
|
||||
@@ -80,7 +80,7 @@ def get_position_list_value(positions):
|
||||
instruments = set()
|
||||
for day, position in positions.items():
|
||||
instruments.update(position.keys())
|
||||
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
|
||||
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
|
||||
instruments.sort()
|
||||
day_list = list(positions.keys())
|
||||
day_list.sort()
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
try:
|
||||
from .catboost_model import CatBoostModel
|
||||
except ModuleNotFoundError:
|
||||
CatBoostModel = None
|
||||
print("Please install necessary libs for CatBoostModel.")
|
||||
try:
|
||||
from .double_ensemble import DEnsembleModel
|
||||
from .gbdt import LGBModel
|
||||
except ModuleNotFoundError:
|
||||
DEnsembleModel, LGBModel = None, None
|
||||
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
|
||||
try:
|
||||
from .xgboost import XGBModel
|
||||
except ModuleNotFoundError:
|
||||
XGBModel = None
|
||||
print("Please install necessary libs for XGBModel, such as xgboost.")
|
||||
try:
|
||||
from .linear import LinearModel
|
||||
except ModuleNotFoundError:
|
||||
LinearModel = None
|
||||
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
|
||||
# import pytorch models
|
||||
try:
|
||||
from .pytorch_alstm import ALSTM
|
||||
from .pytorch_gats import GATs
|
||||
from .pytorch_gru import GRU
|
||||
from .pytorch_lstm import LSTM
|
||||
from .pytorch_nn import DNNModelPytorch
|
||||
from .pytorch_tabnet import TabnetModel
|
||||
from .pytorch_sfm import SFM_Model
|
||||
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
|
||||
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes
|
||||
|
||||
@@ -3,15 +3,17 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from catboost import Pool, CatBoost
|
||||
from catboost.utils import get_gpu_device_count
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import FeatureInt
|
||||
|
||||
|
||||
class CatBoostModel(Model):
|
||||
class CatBoostModel(Model, FeatureInt):
|
||||
"""CatBoost Model"""
|
||||
|
||||
def __init__(self, loss="RMSE", **kwargs):
|
||||
@@ -62,12 +64,24 @@ class CatBoostModel(Model):
|
||||
evals_result["train"] = list(evals_result["learn"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["validation"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Notes
|
||||
-----
|
||||
parameters references:
|
||||
https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance
|
||||
"""
|
||||
return pd.Series(
|
||||
data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_
|
||||
).sort_values(ascending=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cat = CatBoostModel()
|
||||
|
||||
265
qlib/contrib/model/double_ensemble.py
Normal file
265
qlib/contrib/model/double_ensemble.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import FeatureInt
|
||||
from ...log import get_module_logger
|
||||
|
||||
|
||||
class DEnsembleModel(Model, FeatureInt):
|
||||
"""Double Ensemble Model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model="gbm",
|
||||
loss="mse",
|
||||
num_models=6,
|
||||
enable_sr=True,
|
||||
enable_fs=True,
|
||||
alpha1=1.0,
|
||||
alpha2=1.0,
|
||||
bins_sr=10,
|
||||
bins_fs=5,
|
||||
decay=None,
|
||||
sample_ratios=None,
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
**kwargs
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
self.enable_sr = enable_sr
|
||||
self.enable_fs = enable_fs
|
||||
self.alpha1 = alpha1
|
||||
self.alpha2 = alpha2
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if sample_ratios is None: # the default values for sample_ratios
|
||||
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
|
||||
if sub_weights is None: # the default values for sub_weights
|
||||
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
if not len(sub_weights) == num_models:
|
||||
raise ValueError("The length of sub_weights should be equal to num_models.")
|
||||
self.sub_weights = sub_weights
|
||||
self.epochs = epochs
|
||||
self.logger = get_module_logger("DEnsembleModel")
|
||||
self.logger.info("Double Ensemble Model...")
|
||||
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
|
||||
self.sub_features = [] # the features for each sub model in the form of pandas.Index
|
||||
self.params = {"objective": loss}
|
||||
self.params.update(kwargs)
|
||||
self.loss = loss
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
# initialize the sample weights
|
||||
N, F = x_train.shape
|
||||
weights = pd.Series(np.ones(N, dtype=float))
|
||||
# initialize the features
|
||||
features = x_train.columns
|
||||
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
|
||||
# train sub-models
|
||||
for k in range(self.num_models):
|
||||
self.sub_features.append(features)
|
||||
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
|
||||
model_k = self.train_submodel(df_train, df_valid, weights, features)
|
||||
self.ensemble.append(model_k)
|
||||
# no further sample re-weight and feature selection needed for the last sub-model
|
||||
if k + 1 == self.num_models:
|
||||
break
|
||||
|
||||
self.logger.info("Retrieving loss curve and loss values...")
|
||||
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
|
||||
pred_k = self.predict_sub(model_k, df_train, features)
|
||||
pred_sub.iloc[:, k] = pred_k
|
||||
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
|
||||
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
|
||||
|
||||
if self.enable_sr:
|
||||
self.logger.info("Sample re-weighting...")
|
||||
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
|
||||
|
||||
if self.enable_fs:
|
||||
self.logger.info("Feature selection...")
|
||||
features = self.feature_selection(df_train, loss_values)
|
||||
|
||||
def train_submodel(self, df_train, df_valid, weights, features):
|
||||
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
|
||||
evals_result = dict()
|
||||
model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=self.epochs,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
verbose_eval=20,
|
||||
evals_result=evals_result,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
return model
|
||||
|
||||
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train, label=y_train, weight=weights)
|
||||
dvalid = lgb.Dataset(x_valid, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def sample_reweight(self, loss_curve, loss_values, k_th):
|
||||
"""
|
||||
the SR module of Double Ensemble
|
||||
:param loss_curve: the shape is NxT
|
||||
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
|
||||
after the t-th iteration in the training of the previous sub-model.
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:param k_th: the index of the current sub-model, starting from 1
|
||||
:return: weights
|
||||
the weights for all the samples.
|
||||
"""
|
||||
# normalize loss_curve and loss_values with ranking
|
||||
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
|
||||
loss_values_norm = (-loss_values).rank(pct=True)
|
||||
|
||||
# calculate l_start and l_end from loss_curve
|
||||
N, T = loss_curve.shape
|
||||
part = np.maximum(int(T * 0.1), 1)
|
||||
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
|
||||
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
|
||||
|
||||
# calculate h-value for each sample
|
||||
h1 = loss_values_norm
|
||||
h2 = (l_end / l_start).rank(pct=True)
|
||||
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
|
||||
|
||||
# calculate weights
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
"""
|
||||
the FS module of Double Ensemble
|
||||
:param df_train: the shape is NxF
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:return: res_feat: in the form of pandas.Index
|
||||
|
||||
"""
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
features = x_train.columns
|
||||
N, F = x_train.shape
|
||||
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
|
||||
M = len(self.ensemble)
|
||||
|
||||
# shuffle specific columns and calculate g-value for each feature
|
||||
x_train_tmp = x_train.copy()
|
||||
for i_f, feat in enumerate(features):
|
||||
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
|
||||
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
|
||||
for i_s, submodel in enumerate(self.ensemble):
|
||||
pred += (
|
||||
pd.Series(
|
||||
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
|
||||
)
|
||||
/ M
|
||||
)
|
||||
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
|
||||
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
|
||||
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
|
||||
|
||||
# one column in train features is all-nan # if g['g_value'].isna().any()
|
||||
g["g_value"].replace(np.nan, 0, inplace=True)
|
||||
|
||||
# divide features into bins_fs bins
|
||||
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
|
||||
|
||||
# randomly sample features from bins to construct the new features
|
||||
res_feat = []
|
||||
sorted_bins = sorted(g["bins"].unique(), reverse=True)
|
||||
for i_b, b in enumerate(sorted_bins):
|
||||
b_feat = features[g["bins"] == b]
|
||||
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
|
||||
res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist()
|
||||
return pd.Index(set(res_feat))
|
||||
|
||||
def get_loss(self, label, pred):
|
||||
if self.loss == "mse":
|
||||
return (label - pred) ** 2
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
|
||||
def retrieve_loss_curve(self, model, df_train, features):
|
||||
if self.base_model == "gbm":
|
||||
num_trees = model.num_trees()
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train = np.squeeze(y_train.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
N = x_train.shape[0]
|
||||
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
|
||||
pred_tree = np.zeros(N, dtype=float)
|
||||
for i_tree in range(num_trees):
|
||||
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
|
||||
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
pred += (
|
||||
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
|
||||
* self.sub_weights[i_sub]
|
||||
)
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Notes
|
||||
-----
|
||||
parameters reference:
|
||||
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
|
||||
"""
|
||||
res = []
|
||||
for _model, _weight in zip(self.ensemble, self.sub_weights):
|
||||
res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight)
|
||||
return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False)
|
||||
@@ -4,13 +4,14 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import LightGBMFInt
|
||||
|
||||
|
||||
class LGBModel(ModelFT):
|
||||
class LGBModel(ModelFT, LightGBMFInt):
|
||||
"""LightGBM Model"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
@@ -33,8 +34,8 @@ class LGBModel(ModelFT):
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
dtrain = lgb.Dataset(x_train, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def fit(
|
||||
@@ -61,10 +62,10 @@ class LGBModel(ModelFT):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
|
||||
158
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
158
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import LightGBMFInt
|
||||
|
||||
|
||||
class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""LightGBM Model for high frequency prediction"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self.params = {"objective": loss, "verbosity": -1}
|
||||
self.params.update(kwargs)
|
||||
self.model = None
|
||||
|
||||
def _cal_signal_metrics(self, y_test, l_cut, r_cut):
|
||||
"""
|
||||
Calcaute the signal metrics by daily level
|
||||
"""
|
||||
up_pre, down_pre = [], []
|
||||
up_alpha_ll, down_alpha_ll = [], []
|
||||
for date in y_test.index.get_level_values(0).unique():
|
||||
df_res = y_test.loc[date].sort_values("pred")
|
||||
if int(l_cut * len(df_res)) < 10:
|
||||
warnings.warn("Warning: threhold is too low or instruments number is not enough")
|
||||
continue
|
||||
top = df_res.iloc[: int(l_cut * len(df_res))]
|
||||
bottom = df_res.iloc[int(r_cut * len(df_res)) :]
|
||||
|
||||
down_precision = len(top[top[top.columns[0]] < 0]) / (len(top))
|
||||
up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom))
|
||||
|
||||
down_alpha = top[top.columns[0]].mean()
|
||||
up_alpha = bottom[bottom.columns[0]].mean()
|
||||
|
||||
up_pre.append(up_precision)
|
||||
down_pre.append(down_precision)
|
||||
up_alpha_ll.append(up_alpha)
|
||||
down_alpha_ll.append(down_alpha)
|
||||
|
||||
return (
|
||||
np.array(up_pre).mean(),
|
||||
np.array(down_pre).mean(),
|
||||
np.array(up_alpha_ll).mean(),
|
||||
np.array(down_alpha_ll).mean(),
|
||||
)
|
||||
|
||||
def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
|
||||
"""
|
||||
Test the sigal in high frequency test set
|
||||
"""
|
||||
if self.model == None:
|
||||
raise ValueError("Model hasn't been trained yet")
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
df_test.dropna(inplace=True)
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
# Convert label into alpha
|
||||
y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0)
|
||||
|
||||
res = pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
y_test["pred"] = res
|
||||
|
||||
up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
|
||||
print("===============================")
|
||||
print("High frequency signal test")
|
||||
print("===============================")
|
||||
print("Test set precision: ")
|
||||
print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
|
||||
print("Test Alpha Average in test set: ")
|
||||
print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_train["feature"], df_valid["label"]
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
l_name = df_train["label"].columns[0]
|
||||
# Convert label into alpha
|
||||
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
|
||||
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
|
||||
mapping_fn = lambda x: 0 if x < 0 else 1
|
||||
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
|
||||
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
|
||||
x_train, y_train = df_train["feature"], df_train["label_c"].values
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
):
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
"""
|
||||
finetune model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
num_boost_round : int
|
||||
number of round to finetune model
|
||||
verbose_eval : int
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
)
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Text, Union
|
||||
from scipy.optimize import nnls
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
|
||||
@@ -84,8 +84,8 @@ class LinearModel(Model):
|
||||
self.coef_ = coef
|
||||
self.intercept_ = 0.0
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.coef_ is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)
|
||||
|
||||
@@ -8,21 +8,16 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -39,8 +34,8 @@ class ALSTM(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -76,8 +71,7 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -93,7 +87,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -107,7 +101,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -123,6 +117,9 @@ class ALSTM(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -130,9 +127,13 @@ class ALSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +202,13 @@ class ALSTM(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +216,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +228,7 @@ class ALSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -238,7 +238,7 @@ class ALSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -269,11 +269,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.ALSTM_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -290,10 +290,7 @@ class ALSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(x_batch).detach().numpy()
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -8,22 +8,17 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -40,8 +35,8 @@ class ALSTM(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -78,9 +73,8 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +90,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +105,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +121,10 @@ class ALSTM(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -135,9 +132,13 @@ class ALSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +189,13 @@ class ALSTM(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +203,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +211,14 @@ class ALSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -225,7 +229,7 @@ class ALSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -256,11 +260,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.ALSTM_model.eval()
|
||||
@@ -271,10 +275,7 @@ class ALSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -8,20 +8,15 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -42,8 +37,8 @@ class GATs(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -83,8 +78,7 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -102,7 +96,7 @@ class GATs(Model):
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -118,7 +112,7 @@ class GATs(Model):
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -135,6 +129,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -142,9 +139,13 @@ class GATs(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -232,7 +233,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -245,8 +245,7 @@ class GATs(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
@@ -275,7 +274,7 @@ class GATs(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -306,11 +305,11 @@ class GATs(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
index = x_test.index
|
||||
self.GAT_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -324,10 +323,7 @@ class GATs(Model):
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(x_batch).detach().numpy()
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,21 +9,15 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -62,8 +56,8 @@ class GATs(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -104,9 +98,8 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -157,6 +150,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -164,9 +160,13 @@ class GATs(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -245,7 +245,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -258,11 +257,10 @@ class GATs(Model):
|
||||
sampler_train = DailyBatchSampler(dl_train)
|
||||
sampler_valid = DailyBatchSampler(dl_valid)
|
||||
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -297,7 +295,7 @@ class GATs(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -329,7 +327,7 @@ class GATs(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
@@ -345,10 +343,7 @@ class GATs(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(feature.float()).detach().numpy()
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -8,21 +8,16 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -76,8 +71,7 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -123,6 +117,9 @@ class GRU(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -130,9 +127,13 @@ class GRU(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.gru_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +202,13 @@ class GRU(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +216,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +228,7 @@ class GRU(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -238,7 +238,7 @@ class GRU(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -269,11 +269,11 @@ class GRU(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.gru_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -290,10 +290,7 @@ class GRU(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.gru_model(x_batch).detach().numpy()
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,21 +9,15 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -78,9 +72,8 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +89,7 @@ class GRU(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +104,7 @@ class GRU(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +120,10 @@ class GRU(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GRU_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -135,9 +131,13 @@ class GRU(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.GRU_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +188,13 @@ class GRU(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +202,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +210,14 @@ class GRU(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -225,7 +228,7 @@ class GRU(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -257,7 +260,7 @@ class GRU(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
@@ -271,10 +274,7 @@ class GRU(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GRU_model(feature.float()).detach().numpy()
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -8,16 +8,10 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -76,8 +70,7 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -130,9 +123,13 @@ class LSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.lstm_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -214,7 +211,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +223,7 @@ class LSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -238,7 +233,7 @@ class LSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -269,11 +264,11 @@ class LSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.lstm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -281,20 +276,13 @@ class LSTM(Model):
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.lstm_model(x_batch).detach().numpy()
|
||||
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
@@ -9,15 +9,8 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -78,9 +71,8 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +88,7 @@ class LSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +103,7 @@ class LSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -135,9 +127,13 @@ class LSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.LSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,7 +197,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +205,14 @@ class LSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -225,7 +223,7 @@ class LSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -257,7 +255,7 @@ class LSTM(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
@@ -271,10 +269,7 @@ class LSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.LSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -6,20 +6,21 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
|
||||
|
||||
@@ -42,14 +43,14 @@ class DNNModelPytorch(Model):
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
input_dim=360,
|
||||
output_dim=1,
|
||||
layers=(256,),
|
||||
lr=0.001,
|
||||
max_steps=300,
|
||||
@@ -80,8 +81,7 @@ class DNNModelPytorch(Model):
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_GPU = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
@@ -99,7 +99,7 @@ class DNNModelPytorch(Model):
|
||||
"\nloss_type : {}"
|
||||
"\neval_steps : {}"
|
||||
"\nseed : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}".format(
|
||||
layers,
|
||||
@@ -114,8 +114,8 @@ class DNNModelPytorch(Model):
|
||||
loss,
|
||||
eval_steps,
|
||||
seed,
|
||||
GPU,
|
||||
self.use_GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
weight_decay,
|
||||
)
|
||||
)
|
||||
@@ -129,6 +129,9 @@ class DNNModelPytorch(Model):
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
|
||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -150,9 +153,13 @@ class DNNModelPytorch(Model):
|
||||
eps=1e-08,
|
||||
)
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
@@ -172,7 +179,7 @@ class DNNModelPytorch(Model):
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
@@ -180,7 +187,7 @@ class DNNModelPytorch(Model):
|
||||
evals_result["valid"] = []
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
# return
|
||||
# prepare training data
|
||||
x_train_values = torch.from_numpy(x_train.values).float()
|
||||
@@ -215,7 +222,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
if step and step % self.eval_steps == 0:
|
||||
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
@@ -248,9 +256,9 @@ class DNNModelPytorch(Model):
|
||||
# update learning rate
|
||||
self.scheduler.step(cur_loss_val)
|
||||
|
||||
# restore the optimal parameters after training ??
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path))
|
||||
if self.use_GPU:
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_loss(self, pred, w, target, loss_type):
|
||||
@@ -264,18 +272,14 @@ class DNNModelPytorch(Model):
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test_pd = dataset.prepare("test", col_set="feature")
|
||||
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
|
||||
self.dnn_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_GPU:
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
else:
|
||||
preds = self.dnn_model(x_test).detach().numpy()
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
|
||||
@@ -7,22 +7,17 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -196,8 +191,8 @@ class SFM(Model):
|
||||
learning rate
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -216,7 +211,7 @@ class SFM(Model):
|
||||
eval_steps=5,
|
||||
loss="mse",
|
||||
optimizer="gd",
|
||||
GPU="0",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -239,8 +234,7 @@ class SFM(Model):
|
||||
self.eval_steps = eval_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -259,7 +253,7 @@ class SFM(Model):
|
||||
"\neval_steps : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -276,7 +270,7 @@ class SFM(Model):
|
||||
eval_steps,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -295,6 +289,9 @@ class SFM(Model):
|
||||
dropout_U=self.dropout_U,
|
||||
device=self.device,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.sfm_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.sfm_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -302,9 +299,13 @@ class SFM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.sfm_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
@@ -365,7 +366,6 @@ class SFM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -377,6 +377,7 @@ class SFM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -386,7 +387,7 @@ class SFM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -409,7 +410,10 @@ class SFM(Model):
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.sfm_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
if self.device != "cpu":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -434,11 +438,11 @@ class SFM(Model):
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.sfm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -451,10 +455,7 @@ class SFM(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float()
|
||||
|
||||
if self.device != "cpu":
|
||||
x_batch = x_batch.to(self.device)
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.sfm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
@@ -6,16 +6,10 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -23,6 +17,7 @@ import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -49,12 +44,12 @@ class TabnetModel(Model):
|
||||
loss="mse",
|
||||
metric="",
|
||||
early_stop=20,
|
||||
GPU="1",
|
||||
GPU=0,
|
||||
pretrain_loss="custom",
|
||||
ps=0.3,
|
||||
lr=0.01,
|
||||
pretrain=True,
|
||||
pretrain_file="./pretrain/best.model",
|
||||
pretrain_file=None,
|
||||
):
|
||||
"""
|
||||
TabNet model for Qlib
|
||||
@@ -75,28 +70,27 @@ class TabnetModel(Model):
|
||||
self.n_epochs = n_epochs
|
||||
self.logger = get_module_logger("TabNet")
|
||||
self.pretrain_n_epochs = pretrain_n_epochs
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() else "cpu"
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.pretrain = pretrain
|
||||
self.pretrain_file = pretrain_file
|
||||
self.pretrain_file = get_or_create_path(pretrain_file)
|
||||
self.logger.info(
|
||||
"TabNet:"
|
||||
"\nbatch_size : {}"
|
||||
"\nvirtual bs : {}"
|
||||
"\nGPU : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
|
||||
"\ndevice : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, self.device, self.pretrain)
|
||||
)
|
||||
self.fitted = False
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.tabnet_model = TabNet(
|
||||
inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device
|
||||
).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
|
||||
self.device
|
||||
)
|
||||
self.tabnet_model = TabNet(inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device)
|
||||
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.pretrain_optimizer = optim.Adam(
|
||||
@@ -112,11 +106,12 @@ class TabnetModel(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
|
||||
# make a directory if pretrian director does not exist
|
||||
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
|
||||
self.logger.info("make folder to store model...")
|
||||
os.makedirs("pretrain")
|
||||
get_or_create_path(pretrain_file)
|
||||
|
||||
[df_train, df_valid] = dataset.prepare(
|
||||
["pretrain", "pretrain_validation"],
|
||||
@@ -158,7 +153,6 @@ class TabnetModel(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
if self.pretrain:
|
||||
@@ -178,16 +172,17 @@ class TabnetModel(Model):
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = np.inf
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for epoch_idx in range(self.n_epochs):
|
||||
self.logger.info("epoch: %s" % (epoch_idx))
|
||||
@@ -200,22 +195,29 @@ class TabnetModel(Model):
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score < best_score:
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = epoch_idx
|
||||
best_param = copy.deepcopy(self.tabnet_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.tabnet_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.tabnet_model.eval()
|
||||
x_values = torch.from_numpy(x_test.values)
|
||||
@@ -259,12 +261,13 @@ class TabnetModel(Model):
|
||||
feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -347,10 +350,11 @@ class TabnetModel(Model):
|
||||
label = y_train_values.float().to(self.device)
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
with torch.no_grad():
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
@@ -396,9 +400,9 @@ class FinetuneModel(nn.Module):
|
||||
|
||||
|
||||
class DecoderStep(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
|
||||
super().__init__()
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs)
|
||||
self.fc = nn.Linear(out_dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -407,13 +411,12 @@ class DecoderStep(nn.Module):
|
||||
|
||||
|
||||
class TabNet_Decoder(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps):
|
||||
"""
|
||||
TabNet decoder that is used in pre-training
|
||||
"""
|
||||
self.out_dim = out_dim
|
||||
|
||||
super().__init__()
|
||||
self.out_dim = out_dim
|
||||
if n_shared > 0:
|
||||
self.shared = nn.ModuleList()
|
||||
self.shared.append(nn.Linear(inp_dim, 2 * out_dim))
|
||||
@@ -424,7 +427,7 @@ class TabNet_Decoder(nn.Module):
|
||||
self.n_steps = n_steps
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps):
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device))
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs))
|
||||
|
||||
def forward(self, x):
|
||||
out = torch.zeros(x.size(0), self.out_dim).to(x.device)
|
||||
@@ -434,9 +437,7 @@ class TabNet_Decoder(nn.Module):
|
||||
|
||||
|
||||
class TabNet(nn.Module):
|
||||
def __init__(
|
||||
self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu"
|
||||
):
|
||||
def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024):
|
||||
"""
|
||||
TabNet AKA the original encoder
|
||||
|
||||
@@ -460,10 +461,10 @@ class TabNet(nn.Module):
|
||||
else:
|
||||
self.shared = None
|
||||
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device)
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs)
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps - 1):
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device))
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
|
||||
self.fc = nn.Linear(n_d, out_dim)
|
||||
self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01)
|
||||
self.n_d = n_d
|
||||
@@ -472,14 +473,14 @@ class TabNet(nn.Module):
|
||||
assert not torch.isnan(x).any()
|
||||
x = self.bn(x)
|
||||
x_a = self.first_step(x)[:, self.n_d :]
|
||||
sparse_loss = torch.zeros(1).to(x.device)
|
||||
sparse_loss = []
|
||||
out = torch.zeros(x.size(0), self.n_d).to(x.device)
|
||||
for step in self.steps:
|
||||
x_te, l = step(x, x_a, priors)
|
||||
out += F.relu(x_te[:, : self.n_d]) # split the feautre from feat_transformer
|
||||
x_a = x_te[:, self.n_d :]
|
||||
sparse_loss += l
|
||||
return self.fc(out), sparse_loss
|
||||
sparse_loss.append(l)
|
||||
return self.fc(out), sum(sparse_loss)
|
||||
|
||||
|
||||
class GBN(nn.Module):
|
||||
@@ -497,9 +498,12 @@ class GBN(nn.Module):
|
||||
self.vbs = vbs
|
||||
|
||||
def forward(self, x):
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
if x.size(0) <= self.vbs: # can not be chunked
|
||||
return self.bn(x)
|
||||
else:
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
@@ -547,7 +551,7 @@ class AttentionTransformer(nn.Module):
|
||||
|
||||
|
||||
class FeatureTransformer(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
|
||||
super().__init__()
|
||||
first = True
|
||||
self.shared = nn.ModuleList()
|
||||
@@ -563,7 +567,7 @@ class FeatureTransformer(nn.Module):
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = torch.sqrt(torch.tensor([0.5], device=device))
|
||||
self.scale = float(np.sqrt(0.5))
|
||||
|
||||
def forward(self, x):
|
||||
if self.shared:
|
||||
@@ -582,10 +586,10 @@ class DecisionStep(nn.Module):
|
||||
One step for the TabNet
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device):
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs):
|
||||
super().__init__()
|
||||
self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)
|
||||
|
||||
def forward(self, x, a, priors):
|
||||
mask = self.atten_tran(a, priors)
|
||||
|
||||
37
qlib/contrib/model/pytorch_utils.py
Normal file
37
qlib/contrib/model/pytorch_utils.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_parameters(models_or_parameters, unit="m"):
|
||||
"""
|
||||
This function is to obtain the storage size unit of a (or multiple) models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
models_or_parameters : PyTorch model(s) or a list of parameters.
|
||||
unit : the storage size unit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The number of parameters of the given model(s) or parameters.
|
||||
"""
|
||||
if isinstance(models_or_parameters, nn.Module):
|
||||
counts = sum(v.numel() for v in models_or_parameters.parameters())
|
||||
elif isinstance(models_or_parameters, nn.Parameter):
|
||||
counts = models_or_parameters.numel()
|
||||
elif isinstance(models_or_parameters, (list, tuple)):
|
||||
return sum(count_parameters(x, unit) for x in models_or_parameters)
|
||||
else:
|
||||
counts = sum(v.numel() for v in models_or_parameters)
|
||||
unit = unit.lower()
|
||||
if unit == "kb" or unit == "k":
|
||||
counts /= 2 ** 10
|
||||
elif unit == "mb" or unit == "m":
|
||||
counts /= 2 ** 20
|
||||
elif unit == "gb" or unit == "g":
|
||||
counts /= 2 ** 30
|
||||
elif unit is not None:
|
||||
raise ValueError("Unknow unit: {:}".format(unit))
|
||||
return counts
|
||||
@@ -4,13 +4,14 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import FeatureInt
|
||||
|
||||
|
||||
class XGBModel(Model):
|
||||
class XGBModel(Model, FeatureInt):
|
||||
"""XGBModel Model"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -42,8 +43,8 @@ class XGBModel(Model):
|
||||
else:
|
||||
raise ValueError("XGBoost doesn't support multi-label training")
|
||||
|
||||
dtrain = xgb.DMatrix(x_train.values, label=y_train_1d)
|
||||
dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d)
|
||||
dtrain = xgb.DMatrix(x_train, label=y_train_1d)
|
||||
dvalid = xgb.DMatrix(x_valid, label=y_valid_1d)
|
||||
self.model = xgb.train(
|
||||
self._params,
|
||||
dtrain=dtrain,
|
||||
@@ -57,8 +58,18 @@ class XGBModel(Model):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Notes
|
||||
-------
|
||||
parameters reference:
|
||||
https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score
|
||||
"""
|
||||
return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False)
|
||||
|
||||
@@ -63,7 +63,7 @@ class UserManager:
|
||||
account_path = self.data_path / user_id
|
||||
strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id)
|
||||
model_file = self.data_path / user_id / "model_{}.pickle".format(user_id)
|
||||
cur_user_list = [user_id for user_id in self.users]
|
||||
cur_user_list = list(self.users)
|
||||
if user_id in cur_user_list:
|
||||
raise ValueError("User {} has been loaded".format(user_id))
|
||||
else:
|
||||
@@ -110,7 +110,7 @@ class UserManager:
|
||||
raise ValueError("User data for {} already exists".format(user_id))
|
||||
|
||||
with config_file.open("r") as fp:
|
||||
config = yaml.load(fp)
|
||||
config = yaml.safe_load(fp)
|
||||
# load model
|
||||
model = init_instance_by_config(config["model"])
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ class Operator:
|
||||
for user_id, user in um.users.items():
|
||||
dates, trade_exchange = prepare(um, trade_date, user_id, exchange_config)
|
||||
executor = SimulatorExecutor(trade_exchange=trade_exchange)
|
||||
if not str(dates[0].date()) == str(pred_date.date()):
|
||||
if str(dates[0].date()) != str(pred_date.date()):
|
||||
raise ValueError(
|
||||
"The account data is not newest! last trading date {}, today {}".format(
|
||||
dates[0].date(), trade_date.date()
|
||||
|
||||
@@ -88,7 +88,7 @@ def prepare(um, today, user_id, exchange_config=None):
|
||||
dates.append(get_next_trading_date(dates[-1], future=True))
|
||||
if exchange_config:
|
||||
with pathlib.Path(exchange_config).open("r") as fp:
|
||||
exchange_paras = yaml.load(fp)
|
||||
exchange_paras = yaml.safe_load(fp)
|
||||
else:
|
||||
exchange_paras = {}
|
||||
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)
|
||||
|
||||
@@ -214,7 +214,7 @@ def cumulative_return_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
|
||||
|
||||
Graph desc:
|
||||
|
||||
@@ -94,7 +94,7 @@ def rank_label_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
|
||||
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
|
||||
|
||||
@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
|
||||
|
||||
qcr.report_graph(report_normal_df)
|
||||
qcr.analysis_position.report_graph(report_normal_df)
|
||||
|
||||
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class BaseGraph:
|
||||
""""""
|
||||
""" """
|
||||
|
||||
_name = None
|
||||
|
||||
@@ -161,7 +161,7 @@ class DistplotGraph(BaseGraph):
|
||||
"""
|
||||
_t_df = self._df.dropna()
|
||||
_data_list = [_t_df[_col] for _col in self._name_dict]
|
||||
_label_list = [_name for _name in self._name_dict.values()]
|
||||
_label_list = list(self._name_dict.values())
|
||||
_fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs)
|
||||
|
||||
return _fig["data"]
|
||||
|
||||
@@ -7,7 +7,6 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..backtest.order import Order
|
||||
from ...utils import get_pre_trading_date
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
@@ -252,7 +251,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
|
||||
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
|
||||
"""
|
||||
Gnererate order list according to score_series at trade_date, will not change current.
|
||||
Generate order list according to score_series at trade_date, will not change current.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
@@ -390,11 +389,11 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
current_stock_list = current_temp.get_stock_list()
|
||||
value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
|
||||
|
||||
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it
|
||||
# as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
|
||||
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
|
||||
# consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
|
||||
# value = value / (1+trade_exchange.open_cost) # set open_cost limit
|
||||
for code in buy:
|
||||
# check is stock supended
|
||||
# check is stock suspended
|
||||
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
|
||||
continue
|
||||
# buy order
|
||||
|
||||
@@ -14,7 +14,7 @@ class TunerConfigManager:
|
||||
self.config_path = config_path
|
||||
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp)
|
||||
config = yaml.safe_load(fp)
|
||||
self.config = copy.deepcopy(config)
|
||||
|
||||
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)
|
||||
|
||||
4
qlib/contrib/workflow/__init__.py
Normal file
4
qlib/contrib/workflow/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
89
qlib/contrib/workflow/record_temp.py
Normal file
89
qlib/contrib/workflow/record_temp.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from typing import Dict, Text, Any
|
||||
|
||||
from ...contrib.eva.alpha import calc_ic
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
|
||||
"""
|
||||
This is the multiple segments signal record class that generates the signal prediction.
|
||||
This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model, dataset, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
if not isinstance(dataset, qlib_dataset.DatasetH):
|
||||
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
|
||||
def generate(self, segments: Dict[Text, Any], save: bool = False):
|
||||
for key, segment in segments.items():
|
||||
predics = self.model.predict(self.dataset, segment)
|
||||
if isinstance(predics, pd.Series):
|
||||
predics = predics.to_frame("score")
|
||||
labels = self.dataset.prepare(
|
||||
segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R
|
||||
)
|
||||
# Compute the IC and Rank IC
|
||||
ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])
|
||||
results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()}
|
||||
logger.info("--- Results for {:} ({:}) ---".format(key, segment))
|
||||
ic_x100, ric_x100 = ic * 100, ric * 100
|
||||
logger.info("IC: {:.4f}%".format(ic_x100.mean()))
|
||||
logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std()))
|
||||
logger.info("Rank IC: {:.4f}%".format(ric_x100.mean()))
|
||||
logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std()))
|
||||
|
||||
if save:
|
||||
save_name = "results-{:}.pkl".format(key)
|
||||
self.recorder.save_objects(**{save_name: results})
|
||||
logger.info(
|
||||
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
|
||||
save_name, self.recorder.experiment_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
This class inherits the ``SignalMseRecord`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
masks = ~np.isnan(label.values)
|
||||
mse = mean_squared_error(pred.values[masks], label[masks])
|
||||
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
return paths
|
||||
@@ -1045,9 +1045,6 @@ class SimpleDatasetCache(DatasetCache):
|
||||
class DatasetURICache(DatasetCache):
|
||||
"""Prepared cache mechanism for server."""
|
||||
|
||||
def __init__(self, provider):
|
||||
super(DatasetURICache, self).__init__(provider)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import copy
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
@@ -27,12 +29,41 @@ from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
class ProviderBackendMixin:
|
||||
def get_default_backend(self):
|
||||
backend = {}
|
||||
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
|
||||
# set default storage class
|
||||
backend.setdefault("class", f"File{provider_name}Storage")
|
||||
# set default storage module
|
||||
backend.setdefault("module_path", "qlib.data.storage.file_storage")
|
||||
return backend
|
||||
|
||||
def backend_obj(self, **kwargs):
|
||||
backend = self.backend if self.backend else self.get_default_backend()
|
||||
backend = copy.deepcopy(backend)
|
||||
|
||||
# set default storage kwargs
|
||||
backend_kwargs = backend.setdefault("kwargs", {})
|
||||
# default provider_uri map
|
||||
if "provider_uri" not in backend_kwargs:
|
||||
# if the user has no uri configured, use: uri = uri_map[freq]
|
||||
freq = kwargs.get("freq", "day")
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
|
||||
backend_kwargs["provider_uri"] = provider_uri_map[freq]
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Calendar provider base class
|
||||
|
||||
Provide calendar data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
"""Get calendar of certain market in given time range.
|
||||
@@ -127,12 +158,15 @@ class CalendarProvider(abc.ABC):
|
||||
return hash_args(start_time, end_time, freq, future)
|
||||
|
||||
|
||||
class InstrumentProvider(abc.ABC):
|
||||
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Instrument provider base class
|
||||
|
||||
Provide instrument data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@staticmethod
|
||||
def instruments(market="all", filter_pipe=None):
|
||||
"""Get the general config dictionary for a base market adding several dynamic filters.
|
||||
@@ -215,12 +249,15 @@ class InstrumentProvider(abc.ABC):
|
||||
raise ValueError(f"Unknown instrument type {inst}")
|
||||
|
||||
|
||||
class FeatureProvider(abc.ABC):
|
||||
class FeatureProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Feature provider class
|
||||
|
||||
Provide feature data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def feature(self, instrument, field, start_time, end_time, freq):
|
||||
"""Get feature data.
|
||||
@@ -478,13 +515,13 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
data = pd.DataFrame(obj)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(np.int)]
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
if spans is None:
|
||||
return data
|
||||
else:
|
||||
mask = np.zeros(len(data), dtype=np.bool)
|
||||
mask = np.zeros(len(data), dtype=bool)
|
||||
for begin, end in spans:
|
||||
mask |= (data.index >= begin) & (data.index <= end)
|
||||
return data[mask]
|
||||
@@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalCalendarProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
@@ -517,18 +555,22 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
list
|
||||
list of timestamps
|
||||
"""
|
||||
if future:
|
||||
fname = self._uri_cal.format(freq + "_future")
|
||||
# if future calendar not exists, return current calendar
|
||||
if not os.path.exists(fname):
|
||||
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
|
||||
fname = self._uri_cal.format(freq)
|
||||
else:
|
||||
fname = self._uri_cal.format(freq)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("calendar not exists for freq " + freq)
|
||||
with open(fname) as f:
|
||||
return [pd.Timestamp(x.strip()) for x in f]
|
||||
|
||||
try:
|
||||
backend_obj = self.backend_obj(freq=freq, future=future).data
|
||||
except ValueError:
|
||||
if future:
|
||||
get_module_logger("data").warning(
|
||||
f"load calendar error: freq={freq}, future={future}; return current calendar!"
|
||||
)
|
||||
get_module_logger("data").warning(
|
||||
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
|
||||
)
|
||||
backend_obj = self.backend_obj(freq=freq, future=False).data
|
||||
else:
|
||||
raise
|
||||
|
||||
return [pd.Timestamp(x) for x in backend_obj]
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
_calendar, _calendar_index = self._get_calendar(freq, future)
|
||||
@@ -559,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
Provide instrument data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def _uri_inst(self):
|
||||
"""Instrument file uri."""
|
||||
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
|
||||
|
||||
def _load_instruments(self, market):
|
||||
fname = self._uri_inst.format(market)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("instruments not exists for market " + market)
|
||||
|
||||
_instruments = dict()
|
||||
df = pd.read_csv(
|
||||
fname,
|
||||
sep="\t",
|
||||
usecols=[0, 1, 2],
|
||||
names=["inst", "start_datetime", "end_datetime"],
|
||||
dtype={"inst": str},
|
||||
parse_dates=["start_datetime", "end_datetime"],
|
||||
)
|
||||
for row in df.itertuples(index=False):
|
||||
_instruments.setdefault(row[0], []).append((row[1], row[2]))
|
||||
return _instruments
|
||||
def _load_instruments(self, market, freq):
|
||||
return self.backend_obj(market=market, freq=freq).data
|
||||
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
market = instruments["market"]
|
||||
if market in H["i"]:
|
||||
_instruments = H["i"][market]
|
||||
else:
|
||||
_instruments = self._load_instruments(market)
|
||||
_instruments = self._load_instruments(market, freq=freq)
|
||||
H["i"][market] = _instruments
|
||||
# strip
|
||||
# use calendar boundary
|
||||
@@ -601,7 +625,7 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
inst: list(
|
||||
filter(
|
||||
lambda x: x[0] <= x[1],
|
||||
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
|
||||
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
|
||||
)
|
||||
)
|
||||
for inst, spans in _instruments.items()
|
||||
@@ -627,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalFeatureProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
@@ -638,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
instrument = code_to_fname(instrument)
|
||||
uri_data = self._uri_data.format(instrument.lower(), field, freq)
|
||||
if not os.path.exists(uri_data):
|
||||
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
|
||||
return pd.Series(dtype=np.float32)
|
||||
# raise ValueError('uri_data not found: ' + uri_data)
|
||||
# load
|
||||
series = read_bin(uri_data, start_index, end_index)
|
||||
return series
|
||||
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
|
||||
|
||||
class LocalExpressionProvider(ExpressionProvider):
|
||||
@@ -654,9 +672,6 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
Provide expression data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
|
||||
expression = self.get_expression_instance(field)
|
||||
start_time = pd.Timestamp(start_time)
|
||||
@@ -1019,7 +1034,8 @@ class ClientProvider(BaseProvider):
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if isinstance(Cal, ClientCalendarProvider):
|
||||
Cal.set_conn(self.client)
|
||||
Inst.set_conn(self.client)
|
||||
if isinstance(Inst, ClientInstrumentProvider):
|
||||
Inst.set_conn(self.client)
|
||||
if hasattr(DatasetD, "provider"):
|
||||
DatasetD.provider.set_conn(self.client)
|
||||
else:
|
||||
@@ -1064,7 +1080,8 @@ def register_all_wrappers(C):
|
||||
register_wrapper(Cal, _calendar_provider, "qlib.data")
|
||||
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
|
||||
|
||||
register_wrapper(Inst, C.instrument_provider, "qlib.data")
|
||||
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
|
||||
register_wrapper(Inst, _instrument_provider, "qlib.data")
|
||||
logger.debug(f"registering Inst {C.instrument_provider}")
|
||||
|
||||
if getattr(C, "feature_provider", None) is not None:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,22 +17,28 @@ class Dataset(Serializable):
|
||||
Preparing data for model training and inferencing.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
init is designed to finish following steps:
|
||||
|
||||
- init the sub instance and the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
- initialize the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
The data could specify the info to caculate the essential data for preparation
|
||||
The data could specify the info to calculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(*args, **kwargs)
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
config is designed to configure and parameters that cannot be learned from the data
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
"""
|
||||
Setup the data.
|
||||
|
||||
@@ -39,7 +46,7 @@ class Dataset(Serializable):
|
||||
|
||||
- User have a Dataset object with learned status on disk.
|
||||
|
||||
- User load the Dataset object from the disk(Note the init function is skiped).
|
||||
- User load the Dataset object from the disk.
|
||||
|
||||
- User call `setup_data` to load new data.
|
||||
|
||||
@@ -47,7 +54,7 @@ class Dataset(Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, *args, **kwargs) -> object:
|
||||
def prepare(self, **kwargs) -> object:
|
||||
"""
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
The parameters should specify the scope for the prepared data
|
||||
@@ -76,22 +83,7 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
handler : Union[dict, DataHandler]
|
||||
handler will be passed into setup_data.
|
||||
segments : dict
|
||||
handler will be passed into setup_data.
|
||||
"""
|
||||
super().__init__(handler, segments)
|
||||
|
||||
def init(self, **kwargs):
|
||||
"""Initialize the DatasetH, Only parameters belonging to handler.init will be passed in"""
|
||||
self.handler.init(**kwargs)
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -100,7 +92,7 @@ class DatasetH(Dataset):
|
||||
handler : Union[dict, DataHandler]
|
||||
handler could be:
|
||||
|
||||
- insntance of `DataHandler`
|
||||
- instance of `DataHandler`
|
||||
|
||||
- config of `DataHandler`. Please refer to `DataHandler`
|
||||
|
||||
@@ -120,8 +112,57 @@ class DatasetH(Dataset):
|
||||
'outsample': ("2017-01-01", "2020-08-01",),
|
||||
}
|
||||
"""
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
self.fetch_kwargs = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHandler, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
kwargs : dict
|
||||
Config of DatasetH, such as
|
||||
|
||||
- segments : dict
|
||||
Config of segments which is same as 'segments' in self.__init__
|
||||
|
||||
"""
|
||||
if handler_kwargs is not None:
|
||||
self.handler.config(**handler_kwargs)
|
||||
if "segments" in kwargs:
|
||||
self.segments = deepcopy(kwargs.pop("segments"))
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Setup the Data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHandler, which could include the following arguments:
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
- enable_cache : whether to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
name=self.__class__.__name__, handler=self.handler, segments=self.segments
|
||||
)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
"""
|
||||
@@ -131,11 +172,14 @@ class DatasetH(Dataset):
|
||||
----------
|
||||
slc : slice
|
||||
"""
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
if hasattr(self, "fetch_kwargs"):
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
else:
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
@@ -145,7 +189,7 @@ class DatasetH(Dataset):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segments : Union[List[str], Tuple[str], str, slice]
|
||||
segments : Union[List[Text], Tuple[Text], Text, slice]
|
||||
Describe the scope of the data to be prepared
|
||||
Here are some examples:
|
||||
|
||||
@@ -159,6 +203,12 @@ class DatasetH(Dataset):
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicate fetching data for **inference**.
|
||||
|
||||
kwargs :
|
||||
The parameters that kwargs may contain:
|
||||
flt_col : str
|
||||
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
|
||||
This parameter is only supported when it is an instance of TSDatasetH.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
@@ -191,7 +241,7 @@ class TSDataSampler:
|
||||
(T)ime-(S)eries DataSampler
|
||||
This is the result of TSDatasetH
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
@@ -203,7 +253,9 @@ class TSDataSampler:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
|
||||
def __init__(
|
||||
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
|
||||
):
|
||||
"""
|
||||
Build a dataset which looks like torch.data.utils.Dataset.
|
||||
|
||||
@@ -225,6 +277,11 @@ class TSDataSampler:
|
||||
ffill with previous sample
|
||||
ffill+bfill:
|
||||
ffill with previous samples first and fill with later samples second
|
||||
flt_data : pd.Series
|
||||
a column of data(True or False) to filter data.
|
||||
None:
|
||||
kepp all data
|
||||
|
||||
"""
|
||||
self.start = start
|
||||
self.end = end
|
||||
@@ -232,24 +289,51 @@ class TSDataSampler:
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But
|
||||
# NOTE: append last line with full NaN for better performance in `__getitem__`
|
||||
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
|
||||
|
||||
kwargs = {"object": self.data}
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE:
|
||||
# - append last line with full NaN for better performance in `__getitem__`
|
||||
# - Keep the same dtype will result in a better performance
|
||||
self.data_arr = np.append(
|
||||
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
|
||||
def flt_idx_map(flt_data, idx_map):
|
||||
idx = 0
|
||||
new_idx_map = {}
|
||||
for i, exist in enumerate(flt_data):
|
||||
if exist:
|
||||
new_idx_map[idx] = idx_map[i]
|
||||
idx += 1
|
||||
return new_idx_map
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
Get the pandas index of the data, it will be useful in following scenarios
|
||||
- Special sampler will be used (e.g. user want to sample day by day)
|
||||
"""
|
||||
return self.data.index[self.start_idx : self.end_idx]
|
||||
return self.data_index[self.start_idx : self.end_idx]
|
||||
|
||||
def config(self, **kwargs):
|
||||
# Config the attributes
|
||||
@@ -273,7 +357,7 @@ class TSDataSampler:
|
||||
# get the previous index of a line given index
|
||||
"""
|
||||
# object incase of pandas converting int to flaot
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
idx_df = lazy_sort_index(idx_df.unstack())
|
||||
# NOTE: the correctness of `__getitem__` depends on columns sorted here
|
||||
idx_df = lazy_sort_index(idx_df, axis=1)
|
||||
@@ -375,7 +459,7 @@ class TSDataSampler:
|
||||
# 1) for better performance, use the last nan line for padding the lost date
|
||||
# 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
|
||||
# precision problems. It will not cause any problems in my tests at least
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(np.int)
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)
|
||||
|
||||
data = self.data_arr[indices]
|
||||
if isinstance(idx, mtit):
|
||||
@@ -393,7 +477,7 @@ class TSDatasetH(DatasetH):
|
||||
(T)ime-(S)eries Dataset (H)andler
|
||||
|
||||
|
||||
Covnert the tabular data to Time-Series data
|
||||
Convert the tabular data to Time-Series data
|
||||
|
||||
Requirements analysis
|
||||
|
||||
@@ -407,18 +491,22 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, *args, **kwargs):
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
super().setup_data(*args, **kwargs)
|
||||
def config(self, **kwargs):
|
||||
if "step_len" in kwargs:
|
||||
self.step_len = kwargs.pop("step_len")
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
# Get the datatime index for building timestamp
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
@@ -427,6 +515,25 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
return data
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = self._prepare_raw_seg(slc, **kwargs)
|
||||
|
||||
flt_kwargs = deepcopy(kwargs)
|
||||
if flt_col is not None:
|
||||
flt_kwargs["col_set"] = flt_col
|
||||
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
|
||||
assert len(flt_data.columns) == 1
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
|
||||
return tsds
|
||||
|
||||
@@ -6,7 +6,8 @@ import abc
|
||||
import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
from inspect import getfullargspec
|
||||
from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,7 +17,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import get_level_index, fetch_df_by_index
|
||||
from .utils import fetch_df_by_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -35,7 +36,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported(The order will implied in the data).
|
||||
Any order of the index level can be supported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -47,9 +48,12 @@ class DataHandler(Serializable):
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
|
||||
Tips for improving the performance of datahandler
|
||||
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -57,7 +61,7 @@ class DataHandler(Serializable):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
data_loader: Union[dict, str, DataLoader] = None,
|
||||
init_data=True,
|
||||
fetch_orig=True,
|
||||
):
|
||||
@@ -70,10 +74,10 @@ class DataHandler(Serializable):
|
||||
start_time of the original data.
|
||||
end_time :
|
||||
end_time of the original data.
|
||||
data_loader : Tuple[dict, str, DataLoader]
|
||||
data_loader : Union[dict, str, DataLoader]
|
||||
data loader to load the data.
|
||||
init_data :
|
||||
intialize the original data in the constructor.
|
||||
initialize the original data in the constructor.
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
"""
|
||||
@@ -99,10 +103,10 @@ class DataHandler(Serializable):
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
self.init()
|
||||
self.setup_data()
|
||||
super().__init__()
|
||||
|
||||
def conf_data(self, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
@@ -115,13 +119,16 @@ class DataHandler(Serializable):
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise KeyError("Such config is not supported.")
|
||||
|
||||
def init(self, enable_cache: bool = False):
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
initialize the data.
|
||||
In case of running intialization for multiple time, it will do nothing for the second time.
|
||||
Set Up the data in case of running initialization for multiple time
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
@@ -159,6 +166,7 @@ class DataHandler(Serializable):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -181,6 +189,14 @@ class DataHandler(Serializable):
|
||||
- if isinstance(col_set, List[str]):
|
||||
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
proc_func: Callable
|
||||
- Give a hook for processing data before fetching
|
||||
- An example to explain the necessity of the hook:
|
||||
- A Dataset learned some processors to process data which is related to data segmentation
|
||||
- It will apply them every time when preparing data.
|
||||
- The learned processor require the dataframe remains the same format when fitting and applying
|
||||
- However the data format will change according to the parameters.
|
||||
- So the processors should be applied to the underlayer data.
|
||||
|
||||
squeeze : bool
|
||||
whether squeeze columns and index
|
||||
@@ -189,8 +205,15 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
if proc_func is None:
|
||||
df = self._data
|
||||
else:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(self._data, col_set)
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
@@ -257,6 +280,10 @@ class DataHandler(Serializable):
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
|
||||
Tips to improving the performance of data handler
|
||||
- To reduce the memory cost
|
||||
- `drop_raw=True`: this will modify the data inplace on raw data;
|
||||
"""
|
||||
|
||||
# data key
|
||||
@@ -278,7 +305,7 @@ class DataHandlerLP(DataHandler):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
data_loader: Union[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
process_type=PTYPE_A,
|
||||
@@ -405,14 +432,28 @@ class DataHandlerLP(DataHandler):
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
|
||||
def config(self, processor_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
|
||||
This method will be used when loading pickled handler from dataset.
|
||||
The data will be initialized with different time range.
|
||||
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
if processor_kwargs is not None:
|
||||
for processor in self.get_all_processors():
|
||||
processor.config(**processor_kwargs)
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
IT_LS = "load_state" # The state of the object has been load by pickle
|
||||
|
||||
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
|
||||
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
Set up the data in case of running initialization for multiple time
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -427,7 +468,7 @@ class DataHandlerLP(DataHandler):
|
||||
when we call `init` next time
|
||||
"""
|
||||
# init raw data
|
||||
super().init(enable_cache=enable_cache)
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
with TimeInspector.logt("fit & process data"):
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
@@ -456,6 +497,7 @@ class DataHandlerLP(DataHandler):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -470,12 +512,18 @@ class DataHandlerLP(DataHandler):
|
||||
select a set of meaningful columns.(e.g. features, columns).
|
||||
data_key : str
|
||||
the data to fetch: DK_*.
|
||||
proc_func: Callable
|
||||
please refer to the doc of DataHandler.fetch
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
@@ -13,6 +13,7 @@ from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class DataLoader(abc.ABC):
|
||||
@@ -217,3 +218,68 @@ class StaticDataLoader(DataLoader):
|
||||
join=self.join,
|
||||
)
|
||||
self._data.sort_index(inplace=True)
|
||||
|
||||
|
||||
class DataLoaderDH(DataLoader):
|
||||
"""DataLoaderDH
|
||||
DataLoader based on (D)ata (H)andler
|
||||
It is designed to load multiple data from data handler
|
||||
- If you just want to load data from single datahandler, you can write them in single data handler
|
||||
|
||||
TODO: What make this module not that easy to use.
|
||||
- For online scenario
|
||||
- The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
|
||||
"""
|
||||
|
||||
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
handler_config : dict
|
||||
handler_config will be used to describe the handlers
|
||||
|
||||
.. code-block::
|
||||
|
||||
<handler_config> := {
|
||||
"group_name1": <handler>
|
||||
"group_name2": <handler>
|
||||
}
|
||||
or
|
||||
<handler_config> := <handler>
|
||||
<handler> := DataHandler Instance | DataHandler Config
|
||||
|
||||
fetch_kwargs : dict
|
||||
fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
|
||||
|
||||
is_group: bool
|
||||
is_group will be used to describe whether the key of handler_config is group
|
||||
|
||||
"""
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
|
||||
if is_group:
|
||||
self.handlers = {
|
||||
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
|
||||
}
|
||||
else:
|
||||
self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)
|
||||
|
||||
self.is_group = is_group
|
||||
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
|
||||
self.fetch_kwargs.update(fetch_kwargs)
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is not None:
|
||||
get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
for grp, dh in self.handlers.items()
|
||||
},
|
||||
axis=1,
|
||||
)
|
||||
else:
|
||||
df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
return df
|
||||
|
||||
18
qlib/data/dataset/processor.py
Executable file → Normal file
18
qlib/data/dataset/processor.py
Executable file → Normal file
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from typing import Union, Text
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
def get_group_columns(df: pd.DataFrame, group: str):
|
||||
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
|
||||
"""
|
||||
get a group of columns from multi-index columns DataFrame
|
||||
|
||||
@@ -72,6 +73,17 @@ class Processor(Serializable):
|
||||
"""
|
||||
return True
|
||||
|
||||
def config(self, **kwargs):
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list and hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class DropnaProcessor(Processor):
|
||||
def __init__(self, fields_group=None):
|
||||
@@ -118,7 +130,7 @@ class FilterCol(Processor):
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
""" Use tanh to process noise data"""
|
||||
"""Use tanh to process noise data"""
|
||||
|
||||
def __call__(self, df):
|
||||
def tanh_denoise(data):
|
||||
@@ -133,7 +145,7 @@ class TanhProcess(Processor):
|
||||
|
||||
|
||||
class ProcessInf(Processor):
|
||||
"""Process infinity """
|
||||
"""Process infinity"""
|
||||
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
|
||||
@@ -355,6 +355,7 @@ class ExpressionDFilter(SeriesDFilter):
|
||||
all_filter_series = _features[rule_expression_field_name]
|
||||
return all_filter_series
|
||||
|
||||
@staticmethod
|
||||
def from_config(config):
|
||||
return ExpressionDFilter(
|
||||
rule_expression=config["rule_expression"],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user