1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-12 00:40:58 +08:00

Compare commits

...

660 Commits

Author SHA1 Message Date
Young
77ba7b4e91 Add analyser example and finetune example 2021-06-06 07:51:52 +00:00
Young
7a639eeea7 add IC and rank IC 2021-06-06 07:43:26 +00:00
Young
cddaf90ef5 monitor initial version 2021-06-06 07:43:26 +00:00
you-n-g
4ff0c4fb0f Update strategy.rst 2021-05-31 08:52:41 +08:00
you-n-g
b3eece155f Merge pull request #452 from arisliang/patch-1
Add import stock pool in doc
2021-05-30 09:44:47 +08:00
al
02e34eb9e9 Add import stock pool (csi300) in documentation 2021-05-30 08:27:21 +08:00
you-n-g
3033fdf4b7 Merge pull request #444 from zhupr/feature_importance
add get_feature_importance to model interpret
2021-05-28 16:23:53 +08:00
zhupr
ef11a9d95c modify the default value of exists_skip in the GetData.qlib_data parameter to False 2021-05-28 14:57:06 +08:00
zhupr
98eacf8f88 add test/config.py 2021-05-28 13:24:47 +08:00
you-n-g
43cad1ec27 Merge pull request #442 from arisliang/patch-3
Update 1min demo data in CSV format
2021-05-28 11:22:10 +08:00
you-n-g
e7aa7ffcdd Merge pull request #447 from arisliang/patch-2
Remove repeated package from requirements
2021-05-28 10:46:58 +08:00
you-n-g
ed3c9d9212 Merge pull request #448 from arisliang/patch-4
Update integration.rst
2021-05-28 10:46:25 +08:00
you-n-g
2f3fbae73b Merge pull request #450 from arisliang/patch-5
Update report.rst
2021-05-28 10:45:52 +08:00
al
e409bee9b9 Update report.rst
typo
2021-05-28 07:54:45 +08:00
al
7ceec37848 Update integration.rst
Fix typo
2021-05-27 22:35:43 +08:00
al
c12c861b7a Remove repeated package from requirements 2021-05-27 19:37:57 +08:00
zhupr
0a4e241608 add get_feature_importance to model interpret 2021-05-27 14:19:10 +08:00
al
5a382d7e99 Update data.rst
Update csv format according to feedback
2021-05-27 12:40:55 +08:00
al
9b431bc503 Update 1min demo data in CSV format 2021-05-26 22:01:15 +08:00
you-n-g
cbbf6cd822 Merge pull request #441 from zhupr/fix_yahoo_collector
Fix YahooCollector can't download 1min data
2021-05-26 21:41:14 +08:00
you-n-g
928bae08f4 Merge pull request #440 from arisliang/patch-2
Update collector.py
2021-05-26 21:40:23 +08:00
you-n-g
c65fc226bd Merge pull request #439 from arisliang/patch-1
Update README.md
2021-05-26 21:39:53 +08:00
zhupr
114162693f Fix YahooCollector can't download 1min data 2021-05-26 18:29:41 +08:00
al
b884c8c571 Update collector.py
fix typo
2021-05-26 18:00:23 +08:00
al
6222940b9c Update README.md
fix typo
2021-05-26 17:50:49 +08:00
you-n-g
bb0c555803 Merge pull request #372 from zhupr/data_storage
add data storage
2021-05-26 14:30:46 +08:00
zhupr
5da33562dd remove uri parameter from storage && modify file_storage 2021-05-26 12:33:57 +08:00
you-n-g
db3aa8b887 Merge pull request #433 from arisliang/patch-1
Update README.md
2021-05-23 23:53:49 +08:00
al
ae32d79549 Update README.md
fix typo
2021-05-23 16:54:14 +08:00
you-n-g
a80369b80a Update README.md 2021-05-22 19:00:43 +08:00
you-n-g
369177acf9 Update README.md 2021-05-22 19:00:14 +08:00
you-n-g
2ac5ceb4de Update README.md 2021-05-22 17:58:21 +08:00
zhupr
602f78b568 add documentation on which storage methods are used in qlib 2021-05-22 08:30:12 +08:00
zhupr
669f6bd6f5 modify exception message hint for storage.py && fix FileFeatureStorage[:] bug 2021-05-22 02:06:15 +08:00
you-n-g
3d71fd1966 Merge pull request #431 from evanzd/fix_get_module
Fix `get_module_by_module_path` to support pickle module loaded from arbitrage source file
2021-05-21 12:53:14 +08:00
zhupr
b887d2ec32 code for formatting storage.py using black(v21.5) 2021-05-21 10:29:39 +08:00
zhupr
8e6c744a1b Merge remote-tracking branch 'microsoft/main' into data_storage 2021-05-21 09:49:29 +08:00
zhupr
9e296a8a4e replace the type of numpy deprecated 2021-05-21 08:56:44 +08:00
zhupr
4ba4512619 add write method to FeatureStorage && remove extend 2021-05-21 08:45:11 +08:00
Dong Zhou
2fa7ef32fb fix picker error on importlib loaded module 2021-05-18 22:37:31 +08:00
Dong Zhou
c72ee9091e fix picker error on importlib loaded module 2021-05-18 22:12:41 +08:00
you-n-g
19eda8f4f0 Update README.md 2021-05-17 17:45:31 +08:00
you-n-g
d08146c30f Merge pull request #290 from you-n-g/online_srv
init version of online serving and rolling
2021-05-17 17:35:29 +08:00
lzh222333
8c3a08b18d Finally! 2021-05-17 07:27:55 +00:00
you-n-g
142a9dca3c Merge pull request #423 from ChengYen-Tang/LightGBM-Hyperparameter
LightGBM hyperparameter
2021-05-16 19:43:59 +08:00
Kenneth Tang
41ab130807 Fix CI lint with black 2021-05-18 00:01:45 +08:00
Kenneth Tang
8f67010b58 Fix CI lint with black 2021-05-17 23:09:42 +08:00
lzh222333
a986379deb bug fixed 2021-05-14 11:31:50 +00:00
lzh222333
aef3f186c1 format code 2021-05-14 06:58:02 +00:00
lzh222333
ebd01e0de5 Online Serving V11 2021-05-14 06:44:16 +00:00
Kenneth Tang
f51e04a1cc LightGBM hyperparameter 2021-05-13 23:12:29 +08:00
lzh222333
d71a666904 Online serving V11 2021-05-13 09:43:42 +00:00
you-n-g
8b15ffc027 Merge pull request #419 from Derek-Wds/main
Fix bug and update doc
2021-05-13 16:38:22 +08:00
you-n-g
f15ca39df8 Merge pull request #418 from zhupr/set_global_logger_level
Modify set_global_logger_level use of contextmanager
2021-05-13 16:36:40 +08:00
Jactus
bd37f5d953 Fix bug and update doc 2021-05-13 14:21:54 +08:00
zhupr
76c5c5d1b6 Add docstrings to set_global_logger_level 2021-05-12 22:38:50 +08:00
zhupr
b8e64dc526 Modify set_global_logger_level use of contextmanager 2021-05-12 17:58:39 +08:00
you-n-g
df7c882fe3 Merge pull request #416 from zhupr/configurable_dataset
Add configurable dataset to examples
2021-05-12 09:03:56 +08:00
zhupr
9bd77bd89f Add configurable dataset to examples 2021-05-11 17:36:55 +08:00
you-n-g
c43a0b208d Merge pull request #413 from Derek-Wds/stale
Stale bot update
2021-05-10 21:15:36 +08:00
Jactus
8ba5e93d04 Skip enhancement in stale bot 2021-05-10 12:33:58 +08:00
lzh222333
370b6aad74 logger & doc 2021-05-09 11:58:06 +00:00
lzh222333
f5ded06a15 Merge remote-tracking branch 'microsoft/main' into online_srv 2021-05-09 10:53:41 +00:00
lzh222333
4c232610f1 Merge branch 'online_srv' of https://github.com/you-n-g/qlib into online_srv 2021-05-09 10:52:07 +00:00
Dingsu Wang
81bd2ca8fb Merge branch 'microsoft:main' into stale 2021-05-09 17:59:46 +08:00
Jactus
143c257fa2 Update stale bot 2021-05-09 17:56:37 +08:00
Jactus
724f9ba8d2 Update stale bot 2021-05-09 17:52:18 +08:00
you-n-g
aa1f9b464b Merge pull request #412 from zhupr/set_global_logger_level
Add set_global_logger_level
2021-05-09 11:25:56 +08:00
zhupr
d2daba99d3 Add set_global_logger_level 2021-05-09 09:05:57 +08:00
you-n-g
1c605e505a Merge pull request #14 from you-n-g/online_srv_blin
Online srv blin
2021-05-07 21:07:47 +08:00
you-n-g
060a32e0f6 Merge branch 'online_srv' into online_srv_blin 2021-05-07 21:07:27 +08:00
binlins
08edb92461 add flt_data doc 2021-05-07 12:56:58 +00:00
binlins
bec65ddf94 add document and reindex 2021-05-07 11:47:47 +00:00
lzh222333
9dfd001f6f online serving v10 2021-05-07 09:59:15 +00:00
you-n-g
95a4a98de8 Update README.md 2021-05-07 10:56:57 +08:00
you-n-g
d4639b7df9 Update README.md
Update high-frequency trading link
2021-05-07 10:46:45 +08:00
blin
846c64f6c6 fix param 2021-05-06 12:00:41 +00:00
lzh222333
84c56f13bd docs and bug fixed 2021-05-06 04:18:55 +00:00
lzh222333
1c99fb35da Merge remote-tracking branch 'microsoft/main' into online_srv 2021-05-06 02:24:16 +00:00
you-n-g
5bc2b96346 Update data.rst 2021-05-03 12:34:08 +08:00
you-n-g
ee269b0914 Merge pull request #407 from Derek-Wds/log
Fix logger pickling error
2021-04-30 17:33:42 +08:00
you-n-g
2a2d2cf709 Merge pull request #404 from Derek-Wds/main
Support start exp with given exp & recorder id
2021-04-30 16:46:46 +08:00
Jactus
5eb9dfff16 Remove redundant 2021-04-30 15:28:37 +08:00
Jactus
694ae34027 Update api 2021-04-30 13:27:19 +08:00
Jactus
51b649ec39 Update QlibLogger 2021-04-30 13:13:05 +08:00
Jactus
ca92cb980c Update meta logger 2021-04-29 22:40:52 +08:00
Jactus
f58c61a2e0 Fix logger pickling error 2021-04-29 16:54:51 +08:00
lzh222333
2b7ffa100f Merge remote-tracking branch 'microsoft/main' into online_srv 2021-04-29 05:23:42 +00:00
lzh222333
67c5740c83 OnlineServing V9 2021-04-29 04:30:09 +00:00
lzh222333
6f669348a8 Merge branch 'online_srv' of https://github.com/you-n-g/qlib into online_srv 2021-04-28 09:23:13 +00:00
lzh222333
40cf83e557 online serving V9 middle status 2021-04-28 09:23:07 +00:00
blin
fa4511cb0a filter 2021-04-28 07:30:22 +00:00
blin
45c6dfc5da filter 2021-04-28 07:25:19 +00:00
blin
36ab078fbd filter 2021-04-28 07:15:59 +00:00
you-n-g
5a7f9ef720 Merge pull request #405 from zhupr/future_trading_date_collector
Add future trading date collector
2021-04-28 10:07:08 +08:00
zhupr
8b8d21107c Add future trading date collector 2021-04-27 21:20:47 +08:00
Jactus
eab19de080 Support start exp with given exp & recorder id 2021-04-27 16:56:07 +08:00
lzh222333
42f510024c update collector 2021-04-27 04:12:08 +00:00
Young
5a7eecabee black formating (black is upgraded in github) 2021-04-27 04:04:43 +00:00
lzh222333
0058f7d0dc Online Serving V8 2021-04-26 09:31:47 +00:00
you-n-g
9a74fe34f6 Merge pull request #401 from zhupr/online_fix
Fix online mode bugs
2021-04-26 12:29:55 +08:00
zhupr
e15ea06122 Fix ClientProvider not supporting LocalInstrumentProvider && online using the latest python-socketio 2021-04-25 23:50:29 +08:00
lzh222333
319396c815 online serving V8 2021-04-25 06:26:45 +00:00
you-n-g
50be7a9171 Merge pull request #393 from Derek-Wds/main
Update qlib logger
2021-04-23 12:41:05 +08:00
Jactus
e410caaa8f Simplify meta class 2021-04-23 10:08:12 +08:00
Jactus
fbff4c271a Remove redundant methods in meta 2021-04-23 00:38:45 +08:00
you-n-g
ee91503973 Merge pull request #399 from Derek-Wds/doc
Update doc
2021-04-22 20:37:40 +08:00
lzh222333
de0a0c083d bug fixed 2021-04-22 08:09:15 +00:00
Jactus
8adfafa6aa Black format 2021-04-22 14:17:25 +08:00
Jactus
aafaff45d2 Update doc 2021-04-22 14:13:36 +08:00
Jactus
6a05d4e255 Enable IDEs docstrings 2021-04-19 11:36:00 +08:00
Jactus
cbf1fa721e Update 2021-04-17 15:47:49 +08:00
Jactus
4ebf684794 Update workflow logging 2021-04-16 15:35:11 +08:00
Jactus
f4bfe8e619 First trial of adding docstring 2021-04-16 14:35:05 +08:00
lzh222333
cec318fbfe online serving V7 2021-04-16 05:37:13 +00:00
Jactus
78bb8882cd Format 2021-04-16 12:00:18 +08:00
Jactus
848d953226 Update qlib logger 2021-04-16 09:58:55 +08:00
you-n-g
a3a2b5ae0b Merge pull request #358 from javaThonc/high_freq_demp
update high freq demo
2021-04-14 20:05:07 +08:00
Alex Wang
941c980d06 update tabnet 2021-04-14 17:35:19 +08:00
Alex Wang
fe190dec4b update readme 2021-04-14 14:40:28 +08:00
lzh222333
5095b2a470 simulator & examples 2021-04-13 09:45:16 +00:00
zhupr
317357b50d Modify data.storage 2021-04-13 10:47:01 +08:00
Young
b15e5e33fd Fix the multi-processing bug 2021-04-12 06:33:31 +00:00
Young
cca43cf102 Refactor update & modification when running NN 2021-04-11 14:39:19 +00:00
Young
a366c11d67 Update features for hyb nn 2021-04-09 13:48:01 +00:00
Young
18bf4b5477 parameter adjustment 2021-04-08 03:52:58 +00:00
lzh222333
c20eb5c8a6 format code 2021-04-08 03:30:24 +00:00
Young
71605794a2 Merge branch 'online_srv_wd' into online_srv 2021-04-07 05:17:07 +00:00
Young
1dbb561744 Fix some API(for lb nn) 2021-04-07 03:53:56 +00:00
lzh222333
cb42e99bee bug fixed & examples fire 2021-04-07 03:33:27 +00:00
lzh222333
431a9c92c1 online serving v5 2021-04-02 07:09:29 +00:00
lzh222333
bd7a1c11b9 trainer & group & collect & ensemble 2021-04-02 04:27:14 +00:00
zhupr
70fc58104b Modify FileStorage 2021-04-01 12:58:34 +08:00
lzh222333
edcd7b1ff9 bug fixed & code format 2021-03-31 03:08:48 +00:00
lzh222333
3724273d73 Merge remote-tracking branch 'microsoft/qlib/main' into online_srv 2021-03-31 02:54:05 +00:00
lzh222333
544365f3a9 ensemble & get_exp & dataset_pickle 2021-03-31 02:39:14 +00:00
Alex Wang
bed1175e24 update dataset 2021-03-30 19:29:17 +08:00
you-n-g
70c84cbc77 Merge pull request #381 from D-X-Y/main
Fix print issue
2021-03-30 17:25:26 +08:00
you-n-g
da59b35c0a Merge pull request #380 from Derek-Wds/main
Modify get_exp & get_recorder api
2021-03-30 16:54:01 +08:00
lzh222333
eae94d1ee8 Merge remote-tracking branch 'microsoft/qlib/main' into online_srv 2021-03-30 07:16:56 +00:00
lzh222333
1f2d2c9b69 online debug 2021-03-30 06:56:04 +00:00
Jactus
b6df11b6b4 Modify get_exp & get_recorder api 2021-03-30 14:41:56 +08:00
you-n-g
ae57110f64 Merge pull request #374 from bxdd/qlib_loaderhandler
Add DataLoader Based on DataHandler & Add Rolling Process Example & Restructure the Config & Setup_data
2021-03-30 13:37:55 +08:00
bxdd
7a2203f116 update comments 2021-03-30 11:03:07 +08:00
bxdd
023603479c fix readme 2021-03-30 01:00:12 +08:00
bxdd
f8da79b802 fix readme 2021-03-30 00:54:00 +08:00
bxdd
136830bc2b update comments 2021-03-30 00:38:15 +08:00
you-n-g
45f78676ea Merge pull request #379 from zhupr/fix_usinex_collector
Fix us_index collector
2021-03-29 23:54:05 +08:00
bxdd
1074284666 fix docstring 2021-03-29 20:38:09 +08:00
bxdd
d18c367497 update README 2021-03-29 20:34:36 +08:00
bxdd
8743576f72 black format 2021-03-29 20:16:00 +08:00
bxdd
fb7f84f31e fix ubg 2021-03-29 20:15:42 +08:00
bxdd
31bc85bf86 restructure data layer config & setup 2021-03-29 19:49:30 +08:00
D-X-Y
968930e85f Fix print issue 2021-03-29 04:46:38 +00:00
zhupr
4b66304978 Fix us_index collector 2021-03-29 11:25:24 +08:00
you-n-g
253378a44e Merge pull request #378 from D-X-Y/main
Add MultiSegRecord and add segment kwargs in model.pred
2021-03-29 01:06:41 +08:00
D-X-Y
f809f0a063 Remove un-used imports 2021-03-28 10:50:25 +00:00
D-X-Y
0386df7b16 Collect all contrib models in __init__ and add unit tests for init 2021-03-28 10:39:28 +00:00
D-X-Y
8a2e7b62af Add segment args for pred and refine MultiSegRecord 2021-03-28 08:30:16 +00:00
D-X-Y
9d04ae4676 Add MultiSegRecord in contrib.workflow and decouple its tests from test_all_pipeline 2021-03-28 00:33:59 -07:00
zhupr
9b8acd9a82 Fix FileStorage 2021-03-27 01:15:33 +08:00
lzh222333
ee45a7833e Merge branch 'main' into online_srv 2021-03-26 08:20:21 +00:00
zhupr
d395c904f2 Add FileStorage 2021-03-26 16:14:45 +08:00
lzh222333
9bf819e653 Merge branch 'online_srv' of https://github.com/you-n-g/qlib into online_srv 2021-03-26 04:32:20 +00:00
lzh222333
46cd57688e Online Serving V4 2021-03-26 04:20:25 +00:00
you-n-g
0387eaf7ab Merge pull request #373 from lewwang1995/main
debug
2021-03-26 11:59:20 +08:00
bxdd
4ee0240c24 black format 2021-03-25 22:08:39 +08:00
bxdd
5f60d18dfe fix config_data bug 2021-03-25 22:08:23 +08:00
bxdd
194217fb07 fix bug 2021-03-25 21:47:17 +08:00
bxdd
d6ff764bb2 black format 2021-03-25 20:36:45 +08:00
bxdd
9cc3b18e4e fix but 2021-03-25 20:36:07 +08:00
LewenWang
56eaacd931 debug 2021-03-25 20:34:45 +08:00
bxdd
e119c8576c black format 2021-03-25 19:59:22 +08:00
bxdd
68246b3b6d update workflow 2021-03-25 19:58:55 +08:00
bxdd
a04c6bd6c9 balck format 2021-03-25 19:56:22 +08:00
bxdd
efe134e9f4 update workflow 2021-03-25 19:56:04 +08:00
bxdd
4ec300787e update rolling workflow 2021-03-25 19:54:52 +08:00
you-n-g
3886022669 Merge pull request #371 from Derek-Wds/main
Update notebook
2021-03-25 18:13:54 +08:00
zhupr
8264033a72 add data-storage 2021-03-25 17:22:05 +08:00
Jactus
4861552d28 Update notebook 2021-03-25 17:13:52 +08:00
you-n-g
834f9bd9b8 Update README.md 2021-03-25 16:58:35 +08:00
bxdd
f6dc25b229 update rolling process 2021-03-25 16:14:22 +08:00
bxdd
1fcfe8e4ba add rolling process data 2021-03-25 01:37:17 +08:00
bxdd
b1a28358ad black format 2021-03-25 01:30:31 +08:00
bxdd
1ca3c6a61c add DataHandlerDL 2021-03-25 01:29:59 +08:00
Alex Wang
e3739bb980 fix naming and code style 2021-03-24 15:47:26 +08:00
you-n-g
419629e4d2 Merge pull request #365 from Flouse/main
fix docs
2021-03-24 13:12:15 +08:00
Flouse
e490e83a16 fix docs 2021-03-24 11:37:09 +08:00
you-n-g
fda144e66f Merge pull request #362 from D-X-Y/main
Add load_object function for R
2021-03-23 22:47:49 +08:00
you-n-g
4dc10d27e0 Merge pull request #359 from bxdd/doc0
Fix data doc
2021-03-23 21:40:27 +08:00
D-X-Y
0a0c6a3185 Add load_object function for R 2021-03-23 10:10:17 +00:00
lzh222333
d66d4ec93d Merge branch 'main' into online_srv 2021-03-23 11:20:02 +08:00
bxdd
4b56a4e907 fix doc 2021-03-22 18:45:27 +08:00
bxdd
7370d5af9e add label doc 2021-03-22 18:37:44 +08:00
bxdd
c6b67cb8fe fix doc 2021-03-22 18:37:13 +08:00
Alex Wang
3bf6c7f95f update format 2021-03-22 15:37:54 +08:00
Alex Wang
1ad237f89f update high freq demo 2021-03-22 14:20:44 +08:00
you-n-g
2b74b4dfa4 Merge pull request #357 from zhupr/fix_yahoo_collector
Fix yahoo_collector
2021-03-22 12:41:17 +08:00
zhupr
598ee875a0 Fix yahoo_collector 2021-03-22 10:29:07 +08:00
Young
84d5318bda Merge branch 'online_srv_wd' into online_srv 2021-03-19 07:49:16 +00:00
you-n-g
ba56e4071e Merge pull request #292 from wangershi/addFund
Add fund data as an example
2021-03-19 11:05:29 +08:00
wangershi
d3160e9439 remove some useless code 2021-03-18 21:15:45 +08:00
you-n-g
06c90d654d Merge pull request #350 from zhupr/fix_dump_bin
Fix dump_bin.py && check_dump_bin.py
2021-03-18 18:39:40 +08:00
you-n-g
f72771cc81 Merge pull request #351 from wendili-cs/patch-1
Update __init__.py
2021-03-18 18:28:18 +08:00
lzh222333
8abdd63869 online_serving V3 2021-03-18 09:30:01 +00:00
Wendi Li
38f35658e7 Update __init__.py 2021-03-18 13:19:27 +08:00
zhupr
d245242f2f Fix dump_bin.py && check_dump_bin.py 2021-03-18 11:21:25 +08:00
you-n-g
6ef204f190 Merge pull request #348 from microsoft/you-n-g-patch-1
Update Contact Us Information
2021-03-17 16:29:28 +08:00
Young
dad18074ac update image 2021-03-17 08:27:59 +00:00
you-n-g
3cf84f8859 Update Contact Us Information 2021-03-17 16:16:00 +08:00
you-n-g
0403237232 Update Contact US 2021-03-17 15:41:16 +08:00
you-n-g
689774c6be Merge pull request #340 from Derek-Wds/main
Support resuming recorder
2021-03-17 14:48:59 +08:00
Jactus
d78e42e2fe Update exp base method 2021-03-17 13:37:17 +08:00
you-n-g
4de628c736 Merge pull request #347 from 2young-2simple-sometimes-naive/patch-1
fix bug of consider TURE as boolean instead of stock code
2021-03-17 12:08:57 +08:00
you-n-g
023c1fedfe Merge pull request #280 from yongzhengqi/main
Implement Enhanced Indexing as a Portfolio Optimizer
2021-03-17 12:07:39 +08:00
you-n-g
9be6866972 Update README.md 2021-03-17 12:03:53 +08:00
you-n-g
be55e0e3fe Fix Typos 2021-03-17 12:03:00 +08:00
you-n-g
619a3bb25d Update plan and news on Index Page 2021-03-17 12:02:24 +08:00
you-n-g
4bd2cd4611 Add feature news on index page 2021-03-17 11:36:12 +08:00
you-n-g
aa552fdb20 Merge pull request #345 from D-X-Y/main
Fix errors when SignalRecord is not called before SigAna/PortAna
2021-03-17 10:47:00 +08:00
安阁锐
5520463395 fix bug of consider TURE as boolean instead of stock code 2021-03-16 22:18:11 -04:00
D-X-Y
872ddc6f95 Fix black error 2021-03-16 22:57:26 +08:00
D-X-Y
88b0871c12 Add RMSE for contrib.workflow.record_temp and unit tests 2021-03-16 22:55:28 +08:00
D-X-Y
d4aa681652 Add MSERecord in contrib.workflow 2021-03-16 12:54:12 +00:00
Jactus
34f0be2836 Fix arg error 2021-03-16 17:18:48 +08:00
Jactus
447fed8e54 Update structure for resuming 2021-03-16 17:16:00 +08:00
D-X-Y
4cb74d77d1 add error type for record_temp 2021-03-16 09:01:10 +00:00
D-X-Y
b0fd0d2395 Add tests for SigAnaRecord 2021-03-16 08:30:46 +00:00
D-X-Y
6559d44c7d Add tests for SigAnaRecord 2021-03-16 08:17:13 +00:00
D-X-Y
9f57681032 Fix errors when SignalRecord is not called before SigAna/PortAna 2021-03-16 08:11:05 +00:00
lzh222333
d33041dc24 format example 2021-03-16 02:52:20 +00:00
lzh222333
5953365af3 finished update_online_pred demo 2021-03-16 02:43:12 +00:00
lzh222333
e3730b32d7 more clearly structure 2021-03-16 02:23:28 +00:00
Jactus
08b44ed727 Update docs 2021-03-15 14:12:35 +08:00
Jactus
83fb482f1e Fix Bug 2021-03-15 13:55:57 +08:00
Jactus
734bb9ee3c Support resuming recorder 2021-03-15 13:50:10 +08:00
you-n-g
d47e35d64e Merge pull request #337 from D-X-Y/main
Fix bugs in Ghost BN in TabNet and typos in README
2021-03-15 12:42:48 +08:00
lzh222333
0bc49dab60 add task management to index.rst 2021-03-15 03:58:05 +00:00
lzh222333
646d899f8d update docstring and document 2021-03-15 03:50:43 +00:00
D-X-Y
07434da8b0 Remove unused imports 2021-03-15 03:35:34 +00:00
D-X-Y
53a6b72ce5 Fix black errors 2021-03-15 03:09:52 +00:00
D-X-Y
a51dafcb4c Remove unnecessary codes 2021-03-15 03:07:25 +00:00
Young
8362780e22 fix import bug 2021-03-14 15:16:38 +00:00
D-X-Y
358de88602 Fix typos in README and add TabNet config for Alpha360 2021-03-14 07:42:24 +00:00
D-X-Y
32a7be9964 Fix typos in README 2021-03-14 07:31:18 +00:00
D-X-Y
d5f9395e51 Fix Ghost BN bugs in TabNet and simplify its implementation 2021-03-14 07:25:09 +00:00
wangershi
4e7a147759 use base.py 2021-03-14 14:24:14 +08:00
wangershi
1344c40598 Merge remote-tracking branch 'remoteGit/main' into addFund 2021-03-14 11:19:01 +08:00
you-n-g
1d2b2f4f01 Merge pull request #333 from Rekind1e/main
another typo of docs
2021-03-13 20:05:25 +08:00
Hy
373f6e0900 another typo of docs 2021-03-13 15:47:26 +08:00
you-n-g
ba64758c24 Merge pull request #332 from Rekind1e/main
Fix typo of docs
2021-03-13 14:48:02 +08:00
Hy
abddcfccdf fix typo of docs 2021-03-13 14:32:01 +08:00
you-n-g
6d5381f9b1 Merge pull request #311 from withshubh/main
Fixed code quality issues
2021-03-12 18:31:02 +08:00
Young
e4e8a4abcd fix task name & add cur_path 2021-03-12 10:17:16 +00:00
Shubhendra Singh Chauhan
e41373b8ad revert fix 2021-03-12 14:10:52 +05:30
lzh222333
9d84d389ab format code and add example 2021-03-12 08:24:21 +00:00
lzh222333
6d8aa215d6 the second version of online serving 2021-03-12 08:04:08 +00:00
shubhendra
0969c3e7e0 formatted with black 2021-03-12 13:29:20 +05:30
Young
5de7870f9b Merge branch 'online_srv' of github.com:you-n-g/qlib into online_srv 2021-03-12 07:52:31 +00:00
Young
44a7dc004d update docs and fix duplicated pred bug 2021-03-12 07:50:17 +00:00
Shubhendra Singh Chauhan
5f8d0e0436 Merge branch 'main' into main 2021-03-12 13:08:45 +05:30
shubhendra
4fbb5a03c1 revert fixes that failed unit test 2021-03-12 13:05:48 +05:30
you-n-g
0cffb87cbc Merge pull request #329 from D-X-Y/main
Fix Various Bugs for contrib.pytorch_ models
2021-03-12 12:30:08 +08:00
you-n-g
df56e3bdf9 Merge pull request #324 from zhupr/add_base_collector
Add BaseCollector
2021-03-12 12:20:24 +08:00
D-X-Y
1d435248e2 Add return for use_gpu.. 2021-03-11 19:28:00 -08:00
D-X-Y
593553f573 Fix bug in MLP 2021-03-11 19:15:18 -08:00
D-X-Y
d38b8d6001 Fix bugs in use_gpu 2021-03-11 19:10:32 -08:00
D-X-Y
db59713d36 Add torch.no_grad for evaluation 2021-03-12 02:46:04 +00:00
D-X-Y
67fbdafe76 Fix many bugs in TabNet and use_gpu 2021-03-12 02:42:25 +00:00
zhupr
42be8ac312 Add BaseCollector 2021-03-12 10:30:38 +08:00
lzh222333
0df88c07f6 bug fixed and update collect.py 2021-03-11 16:25:46 +00:00
you-n-g
f6b019dcec Merge pull request #328 from D-X-Y/fshare
Move get_path to get_or_create_path, use the best model of SFM / TabNet
2021-03-11 22:07:20 +08:00
D-X-Y
e626264d5a Merge branch 'main' of github.com:microsoft/qlib into fshare 2021-03-11 12:54:04 +00:00
D-X-Y
b99de068f8 Move save_path to get_or_create_path, and fix bugs in sfm / tabnet 2021-03-11 12:52:26 +00:00
you-n-g
e8beaa5257 Merge pull request #314 from D-X-Y/fshare
(1) Fix /0 bug in double_ensemble, (2) remove _default_uri for R/expm, (3) support model size in pytorch models
2021-03-11 20:51:16 +08:00
D-X-Y
0ef7c8e0e6 Fix bugs for get_local_dir 2021-03-11 03:05:31 +00:00
lzh222333
48f0fc147f first version of online serving 2021-03-11 03:00:30 +00:00
D-X-Y
cda96be8c3 Refine default uri in expm 2021-03-11 02:49:03 +00:00
D-X-Y
f6ed175070 Remove set_log_basic_config, refine count_parameters, rename root_uri as get_local_dir 2021-03-11 02:33:00 +00:00
lzh222333
2ca2071d95 format code 2021-03-10 17:06:08 +00:00
you-n-g
0054a4db2a Merge pull request #322 from Derek-Wds/bug
Fix pytorch ts model loader bug
2021-03-10 19:47:56 +08:00
lzh222333
e2f58274ba update task manager 2021-03-10 10:58:49 +00:00
Jactus
119fe90570 Fix pytorch ts model loader bug 2021-03-10 16:43:32 +08:00
you-n-g
e2817ab87c Merge pull request #319 from Derek-Wds/main
Update Filter doc
2021-03-10 15:38:39 +08:00
you-n-g
2e37033e35 Merge pull request #318 from microsoft/bxdd-patch-2
Fix code in ops
2021-03-10 14:45:35 +08:00
Jactus
105fe1d3ed Remove deprecated warning for numpy>=1.20.0 2021-03-10 10:38:43 +08:00
Jactus
78bc2c8748 Update Filter doc 2021-03-09 17:31:27 +08:00
lzh222333
83dbdfb45e finished document and example 2021-03-09 17:22:36 +08:00
bxdd
81987bb143 Update ops.py 2021-03-09 15:38:04 +08:00
Charles Young
53cf89d7c2 Reformat with black. 2021-03-08 19:43:03 +08:00
Charles Young
8b9065c166 Reformat with black. 2021-03-08 19:32:13 +08:00
Charles Young
6a305c73ae Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589166529 2021-03-08 19:08:55 +08:00
Charles Young
7022675d00 Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589169489 2021-03-08 19:07:28 +08:00
Charles Young
2f9af1af8f Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589169769 2021-03-08 19:02:40 +08:00
Charles Young
fc89fec46d Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589168764 2021-03-08 18:56:54 +08:00
Charles Young
c6675be792 Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589166143 2021-03-08 17:51:36 +08:00
Charles Young
351d598c9f Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589165409 2021-03-08 17:49:59 +08:00
Charles Young
81b86f8022 Update test to cover changes in structured_cov_estimator 2021-03-08 17:18:07 +08:00
Charles Young
4d5a30b30b Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589167776 2021-03-08 17:14:29 +08:00
Charles Young
79c1142d3e Pass nan_option to structured covariance estimator. 2021-03-08 17:09:33 +08:00
D-X-Y
e061443560 Fix lint error with Black 2021-03-08 08:27:58 +00:00
D-X-Y
03ef918dd8 Fix bugs in count_parameters 2021-03-08 08:24:23 +00:00
D-X-Y
ca48345b29 Simplify count_parameters 2021-03-08 08:16:17 +00:00
lzh222333
def132e140 modified format and added TaskCollector 2021-03-08 16:10:16 +08:00
D-X-Y
7bed3b4c2e Fix python format by black 2021-03-08 06:39:03 +00:00
D-X-Y
4266492a34 Merge branch 'main' of github.com:microsoft/qlib into fshare 2021-03-08 06:12:49 +00:00
you-n-g
91eef93386 Merge pull request #302 from D-X-Y/main
Update repr for dataset/workflow classes and add uri kwarg for QlibRecorder
2021-03-08 14:01:53 +08:00
lzh222333
a244f87f95 modified the comments 2021-03-08 13:25:11 +08:00
wangershi
9df0361262 black 2021-03-07 19:35:50 +08:00
wangershi
6bcd88973b resolve one bug 2021-03-07 19:32:37 +08:00
D-X-Y
d13c9ae018 Avoid dividing zero in model.double_ensemble 2021-03-07 11:25:53 +00:00
wangershi
11412727ef add normalizer 2021-03-07 18:51:38 +08:00
D-X-Y
73b7107ee8 Remove useless verbose kwarg 2021-03-07 00:52:27 -08:00
D-X-Y
91fd53ab4d Add reset_default_uri func for R and expm 2021-03-06 05:33:08 -08:00
shubhendra
aab5c5b311 Refactor the comparison involving not
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:35 +05:30
shubhendra
dc86a6abc5 Refactor unnecessary else / elif when if block has a continue statement
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:34 +05:30
shubhendra
a62d1a1b36 Use literal syntax to create data structure
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:34 +05:30
shubhendra
5015d218ff Remove methods with unnecessary super delegation.
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:34 +05:30
shubhendra
6f034ccb5d Remove unnecessary use of comprehension
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:33 +05:30
shubhendra
07eef18337 Remove length check in favour of truthiness of the object
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:33 +05:30
shubhendra
f277a66582 Add .deepsource.toml
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:32 +05:30
D-X-Y
49697b1f15 Show model size for pytorch models 2021-03-05 12:46:41 +00:00
D-X-Y
131f0e2e67 Add count_parameters for pytorch models in contrib 2021-03-05 12:07:43 +00:00
D-X-Y
b14a559a52 Merge branch 'main' of github.com:D-X-Y/qlib into fshare 2021-03-05 07:25:18 +00:00
D-X-Y
b115ca5353 Merge branch 'main' of github.com:microsoft/qlib into main 2021-03-04 23:19:33 -08:00
D-X-Y
b40bfb1ea5 Merge branch 'main' of github.com:microsoft/qlib into fshare 2021-03-05 07:17:09 +00:00
D-X-Y
4f980a0266 Merge branch 'fshare' of github.com:D-X-Y/qlib into fshare 2021-03-05 07:15:46 +00:00
D-X-Y
19d93744f3 Add set_log_basic_config function support re-directing log stream 2021-03-05 07:15:25 +00:00
D-X-Y
e327f404e3 Fix pylint issues 2021-03-04 22:37:58 -08:00
D-X-Y
452fb8f013 Make mlflow client consistant with uri 2021-03-04 22:33:35 -08:00
D-X-Y
c4d6e00470 Fix logic of uri in ExpM and add test 2021-03-04 21:04:01 -08:00
Charles Young
0f3e3d206b Update __init__.py. 2021-03-04 22:47:42 +08:00
Charles Young
83c6e74783 Reindex files. 2021-03-04 22:30:38 +08:00
Charles Young
2bff6eb781 Split classes in riskmodel.py & optimizer.py into seperate files. 2021-03-04 22:08:11 +08:00
D-X-Y
ee7eb79277 Fix unexpected mlruns folder error 2021-03-04 06:15:24 +00:00
D-X-Y
592db903b3 Update repr for Experiment & Recorder 2021-03-04 05:02:56 +00:00
wangershi
34b7da1dd8 add calendar list by threshold 2021-03-03 22:49:48 +08:00
lzh222333
2882929c5d Add an example about workflow using RollingGen. 2021-03-03 16:58:05 +08:00
lzh222333
fd2c1ba1ed Update some hint 2021-03-03 16:36:15 +08:00
lzh222333
05cf0e1edc add task_generator method and update some hint 2021-03-03 15:42:39 +08:00
D-X-Y
229a39d0d3 Fix typos in DataHandler's doc 2021-03-03 07:30:55 +00:00
D-X-Y
a9a70dfddf Update repr for DatasetH and ExpManager 2021-03-03 06:47:52 +00:00
you-n-g
b8cf229b05 Merge pull request #272 from ChengYen-Tang/Fix_collector_doc
Fix collector doc
2021-03-03 14:14:07 +08:00
you-n-g
7258340e0c Merge pull request #300 from bxdd/qlib_2021_03_02
Update FAQ doc & Update ops docstring
2021-03-03 14:11:11 +08:00
lzh222333
b84156fde8 Consider more situations about task_config.
Save the "param" file which is collect.py need.
2021-03-03 11:25:37 +08:00
you-n-g
d1d70616a3 Merge pull request #298 from D-X-Y/main
Update the Wrapper class to have an informative str representation
2021-03-03 11:17:00 +08:00
bxdd
5378d261b4 update ops.py docstring 2021-03-03 00:47:50 +09:00
bxdd
a1fb10f7cf update FAQ doc 2021-03-03 00:44:52 +09:00
D-X-Y
dbc8ca7379 Fix pylint by black -l 120 2021-03-02 22:15:30 +08:00
D-X-Y
c48b4c9971 Make Wrapper with a informative str repr. 2021-03-02 21:06:32 +08:00
you-n-g
b592669d1f Update Index Report 2021-03-02 20:23:27 +08:00
you-n-g
0bcaab3a5a Merge pull request #286 from meng-ustc/main
Add a new method to benchmarks: DoubleEnsemble
2021-03-02 18:34:17 +08:00
meng-ustc
1de4def444 Update parameter names: 'k' and 'base' 2021-03-02 16:14:56 +09:00
meng-ustc
ee4692a355 Merge branch 'main' of https://github.com/meng-ustc/qlib 2021-03-02 12:20:45 +09:00
meng-ustc
6e2ce6f1dc Add the results of DoubleEnsemble 2021-03-02 12:17:05 +09:00
lzh222333
c4733f601f Merge pull request #1 from you-n-g/online_srv
qlib auto init basedon project & black format
2021-03-02 10:06:45 +08:00
wangershi
82353b20e1 black format 2021-03-01 21:10:46 +08:00
wangershi
51baf57b40 Merge remote-tracking branch 'remoteGit/main' into addFund 2021-02-28 19:29:18 +08:00
wangershi
3082f6ac1b ready for dump_bin 2021-02-28 19:06:40 +08:00
wangershi
db80b620d8 ready for collector 2021-02-28 17:03:14 +08:00
wangershi
6e56396217 add crawler 2021-02-28 12:24:26 +08:00
Young
24024d51c7 qlib auto init basedon project & black format 2021-02-27 09:44:01 +00:00
Wendi Li
a96f0c2e5f Update README.md
Typos fixed.
2021-02-27 11:37:45 +08:00
Young
1e5cf1c174 init version of online serving and rolling 2021-02-26 09:14:40 +00:00
wangershi
719074d306 touch file 2021-02-25 19:20:14 +08:00
Meng Dong
70575e8a1c Delete workflow_by_code_lgb_risk_demo.py 2021-02-24 16:10:38 +08:00
meng-ustc
ce60097722 Add README and Formatted 2021-02-24 16:59:31 +09:00
meng-ustc
1a990fdd25 Add Risk Prediction Demo 2021-02-23 19:08:11 +09:00
Charles Young
527718a440 Allow enhanced indexing to generate portfolio without industry related restriction. 2021-02-22 19:04:31 +08:00
Charles Young
d3caea60ee Add unittest for TestStructuredCovEstimator. 2021-02-22 17:32:03 +08:00
Charles Young
f947a2fdef Correct two mistakes in annotation. 2021-02-22 15:15:51 +08:00
Jactus
dc4aa67503 Black format 2021-02-22 11:42:36 +08:00
Charles Young
37871389b9 Format code with the latest version of black. 2021-02-22 11:25:42 +08:00
Charles Young
2f9d45e03a Reformat code with black. 2021-02-22 10:29:29 +08:00
Charles Young
b8647c13c7 Reformat code to follow PEP 8. 2021-02-22 10:20:51 +08:00
Charles Young
164687d54b Add scikit-learn to dependencies. 2021-02-22 10:13:08 +08:00
Charles Young
58f74cfd84 Reformat code to follow PEP 8. 2021-02-22 10:07:03 +08:00
Charles Young
f7d3e56561 Merge optimization related portfolio construction back to portfolio/optimizer. 2021-02-22 09:57:41 +08:00
Charles Young
42f882504e Reformat code to follow PEP 8. 2021-02-22 09:25:48 +08:00
Charles Young
9448a6e2c7 Add a abstract class as the base class for all optimization related portfolio constructions. 2021-02-22 09:23:48 +08:00
Charles Young
2cc057e438 Fix minor mismatches of type hints. 2021-02-22 09:09:03 +08:00
Charles Young
b2e2142594 Applied slight modification to follow PEP 8. 2021-02-22 09:00:12 +08:00
Charles Young
4000518698 Separate specific implementation of Portfolio Optimizer to folder. 2021-02-22 08:41:35 +08:00
you-n-g
fa8f1cba06 Update filter.py 2021-02-19 22:23:22 +08:00
you-n-g
a72911e4f8 Update filter.py 2021-02-19 22:23:22 +08:00
meng-ustc
cd5b721bc6 Update 2021-02-19 11:56:50 +09:00
meng-ustc
42590972e4 Modify run_all_model.py 2021-02-18 19:15:02 +09:00
meng-ustc
d27dc8bab8 Add A New Baseline: DoubleEnsemble 2021-02-18 19:02:33 +09:00
Young
50d5fcf61e yml afe load 2021-02-18 08:43:12 +08:00
Young
77830a546e safe yaml loader 2021-02-18 08:43:12 +08:00
Young
83237ba4ed yml afe load 2021-02-17 05:17:18 +00:00
Young
04b916c8ae safe yaml loader 2021-02-16 15:07:14 +00:00
Kenneth Tang
b90bd66ac6 Merge branch 'main' into Fix_collector_doc 2021-02-16 14:16:53 +08:00
Kenneth Tang
63d05e4a1a Fix typo 2021-02-16 14:08:58 +08:00
lbaiao
0b11dc5167 Update workflow.rst
Corrected an identation problem on the configuration.yaml file.
2021-02-11 10:54:32 +08:00
Charles Young
9c2653f125 Add an implementation of Enhanced Indexing to optimizer.py 2021-02-09 20:31:00 +08:00
Charles Young
7b01c5cae7 Add an implementation of Enhanced Indexing to optimizer.py 2021-02-09 20:30:26 +08:00
Charles Young
988b42e159 Add Structured Covariance Estimator to riskmodel.py 2021-02-09 20:28:42 +08:00
Miae Kim
12c8bfa545 Fix typo in Data Layer
Commit changes
2021-02-08 10:29:37 +08:00
Jactus
c948385e76 Add model saving for qrun and workflow example 2021-02-07 21:32:35 +08:00
bxdd
07b905c153 update 2021-02-05 16:09:06 +08:00
bxdd
0192f28bf4 add docstring & fix code 2021-02-05 16:09:06 +08:00
bxdd
cbf97f56a4 fix market 2021-02-05 16:09:06 +08:00
bxdd
d702c8bcb1 add Cut ops 2021-02-05 16:09:06 +08:00
Jactus
b84686b215 Update models to enable save/load 2021-02-05 13:14:12 +08:00
bxdd
6a670828a5 Update serial.rst 2021-02-05 12:33:47 +08:00
bxdd
ca6c2ffc27 Update serial.rst 2021-02-05 12:33:47 +08:00
bxdd
914637b3ef update index 2021-02-05 12:33:47 +08:00
bxdd
d8da94de10 update docs 2021-02-05 12:33:47 +08:00
bxdd
477a548fe9 update data.rst docs 2021-02-05 12:33:47 +08:00
bxdd
35af9ad954 update docs 2021-02-05 12:33:47 +08:00
bxdd
8a91e7d34d black format 2021-02-05 12:33:47 +08:00
bxdd
4ed8b8e233 add docs & fix reinit of datatset 2021-02-05 12:33:47 +08:00
Jactus
c71b645777 Update docs about record-temp 2021-02-05 12:32:47 +08:00
HUAN-PING SU
f2ffb80a0b Update workflow_by_code.py 2021-02-04 17:01:34 +08:00
HUAN-PING SU
cda1d4be40 Update workflow_by_code.py
Fix typo in workflow_by_code.py
2021-02-04 17:01:34 +08:00
zhupr
fc1431cd4e update 1min docs 2021-02-04 16:49:58 +08:00
Jactus
06158fb621 Update readme and setup.py 2021-02-04 16:09:36 +08:00
Jactus
1e2e02368c Update readme 2021-02-03 14:42:58 +08:00
Jactus
d87d29aca9 Update Windows CI 2021-02-03 14:42:58 +08:00
Jactus
005da6306c Update CI 2021-02-03 14:42:58 +08:00
bxdd
090b68e44e Update workflow.py 2021-02-03 13:12:19 +08:00
Young
bf748ba4b7 update version number to dev 2021-02-03 05:11:39 +00:00
Young
0da9b909f6 update version number 2021-02-02 09:48:41 +00:00
Jactus
5f87fc32ad Fix CI 2021-02-02 17:35:07 +08:00
Young
97d354fa73 update version for releasing 2021-02-02 09:24:33 +00:00
Young
a87fb5a68c fix contrib data freq 2021-02-02 16:52:50 +08:00
Young
835b47a7e7 simplify parameters 2021-02-02 16:52:50 +08:00
Young
802dac81c9 move freq params to dataloader 2021-02-02 16:52:50 +08:00
Wendi Li
bdc70c192a Update pytorch_nn.py 2021-02-02 14:48:12 +08:00
Wendi Li
213f809148 Update pytorch_alstm_ts.py 2021-02-02 14:47:41 +08:00
Wendi Li
f3fd5e0773 Update pytorch_gats.py 2021-02-02 14:47:31 +08:00
Wendi Li
decf74cbdf Update pytorch_gru.py 2021-02-02 14:47:20 +08:00
Wendi Li
b4a92d55f8 Update pytorch_gru_ts.py 2021-02-02 14:47:00 +08:00
Wendi Li
ebc31b9bdb Update pytorch_lstm.py 2021-02-02 14:46:49 +08:00
Wendi Li
56ebe9bf36 Update pytorch_lstm_ts.py 2021-02-02 14:46:21 +08:00
Wendi Li
ddd68fc761 Update pytorch_alstm.py 2021-02-02 14:34:57 +08:00
Meng Dong
fd5c68a7d1 Update workflow_config_doubleensemble_Alpha158.yaml 2021-02-02 12:39:07 +08:00
meng-ustc
8c3ec164ff Add A New Baseline: DoubleEnsemble 2021-02-02 11:46:37 +09:00
meng-ustc
acdc469e39 Add A New Baseline: DoubleEnsemble 2021-02-01 21:05:34 +09:00
bxdd
f50463aca9 Fix bug in alpha360 2021-02-01 18:33:51 +08:00
Jactus
c0e7cbc983 Add filter_pipe API 2021-01-29 12:47:04 +08:00
you-n-g
828993b397 Merge pull request #222 from bxdd/rl-highfreq-include-examples
Qlib Highfreq Support & Highfreq DataHanlder/Operator/Processor Examples
2021-01-29 00:08:10 +08:00
bxdd
8ef89b4fa8 update 2021-01-28 15:01:07 +00:00
bxdd
76cf9dad99 update 2021-01-28 14:30:20 +00:00
bxdd
f3eb02a0bd update docstring 2021-01-28 14:26:30 +00:00
bxdd
ffa68fd010 update 2021-01-28 14:25:55 +00:00
bxdd
f6dd006c35 update 2021-01-28 11:31:15 +00:00
you-n-g
8c29105bca Update cache.py 2021-01-27 19:52:33 +08:00
bxdd
948b829ff4 add get_data in highfreq 2021-01-27 10:34:31 +00:00
Jactus
304a0c3d7a Add paper year 2021-01-27 18:15:52 +08:00
bxdd
02dea2aeb6 update paused 2021-01-27 07:42:00 +00:00
bxdd
6fc4f2b249 fix a bug 2021-01-27 07:02:59 +00:00
bxdd
2a5f06ee9e update dataset test 2021-01-27 06:25:40 +00:00
zhupr
7f9216dc90 Fix the number of minutes on the first and last trading day of high frequency 2021-01-27 10:59:46 +08:00
zhupr
263ccdfe6f US stock code supports Windows 2021-01-27 10:59:46 +08:00
zhupr
1a8f1bfc57 support collecting yahoo 1min data 2021-01-27 10:59:46 +08:00
bxdd
9dc11a9e3c Merge github.com:microsoft/qlib into qlib_register_ops 2021-01-26 17:12:33 +00:00
bxdd
3bdd54308b update some little code 2021-01-26 17:02:30 +00:00
bxdd
1b569d371d simpson vwap 2021-01-26 14:32:08 +00:00
you-n-g
36e5c601de Merge pull request #78 from zhupr/main
Fix the error when the stock code is a number
2021-01-26 21:50:21 +08:00
zhupr
ae45711e2b Merge remote-tracking branch 'qlib/main' into save_inst 2021-01-26 19:42:59 +08:00
you-n-g
bcc47aa4cb Merge pull request #92 from bxdd/qlib_register_ops
Support Register of Custom Feature Operators Easily
2021-01-26 18:53:43 +08:00
bxdd
ee94634b23 black 2021-01-26 08:47:53 +00:00
bxdd
2016ebbbb2 update tests 2021-01-26 08:47:07 +00:00
zhupr
1eaf09cce1 version removed .dev 2021-01-26 16:29:26 +08:00
zhupr
7579f4b4c0 Merge remote-tracking branch 'qlib/main' into save_inst 2021-01-26 16:14:11 +08:00
zhupr
1a1c45981c US stock code supports Windows 2021-01-26 16:06:38 +08:00
bxdd
e4ecea55e4 fix 2021-01-26 07:41:22 +00:00
bxdd
58616fced9 black format 2021-01-26 07:33:50 +00:00
bxdd
8e9ca22b07 del some print 2021-01-26 07:33:26 +00:00
bxdd
6a145df87c fix bug 2021-01-26 07:32:06 +00:00
bxdd
06dbd02b99 black format 2021-01-25 17:59:48 +00:00
bxdd
ffedb6382f add highfreq example 2021-01-25 17:58:45 +00:00
zhupr
3f9f295a87 add register in config 2021-01-24 11:22:02 +08:00
Wendi Li
84d77f4585 Update pytorch_nn.py 2021-01-24 10:40:47 +08:00
you-n-g
afdf58b4fa Update serial.py 2021-01-24 10:36:56 +08:00
Alex Wang
2b6d16feb1 fix naming 2021-01-22 19:16:57 +08:00
Alex Wang
0a86a6f392 update format 2021-01-22 19:16:57 +08:00
Alex Wang
5da5ad4b9f tabnet 2021-01-22 19:16:57 +08:00
you-n-g
dd07810b66 Update README.md 2021-01-22 12:53:05 +08:00
bxdd
a762248d98 update test&docs 2021-01-22 01:06:32 +09:00
bxdd
80c9a47e51 Merge github.com:microsoft/qlib into qlib_register_ops 2021-01-22 00:52:30 +09:00
王雪
784e73bceb black formatting 2021-01-21 00:07:03 +08:00
王雪
5ad1b4cc33 for IDE auto-complete with global Wrapper
R, D, Cal, Inst, FeatureD, ExpressionD, DatasetD, D
2021-01-21 00:07:03 +08:00
王雪
e85646762c Update .gitignore 2021-01-20 22:12:35 +08:00
Young
fc81a39317 Add dataset standalone usage example 2021-01-20 21:14:27 +08:00
you-n-g
d44c5bb2b2 Update README.md 2021-01-20 21:14:03 +08:00
bxdd
c622d3f6f8 Update data.rst 2021-01-20 18:55:30 +08:00
bxdd
6daaa79519 add register ops config 2021-01-20 18:44:53 +09:00
zhupr
3dda2cb379 Merge remote-tracking branch 'qlib/main' into qlib_register_ops 2021-01-20 15:16:06 +08:00
zhupr
4fcfde7cfb Initialization is split into: set_config and config_based_on_C 2021-01-20 15:06:18 +08:00
bxdd
3403c00b6b Update requirements.txt
fix readthedocs cant find cmake error
2021-01-19 20:35:11 +08:00
bxdd
ecdfe49fd1 del custom ops test for check the CI status 2021-01-19 20:39:15 +09:00
bxdd
cc214a3462 black format 2021-01-19 09:14:17 +08:00
bxdd
65d8af41e7 restructure backtest 2021-01-19 09:14:17 +08:00
bxdd
0e0970f06e update backtest 2021-01-19 09:14:17 +08:00
bxdd
917261dbf6 update backtest 2021-01-19 09:14:17 +08:00
bxdd
6a9105e065 add highfreq_backtest 2021-01-19 09:14:17 +08:00
王雪
570bb272eb fix setup error
why required pymongo
2021-01-18 19:37:24 +08:00
Wendi Li
0524a47cf4 Update pytorch_lstm_ts.py 2021-01-18 12:20:40 +08:00
Wendi Li
9abc0b0d4f Update pytorch_gru_ts.py 2021-01-18 12:20:31 +08:00
Wendi Li
fe60e40927 Update pytorch_gats_ts.py 2021-01-18 12:20:20 +08:00
Wendi Li
740c297618 Update pytorch_alstm_ts.py 2021-01-18 12:20:00 +08:00
Anon-Artist
b4a088efe8 Update cli.py 2021-01-14 18:42:33 +08:00
Jactus
b34890772f Make note more clear 2021-01-13 19:19:48 +08:00
Jactus
054ffa29f6 Update readme 2021-01-13 19:19:48 +08:00
Jactus
74e08c9e37 Add deepcopy to config 2021-01-13 19:19:48 +08:00
Jactus
ea96c9e22d Update docs and support Python 3.9 2021-01-13 19:19:48 +08:00
王雪
86e7c44c6b Update initialization.rst
need line changing
2021-01-13 15:28:05 +08:00
you-n-g
64cf2e2df8 Update data.rst 2021-01-12 18:43:05 +08:00
Jactus
4361a4049a Fix create_recorder bug 2021-01-07 18:30:18 +08:00
Zhichong Fang
231f37376b Fix unrecognized config bug 2021-01-07 18:28:17 +08:00
you-n-g
328cdeda4a Update README.md 2021-01-07 11:12:49 +08:00
Zhichong Fang
4dbc8e52ec Update data.py
Fix some typo
2021-01-06 16:36:23 +08:00
Young
ba447d3448 update valute 2021-01-06 14:43:14 +08:00
zhupr
df556532d0 Fix the error when the stock code is a number 2021-01-06 11:21:33 +08:00
Wendi Li
18e040f506 Update workflow_config_gru_Alpha158.yaml
Delete a redundant parameter.
2021-01-04 17:05:21 +08:00
Wendi Li
aefc98b1d7 Update workflow_config_lstm_Alpha158.yaml
Delete a redundant parameter.
2021-01-04 17:05:13 +08:00
Jactus
46c8d791ac Fix doc bugs 2020-12-30 23:51:05 +08:00
Young
afcd91a2d0 black format 2020-12-28 12:04:03 +00:00
Young
4a30d9d1ec update github issue template 2020-12-28 12:02:01 +00:00
you-n-g
2da2e9bd9e Update README.md 2020-12-26 20:21:30 +08:00
you-n-g
3e6877ff0f Update README.md 2020-12-25 22:01:18 +08:00
zhupr
a0f32036a6 Fix the first trading day of the calendar extra in report_df 2020-12-24 11:22:48 +08:00
bxdd
d8f36df7f4 debug on macos 2020-12-23 18:28:05 +00:00
bxdd
cb3b6c5bde black format 2020-12-23 16:41:32 +00:00
bxdd
b11712fa54 fix cant find ops error on Windows 2020-12-23 16:39:17 +00:00
Jactus
660edeb94f Remove fm in recorder 2020-12-23 21:14:53 +08:00
Jactus
95de4088df Fix recorder temp dir bug 2020-12-23 21:14:53 +08:00
hadrianl
e8d7a22651 fix _adjust_size 2020-12-23 17:39:04 +08:00
hadrianl
4a62b929ad add _get_value_size and remove _limit_flag 2020-12-23 17:39:04 +08:00
hadrianl
5efe82fb56 make code cleaner 2020-12-23 17:39:04 +08:00
hadrianl
40bbafcaab black format 2020-12-23 17:39:04 +08:00
hadrianl
4c4f0f3c5e black format 2020-12-23 17:39:04 +08:00
hadrianl
ae0e0eca3d better MemCacheUnit implement 2020-12-23 17:39:04 +08:00
bxdd
7e37fa710a update alpha.rst 2020-12-21 23:31:31 +08:00
bxdd
e0c460c33c Update alpha.rst 2020-12-21 23:31:31 +08:00
bxdd
53f501ac19 del import 2020-12-21 12:44:27 +00:00
bxdd
132df027a5 update format 2020-12-21 12:09:25 +00:00
bxdd
7d97fd39ce update ops register 2020-12-21 12:06:42 +00:00
Young
995fa98fc6 add more doc to PortAnaRecord 2020-12-20 16:11:07 +08:00
Maciej Domagała
824de921d1 fixing typos #4 2020-12-19 11:59:23 +08:00
Maciej Domagała
66d9bd1a68 fixing typos #3
I just randomly find these by the way. Good work on the framework!
2020-12-18 20:16:54 +08:00
you-n-g
1c0bb2f827 Merge pull request #97 from Derek-Wds/main
Update benchmark performance
2020-12-17 17:12:40 +08:00
Maciej Domagała
ea018ed4dc fixing typos #2 2020-12-17 17:12:18 +08:00
hadrianl
f3f1867b14 fix wrong attribute 2020-12-17 15:04:07 +08:00
hadrianl
8bbfd8810c formatting 2020-12-17 15:04:07 +08:00
hadrianl
3f84c3768a Make __getattr__ to raise AttributeError instead of return it.Avoid using try except. 2020-12-17 15:04:07 +08:00
Dingsu Wang
7372a3a598 Merge branch 'main' into main 2020-12-17 14:43:21 +08:00
Jactus
4b4cd38ca6 Update benchmark results 2020-12-17 14:41:12 +08:00
you-n-g
7d40ba753a Update README.md 2020-12-17 00:35:35 +08:00
Young
9b60214e0c make info more friendly 2020-12-16 02:16:06 +00:00
Young
f7e775f941 make message more friendly 2020-12-16 02:14:38 +00:00
Young
aefbf3b5f1 update collect info 2020-12-15 13:24:29 +00:00
G_will
3f85af05e5 Refactor to Python3 style 2020-12-15 20:37:43 +08:00
Jactus
192c2dc5ef Add demo 2020-12-15 20:33:32 +08:00
Jactus
911edd7839 Add stale bot 2020-12-15 20:31:38 +08:00
Maciej Domagała
3d47dd78c8 Typo fix 2020-12-15 20:29:30 +08:00
Jactus
8f6ab0af54 Format 2020-12-14 19:23:43 +08:00
Jactus
cb0b6fcdaa Update CI and script 2020-12-14 19:23:43 +08:00
Yifan Deng (FA Talent)
6b8824dd29 Update Sign in ops.py 2020-12-14 16:55:23 +08:00
Yifan Deng
c217e7c479 Update ops.py
Fix the bug when Sign followed by True/False
2020-12-14 16:55:23 +08:00
you-n-g
ea4fe1577b Update README.md 2020-12-14 13:05:12 +08:00
you-n-g
1bab07e419 Update README.md 2020-12-13 22:45:07 +08:00
bxdd
422d1d8c93 Update README.md 2020-12-12 19:41:16 +08:00
bxdd
c8f9b1162d Update README.md 2020-12-12 19:01:00 +08:00
Young
e2bdef7ffe update version number to dev 2020-12-12 10:09:18 +00:00
Young
e49b590322 Release qlib 0.6.1 2020-12-12 09:51:52 +00:00
bxdd
9d19294f15 update Note 2020-12-12 17:42:23 +08:00
bxdd
b0e7a85601 update readme 2020-12-12 17:42:23 +08:00
you-n-g
8ea45802df Update README.md 2020-12-12 14:04:21 +08:00
Jactus
bba94d72dc Add author names 2020-12-11 17:19:12 +08:00
Jactus
d6dd423dc2 Update benchmark performance 2020-12-11 17:19:12 +08:00
Jactus
c10955d026 Update tft 2020-12-11 14:33:16 +08:00
Jactus
d642c7b6ea Update benchmark performance 2020-12-11 09:55:37 +08:00
G_will
9307bcc8d1 fix typo
fix typo
2020-12-10 20:11:57 +08:00
Jactus
99f3820e42 Update readme 2020-12-10 19:37:18 +08:00
Jactus
b04d2c39c8 Update CI 2020-12-10 19:37:18 +08:00
bxdd
0cdc5e125a update docs 2020-12-10 10:08:29 +00:00
bxdd
2de812f262 update ops docs 2020-12-10 10:04:09 +00:00
bxdd
16450c2876 fix import 2020-12-10 09:54:05 +00:00
bxdd
729b57e4a7 add example script 2020-12-10 09:11:12 +00:00
bxdd
87cc52cd05 black format 2020-12-10 09:02:43 +00:00
bxdd
0be57d51be support register custom feature ops easily 2020-12-10 09:00:00 +00:00
bxdd
9c482ebbe2 black format 2020-12-10 15:50:14 +08:00
bxdd
eb67f1037a update setup 2020-12-10 15:50:14 +08:00
bxdd
59282c8965 fix req 2020-12-10 15:50:14 +08:00
bxdd
03ab67ad5c fix req 2020-12-10 15:50:14 +08:00
bxdd
e2d862bfb2 fix system package 2020-12-10 15:50:14 +08:00
bxdd
936d5abb1f fix docs req 2020-12-10 15:50:14 +08:00
bxdd
7296780149 fix setup 2020-12-10 15:50:14 +08:00
bxdd
97c053ba73 update setup 2020-12-10 15:50:14 +08:00
bxdd
2c5864204e del qlib.readthedocs.yml 2020-12-09 13:58:17 +00:00
bxdd
6562c9aaa4 Merge https://github.com/microsoft/qlib into main 2020-12-09 13:55:56 +00:00
bxdd
85a217c121 update readthedocs 2020-12-09 13:54:22 +00:00
bxdd
f156280a51 Merge pull request #84 from bxdd/qlib_readthedocs
Fix readthedocs
2020-12-09 21:43:51 +08:00
bxdd
e8eb034a97 fix the config of readthedocs 2020-12-09 13:41:18 +00:00
bxdd
7763cf5a5c add the config of readthedocs 2020-12-09 21:27:39 +08:00
bxdd
053736c0ea add the config of readthedocs 2020-12-09 13:23:59 +00:00
Young
74ac230edb rename for pytest 2020-12-09 20:24:56 +08:00
Young
303021cd47 CI updating 2020-12-09 20:24:56 +08:00
Young
c0f1696adb add downcast to save data 2020-12-09 20:24:56 +08:00
Young
361d168890 remove dataloader 2020-12-09 20:24:56 +08:00
lwwang1995
73669de392 Fix code. 2020-12-09 20:24:56 +08:00
Wendi Li
89ec87e45b Add files via upload 2020-12-09 17:20:36 +08:00
Wendi Li
15cdfeb121 Add files via upload 2020-12-09 17:20:36 +08:00
Wendi Li
1bbd026195 Add files via upload 2020-12-09 17:20:36 +08:00
Jactus
a5c098de92 Update tft results 2020-12-09 17:20:36 +08:00
bxdd
a63ba3e819 black format 2020-12-09 17:20:36 +08:00
bxdd
56e579e20f add arg weight_decay 2020-12-09 17:20:36 +08:00
bxdd
2873813562 update mlp model 2020-12-09 17:20:36 +08:00
Jactus
a8ac56a82f Format 2020-12-09 17:20:36 +08:00
Jactus
6ef339b1ec Update mlp results and add doc 2020-12-09 17:20:36 +08:00
Jactus
579caa757c Update readme 2020-12-09 17:20:36 +08:00
Jactus
a1e579ff39 Update readme 2020-12-09 17:20:36 +08:00
Jactus
217019a640 Update benchmark readme 2020-12-09 17:20:36 +08:00
Jactus
c14404afe1 Add benchmark results 2020-12-09 17:20:36 +08:00
lwwang1995
4596a7e000 Delete the setting of SFM on the Alpha158. 2020-12-09 17:20:36 +08:00
lwwang1995
ec40845513 Update settings. 2020-12-09 17:20:36 +08:00
lwwang1995
dcfa8110e8 Fix bugs for Gats model. 2020-12-09 17:20:36 +08:00
lwwang1995
666e1ffcbd Update settings. 2020-12-09 17:20:36 +08:00
lwwang1995
70fb760830 Fix bugs for models. 2020-12-09 17:20:36 +08:00
lwwang1995
4a748525bc Fix bugs of model. 2020-12-09 17:20:36 +08:00
Young
fb4a2e65cc support multi indexing of TSDatasetSample 2020-12-09 17:20:36 +08:00
lwwang1995
71ad651514 Add sample for Gats. 2020-12-09 17:20:36 +08:00
lwwang1995
65a9a72a88 Update models. 2020-12-09 17:20:36 +08:00
lwwang1995
ec0d7838ac Update models. 2020-12-09 17:20:36 +08:00
lwwang1995
752f17e51e Fix SFM bug. 2020-12-09 17:20:36 +08:00
lwwang1995
8d42092a7e Fix models. 2020-12-09 17:20:36 +08:00
lwwang1995
412c9eee2e Update models. 2020-12-09 17:20:36 +08:00
Young
abb90ca2f6 fix sampler performance bug 2020-12-09 17:20:36 +08:00
lwwang1995
a7c6aea386 Update record to support time series dataset. 2020-12-09 17:20:36 +08:00
lwwang1995
a88697151a Test CSRankNorm. 2020-12-09 17:20:36 +08:00
Young
d2107c9957 dataset performance optm 2020-12-09 17:20:36 +08:00
lwwang1995
65902e424c Add filter columns. 2020-12-09 17:20:36 +08:00
lwwang1995
bf8de72605 Update test_dataset 2020-12-09 17:20:36 +08:00
lwwang1995
60f62482b7 Update test_dataset 2020-12-09 17:20:36 +08:00
Young
d2d865fb7a fix bug of workflow 2020-12-09 17:20:36 +08:00
Young
5d5f8c8868 update TimeSeriesDataset 2020-12-09 17:20:36 +08:00
you-n-g
d093afd684 Merge pull request #74 from Derek-Wds/main
Update scripts
2020-12-04 14:30:42 +08:00
Dingsu Wang
46396c229a Update README.md 2020-12-04 14:28:13 +08:00
Jactus
eef90c7901 Update readme 2020-12-04 14:25:26 +08:00
Jactus
895b1e7944 Update CI 2020-12-04 13:10:45 +08:00
Jactus
2fb7774927 Update config names 2020-12-04 13:02:09 +08:00
Jactus
86b0b63771 Update table generator 2020-12-04 12:55:22 +08:00
Jactus
99adc514a5 Update CI 2020-12-04 11:51:45 +08:00
Jactus
07fb9031c6 Update setup 2020-12-04 10:47:20 +08:00
Jactus
f237a344c3 Update 2020-12-04 10:31:50 +08:00
Dingsu Wang
2cb888c8b9 Merge branch 'main' into main 2020-12-04 09:45:00 +08:00
Jactus
ab762b3cd7 Update lightgbm lr 2020-12-02 20:04:00 +08:00
Jactus
703ae5d4aa Update tft and readme 2020-12-02 18:00:26 +08:00
Jactus
91c3dfddf5 Update run_all_model 2020-12-02 17:55:11 +08:00
Jactus
745b93138d Update devices 2020-12-02 13:28:32 +08:00
Jactus
7f385345bb Fix GPU 2020-12-02 13:25:29 +08:00
Jactus
d109d3d44e Fix GPU 2020-12-02 12:23:42 +08:00
Jactus
a2603fe27a Update config 2020-12-02 09:41:24 +08:00
Young
e5590de2a4 udpate header and test docs 2020-12-01 10:11:29 +00:00
Jactus
77884db3a5 Fix and update run_all_model 2020-12-01 11:54:26 +08:00
Jactus
bb5f3cb33d Fix and move to version 0.6.0 2020-12-01 09:59:46 +08:00
219 changed files with 16673 additions and 2861 deletions

12
.deepsource.toml Normal file
View File

@@ -0,0 +1,12 @@
version = 1
test_patterns = ["tests/test_*.py"]
exclude_patterns = ["examples/**"]
[[analyzers]]
name = "python"
enabled = true
[analyzers.meta]
runtime_version = "3.x.x"

View File

@@ -28,7 +28,8 @@ Steps to reproduce the behavior:
## Environment
**Note**: One could run `python scripts/collect_info.py` under the `qlib` directory to get the following information.
**Note**: User could run `cd scripts && python collect_info.py all` under project directory to get system information
and paste them here directly.
- Qlib version:
- Python version:
@@ -37,4 +38,4 @@ Steps to reproduce the behavior:
## Additional Notes
<!-- Add any other information about the problem here. -->
<!-- Add any other information about the problem here. -->

24
.github/workflows/stale.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Mark stale issues and pull requests
on:
schedule:
- cron: "0 0/3 * * *"
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue otherwise this will be closed in 5 days'
stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR otherwise this will be closed in 5 days'
stale-issue-label: 'stale'
stale-pr-label: 'stale'
days-before-stale: 90
days-before-close: 5
operations-per-run: 100
exempt-issue-labels: 'bug,enhancement'
remove-stale-when-updated: true

View File

@@ -12,8 +12,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, macos-latest]
python-version: [3.6, 3.7, 3.8]
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
@@ -23,37 +23,98 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
- name: Lint with Black
run: |
pip install --upgrade cython
pip install numpy jupyter jupyter_contrib_nbextensions
python setup.py install
cd ..
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install black
$CONDA\\python.exe -m black qlib -l 120 --check --diff
else
sudo $CONDA/bin/python -m pip install black
$CONDA/bin/python -m black qlib -l 120 --check --diff
fi
shell: bash
# Test Qlib installed with pip
- name: Install Qlib with pip
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install numpy==1.19.5
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
else
sudo $CONDA/bin/python -m pip install numpy==1.19.5
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
fi
shell: bash
- name: Install Lightgbm for MacOS
if: runner.os == 'macOS'
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
- name: Test data downloads
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
else
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
fi
shell: bash
- name: Test workflow by config (install from pip)
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
$CONDA\\python.exe -m pip uninstall -y pyqlib
else
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
fi
shell: bash
# Test Qlib installed from source
- name: Install Qlib from source
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install --upgrade cython
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
$CONDA\\python.exe setup.py install
else
sudo $CONDA/bin/python -m pip install --upgrade cython
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
sudo $CONDA/bin/python setup.py install
fi
shell: bash
- name: Install test dependencies
run: |
python -m pip install --upgrade pip
pip install black pytest
- name: Lint with Black
run: |
cd ..
python -m black qlib -l 120 --check --diff
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install --upgrade pip
$CONDA\\python.exe -m pip install black pytest
else
sudo $CONDA/bin/python -m pip install --upgrade pip
sudo $CONDA/bin/python -m pip install black pytest
fi
shell: bash
- name: Unit tests with Pytest
run: |
cd tests
pytest . --durations=0
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pytest . --durations=0
else
$CONDA/bin/python -m pytest . --durations=0
fi
shell: bash
- name: Test data downloads
- name: Test workflow by config (install from source)
run: |
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Test workflow by config
run: |
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
else
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
fi
shell: bash

5
.gitignore vendored
View File

@@ -2,6 +2,7 @@
__pycache__/
*.pyc
*.pyd
*.so
*.ipynb
.ipynb_checkpoints
@@ -33,3 +34,7 @@ tags
.pytest_cache/
.vscode/
*.swp
./pretrain

21
.readthedocs.yml Normal file
View File

@@ -0,0 +1,21 @@
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
# Build all formats
formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
install:
- requirements: docs/requirements.txt
- method: setuptools
path: .

View File

@@ -114,7 +114,7 @@ Version 0.4.1
Version 0.4.2
--------------------
- Refactor DataHandler
- Add ``ALPHA360`` DataHandler
- Add ``Alpha360`` DataHandler
Version 0.4.3

155
README.md
View File

@@ -7,6 +7,20 @@
[![License](https://img.shields.io/pypi/l/pyqlib)](LICENSE)
[![Join the chat at https://gitter.im/Microsoft/qlib](https://badges.gitter.im/Microsoft/qlib.svg)](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
## :newspaper: **What's NEW!** &nbsp; :sparkling_heart:
Recent released features
| Feature | Status |
| -- | ------ |
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
| High-frequency trading example | [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
| High-frequency data(1min) | [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
| Tabnet Model | [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
Features released before 2021 are not listed here.
<p align="center">
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
@@ -17,10 +31,11 @@ Qlib is an AI-oriented quantitative investment platform, which aims to realize t
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
With Qlib, user can easily try ideas to create better Quant investment strategies.
With Qlib, users can easily try ideas to create better Quant investment strategies.
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
- [**Plans**](#plans)
- [Framework of Qlib](#framework-of-qlib)
- [Quick Start](#quick-start)
- [Installation](#installation)
@@ -34,9 +49,21 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [More About Qlib](#more-about-qlib)
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
- [Related Reports](#related-reports)
- [Contact Us](#contact-us)
- [Contributing](#contributing)
# Plans
New features under development(order by estimated release time).
Your feedbacks about the features are very important.
| Feature | Status |
| -- | ------ |
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
| Meta-Learning-based data selection | Initial opensource version under development |
# Framework of Qlib
@@ -45,11 +72,11 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
</div>
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
| Name | Description |
| ------ | ----- |
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
@@ -61,17 +88,36 @@ At the module level, Qlib is a platform that consists of the above components. T
This quick start guide tries to demonstrate
1. It's very easy to build a complete Quant research workflow and try your ideas with _Qlib_.
1. Though with *public data* and *simple models*, machine learning technologies **work very well** in practical Quant investment.
2. Though with *public data* and *simple models*, machine learning technologies **work very well** in practical Quant investment.
Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how to install ``Qlib``, and run LightGBM with ``qrun``. **But**, please make sure you have already prepared the data following the [instruction](#data-preparation).
## Installation
Users can easily install ``Qlib`` by pip according to the following command
This table demonstrates the supported Python version of `Qlib`:
| | install with pip | install from source | plot |
| ------------- |:---------------------:|:--------------------:|:----:|
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
**Note**:
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
### Install with pip
Users can easily install ``Qlib`` by pip according to the following command.
```bash
pip install pyqlib
```
Also, users can install ``Qlib`` by the source code according to the following steps:
**Note**: pip will install the latest stable qlib. However, the main branch of qlib is in active development. If you want to test the latest scripts or functions in the main branch. Please install qlib with the methods below.
### Install from source
Also, users can install the latest dev version ``Qlib`` by the source code according to the following steps:
* Before installing ``Qlib`` from source, users need to install some dependencies:
@@ -80,25 +126,38 @@ Also, users can install ``Qlib`` by the source code according to the following s
pip install --upgrade cython
```
* Clone the repository and install ``Qlib``:
```bash
git clone https://github.com/microsoft/qlib.git && cd qlib
python setup.py install
```
* Clone the repository and install ``Qlib`` as follows.
* If you haven't installed qlib by the command ``pip install pyqlib`` before:
```bash
git clone https://github.com/microsoft/qlib.git && cd qlib
python setup.py install
```
* If you have already installed the stable version by the command ``pip install pyqlib``:
```bash
git clone https://github.com/microsoft/qlib.git && cd qlib
pip install .
```
**Note**: **Only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test.yml) may help you find the problem.
## Data Preparation
Load and prepare data by running the following code:
```bash
# get 1d data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
# get 1min data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
```
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
the same repository.
Users could create the same dataset with it.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
<!--
- Run the initialization code and get stock data:
@@ -130,12 +189,16 @@ Users could create the same dataset with it.
## Auto Quant Research Workflow
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml)) as following.
1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm_Alpha158.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml) as following.
```bash
cd examples # Avoid running program under the directory contains `qlib`
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
```
The result of `qrun` is as follows, please refer to please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
If users want to use `qrun` under debug mode, please use the following command:
```bash
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
```
The result of `qrun` is as follows, please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
```bash
@@ -153,9 +216,6 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
annualized_return 0.128982
information_ratio 1.444287
max_drawdown -0.091078
```
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
@@ -185,40 +245,45 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
- Rank Label
![Rank Label](docs/_static/img/rank_label.png)
-->
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
## Building Customized Quant Research Workflow by Code
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
# [Quant Model Zoo](examples/benchmarks)
Here is a list of models built on `Qlib`.
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
- [GBDT based on XGBoost](qlib/contrib/model/xgboost.py)
- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py)
- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py)
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
<!-- - [TFT based on tensorflow](examples/benchmarks/TFT/tft.py) -->
- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py)
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py)
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
Your PR of new Quant models is highly welcomed.
The performance of each model on the `Alpha158` and `Alpha360` dataset can be found [here](examples/benchmarks/README.md).
## Run a single model
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
## Run multiple models
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored. (**Note**: the script will erase your previous experiment records created by running itself.)
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
Here is an example of running all the models for 10 iterations:
```python
@@ -229,12 +294,12 @@ It also provides the API to run specific models at once. For more use cases, ple
# Quant Dataset Zoo
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`.
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`:
| Dataset | US Market | China Market |
| -- | -- | -- |
| [Alpha360](./qlib/contrib/data/handler.py) | √ | √ |
| [Alpha158](./qlib/contrib/data/handler.py) | √ | √ |
| [Alpha158](./qlib/contrib/data/handler.py) | √ | √ |
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
Your PR to build new Quant dataset is highly welcomed.
@@ -276,13 +341,27 @@ which creates a dataset (14 features/factors) from the basic OHLCV daily data of
* `+(-)E` indicates with (out) `ExpressionCache`
* `+(-)D` indicates with (out) `DatasetCache`
Most general-purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
Most general-purpose databases take too much time to load data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
Such overheads greatly slow down the data loading process.
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
# Related Reports
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
- [Guide To Qlib: Microsofts AI Investment Platform](https://analyticsindiamag.com/qlib/)
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
- [微软也搞AI量化平台还是开源的](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
- [微矿Qlib业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
# Contact Us
- If you have any issues, please create issue [here](https://github.com/microsoft/qlib/issues/new/choose) or send messages in [gitter](https://gitter.im/Microsoft/qlib).
- If you want to make contributions to `Qlib`, please [create pull requests](https://github.com/microsoft/qlib/compare).
- For other reasons, you are welcome to contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)).
- We are recruiting new members(both FTEs and interns), your resumes are welcome!
Join IM discussion groups:
|[Gitter](https://gitter.im/Microsoft/qlib)|
|----|
|![image](http://fintech.msra.cn/images_v060/qrcode/gitter_qr.png)|
# Contributing

View File

@@ -70,3 +70,31 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
------------------------------------------------------------------------------------------------------------------------------------
.. code-block:: python
#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "qlib/qlib/__init__.py", line 19, in init
from .data.cache import H
File "qlib/qlib/data/__init__.py", line 8, in <module>
from .data import (
File "qlib/qlib/data/data.py", line 20, in <module>
from .cache import H
File "qlib/qlib/data/cache.py", line 36, in <module>
from .ops import Operators
File "qlib/qlib/data/ops.py", line 19, in <module>
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
- If the error occurs when importing ``qlib`` package with ``PyCharm`` IDE, users can execute the following command in the project root folder to compile Cython files and generate executable files:
.. code-block:: bash
python setup.py build_ext --inplace
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.

12
docs/_static/demo.sh vendored Normal file
View File

@@ -0,0 +1,12 @@
#!/bin/sh
git clone https://github.com/microsoft/qlib.git
cd qlib
ls
pip install pyqlib
# or
# pip install numpy
# pip install --upgrade cython
# python setup.py install
cd examples
ls
qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

BIN
docs/_static/img/online_serving.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 440 KiB

BIN
docs/_static/img/qrcode/gitter_qr.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

View File

@@ -50,57 +50,37 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
.. code-block:: python
>> from qlib.data.dataset.handler import QLibDataHandler
>> from qlib.data.dataset.loader import QlibDataLoader
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
>> fields = [MACD_EXP] # MACD
>> names = ['MACD']
>> labels = ['$close'] # label
>> labels = ['Ref($close, -2)/Ref($close, -1) - 1'] # label
>> label_names = ['LABEL']
>> data_handler = QLibDataHandler(start_date='2010-01-01', end_date='2017-12-31', fields=fields, names=names, labels=labels, label_names=label_names)
>> TRAINER_CONFIG = {
.. "train_start_date": "2007-01-01",
.. "train_end_date": "2014-12-31",
.. "validate_start_date": "2015-01-01",
.. "validate_end_date": "2016-12-31",
.. "test_start_date": "2017-01-01",
.. "test_end_date": "2020-08-01",
>> data_loader_config = {
.. "feature": (fields, names),
.. "label": (labels, label_names)
.. }
>> feature_train, label_train, feature_validate, label_validate, feature_test, label_test = data_handler.get_split_data(**TRAINER_CONFIG)
>> print(feature_train, label_train)
MACD
instrument datetime
SH600000 2010-01-04 -0.008625
2010-01-05 -0.007234
2010-01-06 -0.007693
2010-01-07 -0.009633
2010-01-08 -0.009891
... ...
SZ300251 2014-12-25 0.043072
2014-12-26 0.041345
2014-12-29 0.042733
2014-12-30 0.042066
2014-12-31 0.036299
[322025 rows x 1 columns]
LABEL
instrument datetime
SH600000 2010-01-04 4.260015
2010-01-05 4.292182
2010-01-06 4.207747
2010-01-07 4.113258
2010-01-08 4.159496
... ...
SZ300251 2014-12-25 4.343212
2014-12-26 4.470587
2014-12-29 4.762474
2014-12-30 4.369748
2014-12-31 4.182222
[322025 rows x 1 columns]
>> data_loader = QlibDataLoader(config=data_loader_config)
>> df = data_loader.load(instruments='csi300', start_time='2010-01-01', end_time='2017-12-31')
>> print(df)
feature label
MACD LABEL
datetime instrument
2010-01-04 SH600000 -0.011547 -0.019672
SH600004 0.002745 -0.014721
SH600006 0.010133 0.002911
SH600008 -0.001113 0.009818
SH600009 0.025878 -0.017758
... ... ...
2017-12-29 SZ300124 0.007306 -0.005074
SZ300136 -0.013492 0.056352
SZ300144 -0.000966 0.011853
SZ300251 0.004383 0.021739
SZ300315 -0.030557 0.012455
Reference
===========
To learn more about ``Data Handler``, please refer to `Data Handler <../component/data.html>`_
To learn more about ``Data Loader``, please refer to `Data Loader <../component/data.html#data-loader>`_
To learn more about ``Data API``, please refer to `Data API <../component/data.html>`_

45
docs/advanced/serial.rst Normal file
View File

@@ -0,0 +1,45 @@
.. _serial:
=================================
Serialization
=================================
.. currentmodule:: qlib
Introduction
===================
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.
Serializable Class
========================
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
Example
==========================
``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``.
Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.
.. code-block:: Python
##=============dump dataset=============
dataset.to_pickle(path="dataset.pkl") # dataset is an instance of qlib.data.dataset.DatasetH
##=============reload dataset=============
with open("dataset.pkl", "rb") as file_dataset:
dataset = pickle.load(file_dataset)
.. note::
Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.
After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk).
A more detailed example is in this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
API
===================
Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.

View File

@@ -0,0 +1,89 @@
.. _task_management:
=================================
Task Management
=================================
.. currentmodule:: qlib
Introduction
=============
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
This whole process can be used in `Online Serving <../component/online.html>`_.
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
Task Generating
===============
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
The specific task template can be viewed in
`Task Section <../component/workflow.html#task-section>`_.
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
Here is the base class of ``TaskGen``:
.. autoclass:: qlib.workflow.task.gen.TaskGen
:members:
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
Task Storing
===============
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
.. code-block:: python
from qlib.config import C
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
"task_db_name" : "rolling_db" # database name
}
.. autoclass:: qlib.workflow.task.manage.TaskManager
:members:
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
Task Training
===============
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
.. autofunction:: qlib.workflow.task.manage.run_task
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
.. autoclass:: qlib.model.trainer.Trainer
:members:
``Trainer`` will train a list of tasks and return a list of model recorders.
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
Task Collecting
===============
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
For example: {C1: object, C2: object} ---``Ensemble``---> object
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.

View File

@@ -31,7 +31,7 @@ Qlib Format Data
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
``Qlib`` provides two different off-the-shelf dataset, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
======================== ================= ================
Dataset US Market China Market
@@ -41,6 +41,7 @@ Alpha360 √ √
Alpha158 √ √
======================== ================= ================
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
Qlib Format Dataset
--------------------
@@ -48,15 +49,19 @@ Qlib Format Dataset
.. code-block:: bash
# download 1d
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
# download 1min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` directory and ``~/.qlib/qlib_data/us_data`` directory respectively.
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
@@ -67,12 +72,19 @@ Converting CSV Format into Qlib Format
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
Here are some example:
.. code-block:: bash
for daily data:
.. code-block:: bash
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
for 1min data:
.. code-block:: bash
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
@@ -126,20 +138,30 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
The arguments of `--include_fields` should correspond with the column names of CSV files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
- `open`
The opening price
The adjusted opening price
- `close`
The closing price
The adjusted closing price
- `high`
The highest price
The adjusted highest price
- `low`
The lowest price
The adjusted lowest price
- `volume`
The trading volume
The adjusted trading volume
- `factor`
The Restoration factor
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
Stock Pool (Market)
--------------------------------
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
.. code-block:: bash
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
Multiple Stock Modes
--------------------------------
@@ -158,7 +180,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
- If users use ``Qlib`` in china-stock mode, china-stock data is required. Users can use ``Qlib`` in china-stock mode according to the following steps:
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in china-stock mode
Supposed that users download their Qlib format data in the directory ``~/.qlib/csv_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
.. code-block:: python
@@ -167,9 +189,9 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in US-stock mode
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
.. code-block:: python
@@ -177,6 +199,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
.. note::
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
Data API
========================
@@ -195,6 +222,7 @@ Feature
- `ExpressionOps`
`ExpressionOps` will use operator for feature construction.
To know more about ``Operator``, please refer to `Operator API <../reference/api.html#module-qlib.data.ops>`_.
Also, ``Qlib`` supports users to define their own custom ``Operator``, an example has been given in ``tests/test_register_ops.py``.
To know more about ``Feature``, please refer to `Feature API <../reference/api.html#module-qlib.data.base>`_.
@@ -212,6 +240,25 @@ Filter
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
.. code-block:: yaml
filter: &filter
filter_type: ExpressionDFilter
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
filter_start_time: 2010-01-01
filter_end_time: 2010-01-07
keep: False
data_handler_config: &data_handler_config
start_time: 2010-01-01
end_time: 2021-01-22
fit_start_time: 2010-01-01
fit_end_time: 2015-12-31
instruments: *market
filter_pipe: [*filter]
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
Reference
@@ -273,9 +320,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
:members: __init__, fetch, get_cols
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
Processor
@@ -295,6 +343,7 @@ The ``Processor`` module in ``Qlib`` is designed to be learnable and it is respo
- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
- ``CSZFillna``: `processor` that fills N/A values in a cross sectional way by the mean of the column.
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
@@ -311,7 +360,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
.. code-block:: Python
import qlib
@@ -338,6 +386,9 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
# fetch all the features
print(h.fetch(col_set="feature"))
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
API
---------
@@ -362,8 +413,7 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
API
---------
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
Cache

46
docs/component/online.rst Normal file
View File

@@ -0,0 +1,46 @@
.. _online:
=================================
Online Serving
=================================
.. currentmodule:: qlib
Introduction
=============
.. image:: ../_static/img/online_serving.png
:align: center
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
``Online Serving`` is a set of modules for online models using the latest data,
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
Online Manager
=============
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
=============
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
=============
.. automodule:: qlib.workflow.online.utils
:members:
Updater
=============
.. automodule:: qlib.workflow.online.update
:members:

View File

@@ -34,8 +34,10 @@ Here is a general view of the structure of the system:
- Recorder 2
- ...
- ...
This experiment management system defines a set of interface and provided a concrete implementation based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
Qlib Recorder
===================
@@ -91,8 +93,54 @@ Record Template
The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
- ``SignalRecord``: This class generates the `preidction` results of the model.
- ``SignalRecord``: This class generates the `prediction` results of the model.
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
Here is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, Long-Short Return with their own prediction and label.
.. code-block:: Python
from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
Here is a simple exampke of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
.. code-block:: Python
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
# backtest
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
}
BACKTEST_CONFIG = {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
# analysis
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
analysis_df = pd.concat(analysis) # type: pd.DataFrame
print(analysis_df)
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.

View File

@@ -101,7 +101,7 @@ Graphical Result
- Axis Y:
- `ic`
The `Pearson correlation coefficient` series between `label` and `prediction score`.
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
- `rank_ic`
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.

View File

@@ -111,8 +111,6 @@ Usage & Example
pred_score, strategy=strategy, **BACKTEST_CONFIG
)
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.

View File

@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
@@ -103,6 +103,12 @@ After saving the config into `configuration.yaml`, users could start the workflo
qrun configuration.yaml
If users want to use ``qrun`` under debug mode, please use the following command:
.. code-block:: bash
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
.. note::
`qrun` will be placed in your $PATH directory when installing ``Qlib``.

View File

@@ -226,3 +226,8 @@ epub_exclude_files = ["search.html"]
autodoc_member_order = "bysource"
autodoc_default_flags = ["members"]
autodoc_default_options = {
"members": True,
"member-order": "bysource",
"special-members": "__init__",
}

View File

@@ -42,6 +42,7 @@ Document Structure
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Qlib Recorder: Experiment Management <component/recorder.rst>
Analysis: Evaluation & Results Analysis <component/report.rst>
Online Serving: Online Management & Strategy & Tool <component/online.rst>
.. toctree::
:maxdepth: 3
@@ -49,6 +50,8 @@ Document Structure
Building Formulaic Alphas <advanced/alpha.rst>
Online & Offline mode <advanced/server.rst>
Serialization <advanced/serial.rst>
Task Management <advanced/task_management.rst>
.. toctree::
:maxdepth: 3

View File

@@ -53,6 +53,34 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:
Storage
-------------
.. autoclass:: qlib.data.storage.storage.BaseStorage
:members:
.. autoclass:: qlib.data.storage.storage.CalendarStorage
:members:
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
:members:
.. autoclass:: qlib.data.storage.storage.FeatureStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
:members:
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
:members:
Dataset
---------------
@@ -152,4 +180,81 @@ Recorder
Record Template
--------------------
.. automodule:: qlib.workflow.record_temp
:members:
:members:
Task Management
====================
TaskGen
--------------------
.. automodule:: qlib.workflow.task.gen
:members:
TaskManager
--------------------
.. automodule:: qlib.workflow.task.manage
:members:
Trainer
--------------------
.. automodule:: qlib.model.trainer
:members:
Collector
--------------------
.. automodule:: qlib.workflow.task.collect
:members:
Group
--------------------
.. automodule:: qlib.model.ens.group
:members:
Ensemble
--------------------
.. automodule:: qlib.model.ens.ensemble
:members:
Utils
--------------------
.. automodule:: qlib.workflow.task.utils
:members:
Online Serving
====================
Online Manager
--------------------
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
--------------------
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
--------------------
.. automodule:: qlib.workflow.online.utils
:members:
RecordUpdater
--------------------
.. automodule:: qlib.workflow.online.update
:members:
Utils
====================
Serializable
--------------------
.. automodule:: qlib.utils.serial.Serializable
:members:

View File

@@ -1 +1,5 @@
Cython==0.29.21
Cython
cmake
numpy
scipy
scikit-learn

View File

@@ -63,6 +63,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
- `exp_manager`
Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
.. code-block:: Python
# For example, if you want to set your tracking_uri to a <specific folder>, you can initialize qlib below
@@ -74,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
"default_exp_name": "Experiment",
}
})
- `mongo`
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
.. code-block:: Python
# For example, you can initialize qlib below
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
"task_url": "mongodb://localhost:27017/", # your mongo url
"task_db_name": "rolling_db", # the database name of Task Management
})

View File

@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
- Override the `finetune` method (Optional)
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
- The parameters must include the parameter `dataset`.
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
.. code-block:: Python

View File

@@ -1,6 +1,6 @@
# Requirements
Here is the minimal hardware requirements to run the example.
Here is the minimal hardware requirements to run the `workflow_by_code` example.
- Memory: 16G
- Free Disk: 5G

View File

@@ -0,0 +1,93 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: ALSTM
module_path: qlib.contrib.model.pytorch_alstm_ts
kwargs:
d_feat: 20
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 10
batch_size: 800
metric: loss
loss: mse
n_jobs: 20
GPU: 0
rnn_type: GRU
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -54,7 +54,6 @@ task:
batch_size: 800
metric: loss
loss: mse
seed: 0
GPU: 0
rnn_type: GRU
dataset:
@@ -62,7 +61,7 @@ task:
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:

View File

@@ -0,0 +1,72 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: []
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: CatBoostModel
module_path: qlib.contrib.model.catboost_model
kwargs:
loss: RMSE
learning_rate: 0.0421
subsample: 0.8789
max_depth: 6
num_leaves: 100
thread_count: 20
grow_policy: Lossguide
bootstrap_type: Poisson
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,4 @@
# DoubleEnsemble
* DoubleEnsemble is an ensemble framework leveraging learning trajectory based sample reweighting and shuffling based feature selection, to solve both the low signal-to-noise ratio and increasing number of features problems. They identify the key samples based on the training dynamics on each sample and elicit key features based on the ablation impact of each feature via shuffling. The model is applicable to a wide range of base models, capable of extracting complex patterns, while mitigating the overfitting and instability issues for financial market prediction.
* This code used in Qlib is implemented by ourselves.
* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf).

View File

@@ -0,0 +1,3 @@
pandas==1.1.2
numpy==1.17.4
lightgbm==3.1.0

View File

@@ -0,0 +1,90 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: DEnsembleModel
module_path: qlib.contrib.model.double_ensemble
kwargs:
base_model: "gbm"
loss: mse
num_models: 6
enable_sr: True
enable_fs: True
alpha1: 1
alpha2: 1
bins_sr: 10
bins_fs: 5
decay: 0.5
sample_ratios:
- 0.8
- 0.7
- 0.6
- 0.5
- 0.4
sub_weights:
- 1
- 0.2
- 0.2
- 0.2
- 0.2
- 0.2
epochs: 28
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
verbosity: -1
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,97 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: []
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: DEnsembleModel
module_path: qlib.contrib.model.double_ensemble
kwargs:
base_model: "gbm"
loss: mse
num_models: 6
enable_sr: True
enable_fs: True
alpha1: 1
alpha2: 1
bins_sr: 10
bins_fs: 5
decay: 0.5
sample_ratios:
- 0.8
- 0.7
- 0.6
- 0.5
- 0.4
sub_weights:
- 1
- 0.2
- 0.2
- 0.2
- 0.2
- 0.2
epochs: 136
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
verbosity: -1
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,92 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GATs
module_path: qlib.contrib.model.pytorch_gats_ts
kwargs:
d_feat: 20
hidden_size: 64
num_layers: 2
dropout: 0.7
n_epochs: 200
lr: 1e-4
early_stop: 10
metric: loss
loss: mse
base_model: LSTM
with_pretrain: True
model_path: "benchmarks/LSTM/csi300_lstm_ts.pkl"
GPU: 0
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -56,14 +56,13 @@ task:
base_model: LSTM
with_pretrain: True
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
seed: 0
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
@@ -74,6 +73,11 @@ task:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:

Binary file not shown.

View File

@@ -0,0 +1,92 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GRU
module_path: qlib.contrib.model.pytorch_gru_ts
kwargs:
d_feat: 20
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 2e-4
early_stop: 10
batch_size: 800
metric: loss
loss: mse
n_jobs: 20
GPU: 0
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -54,14 +54,13 @@ task:
batch_size: 800
metric: loss
loss: mse
seed: 0
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:

Binary file not shown.

View File

@@ -0,0 +1,92 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LSTM
module_path: qlib.contrib.model.pytorch_lstm_ts
kwargs:
d_feat: 20
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 10
batch_size: 800
metric: loss
loss: mse
n_jobs: 20
GPU: 0
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -54,14 +54,13 @@ task:
batch_size: 800
metric: loss
loss: mse
seed: 0
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:

View File

@@ -32,7 +32,7 @@ task:
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768

View File

@@ -0,0 +1,73 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: []
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,81 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
instruments: *market
data_loader:
class: QlibDataLoader
kwargs:
config:
feature:
- ["Resi($close, 15)/$close", "Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)", "Rsquare($close, 5)", "($high-$low)/$open", "Rsquare($close, 10)", "Corr($close, Log($volume+1), 5)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)", "Corr($close, Log($volume+1), 10)", "Rsquare($close, 20)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)", "Corr($close, Log($volume+1), 20)", "(Less($open, $close)-$low)/$open"]
- ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
label:
- ["Ref($close, -2)/Ref($close, -1) - 1"]
- ["LABEL0"]
freq: day
learn_processors:
- class: DropnaLabel
- class: CSZScoreNorm
kwargs:
fields_group: label
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: DataHandlerLP
module_path: qlib.data.dataset.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -65,8 +65,9 @@ task:
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096
batch_size: 8192
GPU: 0
weight_decay: 0.0002
dataset:
class: DatasetH
module_path: qlib.data.dataset

View File

@@ -0,0 +1,82 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: DNNModelPytorch
module_path: qlib.contrib.model.pytorch_nn
kwargs:
loss: mse
input_dim: 360
output_dim: 1
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,39 @@
# Benchmarks Performance
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
## Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0659±0.00 | -0.7072±0.00| -0.2955±0.00 |
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
| XGBoost (Tianqi Chen, et al.) | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
| LightGBM (Guolin Ke, et al.) | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
| MLP | Alpha360 | 0.0285±0.00 | 0.1981±0.02| 0.0402±0.00 | 0.2993±0.02 | 0.0073±0.02 | 0.0880±0.22| -0.1446±0.03 |
| GRU (Kyunghyun Cho, et al.) | Alpha360 | 0.0490±0.01 | 0.3787±0.05| 0.0581±0.00 | 0.4664±0.04 | 0.0726±0.02 | 0.9817±0.34| -0.0902±0.03 |
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
| Linear | Alpha158 | 0.0393±0.00 | 0.2980±0.00| 0.0475±0.00 | 0.3546±0.00 | 0.0795±0.00 | 1.0712±0.00| -0.1449±0.00 |
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1561±0.00| -0.0787±0.00 |
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.

View File

@@ -1,3 +1,3 @@
# State-Frequency-Memory
- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.)
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf)

View File

@@ -57,14 +57,13 @@ task:
eval_steps: 5
loss: mse
optimizer: adam
GPU: 1
seed: 710
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:

View File

@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
return -1, -1
def get_column_definition(self):
""""Returns formatted column definition in order expected by the TFT."""
"""Returns formatted column definition in order expected by the TFT."""
column_definition = self._column_definition

View File

@@ -1,219 +1,229 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Custom formatting functions for Alpha158 dataset.
Defines dataset specific column definitions and data transformations.
"""
import data_formatters.base
import libs.utils as utils
import sklearn.preprocessing
GenericDataFormatter = data_formatters.base.GenericDataFormatter
DataTypes = data_formatters.base.DataTypes
InputTypes = data_formatters.base.InputTypes
class Alpha158Formatter(GenericDataFormatter):
"""Defines and formats data for the Alpha158 dataset.
Attributes:
column_definition: Defines input and data type of column used in the
experiment.
identifiers: Entity identifiers used in experiments.
"""
_column_definition = [
("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
("date", DataTypes.DATE, InputTypes.TIME),
("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
# Selected 10 features
("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
]
def __init__(self):
"""Initialises formatter."""
self.identifiers = None
self._real_scalers = None
self._cat_scalers = None
self._target_scaler = None
self._num_classes_per_cat_input = None
def split_data(self, df, valid_boundary=2016, test_boundary=2018):
"""Splits data frame into training-validation-test data frames.
This also calibrates scaling object, and transforms data for each split.
Args:
df: Source data frame to split.
valid_boundary: Starting year for validation data
test_boundary: Starting year for test data
Returns:
Tuple of transformed (train, valid, test) data.
"""
print("Formatting train-valid-test splits.")
index = df["year"]
train = df.loc[index < valid_boundary]
valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
test = df.loc[index >= test_boundary]
self.set_scalers(train)
return (self.transform_inputs(data) for data in [train, valid, test])
def set_scalers(self, df):
"""Calibrates scalers using the data supplied.
Args:
df: Data to use to calibrate scalers.
"""
print("Setting scalers with training data...")
column_definitions = self.get_column_definition()
id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
# Extract identifiers in case required
self.identifiers = list(df[id_column].unique())
# Format real scalers
real_inputs = utils.extract_cols_from_data_type(
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
data = df[real_inputs].values
self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
df[[target_column]].values
) # used for predictions
# Format categorical scalers
categorical_inputs = utils.extract_cols_from_data_type(
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
categorical_scalers = {}
num_classes = []
for col in categorical_inputs:
# Set all to str so that we don't have mixed integer/string columns
srs = df[col].apply(str)
categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
num_classes.append(srs.nunique())
# Set categorical scaler outputs
self._cat_scalers = categorical_scalers
self._num_classes_per_cat_input = num_classes
def transform_inputs(self, df):
"""Performs feature transformations.
This includes both feature engineering, preprocessing and normalisation.
Args:
df: Data frame to transform.
Returns:
Transformed data frame.
"""
output = df.copy()
if self._real_scalers is None and self._cat_scalers is None:
raise ValueError("Scalers have not been set!")
column_definitions = self.get_column_definition()
real_inputs = utils.extract_cols_from_data_type(
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
categorical_inputs = utils.extract_cols_from_data_type(
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
# Format real inputs
output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
# Format categorical inputs
for col in categorical_inputs:
string_df = df[col].apply(str)
output[col] = self._cat_scalers[col].transform(string_df)
return output
def format_predictions(self, predictions):
"""Reverts any normalisation to give predictions in original scale.
Args:
predictions: Dataframe of model predictions.
Returns:
Data frame of unnormalised predictions.
"""
output = predictions.copy()
column_names = predictions.columns
for col in column_names:
if col not in {"forecast_time", "identifier"}:
output[col] = self._target_scaler.inverse_transform(predictions[col])
return output
# Default params
def get_fixed_params(self):
"""Returns fixed model parameters for experiments."""
fixed_params = {
"total_time_steps": 6 + 6,
"num_encoder_steps": 6,
"num_epochs": 100,
"early_stopping_patience": 10,
"multiprocessing_workers": 5,
}
return fixed_params
def get_default_model_params(self):
"""Returns default optimised model parameters."""
model_params = {
"dropout_rate": 0.4,
"hidden_layer_size": 16,
"learning_rate": 0.0001,
"minibatch_size": 128,
"max_gradient_norm": 0.0135,
"num_heads": 1,
"stack_size": 1,
}
return model_params
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Custom formatting functions for Alpha158 dataset.
Defines dataset specific column definitions and data transformations.
"""
import data_formatters.base
import libs.utils as utils
import sklearn.preprocessing
GenericDataFormatter = data_formatters.base.GenericDataFormatter
DataTypes = data_formatters.base.DataTypes
InputTypes = data_formatters.base.InputTypes
class Alpha158Formatter(GenericDataFormatter):
"""Defines and formats data for the Alpha158 dataset.
Attributes:
column_definition: Defines input and data type of column used in the
experiment.
identifiers: Entity identifiers used in experiments.
"""
_column_definition = [
("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
("date", DataTypes.DATE, InputTypes.TIME),
("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
# Selected features
("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("VSTD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("WVMA60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("STD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR20", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORD60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORD10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR20", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("KLOW", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
]
def __init__(self):
"""Initialises formatter."""
self.identifiers = None
self._real_scalers = None
self._cat_scalers = None
self._target_scaler = None
self._num_classes_per_cat_input = None
def split_data(self, df, valid_boundary=2016, test_boundary=2018):
"""Splits data frame into training-validation-test data frames.
This also calibrates scaling object, and transforms data for each split.
Args:
df: Source data frame to split.
valid_boundary: Starting year for validation data
test_boundary: Starting year for test data
Returns:
Tuple of transformed (train, valid, test) data.
"""
print("Formatting train-valid-test splits.")
index = df["year"]
train = df.loc[index < valid_boundary]
valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
test = df.loc[index >= test_boundary]
self.set_scalers(train)
return (self.transform_inputs(data) for data in [train, valid, test])
def set_scalers(self, df):
"""Calibrates scalers using the data supplied.
Args:
df: Data to use to calibrate scalers.
"""
print("Setting scalers with training data...")
column_definitions = self.get_column_definition()
id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
# Extract identifiers in case required
self.identifiers = list(df[id_column].unique())
# Format real scalers
real_inputs = utils.extract_cols_from_data_type(
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
data = df[real_inputs].values
self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
df[[target_column]].values
) # used for predictions
# Format categorical scalers
categorical_inputs = utils.extract_cols_from_data_type(
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
categorical_scalers = {}
num_classes = []
for col in categorical_inputs:
# Set all to str so that we don't have mixed integer/string columns
srs = df[col].apply(str)
categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
num_classes.append(srs.nunique())
# Set categorical scaler outputs
self._cat_scalers = categorical_scalers
self._num_classes_per_cat_input = num_classes
def transform_inputs(self, df):
"""Performs feature transformations.
This includes both feature engineering, preprocessing and normalisation.
Args:
df: Data frame to transform.
Returns:
Transformed data frame.
"""
output = df.copy()
if self._real_scalers is None and self._cat_scalers is None:
raise ValueError("Scalers have not been set!")
column_definitions = self.get_column_definition()
real_inputs = utils.extract_cols_from_data_type(
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
categorical_inputs = utils.extract_cols_from_data_type(
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
# Format real inputs
output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
# Format categorical inputs
for col in categorical_inputs:
string_df = df[col].apply(str)
output[col] = self._cat_scalers[col].transform(string_df)
return output
def format_predictions(self, predictions):
"""Reverts any normalisation to give predictions in original scale.
Args:
predictions: Dataframe of model predictions.
Returns:
Data frame of unnormalised predictions.
"""
output = predictions.copy()
column_names = predictions.columns
for col in column_names:
if col not in {"forecast_time", "identifier"}:
output[col] = self._target_scaler.inverse_transform(predictions[col])
return output
# Default params
def get_fixed_params(self):
"""Returns fixed model parameters for experiments."""
fixed_params = {
"total_time_steps": 6 + 6,
"num_encoder_steps": 6,
"num_epochs": 100,
"early_stopping_patience": 10,
"multiprocessing_workers": 5,
}
return fixed_params
def get_default_model_params(self):
"""Returns default optimised model parameters."""
model_params = {
"dropout_rate": 0.4,
"hidden_layer_size": 160,
"learning_rate": 0.0001,
"minibatch_size": 128,
"max_gradient_norm": 0.0135,
"num_heads": 1,
"stack_size": 1,
}
return model_params

View File

@@ -25,7 +25,7 @@ import os
import data_formatters.qlib_Alpha158
class ExperimentConfig(object):
class ExperimentConfig:
"""Defines experiment configs and paths to outputs.
Attributes:

View File

@@ -320,7 +320,7 @@ class InterpretableMultiHeadAttention:
return outputs, attn
class TFTDataCache(object):
class TFTDataCache:
"""Caches data for the TFT."""
_data_cache = {}
@@ -348,7 +348,7 @@ class TFTDataCache(object):
# TFT model definitions.
class TemporalFusionTransformer(object):
class TemporalFusionTransformer:
"""Defines Temporal Fusion Transformer.
Attributes:
@@ -972,7 +972,7 @@ class TemporalFusionTransformer(object):
valid_quantiles = self.quantiles
output_size = self.output_size
class QuantileLossCalculator(object):
class QuantileLossCalculator:
"""Computes the combined quantile loss for prespecified quantiles.
Attributes:

View File

@@ -1,249 +1,291 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
import data_formatters.base
import expt_settings.configs
import libs.hyperparam_opt
import libs.tft_model
import libs.utils as utils
import os
import datetime as dte
from qlib.model.base import ModelFT
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
# To register new datasets, please add them here.
ALLOW_DATASET = ["Alpha158"]
DATASET_SETTING = {
"Alpha158": {
"feature_col": ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "ROC60", "RESI10"],
"label_col": ["LABEL0"],
},
}
# To register new datasets, please add their configurations here.
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
def fill_test_na(test_df):
test_df_res = test_df.copy()
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
test_df_res.loc[:, feature_cols] = test_feature_fna
return test_df_res
def process_qlib_data(df, dataset, fillna=False):
"""Prepare data to fit the TFT model.
Args:
df: Original DataFrame.
fillna: Whether to fill the data with the mean values.
Returns:
Transformed DataFrame.
"""
# Several features selected manually
feature_col = DATASET_SETTING[dataset]["feature_col"]
label_col = DATASET_SETTING[dataset]["label_col"]
temp_df = df.loc[:, feature_col + label_col]
if fillna:
temp_df = fill_test_na(temp_df)
temp_df = temp_df.swaplevel()
temp_df = temp_df.sort_index()
temp_df = temp_df.reset_index(level=0)
dates = pd.to_datetime(temp_df.index)
temp_df["date"] = dates
temp_df["day_of_week"] = dates.dayofweek
temp_df["month"] = dates.month
temp_df["year"] = dates.year
temp_df["const"] = 1.0
return temp_df
def process_predicted(df, col_name):
"""Transform the TFT predicted data into Qlib format.
Args:
df: Original DataFrame.
fillna: New column name.
Returns:
Transformed DataFrame.
"""
df_res = df.copy()
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
df_res = df_res[[col_name]]
return df_res
def format_score(forecast_df, col_name="pred", label_shift=5):
pred = process_predicted(forecast_df, col_name=col_name)
pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
pred = pred.dropna()[col_name]
return pred
def transform_df(df, col_name="LABEL0"):
df_res = df["feature"]
df_res[col_name] = df["label"]
return df_res
class TFTModel(ModelFT):
"""TFT Model"""
def __init__(self, **kwargs):
self.model = None
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
return transform_df(df_train), transform_df(df_valid)
def fit(
self,
dataset: DatasetH,
DATASET="Alpha158",
MODEL_FOLDER="qlib_alpha158_model",
LABEL_COL="LABEL0",
LABEL_SHIFT=5,
USE_GPU_ID=0,
**kwargs
):
if DATASET not in ALLOW_DATASET:
raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
dtrain, dvalid = self._prepare_data(dataset)
dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
ExperimentConfig = expt_settings.configs.ExperimentConfig
config = ExperimentConfig(DATASET)
self.data_formatter = config.make_data_formatter()
self.model_folder = MODEL_FOLDER
self.gpu_id = USE_GPU_ID
self.label_shift = LABEL_SHIFT
self.expt_name = DATASET
self.label_col = LABEL_COL
use_gpu = (True, self.gpu_id)
# ===========================Training Process===========================
ModelClass = libs.tft_model.TemporalFusionTransformer
if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
raise ValueError(
"Data formatters should inherit from"
+ "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
)
default_keras_session = tf.keras.backend.get_session()
if use_gpu[0]:
self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
else:
self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
self.data_formatter.set_scalers(train)
# Sets up default params
fixed_params = self.data_formatter.get_experiment_params()
params = self.data_formatter.get_default_model_params()
# Wendi: 合并调优的参数和非调优的参数
params = {**params, **fixed_params}
if not os.path.exists(self.model_folder):
os.makedirs(self.model_folder)
params["model_folder"] = self.model_folder
print("*** Begin training ***")
best_loss = np.Inf
tf.reset_default_graph()
self.tf_graph = tf.Graph()
with self.tf_graph.as_default():
self.sess = tf.Session(config=self.tf_config)
tf.keras.backend.set_session(self.sess)
self.model = ModelClass(params, use_cudnn=use_gpu[0])
self.sess.run(tf.global_variables_initializer())
self.model.fit(train_df=train, valid_df=valid)
print("*** Finished training ***")
saved_model_dir = self.model_folder + "/" + "saved_model"
if not os.path.exists(saved_model_dir):
os.makedirs(saved_model_dir)
self.model.save(saved_model_dir)
def extract_numerical_data(data):
"""Strips out forecast time and identifier columns."""
return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
# p50_loss = utils.numpy_normalised_quantile_loss(
# extract_numerical_data(targets), extract_numerical_data(p50_forecast),
# 0.5)
# p90_loss = utils.numpy_normalised_quantile_loss(
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
print("Training completed.".format(dte.datetime.now()))
# ===========================Training Process===========================
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
d_test = dataset.prepare("test", col_set=["feature", "label"])
d_test = transform_df(d_test)
d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
use_gpu = (True, self.gpu_id)
# ===========================Predicting Process===========================
default_keras_session = tf.keras.backend.get_session()
# Sets up default params
fixed_params = self.data_formatter.get_experiment_params()
params = self.data_formatter.get_default_model_params()
params = {**params, **fixed_params}
print("*** Begin predicting ***")
tf.reset_default_graph()
with self.tf_graph.as_default():
tf.keras.backend.set_session(self.sess)
output_map = self.model.predict(test, return_targets=True)
targets = self.data_formatter.format_predictions(output_map["targets"])
p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
tf.keras.backend.set_session(default_keras_session)
predict50 = format_score(p50_forecast, "pred", 1)
predict90 = format_score(p90_forecast, "pred", 1)
predict = (predict50 + predict90) / 2 # self.label_shift
# ===========================Predicting Process===========================
return predict
def finetune(self, dataset: DatasetH):
"""
finetune model
Parameters
----------
dataset : DatasetH
dataset for finetuning
"""
pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
import data_formatters.base
import expt_settings.configs
import libs.hyperparam_opt
import libs.tft_model
import libs.utils as utils
import os
import datetime as dte
from qlib.model.base import ModelFT
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
# To register new datasets, please add them here.
ALLOW_DATASET = ["Alpha158", "Alpha360"]
# To register new datasets, please add their configurations here.
DATASET_SETTING = {
"Alpha158": {
"feature_col": [
"RESI5",
"WVMA5",
"RSQR5",
"KLEN",
"RSQR10",
"CORR5",
"CORD5",
"CORR10",
"ROC60",
"RESI10",
"VSTD5",
"RSQR60",
"CORR60",
"WVMA60",
"STD5",
"RSQR20",
"CORD60",
"CORD10",
"CORR20",
"KLOW",
],
"label_col": "LABEL0",
},
"Alpha360": {
"feature_col": [
"HIGH0",
"LOW0",
"OPEN0",
"CLOSE1",
"HIGH1",
"VOLUME1",
"LOW1",
"VOLUME3",
"OPEN1",
"VOLUME4",
"CLOSE2",
"CLOSE4",
"VOLUME5",
"LOW2",
"CLOSE3",
"VOLUME2",
"HIGH2",
"LOW4",
"VOLUME8",
"VOLUME11",
],
"label_col": "LABEL0",
},
}
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
def fill_test_na(test_df):
test_df_res = test_df.copy()
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
test_df_res.loc[:, feature_cols] = test_feature_fna
return test_df_res
def process_qlib_data(df, dataset, fillna=False):
"""Prepare data to fit the TFT model.
Args:
df: Original DataFrame.
fillna: Whether to fill the data with the mean values.
Returns:
Transformed DataFrame.
"""
# Several features selected manually
feature_col = DATASET_SETTING[dataset]["feature_col"]
label_col = [DATASET_SETTING[dataset]["label_col"]]
temp_df = df.loc[:, feature_col + label_col]
if fillna:
temp_df = fill_test_na(temp_df)
temp_df = temp_df.swaplevel()
temp_df = temp_df.sort_index()
temp_df = temp_df.reset_index(level=0)
dates = pd.to_datetime(temp_df.index)
temp_df["date"] = dates
temp_df["day_of_week"] = dates.dayofweek
temp_df["month"] = dates.month
temp_df["year"] = dates.year
temp_df["const"] = 1.0
return temp_df
def process_predicted(df, col_name):
"""Transform the TFT predicted data into Qlib format.
Args:
df: Original DataFrame.
fillna: New column name.
Returns:
Transformed DataFrame.
"""
df_res = df.copy()
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
df_res = df_res[[col_name]]
return df_res
def format_score(forecast_df, col_name="pred", label_shift=5):
pred = process_predicted(forecast_df, col_name=col_name)
pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
pred = pred.dropna()[col_name]
return pred
def transform_df(df, col_name="LABEL0"):
df_res = df["feature"]
df_res[col_name] = df["label"]
return df_res
class TFTModel(ModelFT):
"""TFT Model"""
def __init__(self, **kwargs):
self.model = None
self.params = {"DATASET": "Alpha158", "label_shift": 5}
self.params.update(kwargs)
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
return transform_df(df_train), transform_df(df_valid)
def fit(self, dataset: DatasetH, MODEL_FOLDER="qlib_tft_model", USE_GPU_ID=0, **kwargs):
DATASET = self.params["DATASET"]
LABEL_SHIFT = self.params["label_shift"]
LABEL_COL = DATASET_SETTING[DATASET]["label_col"]
if DATASET not in ALLOW_DATASET:
raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
dtrain, dvalid = self._prepare_data(dataset)
dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
ExperimentConfig = expt_settings.configs.ExperimentConfig
config = ExperimentConfig(DATASET)
self.data_formatter = config.make_data_formatter()
self.model_folder = MODEL_FOLDER
self.gpu_id = USE_GPU_ID
self.label_shift = LABEL_SHIFT
self.expt_name = DATASET
self.label_col = LABEL_COL
use_gpu = (True, self.gpu_id)
# ===========================Training Process===========================
ModelClass = libs.tft_model.TemporalFusionTransformer
if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
raise ValueError(
"Data formatters should inherit from"
+ "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
)
default_keras_session = tf.keras.backend.get_session()
if use_gpu[0]:
self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
else:
self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
self.data_formatter.set_scalers(train)
# Sets up default params
fixed_params = self.data_formatter.get_experiment_params()
params = self.data_formatter.get_default_model_params()
# Wendi: 合并调优的参数和非调优的参数
params = {**params, **fixed_params}
if not os.path.exists(self.model_folder):
os.makedirs(self.model_folder)
params["model_folder"] = self.model_folder
print("*** Begin training ***")
best_loss = np.Inf
tf.reset_default_graph()
self.tf_graph = tf.Graph()
with self.tf_graph.as_default():
self.sess = tf.Session(config=self.tf_config)
tf.keras.backend.set_session(self.sess)
self.model = ModelClass(params, use_cudnn=use_gpu[0])
self.sess.run(tf.global_variables_initializer())
self.model.fit(train_df=train, valid_df=valid)
print("*** Finished training ***")
saved_model_dir = self.model_folder + "/" + "saved_model"
if not os.path.exists(saved_model_dir):
os.makedirs(saved_model_dir)
self.model.save(saved_model_dir)
def extract_numerical_data(data):
"""Strips out forecast time and identifier columns."""
return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
# p50_loss = utils.numpy_normalised_quantile_loss(
# extract_numerical_data(targets), extract_numerical_data(p50_forecast),
# 0.5)
# p90_loss = utils.numpy_normalised_quantile_loss(
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
print("Training completed.".format(dte.datetime.now()))
# ===========================Training Process===========================
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
d_test = dataset.prepare("test", col_set=["feature", "label"])
d_test = transform_df(d_test)
d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
use_gpu = (True, self.gpu_id)
# ===========================Predicting Process===========================
default_keras_session = tf.keras.backend.get_session()
# Sets up default params
fixed_params = self.data_formatter.get_experiment_params()
params = self.data_formatter.get_default_model_params()
params = {**params, **fixed_params}
print("*** Begin predicting ***")
tf.reset_default_graph()
with self.tf_graph.as_default():
tf.keras.backend.set_session(self.sess)
output_map = self.model.predict(test, return_targets=True)
targets = self.data_formatter.format_predictions(output_map["targets"])
p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
tf.keras.backend.set_session(default_keras_session)
predict50 = format_score(p50_forecast, "pred", 1)
predict90 = format_score(p90_forecast, "pred", 1)
predict = (predict50 + predict90) / 2 # self.label_shift
# ===========================Predicting Process===========================
return predict
def finetune(self, dataset: DatasetH):
"""
finetune model
Parameters
----------
dataset : DatasetH
dataset for finetuning
"""
pass

View File

@@ -0,0 +1,4 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,75 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TabnetModel
module_path: qlib.contrib.model.pytorch_tabnet
kwargs:
d_feat: 158
pretrain: True
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
pretrain: [2008-01-01, 2014-12-31]
pretrain_validation: [2015-01-01, 2016-12-31]
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,75 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TabnetModel
module_path: qlib.contrib.model.pytorch_tabnet
kwargs:
d_feat: 360
pretrain: True
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
pretrain: [2008-01-01, 2014-12-31]
pretrain_validation: [2015-01-01, 2016-12-31]
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,71 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: []
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: XGBModel
module_path: qlib.contrib.model.xgboost
kwargs:
eval_metric: rmse
colsample_bytree: 0.8879
eta: 0.0421
max_depth: 8
n_estimators: 647
subsample: 0.8789
nthread: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

208
examples/data/monitor.py Normal file
View File

@@ -0,0 +1,208 @@
"""
This script is the demonstrating the implementation of Metric Extractor and Detector
NOTE: A lot of details is not considered in this script
- Corner case that will raise error( std == 0)
The following functions are used to demonstrate the following examples
· Metric Extractor:
case 1) Basic statistics on different slices of the DataFrame df:
1) The statistics include:
· STD, Mean, Skewnes, Kurtosis
2) The above statistics can be calculated on the following data slices:
· df.groupby(['datetime'])
· df.groupby(['datetime', 'industry' ])
3) The statistics could be calculated on the time dimension for each instruments and factor(the factor can be represented by experssion)
· <df implemented by expresion>.groupby(['instrument', 'factor'])
case 2) Advanced statistics on different slices of the DataFrame df:
1) Auto-correlation:
· Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
2) Correlation between factors:
· For any pair of factors (i, j): calculate corr(df.loc[t, :, i], df.loc[t, :, j]). The result is a correlation matrix with each element corresponds to a correlation value between a pair of factors.
· Detector: detect the abnormality of the extracted metric;
a) Algorithms:
§ Basic checks: NaN.
§ Point anomaly detection.
§ Segment anomaly detection.
b) Scenarios:
§ Online anomaly detection: monitoring streaming data.
The usage of the detectors are demonstrated in the `case_1_*`and `case_2_*`
case 3): Examples to use MetricExt to monitor IC and rank IC
1) IC(Information Coefficient) #case_3_1
2) RankIC #case_3_2
"""
# AUTO download data
from typing import List, Union
from qlib.utils import exists_qlib_data
from qlib.tests.data import GetData
from qlib.config import REG_CN
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
import qlib
import pandas as pd
from qlib.contrib.data.handler import Alpha158
from qlib.data.dataset.loader import QlibDataLoader
from qlib.data.monitor.metric import format_conv
from qlib.data.monitor.metric import MeanM, SkewM, KurtM, StdM, AutoCM, CorrM
from qlib.data.monitor.detector import NDDetector, SWNDD, ThresholdD
from qlib.data import D
import fire
UNIVERSE = "csi300"
START_TIME = "20200101"
# ------------------ a helper function to get data to demonstrate the functionality --------------------
def get_data_df(col_idx: Union[int, List[int]] = 0, verbose: bool = True):
"""
a helper function to get data to demonstrate the functionality.
Parameters
----------
col_idx : Union[int, List[int]]
column index of the metrics
"""
dh = Alpha158(instruments=UNIVERSE, infer_processors=[], learn_processors=[], start_time=START_TIME)
df = dh.fetch()
if verbose:
print(df.head())
# We don't have industries in dataframe, we generate the with fake data
industry = pd.Series(df.index.get_level_values("instrument").str.slice(stop=2).to_list(), index=df.index)
# select a factor
factor_df = format_conv(df.iloc[:, col_idx], industry=industry)
if verbose:
print(f"Selected metric: {df.columns[col_idx]}")
print(factor_df)
return factor_df
def get_target(horizon=5):
target = f"Ref($close, -{horizon + 1})/Ref($close, -1) - 1" # There are lots of targets: return is one of them
qdl = QlibDataLoader(config=([target], ["target"]))
df = qdl.load(instruments=UNIVERSE, start_time=START_TIME) # Aligning with factor will improve performance
df = format_conv(df["target"])
return df
# ----------------- Cases to demonstrate the usage of detector and examples ----------------------
def case_1_1():
factor_df = get_data_df()
# 1) Extract metrics
# 1.1) df.groupby(["datetime"])
mtrc = MeanM()
m_mean = mtrc.extract(factor_df)
print(m_mean)
ndd = NDDetector()
ndd.fit(m_mean) # use historical data to fit detector
check_res = ndd.check(m_mean)
print(check_res) # detecting on new data or historical data
print(check_res.value_counts())
def case_1_2():
factor_df = get_data_df()
# 1.2) df.groupby("datetime", "industry")
mtrc = MeanM(group=["industry"])
m_multi = mtrc.extract(factor_df)
print(m_multi)
for col_name, s in m_multi.iteritems():
print(col_name)
ndd = NDDetector()
ndd.fit(s) # use historical data to fit detector
check_res = ndd.check(s)
print(check_res) # detecting on new data or historical data
print(check_res.value_counts())
def case_1_3():
# case 1.3
# factor_df = get_data_df()
qdl = QlibDataLoader(config=(["$close/Ref($close, 1) - 1"], ["return"]))
df = qdl.load(instruments=["SH600519"], start_time=START_TIME)
df = format_conv(df)
s = df.iloc[:, 0]
print(s)
dtc = SWNDD(window=20)
dtc.fit(s) # fit use historical data (TODO: updating will be supported in the future)
check_res = dtc.check(s) #
print(check_res)
print(check_res.value_counts())
print(check_res[check_res])
def case_2_1():
# · Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
factor_df = get_data_df()
acm = AutoCM()
mtrc = acm.extract(factor_df)
print(mtrc)
thd = ThresholdD(0.0, reverse=True)
check_res = thd.check(mtrc)
print(check_res)
print(check_res.value_counts())
def case_2_2():
factor_df1, factor_df2 = get_data_df(0), get_data_df(1)
cm = CorrM()
mtrc = cm.extract(factor_df1, factor_df2)
print(mtrc)
thd = ThresholdD(0.0, reverse=True)
check_res = thd.check(mtrc)
print(check_res)
print(check_res.value_counts())
def case_3_1_3_2():
target, factor = get_target(), get_data_df(0)
ic_m, rank_ic_m = CorrM(), CorrM(mode="spearman")
ic, rank_ic = ic_m.extract(factor, target), rank_ic_m.extract(factor, target)
print(pd.DataFrame({"ic": ic, "rank_ic": rank_ic}))
def run(test_list=["case_1_1", "case_1_2", "case_1_3", "case_2_1", "case_2_2", "case_3_1_3_2"]):
"""
run the specific tests
python monitor.py case_3_1_3_2
Parameters
----------
test_list : str[]
The tests to run
"""
if isinstance(test_list, str):
test_list = [test_list]
for fn in test_list:
globals()[fn]()
if __name__ == "__main__":
qlib.init()
fire.Fire(run)

View File

@@ -0,0 +1,130 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0e62a81e",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from tqdm.auto import tqdm\n",
"%matplotlib inline\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c503217b",
"metadata": {},
"outputs": [],
"source": [
"from qlib.data.monitor.analyser import Analyser\n",
"import qlib\n",
"qlib.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c276470",
"metadata": {},
"outputs": [],
"source": [
"class SimpleDFA(Analyser):\n",
" \"\"\"Simple (D)ata(F)rame (A)nalyser\"\"\"\n",
" def analyse(self, data: pd.DataFrame, *args, **kwargs):\n",
" data.plot(*args, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "110262e4",
"metadata": {},
"outputs": [],
"source": [
"from monitor import get_data_df, AutoCM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ea38c62",
"metadata": {},
"outputs": [],
"source": [
"# get data\n",
"factor_df = get_data_df([1], verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbded6fe",
"metadata": {},
"outputs": [],
"source": [
"# metric extractor\n",
"acm = AutoCM()\n",
"mtrc = acm.extract(factor_df)\n",
"print(mtrc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65517c81",
"metadata": {},
"outputs": [],
"source": [
"# Analyser\n",
"sa = SimpleDFA()\n",
"sa.analyse(mtrc, title='Auto Correlation')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dab6fb2e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,35 @@
# High-Frequency Dataset
This dataset is an example for RL high frequency trading.
## Get High-Frequency Data
Get high-frequency data by running the following command:
```bash
python workflow.py get_data
```
## Dump & Reload & Reinitialize the Dataset
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in the `workflow.py`. `DatatsetH` is the subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped in or loaded from disk in `pickle` format.
### About Reinitialization
After reloading `Dataset` from disk, `Qlib` also support reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states.
The example is given in `workflow.py`, users can run the code as follows.
### Run the Code
Run the example by running the following command:
```bash
python workflow.py dump_and_load_dataset
```
## Benchmarks Performance
### Signal Test
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|---|---|---|---|---|---|---|---|---|---|
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |

View File

@@ -0,0 +1,174 @@
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
from qlib.data.dataset.processor import Processor
from qlib.utils import get_cls_kwargs
from qlib.log import TimeInspector
class HighFreqHandler(DataHandlerLP):
def __init__(
self,
instruments="csi300",
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
drop_raw=True,
):
def check_transform_proc(proc_l):
new_l = []
for p in proc_l:
p["kwargs"].update(
{
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
}
)
new_l.append(p)
return new_l
infer_processors = check_transform_proc(infer_processors)
learn_processors = check_transform_proc(learn_processors)
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
},
}
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
drop_raw=drop_raw,
)
def get_feature_config(self):
fields = []
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
template_fillnan = "BFillNan(FFillNan({0}))"
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
def get_normalized_price_feature(price_field, shift=0):
"""Get normalized price feature ops"""
if shift == 0:
template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)"
else:
template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)"
feature_ops = template_norm.format(
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(price_field),
),
template_fillnan.format(template_paused.format("$close")),
)
return feature_ops
fields += [get_normalized_price_feature("$open", 0)]
fields += [get_normalized_price_feature("$high", 0)]
fields += [get_normalized_price_feature("$low", 0)]
fields += [get_normalized_price_feature("$close", 0)]
fields += [get_normalized_price_feature(simpson_vwap, 0)]
names += ["$open", "$high", "$low", "$close", "$vwap"]
fields += [get_normalized_price_feature("$open", 240)]
fields += [get_normalized_price_feature("$high", 240)]
fields += [get_normalized_price_feature("$low", 240)]
fields += [get_normalized_price_feature("$close", 240)]
fields += [get_normalized_price_feature(simpson_vwap, 240)]
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
fields += [
"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),
template_paused.format("$high"),
)
)
]
names += ["$volume"]
fields += [
"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),
template_paused.format("$high"),
)
)
]
names += ["$volume_1"]
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
names += ["date"]
return fields, names
class HighFreqBacktestHandler(DataHandler):
def __init__(
self,
instruments="csi300",
start_time=None,
end_time=None,
):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
},
}
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
)
def get_feature_config(self):
fields = []
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
template_fillnan = "BFillNan(FFillNan({0}))"
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
fields += [
"Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))),
]
names += ["$close0"]
fields += [
"Cut({0}, 240, None)".format(
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(simpson_vwap),
)
)
]
names += ["$vwap0"]
fields += [
"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),
template_paused.format("$high"),
)
]
names += ["$volume0"]
return fields, names

View File

@@ -0,0 +1,190 @@
import numpy as np
import pandas as pd
import importlib
from qlib.data.ops import ElemOperator, PairOperator
from qlib.config import C
from qlib.data.cache import H
from qlib.data.data import Cal
def get_calendar_day(freq="day", future=False):
"""Load High-Freq Calendar Date Using Memcache.
Parameters
----------
freq : str
frequency of read calendar file.
future : bool
whether including future trading day.
Returns
-------
_calendar:
array of date.
"""
flag = f"{freq}_future_{future}_day"
if flag in H["c"]:
_calendar = H["c"][flag]
else:
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
H["c"][flag] = _calendar
return _calendar
class DayLast(ElemOperator):
"""DayLast Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a series of that each value equals the last value of its day
"""
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
return series.groupby(_calendar[series.index]).transform("last")
class FFillNan(ElemOperator):
"""FFillNan Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a forward fill nan feature
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="ffill")
class BFillNan(ElemOperator):
"""BFillNan Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a backfoward fill nan feature
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="bfill")
class Date(ElemOperator):
"""Date Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a series of that each value is the date corresponding to feature.index
"""
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
return pd.Series(_calendar[series.index], index=series.index)
class Select(PairOperator):
"""Select Operator
Parameters
----------
feature_left : Expression
feature instance, select condition
feature_right : Expression
feature instance, select value
Returns
----------
feature:
value(feature_right) that meets the condition(feature_left)
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
return series_feature.loc[series_condition]
class IsNull(ElemOperator):
"""IsNull Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
A series indicating whether the feature is nan
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.isnull()
class Cut(ElemOperator):
"""Cut Operator
Parameters
----------
feature : Expression
feature instance
l : int
l > 0, delete the first l elements of feature (default is None, which means 0)
r : int
r < 0, delete the last -r elements of feature (default is None, which means 0)
Returns
----------
feature:
A series with the first l and last -r elements deleted from the feature.
Note: It is deleted from the raw data, not the sliced data
"""
def __init__(self, feature, l=None, r=None):
self.l = l
self.r = r
if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0):
raise ValueError("Cut operator l shoud > 0 and r should < 0")
super(Cut, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.iloc[self.l : self.r]
def get_extended_window_size(self):
ll = 0 if self.l is None else self.l
rr = 0 if self.r is None else abs(self.r)
lft_etd, rght_etd = self.feature.get_extended_window_size()
lft_etd = lft_etd + ll
rght_etd = rght_etd + rr
return lft_etd, rght_etd

View File

@@ -0,0 +1,72 @@
import numpy as np
import pandas as pd
from qlib.data.dataset.processor import Processor
from qlib.data.dataset.utils import fetch_df_by_index
class HighFreqNorm(Processor):
def __init__(self, fit_start_time, fit_end_time):
self.fit_start_time = fit_start_time
self.fit_end_time = fit_end_time
def fit(self, df_features):
fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
del df_features
df_values = fetch_df.values
names = {
"price": slice(0, 10),
"volume": slice(10, 12),
}
self.feature_med = {}
self.feature_std = {}
self.feature_vmax = {}
self.feature_vmin = {}
for name, name_val in names.items():
part_values = df_values[:, name_val].astype(np.float32)
if name == "volume":
part_values = np.log1p(part_values)
self.feature_med[name] = np.nanmedian(part_values)
part_values = part_values - self.feature_med[name]
self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12
part_values = part_values / self.feature_std[name]
self.feature_vmax[name] = np.nanmax(part_values)
self.feature_vmin[name] = np.nanmin(part_values)
def __call__(self, df_features):
df_features.set_index("date", append=True, drop=True, inplace=True)
df_values = df_features.values
names = {
"price": slice(0, 10),
"volume": slice(10, 12),
}
for name, name_val in names.items():
if name == "volume":
df_values[:, name_val] = np.log1p(df_values[:, name_val])
df_values[:, name_val] -= self.feature_med[name]
df_values[:, name_val] /= self.feature_std[name]
slice0 = df_values[:, name_val] > 3.0
slice1 = df_values[:, name_val] > 3.5
slice2 = df_values[:, name_val] < -3.0
slice3 = df_values[:, name_val] < -3.5
df_values[:, name_val][slice0] = (
3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5
)
df_values[:, name_val][slice1] = 3.5
df_values[:, name_val][slice2] = (
-3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5
)
df_values[:, name_val][slice3] = -3.5
idx = df_features.index.droplevel("datetime").drop_duplicates()
idx.set_names(["instrument", "datetime"], inplace=True)
# Reshape is specifically for adapting to RL high-freq executor
feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)
df_new_features = pd.DataFrame(
data=np.concatenate((feat, feat_1), axis=1),
index=idx,
columns=["FEATURE_%d" % i for i in range(12 * 240)],
).sort_index()
return df_new_features

View File

@@ -0,0 +1,175 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import fire
import qlib
import pickle
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.utils import init_instance_by_config
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
from qlib.tests.data import GetData
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
class HighfreqWorkflow:
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
MARKET = "all"
start_time = "2020-09-15 00:00:00"
end_time = "2021-01-18 16:00:00"
train_end_time = "2020-11-30 16:00:00"
test_start_time = "2020-12-01 00:00:00"
DATA_HANDLER_CONFIG0 = {
"start_time": start_time,
"end_time": end_time,
"fit_start_time": start_time,
"fit_end_time": train_end_time,
"instruments": MARKET,
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
}
DATA_HANDLER_CONFIG1 = {
"start_time": start_time,
"end_time": end_time,
"instruments": MARKET,
}
task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "HighFreqHandler",
"module_path": "highfreq_handler",
"kwargs": DATA_HANDLER_CONFIG0,
},
"segments": {
"train": (start_time, train_end_time),
"test": (
test_start_time,
end_time,
),
},
},
},
"dataset_backtest": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "HighFreqBacktestHandler",
"module_path": "highfreq_handler",
"kwargs": DATA_HANDLER_CONFIG1,
},
"segments": {
"train": (start_time, train_end_time),
"test": (
test_start_time,
end_time,
),
},
},
},
}
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
qlib.init(**QLIB_INIT_CONFIG)
def _prepare_calender_cache(self):
"""preload the calendar for cache"""
# This code used the copy-on-write feature of Linux to avoid calculating the calendar multiple times in the subprocess
# This code may accelerate, but may be not useful on Windows and Mac Os
Cal.calendar(freq="1min")
get_calendar_day(freq="1min")
def get_data(self):
"""use dataset to get highreq data"""
self._init_qlib()
self._prepare_calender_cache()
dataset = init_instance_by_config(self.task["dataset"])
xtrain, xtest = dataset.prepare(["train", "test"])
print(xtrain, xtest)
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
print(backtest_train, backtest_test)
return
def dump_and_load_dataset(self):
"""dump and load dataset state on disk"""
self._init_qlib()
self._prepare_calender_cache()
dataset = init_instance_by_config(self.task["dataset"])
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
##=============dump dataset=============
dataset.to_pickle(path="dataset.pkl")
dataset_backtest.to_pickle(path="dataset_backtest.pkl")
del dataset, dataset_backtest
##=============reload dataset=============
with open("dataset.pkl", "rb") as file_dataset:
dataset = pickle.load(file_dataset)
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
dataset_backtest = pickle.load(file_dataset_backtest)
self._prepare_calender_cache()
##=============reinit dataset=============
dataset.config(
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segments={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset.setup_data(
handler_kwargs={
"init_type": DataHandlerLP.IT_LS,
},
)
dataset_backtest.config(
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segments={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset_backtest.setup_data(handler_kwargs={})
##=============get data=============
xtest = dataset.prepare("test")
backtest_test = dataset_backtest.prepare("test")
print(xtest, backtest_test)
return
if __name__ == "__main__":
fire.Fire(HighfreqWorkflow)

View File

@@ -0,0 +1,65 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
region: cn
market: &market 'csi300'
start_time: &start_time "2020-09-15 00:00:00"
end_time: &end_time "2021-01-18 16:00:00"
train_end_time: &train_end_time "2020-11-15 16:00:00"
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
test_start_time: &test_start_time "2020-12-01 00:00:00"
data_handler_config: &data_handler_config
start_time: *start_time
end_time: *end_time
fit_start_time: *start_time
fit_end_time: *train_end_time
instruments: *market
freq: '1min'
infer_processors:
- class: 'RobustZScoreNorm'
kwargs:
fields_group: 'feature'
clip_outlier: false
- class: "Fillna"
kwargs:
fields_group: 'feature'
learn_processors:
- class: 'DropnaLabel'
- class: 'CSRankNorm'
kwargs:
fields_group: 'label'
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
task:
model:
class: "HFLGBModel"
module_path: "qlib.contrib.model.highfreq_gdbt_model"
kwargs:
objective: 'binary'
metric: ['binary_logloss','auc']
verbosity: -1
learning_rate: 0.01
max_depth: 8
num_leaves: 150
lambda_l1: 1.5
lambda_l2: 1
num_threads: 20
dataset:
class: "DatasetH"
module_path: "qlib.data.dataset"
kwargs:
handler:
class: "Alpha158"
module_path: "qlib.contrib.data.handler"
kwargs: *data_handler_config
segments:
train: [*start_time, *train_end_time]
valid: [*train_end_time, *valid_end_time]
test: [*test_start_time, *end_time]
record:
- class: "SignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}
- class: "HFSignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}

View File

@@ -0,0 +1,23 @@
# LightGBM hyperparameter
## Alpha158
First terminal
```
optuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
```
Second terminal
```
python hyperparameter_158.py
```
## Alpha360
First terminal
```
optuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
```
Second terminal
```
python hyperparameter_360.py
```

View File

@@ -0,0 +1,46 @@
import qlib
import optuna
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.config import CSI300_DATASET_CONFIG
from qlib.tests.data import GetData
def objective(trial):
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
"subsample": trial.suggest_uniform("subsample", 0, 1),
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
"max_depth": 10,
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
},
},
}
evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region="cn")
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -0,0 +1,49 @@
import qlib
import optuna
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
def objective(trial):
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
"subsample": trial.suggest_uniform("subsample", 0, 1),
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
"max_depth": 10,
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
},
},
}
evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
dataset = init_instance_by_config(DATASET_CONFIG)
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -0,0 +1,5 @@
pandas==1.1.2
numpy==1.17.4
lightgbm==3.1.0
optuna==2.7.0
optuna-dashboard==0.4.1

View File

@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
###################################
# train model
###################################
# model initialization
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
model.fit(dataset)
# get model feature importance
feature_importance = model.get_feature_importance()
print("feature importance:")
print(feature_importance)

View File

@@ -0,0 +1,105 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
"""
from pprint import pprint
import fire
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
class RollingTaskExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region=REG_CN,
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
task_config=None,
rolling_step=550,
rolling_type=RollingGen.ROLL_SD,
):
# TaskManager config
if task_config is None:
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.experiment_name = experiment_name
self.task_pool = task_pool
self.task_config = task_config
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
TaskManager(task_pool=self.task_pool).remove()
exp = R.get_exp(experiment_name=self.experiment_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
def task_generating(self):
print("========== task_generating ==========")
tasks = task_generator(
tasks=self.task_config,
generators=self.rolling_gen, # generate different date segments
)
pprint(tasks)
return tasks
def task_training(self, tasks):
print("========== task_training ==========")
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
def task_collecting(self):
print("========== task_collecting ==========")
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
collector = RecorderCollector(
experiment=self.experiment_name,
process_list=RollingGroup(),
rec_key_func=rec_key,
rec_filter_func=my_filter,
)
print(collector())
def main(self):
self.reset()
tasks = self.task_generating()
self.task_training(tasks)
self.task_collecting()
if __name__ == "__main__":
## to see the whole process with your own parameters, use the command below
# python task_manager_rolling.py main --experiment_name="your_exp_name"
fire.Fire(RollingTaskExample)

View File

@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example is about how can simulate the OnlineManager based on rolling tasks.
"""
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
class OnlineSimulationExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
tasks=None,
):
"""
Init OnlineManagerExample.
Args:
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
region (str, optional): the stock region. Defaults to "cn".
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
task_db_name (str, optional): database name. Defaults to "rolling_db".
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
rolling_step (int, optional): the step for rolling. Defaults to 80.
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
self.end_time = end_time
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
begin_time=self.start_time,
)
self.tasks = tasks
# Reset all things to the first status, be careful to save important data
def reset(self):
TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Run this to run all workflow automatically
def main(self):
print("========== reset ==========")
self.reset()
print("========== simulate ==========")
self.rolling_online_manager.simulate(end_time=self.end_time)
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
fire.Fire(OnlineSimulationExample)

View File

@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineManager works with rolling tasks.
There are four parts including first train, routine 1, add strategy and routine 2.
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
Finally, the OnlineManager will finish second routine and update all strategies.
"""
import os
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
class RollingOnlineExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=550,
tasks=None,
add_tasks=None,
):
if add_tasks is None:
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.tasks = tasks
self.add_tasks = add_tasks
self.rolling_step = rolling_step
strategies = []
for task in tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategies.append(
RollingStrategy(
name_id,
task,
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager = OnlineManager(strategies)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
def first_run(self):
print("========== reset ==========")
self.reset()
print("========== first_run ==========")
self.rolling_online_manager.first_train()
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def routine(self):
print("========== load ==========")
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== routine ==========")
self.rolling_online_manager.routine()
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def add_strategy(self):
print("========== load ==========")
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== add strategy ==========")
strategies = []
for task in self.add_tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategies.append(
RollingStrategy(
name_id,
task,
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager.add_strategy(strategies=strategies)
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def main(self):
self.first_run()
self.routine()
self.add_strategy()
self.routine()
if __name__ == "__main__":
####### to train the first version's models, use the command below
# python rolling_online_management.py first_run
####### to update the models and predictions after the trading time, use the command below
# python rolling_online_management.py routine
####### to define your own parameters, use `--`
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
fire.Fire(RollingOnlineExample)

View File

@@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineTool works when we need update prediction.
There are two parts including first_train and update_online_pred.
Firstly, we will finish the training and set the trained models to the `online` models.
Next, we will finish updating online predictions.
"""
import copy
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.online.utils import OnlineToolR
from qlib.tests.config import CSI300_GBDT_TASK
task = copy.deepcopy(CSI300_GBDT_TASK)
task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
}
class UpdatePredExample:
def __init__(
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
):
qlib.init(provider_uri=provider_uri, region=region)
self.experiment_name = experiment_name
self.online_tool = OnlineToolR(self.experiment_name)
self.task_config = task_config
def first_train(self):
rec = task_train(self.task_config, experiment_name=self.experiment_name)
self.online_tool.reset_online_tag(rec) # set to online model
def update_online_pred(self):
self.online_tool.update_online_pred()
def main(self):
self.first_train()
self.update_online_pred()
if __name__ == "__main__":
## to train a model and set it to online model, use the command below
# python update_online_pred.py first_train
## to update online predictions once a day, use the command below
# python update_online_pred.py update_online_pred
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire(UpdatePredExample)

View File

@@ -0,0 +1,17 @@
# Rolling Process Data
This workflow is an example for `Rolling Process Data`.
## Background
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
## Run the Code
Run the example by running the following command:
```bash
python workflow.py rolling_process
```

View File

@@ -0,0 +1,32 @@
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.loader import DataLoaderDH
from qlib.contrib.data.handler import check_transform_proc
class RollingDataHandler(DataHandlerLP):
def __init__(
self,
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
data_loader_kwargs={},
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "DataLoaderDH",
"kwargs": {**data_loader_kwargs},
}
super().__init__(
instruments=None,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
)

View File

@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
import fire
import pickle
from datetime import datetime
from qlib.config import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
class RollingDataWorkflow:
MARKET = "csi300"
start_time = "2010-01-01"
end_time = "2019-12-31"
rolling_cnt = 5
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def _dump_pre_handler(self, path):
handler_config = {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": self.start_time,
"end_time": self.end_time,
"instruments": self.MARKET,
"infer_processors": [],
"learn_processors": [],
},
}
pre_handler = init_instance_by_config(handler_config)
pre_handler.config(dump_all=True)
pre_handler.to_pickle(path)
def _load_pre_handler(self, path):
with open(path, "rb") as file_dataset:
pre_handler = pickle.load(file_dataset)
return pre_handler
def rolling_process(self):
self._init_qlib()
self._dump_pre_handler("pre_handler.pkl")
pre_handler = self._load_pre_handler("pre_handler.pkl")
train_start_time = (2010, 1, 1)
train_end_time = (2012, 12, 31)
valid_start_time = (2013, 1, 1)
valid_end_time = (2013, 12, 31)
test_start_time = (2014, 1, 1)
test_end_time = (2014, 12, 31)
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "RollingDataHandler",
"module_path": "rolling_handler",
"kwargs": {
"start_time": datetime(*train_start_time),
"end_time": datetime(*test_end_time),
"fit_start_time": datetime(*train_start_time),
"fit_end_time": datetime(*train_end_time),
"infer_processors": [
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
],
"learn_processors": [
{"class": "DropnaLabel"},
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
],
"data_loader_kwargs": {
"handler_config": pre_handler,
},
},
},
"segments": {
"train": (datetime(*train_start_time), datetime(*train_end_time)),
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
"test": (datetime(*test_start_time), datetime(*test_end_time)),
},
},
}
dataset = init_instance_by_config(dataset_config)
for rolling_offset in range(self.rolling_cnt):
print(f"===========rolling{rolling_offset} start===========")
if rolling_offset:
dataset.config(
handler_kwargs={
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
"processor_kwargs": {
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
},
},
segments={
"train": (
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
),
"valid": (
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
),
"test": (
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
),
},
)
dataset.setup_data(
handler_kwargs={
"init_type": DataHandlerLP.IT_FIT_SEQ,
}
)
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
print(dtrain, dvalid, dtest)
## print or dump data
print(f"===========rolling{rolling_offset} end===========")
if __name__ == "__main__":
fire.Fire(RollingDataWorkflow)

View File

@@ -5,16 +5,15 @@ import os
import sys
import fire
import time
import venv
import glob
import shutil
import signal
import inspect
import tempfile
import traceback
import functools
import statistics
import subprocess
from datetime import datetime
from pathlib import Path
from operator import xor
from pprint import pprint
@@ -22,8 +21,7 @@ from pprint import pprint
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.cli import workflow
from qlib.utils import exists_qlib_data
from qlib.tests.data import GetData
# init qlib
@@ -38,15 +36,9 @@ exp_manager = {
"default_exp_name": "Experiment",
},
}
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
if os.path.isdir(exp_path):
shutil.rmtree(exp_path)
# decorator to check the arguments
def only_allow_defined_args(function_to_decorate):
@@ -70,9 +62,9 @@ def handler(signum, frame):
os.system("kill -9 %d" % os.getpid())
signal.signal(signal.SIGTSTP, handler)
signal.signal(signal.SIGINT, handler)
# function to calculate the mean and std of a list in the results dictionary
def cal_mean_std(results) -> dict:
mean_std = dict()
@@ -136,9 +128,9 @@ def get_all_folders(models, exclude) -> dict:
# function to get all the files under the model folder
def get_all_files(folder_path) -> (str, str):
yaml_path = str(Path(f"{folder_path}") / "*.yaml")
req_path = str(Path(f"{folder_path}") / "*.txt")
def get_all_files(folder_path, dataset) -> (str, str):
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
req_path = str(Path(f"{folder_path}") / f"*.txt")
return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
@@ -152,6 +144,10 @@ def get_all_results(folders) -> dict:
result["annualized_return_with_cost"] = list()
result["information_ratio_with_cost"] = list()
result["max_drawdown_with_cost"] = list()
result["ic"] = list()
result["icir"] = list()
result["rank_ic"] = list()
result["rank_icir"] = list()
for recorder_id in recorders:
if recorders[recorder_id].status == "FINISHED":
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
@@ -159,19 +155,27 @@ def get_all_results(folders) -> dict:
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
result["ic"].append(metrics["IC"])
result["icir"].append(metrics["ICIR"])
result["rank_ic"].append(metrics["Rank IC"])
result["rank_icir"].append(metrics["Rank ICIR"])
results[fn] = result
return results
# function to generate and save markdown table
def gen_and_save_md_table(metrics):
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
table += "|---|---|---|---|\n"
def gen_and_save_md_table(metrics, dataset):
table = "| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |\n"
table += "|---|---|---|---|---|---|---|---|---|\n"
for fn in metrics:
ic = metrics[fn]["ic"]
icir = metrics[fn]["icir"]
ric = metrics[fn]["rank_ic"]
ricir = metrics[fn]["rank_icir"]
ar = metrics[fn]["annualized_return_with_cost"]
ir = metrics[fn]["information_ratio_with_cost"]
md = metrics[fn]["max_drawdown_with_cost"]
table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
table += f"| {fn} | {dataset} | {ic[0]:5.4f}±{ic[1]:2.2f} | {icir[0]:5.4f}±{icir[1]:2.2f}| {ric[0]:5.4f}±{ric[1]:2.2f} | {ricir[0]:5.4f}±{ricir[1]:2.2f} | {ar[0]:5.4f}±{ar[1]:2.2f} | {ir[0]:5.4f}±{ir[1]:2.2f}| {md[0]:5.4f}±{md[1]:2.2f} |\n"
pprint(table)
with open("table.md", "w") as f:
f.write(table)
@@ -180,10 +184,11 @@ def gen_and_save_md_table(metrics):
# function to run the all the models
@only_allow_defined_args
def run(times=1, models=None, exclude=False):
def run(times=1, models=None, dataset="Alpha360", exclude=False):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed.
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
for multiple times, and this will be fixed in the future development.
Parameters:
-----------
@@ -193,6 +198,8 @@ def run(times=1, models=None, exclude=False):
determines the specific model or list of models to run or exclude.
exclude : boolean
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
Usage:
-------
@@ -206,13 +213,16 @@ def run(times=1, models=None, exclude=False):
# Case 2 - run specific models multiple times
python run_all_model.py 3 mlp
# Case 3 - run other models except those are given as arguments for multiple times
python run_all_model.py 3 [mlp,tft,lstm] True
# Case 3 - run specific models multiple times with specific dataset
python run_all_model.py 3 mlp Alpha158
# Case 4 - run specific models for one time
# Case 4 - run other models except those are given as arguments for multiple times
python run_all_model.py 3 [mlp,tft,lstm] --exclude=True
# Case 5 - run specific models for one time
python run_all_model.py --models=[mlp,lightgbm]
# Case 5 - run other models except those are given as aruments for one time
# Case 6 - run other models except those are given as aruments for one time
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
"""
@@ -226,7 +236,7 @@ def run(times=1, models=None, exclude=False):
env_path, python_path, conda_activate = create_env()
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn])
yaml_path, req_path = get_all_files(folders[fn], dataset)
sys.stderr.write("\n")
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
@@ -240,6 +250,7 @@ def run(times=1, models=None, exclude=False):
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
if fn == "TFT":
execute(
@@ -272,12 +283,15 @@ def run(times=1, models=None, exclude=False):
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results)
gen_and_save_md_table(results, dataset)
sys.stderr.write("\n")
# print erros
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
sys.stderr.write("\n")
# move results folder
shutil.move(exp_path, exp_path + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
if __name__ == "__main__":

View File

@@ -28,11 +28,17 @@
"import sys, site\n",
"from pathlib import Path\n",
"\n",
"################################# NOTE #################################\n",
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
"# in this cell, users should RESTART the runtime in order to run the #\n",
"# following cells successfully. #\n",
"########################################################################\n",
"\n",
"try:\n",
" import qlib\n",
"except ImportError:\n",
" # install qlib\n",
" ! pip install --upgrade numpy\n",
" ! pip install pyqlib\n",
" # reload\n",
" site.main()\n",
@@ -238,9 +244,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
@@ -359,7 +363,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
"version": "3.8.3"
},
"toc": {
"base_numbering": 1,
@@ -377,4 +381,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@@ -1,85 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
@@ -93,28 +30,36 @@ if __name__ == "__main__":
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"benchmark": CSI300_BENCH,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
"return_order": True,
},
}
# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
# model initialization
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# NOTE: This line is optional
# It demonstrates that the dataset can be used standalone.
example_df = dataset.prepare("train")
print(example_df.head())
# start exp
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()

View File

@@ -2,95 +2,54 @@
# Licensed under the MIT License.
__version__ = "0.6.0"
__version__ = "0.6.3.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
import re
import sys
import copy
import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .log import get_module_logger
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
from .workflow.utils import experiment_exit_handler
# init qlib
def init(default_conf="client", **kwargs):
from .config import C, REG_CN, REG_US, QlibConfig
from .data.data import register_all_wrappers
from .log import get_module_logger, set_log_with_config
from .config import C
from .data.cache import H
from .workflow import R, QlibRecorder
C.reset()
H.clear()
_logging_config = C.logging_config
if "logging_config" in kwargs:
_logging_config = kwargs["logging_config"]
# set global config
if _logging_config:
set_log_with_config(_logging_config)
# FIXME: this logger ignored the level in config
LOG = get_module_logger("Initialization", level=logging.INFO)
LOG.info(f"default_conf: {default_conf}.")
logger = get_module_logger("Initialization", level=logging.INFO)
C.set_mode(default_conf)
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
for k, v in kwargs.items():
C[k] = v
if k not in C:
LOG.warning("Unrecognized config %s" % k)
C.resolve_path()
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
# check redis
if not can_use_cache():
LOG.warning(
f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!"
)
C["expression_cache"] = None
C["dataset_cache"] = None
C.set(default_conf, **kwargs)
# check path if server/local
if C.get_uri_type() == QlibConfig.LOCAL_URI:
if C.get_uri_type() == C.LOCAL_URI:
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
LOG.error(
logger.error(
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
)
else:
LOG.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
elif C.get_uri_type() == QlibConfig.NFS_URI:
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
elif C.get_uri_type() == C.NFS_URI:
_mount_nfs_uri(C)
else:
raise NotImplementedError(f"This type of URI is not supported")
LOG.info("qlib successfully initialized based on %s settings." % default_conf)
register_all_wrappers()
LOG.info(f"data_path={C.get_data_path()}")
C.register()
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
# set up QlibRecorder
exp_manager = init_instance_by_config(C["exp_manager"])
qr = QlibRecorder(exp_manager)
R.register(qr)
# clean up experiment when python program ends
experiment_exit_handler()
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
logger.info(f"data_path={C.get_data_path()}")
def _mount_nfs_uri(C):
from .log import get_module_logger
LOG = get_module_logger("mount nfs", level=logging.INFO)
@@ -189,7 +148,78 @@ def init_from_yaml_conf(conf_path, **kwargs):
"""
with open(conf_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
config = yaml.safe_load(f)
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
- There is a file named `config.yaml` in qlib.
For example:
If your project file system stucuture follows such a pattern
<project_path>/
- config.yaml
- ...some folders...
- qlib/
This folder will return <project_path>
NOTE: link is not supported here.
This method is often used when
- user want to use a relative config path instead of hard-coding qlib config path in code
Raises
------
FileNotFoundError:
If project path is not found
"""
if cur_path is None:
cur_path = Path(__file__).absolute().resolve()
while True:
if (cur_path / config_name).exists():
return cur_path
if cur_path == cur_path.parent:
raise FileNotFoundError("We can't find the project path")
cur_path = cur_path.parent
def auto_init(**kwargs):
"""
This function will init qlib automatically with following priority
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
"""
try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
conf_type = conf.get("conf_type", "origin")
if conf_type == "origin":
# The type of config is just like original qlib config
init_from_yaml_conf(conf_pp, **kwargs)
elif conf_type == "ref":
# This config type will be more convenient in following scenario
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -11,26 +11,30 @@ Two modes are supported
"""
import copy
from pathlib import Path
import re
import os
import re
import copy
import logging
import multiprocessing
from pathlib import Path
class Config:
def __init__(self, default_conf):
self.__dict__["_default_config"] = default_conf # avoiding conflictions with __getattr__
self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflictions with __getattr__
self.reset()
def __getitem__(self, key):
return self.__dict__["_config"][key]
def __getattr__(self, attr):
try:
if attr in self.__dict__["_config"]:
return self.__dict__["_config"][attr]
except KeyError:
return AttributeError(f"No such {attr} in self._config")
raise AttributeError(f"No such {attr} in self._config")
def get(self, key, default=None):
return self.__dict__["_config"].get(key, default)
def __setitem__(self, key, value):
self.__dict__["_config"][key] = value
@@ -59,6 +63,9 @@ class Config:
def update(self, *args, **kwargs):
self.__dict__["_config"].update(*args, **kwargs)
def set_conf_from_C(self, config_c):
self.update(**config_c.__dict__["_config"])
# REGION CONST
REG_CN = "cn"
@@ -86,7 +93,6 @@ _default_config = {
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
"default_disk_cache": 1, # 0:skip/1:use
"disable_disk_cache": False, # disable disk cache; if High-frequency data generally disable_disk_cache=True
"mem_cache_size_limit": 500,
# memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar'
# default 1 hour
@@ -102,7 +108,7 @@ _default_config = {
"redis_port": 6379,
"redis_task_db": 1,
# This value can be reset via qlib.init
"logging_level": "INFO",
"logging_level": logging.INFO,
# Global configuration of qlib log
# logging_level can control the logging level more finely
"logging_config": {
@@ -121,14 +127,14 @@ _default_config = {
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"level": logging.DEBUG,
"formatter": "logger_format",
"filters": ["field_not_found"],
}
},
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
},
# Defatult config for experiment manager
# Default config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
@@ -137,6 +143,11 @@ _default_config = {
"default_exp_name": "Experiment",
},
},
# Default config for MongoDB
"mongo": {
"task_url": "mongodb://localhost:27017/",
"task_db_name": "default_task_db",
},
}
MODE_CONF = {
@@ -182,11 +193,19 @@ MODE_CONF = {
# The nfs should be auto-mounted by qlib on other
# serversS(such as PAI) [auto_mount:True]
"timeout": 100,
"logging_level": "INFO",
"logging_level": logging.INFO,
"region": REG_CN,
## Custom Operator
"custom_ops": [],
},
}
HIGH_FREQ_CONFIG = {
"provider_uri": "~/.qlib/qlib_data/yahoo_cn_1min",
"dataset_cache": None,
"expression_cache": "DiskExpressionCache",
"region": REG_CN,
}
_default_region_config = {
REG_CN: {
@@ -207,6 +226,10 @@ class QlibConfig(Config):
LOCAL_URI = "local"
NFS_URI = "nfs"
def __init__(self, default_conf):
super().__init__(default_conf)
self._registered = False
def set_mode(self, mode):
# raise KeyError
self.update(MODE_CONF[mode])
@@ -243,6 +266,78 @@ class QlibConfig(Config):
else:
raise NotImplementedError(f"This type of uri is not supported")
def set(self, default_conf="client", **kwargs):
from .utils import set_log_with_config, get_module_logger, can_use_cache
self.reset()
_logging_config = self.logging_config
if "logging_config" in kwargs:
_logging_config = kwargs["logging_config"]
# set global config
if _logging_config:
set_log_with_config(_logging_config)
# FIXME: this logger ignored the level in config
logger = get_module_logger("Initialization", level=logging.INFO)
logger.info(f"default_conf: {default_conf}.")
self.set_mode(default_conf)
self.set_region(kwargs.get("region", self["region"] if "region" in self else REG_CN))
for k, v in kwargs.items():
if k not in self:
logger.warning("Unrecognized config %s" % k)
self[k] = v
self.resolve_path()
if not (self["expression_cache"] is None and self["dataset_cache"] is None):
# check redis
if not can_use_cache():
logger.warning(
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
)
self["expression_cache"] = None
self["dataset_cache"] = None
def register(self):
from .utils import init_instance_by_config
from .data.ops import register_all_ops
from .data.data import register_all_wrappers
from .workflow import R, QlibRecorder
from .workflow.utils import experiment_exit_handler
register_all_ops(self)
register_all_wrappers(self)
# set up QlibRecorder
exp_manager = init_instance_by_config(self["exp_manager"])
qr = QlibRecorder(exp_manager)
R.register(qr)
# clean up experiment when python program ends
experiment_exit_handler()
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
self.reset_qlib_version()
self._registered = True
def reset_qlib_version(self):
import qlib
reset_version = self.get("qlib_reset_version", None)
if reset_version is not None:
qlib.__version__ = reset_version
else:
qlib.__version__ = getattr(qlib, "__version__bak")
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
# Using __version__bak instead of __version__
@property
def registered(self):
return self._registered
# global config
C = QlibConfig(_default_config)

View File

@@ -1,9 +1,324 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -*- coding: utf-8 -*-
from .order import Order
from .account import Account
from .position import Position
from .exchange import Exchange
from .report import Report
from .backtest import backtest as backtest_func, get_date_range
import numpy as np
import inspect
from ...utils import init_instance_by_config
from ...log import get_module_logger
from ...config import C
logger = get_module_logger("backtest caller")
def get_strategy(
strategy=None,
topk=50,
margin=0.5,
n_drop=5,
risk_degree=0.95,
str_type="dropout",
adjust_dates=None,
):
"""get_strategy
There will be 3 ways to return a stratgy. Please follow the code.
Parameters
----------
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
Returns
-------
:class: Strategy
an initialized strategy object
"""
# There will be 3 ways to return a strategy.
if strategy is None:
# 1) create strategy with param `strategy`
str_cls_dict = {
"amount": "TopkAmountStrategy",
"weight": "TopkWeightStrategy",
"dropout": "TopkDropoutStrategy",
}
logger.info("Create new strategy ")
from .. import strategy as strategy_pool
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
strategy = str_cls(
topk=topk,
buffer_margin=margin,
n_drop=n_drop,
risk_degree=risk_degree,
adjust_dates=adjust_dates,
)
elif isinstance(strategy, (dict, str)):
# 2) create strategy with init_instance_by_config
logger.info("Create new strategy ")
strategy = init_instance_by_config(strategy)
from ..strategy.strategy import BaseStrategy
# else: nothing happens. 3) Use the strategy directly
if not isinstance(strategy, BaseStrategy):
raise TypeError("Strategy not supported")
return strategy
def get_exchange(
pred,
exchange=None,
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
trade_unit=None,
limit_threshold=None,
deal_price=None,
extract_codes=False,
shift=1,
):
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange().
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost.
close_cost : float
close transaction cost.
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
Returns
-------
:class: Exchange
an initialized Exchange object
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if exchange is None:
logger.info("Create new exchange")
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
if extract_codes:
codes = sorted(pred.index.get_level_values("instrument").unique())
else:
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
dates = sorted(pred.index.get_level_values("datetime").unique())
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
exchange = Exchange(
trade_dates=dates,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
min_cost=min_cost,
trade_unit=trade_unit,
)
return exchange
def get_executor(
executor=None,
trade_exchange=None,
verbose=True,
):
"""get_executor
There will be 3 ways to return a executor. Please follow the code.
Parameters
----------
executor : BaseExecutor
executor used in backtest.
trade_exchange : Exchange
exchange used in executor
verbose : bool
whether to print log.
Returns
-------
:class: BaseExecutor
an initialized BaseExecutor object
"""
# There will be 3 ways to return a executor.
if executor is None:
# 1) create executor with param `executor`
logger.info("Create new executor ")
from ..online.executor import SimulatorExecutor
executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
elif isinstance(executor, (dict, str)):
# 2) create executor with config
logger.info("Create new executor ")
executor = init_instance_by_config(executor)
from ..online.executor import BaseExecutor
# 3) Use the executor directly
if not isinstance(executor, BaseExecutor):
raise TypeError("Executor not supported")
return executor
# This is the API for compatibility for legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
"""This function will help you set a reasonable Exchange and provide default value for strategy
Parameters
----------
- **backtest workflow related or commmon arguments**
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column.
account : float
init account value.
shift : int
whether to shift prediction by one day.
benchmark : str
benchmark code, default is SH000905 CSI 500.
verbose : bool
whether to print log.
return_order : bool
whether to return order list
- **strategy related arguments**
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
- **exchange related arguments**
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
.. note:: This will be faster with offline qlib.
- **executor related arguments**
executor : BaseExecutor()
executor used in backtest.
verbose : bool
whether to print log.
"""
# check strategy:
spec = inspect.getfullargspec(get_strategy)
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
strategy = get_strategy(**str_args)
# init exchange:
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
# init executor:
executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
# run backtest
report_dict = backtest_func(
pred=pred,
strategy=strategy,
executor=executor,
trade_exchange=trade_exchange,
shift=shift,
verbose=verbose,
account=account,
benchmark=benchmark,
return_order=return_order,
)
# for compatibility of the old API. return the dict positions
positions = report_dict.get("positions")
report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
return report_dict

View File

@@ -104,10 +104,9 @@ class Account:
# if suspend, no new price to be updated, profit is 0
if trader.check_stock_suspended(code, today):
continue
else:
today_close = trader.get_close(code, today)
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=today_close)
today_close = trader.get_close(code, today)
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=today_close)
self.rtn += profit
# update holding day count
self.current.add_count_all()

View File

@@ -5,7 +5,6 @@
import numpy as np
import pandas as pd
from ...utils import get_date_by_shift, get_date_range
from ..online.executor import SimulatorExecutor
from ...data import D
from .account import Account
from ...config import C
@@ -15,8 +14,9 @@ from ...data.dataset.utils import get_level_index
LOG = get_module_logger("backtest")
def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark):
"""Parameters
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
"""
Parameters
----------
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column
@@ -69,9 +69,9 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
executor = SimulatorExecutor(trade_exchange, verbose=verbose)
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
if return_order:
multi_order_list = []
# trading apart
for pred_date, trade_date in zip(predict_dates, trade_dates):
# for loop predict date and trading date
@@ -103,6 +103,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
)
else:
order_list = []
if return_order:
multi_order_list.append((trade_account, order_list, trade_date))
# 4. Get result after executing order list
# NOTE: The following operation will modify order.amount.
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
@@ -115,11 +117,17 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
report_df = trade_account.report.generate_report_dataframe()
report_df["bench"] = bench
positions = trade_account.get_positions()
return report_df, positions
report_dict = {"report_df": report_df, "positions": positions}
if return_order:
report_dict.update({"order_list": multi_order_list})
return report_dict
def update_account(trade_account, trade_info, trade_exchange, trade_date):
"""Update the account and strategy
"""
Update the account and strategy
Parameters
----------
trade_account : Account()

View File

@@ -1,10 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import copy
import pathlib
import pandas as pd
import numpy as np
from .order import Order
"""
@@ -128,7 +128,7 @@ class Position:
return self.position["cash"]
def get_stock_amount_dict(self):
"""generate stock amount dict {stock_id : amount of stock} """
"""generate stock amount dict {stock_id : amount of stock}"""
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:
@@ -166,7 +166,7 @@ class Position:
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash = pd.Series(dtype=float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]

View File

@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
# FIXME: the `module_path` parameter is missed.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)
@@ -43,16 +44,18 @@ _DEFAULT_INFER_PROCESSORS = [
]
class ALPHA360(DataHandlerLP):
class Alpha360(DataHandlerLP):
def __init__(
self,
instruments="csi500",
start_time=None,
end_time=None,
freq="day",
infer_processors=_DEFAULT_INFER_PROCESSORS,
learn_processors=_DEFAULT_LEARN_PROCESSORS,
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -65,13 +68,15 @@ class ALPHA360(DataHandlerLP):
"feature": self.get_feature_config(),
"label": kwargs.get("label", self.get_label_config()),
},
"filter_pipe": filter_pipe,
"freq": freq,
},
}
super().__init__(
instruments,
start_time,
end_time,
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
@@ -119,7 +124,7 @@ class ALPHA360(DataHandlerLP):
return fields, names
class ALPHA360vwap(ALPHA360):
class Alpha360vwap(Alpha360):
def get_label_config(self):
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
@@ -130,11 +135,13 @@ class Alpha158(DataHandlerLP):
instruments="csi500",
start_time=None,
end_time=None,
freq="day",
infer_processors=[],
learn_processors=_DEFAULT_LEARN_PROCESSORS,
fit_start_time=None,
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -143,13 +150,18 @@ class Alpha158(DataHandlerLP):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())},
"config": {
"feature": self.get_feature_config(),
"label": kwargs.get("label", self.get_label_config()),
},
"filter_pipe": filter_pipe,
"freq": freq,
},
}
super().__init__(
instruments,
start_time,
end_time,
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,

View File

@@ -8,6 +8,59 @@ import pandas as pd
from typing import Tuple
def calc_long_short_prec(
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
) -> Tuple[pd.Series, pd.Series]:
"""
calculate the precision for long and short operation
:param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.
.. code-block:: python
score
datetime instrument
2020-12-01 09:30:00 SH600068 0.553634
SH600195 0.550017
SH600276 0.540321
SH600584 0.517297
SH600715 0.544674
label :
label
date_col :
date_col
Returns
-------
(pd.Series, pd.Series)
long precision and short precision in time level
"""
if is_alpha:
label = label - label.mean(level=date_col)
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
raise ValueError("Need more instruments to calculate precision")
df = pd.DataFrame({"pred": pred, "label": label})
if dropna:
df.dropna(inplace=True)
group = df.groupby(level=date_col)
N = lambda x: int(len(x) * quantile)
# find the top/low quantile of prediction and treat them as long and short target
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
groupll = long.groupby(date_col)
l_dom = groupll.apply(lambda x: x > 0)
l_c = groupll.count()
groups = short.groupby(date_col)
s_dom = groups.apply(lambda x: x < 0)
s_c = groups.count()
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
"""calc_ic.

View File

@@ -6,17 +6,16 @@ from __future__ import print_function
import numpy as np
import pandas as pd
import inspect
import warnings
from ..log import get_module_logger
from . import strategy as strategy_pool
from .strategy.strategy import BaseStrategy
from .backtest.exchange import Exchange
from .backtest.backtest import backtest as backtest_func, get_date_range
from .backtest import get_exchange, backtest as backtest_func
from .backtest.backtest import get_date_range
from ..data import D
from ..config import C
from ..data.dataset.utils import get_level_index
logger = get_module_logger("Evaluate")
@@ -46,144 +45,6 @@ def risk_analysis(r, N=252):
return res
def get_strategy(
strategy=None,
topk=50,
margin=0.5,
n_drop=5,
risk_degree=0.95,
str_type="amount",
adjust_dates=None,
):
"""get_strategy
Parameters
----------
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
Returns
-------
:class: Strategy
an initialized strategy object
"""
if strategy is None:
str_cls_dict = {
"amount": "TopkAmountStrategy",
"weight": "TopkWeightStrategy",
"dropout": "TopkDropoutStrategy",
}
logger.info("Create new streategy ")
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
strategy = str_cls(
topk=topk,
buffer_margin=margin,
n_drop=n_drop,
risk_degree=risk_degree,
adjust_dates=adjust_dates,
)
if not isinstance(strategy, BaseStrategy):
raise TypeError("Strategy not supported")
return strategy
def get_exchange(
pred,
exchange=None,
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
trade_unit=None,
limit_threshold=None,
deal_price=None,
extract_codes=False,
shift=1,
):
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange().
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost.
close_cost : float
close transaction cost.
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
Returns
-------
:class: Exchange
an initialized Exchange object
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if exchange is None:
logger.info("Create new exchange")
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
if extract_codes:
codes = sorted(pred.index.get_level_values("instrument").unique())
else:
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
dates = sorted(pred.index.get_level_values("datetime").unique())
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
exchange = Exchange(
trade_dates=dates,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
min_cost=min_cost,
trade_unit=trade_unit,
)
return exchange
# This is the API for compatibility for legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
"""This function will help you set a reasonable Exchange and provide default value for strategy
@@ -249,30 +110,22 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
will we pass the codes extracted from the pred to the exchange.
.. note:: This will be faster with offline qlib.
- **executor related arguments**
executor : BaseExecutor()
executor used in backtest.
verbose : bool
whether to print log.
"""
# check strategy:
spec = inspect.getfullargspec(get_strategy)
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
strategy = get_strategy(**str_args)
# init exchange:
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
# run backtest
report_df, positions = backtest_func(
pred=pred,
strategy=strategy,
trade_exchange=trade_exchange,
shift=shift,
verbose=verbose,
account=account,
benchmark=benchmark,
warnings.warn(
"this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning
)
# for compatibility of the old API. return the dict positions
positions = {k: p.position for k, p in positions.items()}
return report_df, positions
report_dict = backtest_func(
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
)
return report_dict.get("report_df"), report_dict.get("positions")
def long_short_backtest(
@@ -340,7 +193,7 @@ def long_short_backtest(
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
long_returns = {}
short_returns = {}

View File

@@ -61,7 +61,7 @@ def get_position_value(evaluate_date, position):
# load close price for position
# position should also consider cash
instruments = list(position.keys())
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
fields = ["$close"]
close_data_df = D.features(
instruments,
@@ -80,7 +80,7 @@ def get_position_list_value(positions):
instruments = set()
for day, position in positions.items():
instruments.update(position.keys())
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
instruments.sort()
day_list = list(positions.keys())
day_list.sort()

View File

@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
try:
from .catboost_model import CatBoostModel
except ModuleNotFoundError:
CatBoostModel = None
print("Please install necessary libs for CatBoostModel.")
try:
from .double_ensemble import DEnsembleModel
from .gbdt import LGBModel
except ModuleNotFoundError:
DEnsembleModel, LGBModel = None, None
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
try:
from .xgboost import XGBModel
except ModuleNotFoundError:
XGBModel = None
print("Please install necessary libs for XGBModel, such as xgboost.")
try:
from .linear import LinearModel
except ModuleNotFoundError:
LinearModel = None
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
# import pytorch models
try:
from .pytorch_alstm import ALSTM
from .pytorch_gats import GATs
from .pytorch_gru import GRU
from .pytorch_lstm import LSTM
from .pytorch_nn import DNNModelPytorch
from .pytorch_tabnet import TabnetModel
from .pytorch_sfm import SFM_Model
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
except ModuleNotFoundError:
pytorch_classes = ()
print("Please install necessary libs for PyTorch models.")
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes

View File

@@ -3,15 +3,17 @@
import numpy as np
import pandas as pd
from typing import Text, Union
from catboost import Pool, CatBoost
from catboost.utils import get_gpu_device_count
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
class CatBoostModel(Model):
class CatBoostModel(Model, FeatureInt):
"""CatBoost Model"""
def __init__(self, loss="RMSE", **kwargs):
@@ -62,12 +64,24 @@ class CatBoostModel(Model):
evals_result["train"] = list(evals_result["learn"].values())[0]
evals_result["valid"] = list(evals_result["validation"].values())[0]
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-----
parameters references:
https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance
"""
return pd.Series(
data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_
).sort_values(ascending=False)
if __name__ == "__main__":
cat = CatBoostModel()

View File

@@ -0,0 +1,265 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import lightgbm as lgb
import numpy as np
import pandas as pd
from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
from ...log import get_module_logger
class DEnsembleModel(Model, FeatureInt):
"""Double Ensemble Model"""
def __init__(
self,
base_model="gbm",
loss="mse",
num_models=6,
enable_sr=True,
enable_fs=True,
alpha1=1.0,
alpha2=1.0,
bins_sr=10,
bins_fs=5,
decay=None,
sample_ratios=None,
sub_weights=None,
epochs=100,
**kwargs
):
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
self.num_models = num_models # the number of sub-models
self.enable_sr = enable_sr
self.enable_fs = enable_fs
self.alpha1 = alpha1
self.alpha2 = alpha2
self.bins_sr = bins_sr
self.bins_fs = bins_fs
self.decay = decay
if sample_ratios is None: # the default values for sample_ratios
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
if sub_weights is None: # the default values for sub_weights
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
if not len(sample_ratios) == bins_fs:
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
self.sample_ratios = sample_ratios
if not len(sub_weights) == num_models:
raise ValueError("The length of sub_weights should be equal to num_models.")
self.sub_weights = sub_weights
self.epochs = epochs
self.logger = get_module_logger("DEnsembleModel")
self.logger.info("Double Ensemble Model...")
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
self.sub_features = [] # the features for each sub model in the form of pandas.Index
self.params = {"objective": loss}
self.params.update(kwargs)
self.loss = loss
def fit(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
# initialize the sample weights
N, F = x_train.shape
weights = pd.Series(np.ones(N, dtype=float))
# initialize the features
features = x_train.columns
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
# train sub-models
for k in range(self.num_models):
self.sub_features.append(features)
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
model_k = self.train_submodel(df_train, df_valid, weights, features)
self.ensemble.append(model_k)
# no further sample re-weight and feature selection needed for the last sub-model
if k + 1 == self.num_models:
break
self.logger.info("Retrieving loss curve and loss values...")
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
pred_k = self.predict_sub(model_k, df_train, features)
pred_sub.iloc[:, k] = pred_k
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
if self.enable_sr:
self.logger.info("Sample re-weighting...")
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
if self.enable_fs:
self.logger.info("Feature selection...")
features = self.feature_selection(df_train, loss_values)
def train_submodel(self, df_train, df_valid, weights, features):
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
evals_result = dict()
model = lgb.train(
self.params,
dtrain,
num_boost_round=self.epochs,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
verbose_eval=20,
evals_result=evals_result,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
return model
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train, label=y_train, weight=weights)
dvalid = lgb.Dataset(x_valid, label=y_valid)
return dtrain, dvalid
def sample_reweight(self, loss_curve, loss_values, k_th):
"""
the SR module of Double Ensemble
:param loss_curve: the shape is NxT
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
after the t-th iteration in the training of the previous sub-model.
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:param k_th: the index of the current sub-model, starting from 1
:return: weights
the weights for all the samples.
"""
# normalize loss_curve and loss_values with ranking
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
loss_values_norm = (-loss_values).rank(pct=True)
# calculate l_start and l_end from loss_curve
N, T = loss_curve.shape
part = np.maximum(int(T * 0.1), 1)
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
# calculate h-value for each sample
h1 = loss_values_norm
h2 = (l_end / l_start).rank(pct=True)
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
# calculate weights
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
h_avg = h.groupby("bins")["h_value"].mean()
weights = pd.Series(np.zeros(N, dtype=float))
for i_b, b in enumerate(h_avg.index):
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
return weights
def feature_selection(self, df_train, loss_values):
"""
the FS module of Double Ensemble
:param df_train: the shape is NxF
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:return: res_feat: in the form of pandas.Index
"""
x_train, y_train = df_train["feature"], df_train["label"]
features = x_train.columns
N, F = x_train.shape
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
M = len(self.ensemble)
# shuffle specific columns and calculate g-value for each feature
x_train_tmp = x_train.copy()
for i_f, feat in enumerate(features):
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
for i_s, submodel in enumerate(self.ensemble):
pred += (
pd.Series(
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
)
/ M
)
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
# one column in train features is all-nan # if g['g_value'].isna().any()
g["g_value"].replace(np.nan, 0, inplace=True)
# divide features into bins_fs bins
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
# randomly sample features from bins to construct the new features
res_feat = []
sorted_bins = sorted(g["bins"].unique(), reverse=True)
for i_b, b in enumerate(sorted_bins):
b_feat = features[g["bins"] == b]
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist()
return pd.Index(set(res_feat))
def get_loss(self, label, pred):
if self.loss == "mse":
return (label - pred) ** 2
else:
raise ValueError("not implemented yet")
def retrieve_loss_curve(self, model, df_train, features):
if self.base_model == "gbm":
num_trees = model.num_trees()
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train = np.squeeze(y_train.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
N = x_train.shape[0]
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
pred_tree = np.zeros(N, dtype=float)
for i_tree in range(num_trees):
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
else:
raise ValueError("not implemented yet")
return loss_curve
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.ensemble is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
for i_sub, submodel in enumerate(self.ensemble):
feat_sub = self.sub_features[i_sub]
pred += (
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
* self.sub_weights[i_sub]
)
return pred
def predict_sub(self, submodel, df_data, features):
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
return pred_sub
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-----
parameters reference:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
"""
res = []
for _model, _weight in zip(self.ensemble, self.sub_weights):
res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight)
return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False)

Some files were not shown because too many files have changed in this diff Show More