1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

...

426 Commits

Author SHA1 Message Date
Young
77ba7b4e91 Add analyser example and finetune example 2021-06-06 07:51:52 +00:00
Young
7a639eeea7 add IC and rank IC 2021-06-06 07:43:26 +00:00
Young
cddaf90ef5 monitor initial version 2021-06-06 07:43:26 +00:00
you-n-g
4ff0c4fb0f Update strategy.rst 2021-05-31 08:52:41 +08:00
you-n-g
b3eece155f Merge pull request #452 from arisliang/patch-1
Add import stock pool in doc
2021-05-30 09:44:47 +08:00
al
02e34eb9e9 Add import stock pool (csi300) in documentation 2021-05-30 08:27:21 +08:00
you-n-g
3033fdf4b7 Merge pull request #444 from zhupr/feature_importance
add get_feature_importance to model interpret
2021-05-28 16:23:53 +08:00
zhupr
ef11a9d95c modify the default value of exists_skip in the GetData.qlib_data parameter to False 2021-05-28 14:57:06 +08:00
zhupr
98eacf8f88 add test/config.py 2021-05-28 13:24:47 +08:00
you-n-g
43cad1ec27 Merge pull request #442 from arisliang/patch-3
Update 1min demo data in CSV format
2021-05-28 11:22:10 +08:00
you-n-g
e7aa7ffcdd Merge pull request #447 from arisliang/patch-2
Remove repeated package from requirements
2021-05-28 10:46:58 +08:00
you-n-g
ed3c9d9212 Merge pull request #448 from arisliang/patch-4
Update integration.rst
2021-05-28 10:46:25 +08:00
you-n-g
2f3fbae73b Merge pull request #450 from arisliang/patch-5
Update report.rst
2021-05-28 10:45:52 +08:00
al
e409bee9b9 Update report.rst
typo
2021-05-28 07:54:45 +08:00
al
7ceec37848 Update integration.rst
Fix typo
2021-05-27 22:35:43 +08:00
al
c12c861b7a Remove repeated package from requirements 2021-05-27 19:37:57 +08:00
zhupr
0a4e241608 add get_feature_importance to model interpret 2021-05-27 14:19:10 +08:00
al
5a382d7e99 Update data.rst
Update csv format according to feedback
2021-05-27 12:40:55 +08:00
al
9b431bc503 Update 1min demo data in CSV format 2021-05-26 22:01:15 +08:00
you-n-g
cbbf6cd822 Merge pull request #441 from zhupr/fix_yahoo_collector
Fix YahooCollector can't download 1min data
2021-05-26 21:41:14 +08:00
you-n-g
928bae08f4 Merge pull request #440 from arisliang/patch-2
Update collector.py
2021-05-26 21:40:23 +08:00
you-n-g
c65fc226bd Merge pull request #439 from arisliang/patch-1
Update README.md
2021-05-26 21:39:53 +08:00
zhupr
114162693f Fix YahooCollector can't download 1min data 2021-05-26 18:29:41 +08:00
al
b884c8c571 Update collector.py
fix typo
2021-05-26 18:00:23 +08:00
al
6222940b9c Update README.md
fix typo
2021-05-26 17:50:49 +08:00
you-n-g
bb0c555803 Merge pull request #372 from zhupr/data_storage
add data storage
2021-05-26 14:30:46 +08:00
zhupr
5da33562dd remove uri parameter from storage && modify file_storage 2021-05-26 12:33:57 +08:00
you-n-g
db3aa8b887 Merge pull request #433 from arisliang/patch-1
Update README.md
2021-05-23 23:53:49 +08:00
al
ae32d79549 Update README.md
fix typo
2021-05-23 16:54:14 +08:00
you-n-g
a80369b80a Update README.md 2021-05-22 19:00:43 +08:00
you-n-g
369177acf9 Update README.md 2021-05-22 19:00:14 +08:00
you-n-g
2ac5ceb4de Update README.md 2021-05-22 17:58:21 +08:00
zhupr
602f78b568 add documentation on which storage methods are used in qlib 2021-05-22 08:30:12 +08:00
zhupr
669f6bd6f5 modify exception message hint for storage.py && fix FileFeatureStorage[:] bug 2021-05-22 02:06:15 +08:00
you-n-g
3d71fd1966 Merge pull request #431 from evanzd/fix_get_module
Fix `get_module_by_module_path` to support pickle module loaded from arbitrage source file
2021-05-21 12:53:14 +08:00
zhupr
b887d2ec32 code for formatting storage.py using black(v21.5) 2021-05-21 10:29:39 +08:00
zhupr
8e6c744a1b Merge remote-tracking branch 'microsoft/main' into data_storage 2021-05-21 09:49:29 +08:00
zhupr
9e296a8a4e replace the type of numpy deprecated 2021-05-21 08:56:44 +08:00
zhupr
4ba4512619 add write method to FeatureStorage && remove extend 2021-05-21 08:45:11 +08:00
Dong Zhou
2fa7ef32fb fix picker error on importlib loaded module 2021-05-18 22:37:31 +08:00
Dong Zhou
c72ee9091e fix picker error on importlib loaded module 2021-05-18 22:12:41 +08:00
you-n-g
19eda8f4f0 Update README.md 2021-05-17 17:45:31 +08:00
you-n-g
d08146c30f Merge pull request #290 from you-n-g/online_srv
init version of online serving and rolling
2021-05-17 17:35:29 +08:00
lzh222333
8c3a08b18d Finally! 2021-05-17 07:27:55 +00:00
you-n-g
142a9dca3c Merge pull request #423 from ChengYen-Tang/LightGBM-Hyperparameter
LightGBM hyperparameter
2021-05-16 19:43:59 +08:00
Kenneth Tang
41ab130807 Fix CI lint with black 2021-05-18 00:01:45 +08:00
Kenneth Tang
8f67010b58 Fix CI lint with black 2021-05-17 23:09:42 +08:00
lzh222333
a986379deb bug fixed 2021-05-14 11:31:50 +00:00
lzh222333
aef3f186c1 format code 2021-05-14 06:58:02 +00:00
lzh222333
ebd01e0de5 Online Serving V11 2021-05-14 06:44:16 +00:00
Kenneth Tang
f51e04a1cc LightGBM hyperparameter 2021-05-13 23:12:29 +08:00
lzh222333
d71a666904 Online serving V11 2021-05-13 09:43:42 +00:00
you-n-g
8b15ffc027 Merge pull request #419 from Derek-Wds/main
Fix bug and update doc
2021-05-13 16:38:22 +08:00
you-n-g
f15ca39df8 Merge pull request #418 from zhupr/set_global_logger_level
Modify set_global_logger_level use of contextmanager
2021-05-13 16:36:40 +08:00
Jactus
bd37f5d953 Fix bug and update doc 2021-05-13 14:21:54 +08:00
zhupr
76c5c5d1b6 Add docstrings to set_global_logger_level 2021-05-12 22:38:50 +08:00
zhupr
b8e64dc526 Modify set_global_logger_level use of contextmanager 2021-05-12 17:58:39 +08:00
you-n-g
df7c882fe3 Merge pull request #416 from zhupr/configurable_dataset
Add configurable dataset to examples
2021-05-12 09:03:56 +08:00
zhupr
9bd77bd89f Add configurable dataset to examples 2021-05-11 17:36:55 +08:00
you-n-g
c43a0b208d Merge pull request #413 from Derek-Wds/stale
Stale bot update
2021-05-10 21:15:36 +08:00
Jactus
8ba5e93d04 Skip enhancement in stale bot 2021-05-10 12:33:58 +08:00
lzh222333
370b6aad74 logger & doc 2021-05-09 11:58:06 +00:00
lzh222333
f5ded06a15 Merge remote-tracking branch 'microsoft/main' into online_srv 2021-05-09 10:53:41 +00:00
lzh222333
4c232610f1 Merge branch 'online_srv' of https://github.com/you-n-g/qlib into online_srv 2021-05-09 10:52:07 +00:00
Dingsu Wang
81bd2ca8fb Merge branch 'microsoft:main' into stale 2021-05-09 17:59:46 +08:00
Jactus
143c257fa2 Update stale bot 2021-05-09 17:56:37 +08:00
Jactus
724f9ba8d2 Update stale bot 2021-05-09 17:52:18 +08:00
you-n-g
aa1f9b464b Merge pull request #412 from zhupr/set_global_logger_level
Add set_global_logger_level
2021-05-09 11:25:56 +08:00
zhupr
d2daba99d3 Add set_global_logger_level 2021-05-09 09:05:57 +08:00
you-n-g
1c605e505a Merge pull request #14 from you-n-g/online_srv_blin
Online srv blin
2021-05-07 21:07:47 +08:00
you-n-g
060a32e0f6 Merge branch 'online_srv' into online_srv_blin 2021-05-07 21:07:27 +08:00
binlins
08edb92461 add flt_data doc 2021-05-07 12:56:58 +00:00
binlins
bec65ddf94 add document and reindex 2021-05-07 11:47:47 +00:00
lzh222333
9dfd001f6f online serving v10 2021-05-07 09:59:15 +00:00
you-n-g
95a4a98de8 Update README.md 2021-05-07 10:56:57 +08:00
you-n-g
d4639b7df9 Update README.md
Update high-frequency trading link
2021-05-07 10:46:45 +08:00
blin
846c64f6c6 fix param 2021-05-06 12:00:41 +00:00
lzh222333
84c56f13bd docs and bug fixed 2021-05-06 04:18:55 +00:00
lzh222333
1c99fb35da Merge remote-tracking branch 'microsoft/main' into online_srv 2021-05-06 02:24:16 +00:00
you-n-g
5bc2b96346 Update data.rst 2021-05-03 12:34:08 +08:00
you-n-g
ee269b0914 Merge pull request #407 from Derek-Wds/log
Fix logger pickling error
2021-04-30 17:33:42 +08:00
you-n-g
2a2d2cf709 Merge pull request #404 from Derek-Wds/main
Support start exp with given exp & recorder id
2021-04-30 16:46:46 +08:00
Jactus
5eb9dfff16 Remove redundant 2021-04-30 15:28:37 +08:00
Jactus
694ae34027 Update api 2021-04-30 13:27:19 +08:00
Jactus
51b649ec39 Update QlibLogger 2021-04-30 13:13:05 +08:00
Jactus
ca92cb980c Update meta logger 2021-04-29 22:40:52 +08:00
Jactus
f58c61a2e0 Fix logger pickling error 2021-04-29 16:54:51 +08:00
lzh222333
2b7ffa100f Merge remote-tracking branch 'microsoft/main' into online_srv 2021-04-29 05:23:42 +00:00
lzh222333
67c5740c83 OnlineServing V9 2021-04-29 04:30:09 +00:00
lzh222333
6f669348a8 Merge branch 'online_srv' of https://github.com/you-n-g/qlib into online_srv 2021-04-28 09:23:13 +00:00
lzh222333
40cf83e557 online serving V9 middle status 2021-04-28 09:23:07 +00:00
blin
fa4511cb0a filter 2021-04-28 07:30:22 +00:00
blin
45c6dfc5da filter 2021-04-28 07:25:19 +00:00
blin
36ab078fbd filter 2021-04-28 07:15:59 +00:00
you-n-g
5a7f9ef720 Merge pull request #405 from zhupr/future_trading_date_collector
Add future trading date collector
2021-04-28 10:07:08 +08:00
zhupr
8b8d21107c Add future trading date collector 2021-04-27 21:20:47 +08:00
Jactus
eab19de080 Support start exp with given exp & recorder id 2021-04-27 16:56:07 +08:00
lzh222333
42f510024c update collector 2021-04-27 04:12:08 +00:00
Young
5a7eecabee black formating (black is upgraded in github) 2021-04-27 04:04:43 +00:00
lzh222333
0058f7d0dc Online Serving V8 2021-04-26 09:31:47 +00:00
you-n-g
9a74fe34f6 Merge pull request #401 from zhupr/online_fix
Fix online mode bugs
2021-04-26 12:29:55 +08:00
zhupr
e15ea06122 Fix ClientProvider not supporting LocalInstrumentProvider && online using the latest python-socketio 2021-04-25 23:50:29 +08:00
lzh222333
319396c815 online serving V8 2021-04-25 06:26:45 +00:00
you-n-g
50be7a9171 Merge pull request #393 from Derek-Wds/main
Update qlib logger
2021-04-23 12:41:05 +08:00
Jactus
e410caaa8f Simplify meta class 2021-04-23 10:08:12 +08:00
Jactus
fbff4c271a Remove redundant methods in meta 2021-04-23 00:38:45 +08:00
you-n-g
ee91503973 Merge pull request #399 from Derek-Wds/doc
Update doc
2021-04-22 20:37:40 +08:00
lzh222333
de0a0c083d bug fixed 2021-04-22 08:09:15 +00:00
Jactus
8adfafa6aa Black format 2021-04-22 14:17:25 +08:00
Jactus
aafaff45d2 Update doc 2021-04-22 14:13:36 +08:00
Jactus
6a05d4e255 Enable IDEs docstrings 2021-04-19 11:36:00 +08:00
Jactus
cbf1fa721e Update 2021-04-17 15:47:49 +08:00
Jactus
4ebf684794 Update workflow logging 2021-04-16 15:35:11 +08:00
Jactus
f4bfe8e619 First trial of adding docstring 2021-04-16 14:35:05 +08:00
lzh222333
cec318fbfe online serving V7 2021-04-16 05:37:13 +00:00
Jactus
78bb8882cd Format 2021-04-16 12:00:18 +08:00
Jactus
848d953226 Update qlib logger 2021-04-16 09:58:55 +08:00
you-n-g
a3a2b5ae0b Merge pull request #358 from javaThonc/high_freq_demp
update high freq demo
2021-04-14 20:05:07 +08:00
Alex Wang
941c980d06 update tabnet 2021-04-14 17:35:19 +08:00
Alex Wang
fe190dec4b update readme 2021-04-14 14:40:28 +08:00
lzh222333
5095b2a470 simulator & examples 2021-04-13 09:45:16 +00:00
zhupr
317357b50d Modify data.storage 2021-04-13 10:47:01 +08:00
Young
b15e5e33fd Fix the multi-processing bug 2021-04-12 06:33:31 +00:00
Young
cca43cf102 Refactor update & modification when running NN 2021-04-11 14:39:19 +00:00
Young
a366c11d67 Update features for hyb nn 2021-04-09 13:48:01 +00:00
Young
18bf4b5477 parameter adjustment 2021-04-08 03:52:58 +00:00
lzh222333
c20eb5c8a6 format code 2021-04-08 03:30:24 +00:00
Young
71605794a2 Merge branch 'online_srv_wd' into online_srv 2021-04-07 05:17:07 +00:00
Young
1dbb561744 Fix some API(for lb nn) 2021-04-07 03:53:56 +00:00
lzh222333
cb42e99bee bug fixed & examples fire 2021-04-07 03:33:27 +00:00
lzh222333
431a9c92c1 online serving v5 2021-04-02 07:09:29 +00:00
lzh222333
bd7a1c11b9 trainer & group & collect & ensemble 2021-04-02 04:27:14 +00:00
zhupr
70fc58104b Modify FileStorage 2021-04-01 12:58:34 +08:00
lzh222333
edcd7b1ff9 bug fixed & code format 2021-03-31 03:08:48 +00:00
lzh222333
3724273d73 Merge remote-tracking branch 'microsoft/qlib/main' into online_srv 2021-03-31 02:54:05 +00:00
lzh222333
544365f3a9 ensemble & get_exp & dataset_pickle 2021-03-31 02:39:14 +00:00
Alex Wang
bed1175e24 update dataset 2021-03-30 19:29:17 +08:00
you-n-g
70c84cbc77 Merge pull request #381 from D-X-Y/main
Fix print issue
2021-03-30 17:25:26 +08:00
you-n-g
da59b35c0a Merge pull request #380 from Derek-Wds/main
Modify get_exp & get_recorder api
2021-03-30 16:54:01 +08:00
lzh222333
eae94d1ee8 Merge remote-tracking branch 'microsoft/qlib/main' into online_srv 2021-03-30 07:16:56 +00:00
lzh222333
1f2d2c9b69 online debug 2021-03-30 06:56:04 +00:00
Jactus
b6df11b6b4 Modify get_exp & get_recorder api 2021-03-30 14:41:56 +08:00
you-n-g
ae57110f64 Merge pull request #374 from bxdd/qlib_loaderhandler
Add DataLoader Based on DataHandler & Add Rolling Process Example & Restructure the Config & Setup_data
2021-03-30 13:37:55 +08:00
bxdd
7a2203f116 update comments 2021-03-30 11:03:07 +08:00
bxdd
023603479c fix readme 2021-03-30 01:00:12 +08:00
bxdd
f8da79b802 fix readme 2021-03-30 00:54:00 +08:00
bxdd
136830bc2b update comments 2021-03-30 00:38:15 +08:00
you-n-g
45f78676ea Merge pull request #379 from zhupr/fix_usinex_collector
Fix us_index collector
2021-03-29 23:54:05 +08:00
bxdd
1074284666 fix docstring 2021-03-29 20:38:09 +08:00
bxdd
d18c367497 update README 2021-03-29 20:34:36 +08:00
bxdd
8743576f72 black format 2021-03-29 20:16:00 +08:00
bxdd
fb7f84f31e fix ubg 2021-03-29 20:15:42 +08:00
bxdd
31bc85bf86 restructure data layer config & setup 2021-03-29 19:49:30 +08:00
D-X-Y
968930e85f Fix print issue 2021-03-29 04:46:38 +00:00
zhupr
4b66304978 Fix us_index collector 2021-03-29 11:25:24 +08:00
you-n-g
253378a44e Merge pull request #378 from D-X-Y/main
Add MultiSegRecord and add segment kwargs in model.pred
2021-03-29 01:06:41 +08:00
D-X-Y
f809f0a063 Remove un-used imports 2021-03-28 10:50:25 +00:00
D-X-Y
0386df7b16 Collect all contrib models in __init__ and add unit tests for init 2021-03-28 10:39:28 +00:00
D-X-Y
8a2e7b62af Add segment args for pred and refine MultiSegRecord 2021-03-28 08:30:16 +00:00
D-X-Y
9d04ae4676 Add MultiSegRecord in contrib.workflow and decouple its tests from test_all_pipeline 2021-03-28 00:33:59 -07:00
zhupr
9b8acd9a82 Fix FileStorage 2021-03-27 01:15:33 +08:00
lzh222333
ee45a7833e Merge branch 'main' into online_srv 2021-03-26 08:20:21 +00:00
zhupr
d395c904f2 Add FileStorage 2021-03-26 16:14:45 +08:00
lzh222333
9bf819e653 Merge branch 'online_srv' of https://github.com/you-n-g/qlib into online_srv 2021-03-26 04:32:20 +00:00
lzh222333
46cd57688e Online Serving V4 2021-03-26 04:20:25 +00:00
you-n-g
0387eaf7ab Merge pull request #373 from lewwang1995/main
debug
2021-03-26 11:59:20 +08:00
bxdd
4ee0240c24 black format 2021-03-25 22:08:39 +08:00
bxdd
5f60d18dfe fix config_data bug 2021-03-25 22:08:23 +08:00
bxdd
194217fb07 fix bug 2021-03-25 21:47:17 +08:00
bxdd
d6ff764bb2 black format 2021-03-25 20:36:45 +08:00
bxdd
9cc3b18e4e fix but 2021-03-25 20:36:07 +08:00
LewenWang
56eaacd931 debug 2021-03-25 20:34:45 +08:00
bxdd
e119c8576c black format 2021-03-25 19:59:22 +08:00
bxdd
68246b3b6d update workflow 2021-03-25 19:58:55 +08:00
bxdd
a04c6bd6c9 balck format 2021-03-25 19:56:22 +08:00
bxdd
efe134e9f4 update workflow 2021-03-25 19:56:04 +08:00
bxdd
4ec300787e update rolling workflow 2021-03-25 19:54:52 +08:00
you-n-g
3886022669 Merge pull request #371 from Derek-Wds/main
Update notebook
2021-03-25 18:13:54 +08:00
zhupr
8264033a72 add data-storage 2021-03-25 17:22:05 +08:00
Jactus
4861552d28 Update notebook 2021-03-25 17:13:52 +08:00
you-n-g
834f9bd9b8 Update README.md 2021-03-25 16:58:35 +08:00
bxdd
f6dc25b229 update rolling process 2021-03-25 16:14:22 +08:00
bxdd
1fcfe8e4ba add rolling process data 2021-03-25 01:37:17 +08:00
bxdd
b1a28358ad black format 2021-03-25 01:30:31 +08:00
bxdd
1ca3c6a61c add DataHandlerDL 2021-03-25 01:29:59 +08:00
Alex Wang
e3739bb980 fix naming and code style 2021-03-24 15:47:26 +08:00
you-n-g
419629e4d2 Merge pull request #365 from Flouse/main
fix docs
2021-03-24 13:12:15 +08:00
Flouse
e490e83a16 fix docs 2021-03-24 11:37:09 +08:00
you-n-g
fda144e66f Merge pull request #362 from D-X-Y/main
Add load_object function for R
2021-03-23 22:47:49 +08:00
you-n-g
4dc10d27e0 Merge pull request #359 from bxdd/doc0
Fix data doc
2021-03-23 21:40:27 +08:00
D-X-Y
0a0c6a3185 Add load_object function for R 2021-03-23 10:10:17 +00:00
lzh222333
d66d4ec93d Merge branch 'main' into online_srv 2021-03-23 11:20:02 +08:00
bxdd
4b56a4e907 fix doc 2021-03-22 18:45:27 +08:00
bxdd
7370d5af9e add label doc 2021-03-22 18:37:44 +08:00
bxdd
c6b67cb8fe fix doc 2021-03-22 18:37:13 +08:00
Alex Wang
3bf6c7f95f update format 2021-03-22 15:37:54 +08:00
Alex Wang
1ad237f89f update high freq demo 2021-03-22 14:20:44 +08:00
you-n-g
2b74b4dfa4 Merge pull request #357 from zhupr/fix_yahoo_collector
Fix yahoo_collector
2021-03-22 12:41:17 +08:00
zhupr
598ee875a0 Fix yahoo_collector 2021-03-22 10:29:07 +08:00
Young
84d5318bda Merge branch 'online_srv_wd' into online_srv 2021-03-19 07:49:16 +00:00
you-n-g
ba56e4071e Merge pull request #292 from wangershi/addFund
Add fund data as an example
2021-03-19 11:05:29 +08:00
wangershi
d3160e9439 remove some useless code 2021-03-18 21:15:45 +08:00
you-n-g
06c90d654d Merge pull request #350 from zhupr/fix_dump_bin
Fix dump_bin.py && check_dump_bin.py
2021-03-18 18:39:40 +08:00
you-n-g
f72771cc81 Merge pull request #351 from wendili-cs/patch-1
Update __init__.py
2021-03-18 18:28:18 +08:00
lzh222333
8abdd63869 online_serving V3 2021-03-18 09:30:01 +00:00
Wendi Li
38f35658e7 Update __init__.py 2021-03-18 13:19:27 +08:00
zhupr
d245242f2f Fix dump_bin.py && check_dump_bin.py 2021-03-18 11:21:25 +08:00
you-n-g
6ef204f190 Merge pull request #348 from microsoft/you-n-g-patch-1
Update Contact Us Information
2021-03-17 16:29:28 +08:00
Young
dad18074ac update image 2021-03-17 08:27:59 +00:00
you-n-g
3cf84f8859 Update Contact Us Information 2021-03-17 16:16:00 +08:00
you-n-g
0403237232 Update Contact US 2021-03-17 15:41:16 +08:00
you-n-g
689774c6be Merge pull request #340 from Derek-Wds/main
Support resuming recorder
2021-03-17 14:48:59 +08:00
Jactus
d78e42e2fe Update exp base method 2021-03-17 13:37:17 +08:00
you-n-g
4de628c736 Merge pull request #347 from 2young-2simple-sometimes-naive/patch-1
fix bug of consider TURE as boolean instead of stock code
2021-03-17 12:08:57 +08:00
you-n-g
023c1fedfe Merge pull request #280 from yongzhengqi/main
Implement Enhanced Indexing as a Portfolio Optimizer
2021-03-17 12:07:39 +08:00
you-n-g
9be6866972 Update README.md 2021-03-17 12:03:53 +08:00
you-n-g
be55e0e3fe Fix Typos 2021-03-17 12:03:00 +08:00
you-n-g
619a3bb25d Update plan and news on Index Page 2021-03-17 12:02:24 +08:00
you-n-g
4bd2cd4611 Add feature news on index page 2021-03-17 11:36:12 +08:00
you-n-g
aa552fdb20 Merge pull request #345 from D-X-Y/main
Fix errors when SignalRecord is not called before SigAna/PortAna
2021-03-17 10:47:00 +08:00
安阁锐
5520463395 fix bug of consider TURE as boolean instead of stock code 2021-03-16 22:18:11 -04:00
D-X-Y
872ddc6f95 Fix black error 2021-03-16 22:57:26 +08:00
D-X-Y
88b0871c12 Add RMSE for contrib.workflow.record_temp and unit tests 2021-03-16 22:55:28 +08:00
D-X-Y
d4aa681652 Add MSERecord in contrib.workflow 2021-03-16 12:54:12 +00:00
Jactus
34f0be2836 Fix arg error 2021-03-16 17:18:48 +08:00
Jactus
447fed8e54 Update structure for resuming 2021-03-16 17:16:00 +08:00
D-X-Y
4cb74d77d1 add error type for record_temp 2021-03-16 09:01:10 +00:00
D-X-Y
b0fd0d2395 Add tests for SigAnaRecord 2021-03-16 08:30:46 +00:00
D-X-Y
6559d44c7d Add tests for SigAnaRecord 2021-03-16 08:17:13 +00:00
D-X-Y
9f57681032 Fix errors when SignalRecord is not called before SigAna/PortAna 2021-03-16 08:11:05 +00:00
lzh222333
d33041dc24 format example 2021-03-16 02:52:20 +00:00
lzh222333
5953365af3 finished update_online_pred demo 2021-03-16 02:43:12 +00:00
lzh222333
e3730b32d7 more clearly structure 2021-03-16 02:23:28 +00:00
Jactus
08b44ed727 Update docs 2021-03-15 14:12:35 +08:00
Jactus
83fb482f1e Fix Bug 2021-03-15 13:55:57 +08:00
Jactus
734bb9ee3c Support resuming recorder 2021-03-15 13:50:10 +08:00
you-n-g
d47e35d64e Merge pull request #337 from D-X-Y/main
Fix bugs in Ghost BN in TabNet and typos in README
2021-03-15 12:42:48 +08:00
lzh222333
0bc49dab60 add task management to index.rst 2021-03-15 03:58:05 +00:00
lzh222333
646d899f8d update docstring and document 2021-03-15 03:50:43 +00:00
D-X-Y
07434da8b0 Remove unused imports 2021-03-15 03:35:34 +00:00
D-X-Y
53a6b72ce5 Fix black errors 2021-03-15 03:09:52 +00:00
D-X-Y
a51dafcb4c Remove unnecessary codes 2021-03-15 03:07:25 +00:00
Young
8362780e22 fix import bug 2021-03-14 15:16:38 +00:00
D-X-Y
358de88602 Fix typos in README and add TabNet config for Alpha360 2021-03-14 07:42:24 +00:00
D-X-Y
32a7be9964 Fix typos in README 2021-03-14 07:31:18 +00:00
D-X-Y
d5f9395e51 Fix Ghost BN bugs in TabNet and simplify its implementation 2021-03-14 07:25:09 +00:00
wangershi
4e7a147759 use base.py 2021-03-14 14:24:14 +08:00
wangershi
1344c40598 Merge remote-tracking branch 'remoteGit/main' into addFund 2021-03-14 11:19:01 +08:00
you-n-g
1d2b2f4f01 Merge pull request #333 from Rekind1e/main
another typo of docs
2021-03-13 20:05:25 +08:00
Hy
373f6e0900 another typo of docs 2021-03-13 15:47:26 +08:00
you-n-g
ba64758c24 Merge pull request #332 from Rekind1e/main
Fix typo of docs
2021-03-13 14:48:02 +08:00
Hy
abddcfccdf fix typo of docs 2021-03-13 14:32:01 +08:00
you-n-g
6d5381f9b1 Merge pull request #311 from withshubh/main
Fixed code quality issues
2021-03-12 18:31:02 +08:00
Young
e4e8a4abcd fix task name & add cur_path 2021-03-12 10:17:16 +00:00
Shubhendra Singh Chauhan
e41373b8ad revert fix 2021-03-12 14:10:52 +05:30
lzh222333
9d84d389ab format code and add example 2021-03-12 08:24:21 +00:00
lzh222333
6d8aa215d6 the second version of online serving 2021-03-12 08:04:08 +00:00
shubhendra
0969c3e7e0 formatted with black 2021-03-12 13:29:20 +05:30
Young
5de7870f9b Merge branch 'online_srv' of github.com:you-n-g/qlib into online_srv 2021-03-12 07:52:31 +00:00
Young
44a7dc004d update docs and fix duplicated pred bug 2021-03-12 07:50:17 +00:00
Shubhendra Singh Chauhan
5f8d0e0436 Merge branch 'main' into main 2021-03-12 13:08:45 +05:30
shubhendra
4fbb5a03c1 revert fixes that failed unit test 2021-03-12 13:05:48 +05:30
you-n-g
0cffb87cbc Merge pull request #329 from D-X-Y/main
Fix Various Bugs for contrib.pytorch_ models
2021-03-12 12:30:08 +08:00
you-n-g
df56e3bdf9 Merge pull request #324 from zhupr/add_base_collector
Add BaseCollector
2021-03-12 12:20:24 +08:00
D-X-Y
1d435248e2 Add return for use_gpu.. 2021-03-11 19:28:00 -08:00
D-X-Y
593553f573 Fix bug in MLP 2021-03-11 19:15:18 -08:00
D-X-Y
d38b8d6001 Fix bugs in use_gpu 2021-03-11 19:10:32 -08:00
D-X-Y
db59713d36 Add torch.no_grad for evaluation 2021-03-12 02:46:04 +00:00
D-X-Y
67fbdafe76 Fix many bugs in TabNet and use_gpu 2021-03-12 02:42:25 +00:00
zhupr
42be8ac312 Add BaseCollector 2021-03-12 10:30:38 +08:00
lzh222333
0df88c07f6 bug fixed and update collect.py 2021-03-11 16:25:46 +00:00
you-n-g
f6b019dcec Merge pull request #328 from D-X-Y/fshare
Move get_path to get_or_create_path, use the best model of SFM / TabNet
2021-03-11 22:07:20 +08:00
D-X-Y
e626264d5a Merge branch 'main' of github.com:microsoft/qlib into fshare 2021-03-11 12:54:04 +00:00
D-X-Y
b99de068f8 Move save_path to get_or_create_path, and fix bugs in sfm / tabnet 2021-03-11 12:52:26 +00:00
you-n-g
e8beaa5257 Merge pull request #314 from D-X-Y/fshare
(1) Fix /0 bug in double_ensemble, (2) remove _default_uri for R/expm, (3) support model size in pytorch models
2021-03-11 20:51:16 +08:00
D-X-Y
0ef7c8e0e6 Fix bugs for get_local_dir 2021-03-11 03:05:31 +00:00
lzh222333
48f0fc147f first version of online serving 2021-03-11 03:00:30 +00:00
D-X-Y
cda96be8c3 Refine default uri in expm 2021-03-11 02:49:03 +00:00
D-X-Y
f6ed175070 Remove set_log_basic_config, refine count_parameters, rename root_uri as get_local_dir 2021-03-11 02:33:00 +00:00
lzh222333
2ca2071d95 format code 2021-03-10 17:06:08 +00:00
you-n-g
0054a4db2a Merge pull request #322 from Derek-Wds/bug
Fix pytorch ts model loader bug
2021-03-10 19:47:56 +08:00
lzh222333
e2f58274ba update task manager 2021-03-10 10:58:49 +00:00
Jactus
119fe90570 Fix pytorch ts model loader bug 2021-03-10 16:43:32 +08:00
you-n-g
e2817ab87c Merge pull request #319 from Derek-Wds/main
Update Filter doc
2021-03-10 15:38:39 +08:00
you-n-g
2e37033e35 Merge pull request #318 from microsoft/bxdd-patch-2
Fix code in ops
2021-03-10 14:45:35 +08:00
Jactus
105fe1d3ed Remove deprecated warning for numpy>=1.20.0 2021-03-10 10:38:43 +08:00
Jactus
78bc2c8748 Update Filter doc 2021-03-09 17:31:27 +08:00
lzh222333
83dbdfb45e finished document and example 2021-03-09 17:22:36 +08:00
bxdd
81987bb143 Update ops.py 2021-03-09 15:38:04 +08:00
Charles Young
53cf89d7c2 Reformat with black. 2021-03-08 19:43:03 +08:00
Charles Young
8b9065c166 Reformat with black. 2021-03-08 19:32:13 +08:00
Charles Young
6a305c73ae Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589166529 2021-03-08 19:08:55 +08:00
Charles Young
7022675d00 Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589169489 2021-03-08 19:07:28 +08:00
Charles Young
2f9af1af8f Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589169769 2021-03-08 19:02:40 +08:00
Charles Young
fc89fec46d Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589168764 2021-03-08 18:56:54 +08:00
Charles Young
c6675be792 Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589166143 2021-03-08 17:51:36 +08:00
Charles Young
351d598c9f Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589165409 2021-03-08 17:49:59 +08:00
Charles Young
81b86f8022 Update test to cover changes in structured_cov_estimator 2021-03-08 17:18:07 +08:00
Charles Young
4d5a30b30b Resolve https://github.com/microsoft/qlib/pull/280\#discussion_r589167776 2021-03-08 17:14:29 +08:00
Charles Young
79c1142d3e Pass nan_option to structured covariance estimator. 2021-03-08 17:09:33 +08:00
D-X-Y
e061443560 Fix lint error with Black 2021-03-08 08:27:58 +00:00
D-X-Y
03ef918dd8 Fix bugs in count_parameters 2021-03-08 08:24:23 +00:00
D-X-Y
ca48345b29 Simplify count_parameters 2021-03-08 08:16:17 +00:00
lzh222333
def132e140 modified format and added TaskCollector 2021-03-08 16:10:16 +08:00
D-X-Y
7bed3b4c2e Fix python format by black 2021-03-08 06:39:03 +00:00
D-X-Y
4266492a34 Merge branch 'main' of github.com:microsoft/qlib into fshare 2021-03-08 06:12:49 +00:00
you-n-g
91eef93386 Merge pull request #302 from D-X-Y/main
Update repr for dataset/workflow classes and add uri kwarg for QlibRecorder
2021-03-08 14:01:53 +08:00
lzh222333
a244f87f95 modified the comments 2021-03-08 13:25:11 +08:00
wangershi
9df0361262 black 2021-03-07 19:35:50 +08:00
wangershi
6bcd88973b resolve one bug 2021-03-07 19:32:37 +08:00
D-X-Y
d13c9ae018 Avoid dividing zero in model.double_ensemble 2021-03-07 11:25:53 +00:00
wangershi
11412727ef add normalizer 2021-03-07 18:51:38 +08:00
D-X-Y
73b7107ee8 Remove useless verbose kwarg 2021-03-07 00:52:27 -08:00
D-X-Y
91fd53ab4d Add reset_default_uri func for R and expm 2021-03-06 05:33:08 -08:00
shubhendra
aab5c5b311 Refactor the comparison involving not
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:35 +05:30
shubhendra
dc86a6abc5 Refactor unnecessary else / elif when if block has a continue statement
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:34 +05:30
shubhendra
a62d1a1b36 Use literal syntax to create data structure
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:34 +05:30
shubhendra
5015d218ff Remove methods with unnecessary super delegation.
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:34 +05:30
shubhendra
6f034ccb5d Remove unnecessary use of comprehension
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:33 +05:30
shubhendra
07eef18337 Remove length check in favour of truthiness of the object
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:33 +05:30
shubhendra
f277a66582 Add .deepsource.toml
Signed-off-by: shubhendra <withshubh@gmail.com>
2021-03-06 13:01:32 +05:30
D-X-Y
49697b1f15 Show model size for pytorch models 2021-03-05 12:46:41 +00:00
D-X-Y
131f0e2e67 Add count_parameters for pytorch models in contrib 2021-03-05 12:07:43 +00:00
D-X-Y
b14a559a52 Merge branch 'main' of github.com:D-X-Y/qlib into fshare 2021-03-05 07:25:18 +00:00
D-X-Y
b115ca5353 Merge branch 'main' of github.com:microsoft/qlib into main 2021-03-04 23:19:33 -08:00
D-X-Y
b40bfb1ea5 Merge branch 'main' of github.com:microsoft/qlib into fshare 2021-03-05 07:17:09 +00:00
D-X-Y
4f980a0266 Merge branch 'fshare' of github.com:D-X-Y/qlib into fshare 2021-03-05 07:15:46 +00:00
D-X-Y
19d93744f3 Add set_log_basic_config function support re-directing log stream 2021-03-05 07:15:25 +00:00
D-X-Y
e327f404e3 Fix pylint issues 2021-03-04 22:37:58 -08:00
D-X-Y
452fb8f013 Make mlflow client consistant with uri 2021-03-04 22:33:35 -08:00
D-X-Y
c4d6e00470 Fix logic of uri in ExpM and add test 2021-03-04 21:04:01 -08:00
Charles Young
0f3e3d206b Update __init__.py. 2021-03-04 22:47:42 +08:00
Charles Young
83c6e74783 Reindex files. 2021-03-04 22:30:38 +08:00
Charles Young
2bff6eb781 Split classes in riskmodel.py & optimizer.py into seperate files. 2021-03-04 22:08:11 +08:00
D-X-Y
ee7eb79277 Fix unexpected mlruns folder error 2021-03-04 06:15:24 +00:00
D-X-Y
592db903b3 Update repr for Experiment & Recorder 2021-03-04 05:02:56 +00:00
wangershi
34b7da1dd8 add calendar list by threshold 2021-03-03 22:49:48 +08:00
lzh222333
2882929c5d Add an example about workflow using RollingGen. 2021-03-03 16:58:05 +08:00
lzh222333
fd2c1ba1ed Update some hint 2021-03-03 16:36:15 +08:00
lzh222333
05cf0e1edc add task_generator method and update some hint 2021-03-03 15:42:39 +08:00
D-X-Y
229a39d0d3 Fix typos in DataHandler's doc 2021-03-03 07:30:55 +00:00
D-X-Y
a9a70dfddf Update repr for DatasetH and ExpManager 2021-03-03 06:47:52 +00:00
you-n-g
b8cf229b05 Merge pull request #272 from ChengYen-Tang/Fix_collector_doc
Fix collector doc
2021-03-03 14:14:07 +08:00
you-n-g
7258340e0c Merge pull request #300 from bxdd/qlib_2021_03_02
Update FAQ doc & Update ops docstring
2021-03-03 14:11:11 +08:00
lzh222333
b84156fde8 Consider more situations about task_config.
Save the "param" file which is collect.py need.
2021-03-03 11:25:37 +08:00
you-n-g
d1d70616a3 Merge pull request #298 from D-X-Y/main
Update the Wrapper class to have an informative str representation
2021-03-03 11:17:00 +08:00
bxdd
5378d261b4 update ops.py docstring 2021-03-03 00:47:50 +09:00
bxdd
a1fb10f7cf update FAQ doc 2021-03-03 00:44:52 +09:00
D-X-Y
dbc8ca7379 Fix pylint by black -l 120 2021-03-02 22:15:30 +08:00
D-X-Y
c48b4c9971 Make Wrapper with a informative str repr. 2021-03-02 21:06:32 +08:00
you-n-g
b592669d1f Update Index Report 2021-03-02 20:23:27 +08:00
you-n-g
0bcaab3a5a Merge pull request #286 from meng-ustc/main
Add a new method to benchmarks: DoubleEnsemble
2021-03-02 18:34:17 +08:00
meng-ustc
1de4def444 Update parameter names: 'k' and 'base' 2021-03-02 16:14:56 +09:00
meng-ustc
ee4692a355 Merge branch 'main' of https://github.com/meng-ustc/qlib 2021-03-02 12:20:45 +09:00
meng-ustc
6e2ce6f1dc Add the results of DoubleEnsemble 2021-03-02 12:17:05 +09:00
lzh222333
c4733f601f Merge pull request #1 from you-n-g/online_srv
qlib auto init basedon project & black format
2021-03-02 10:06:45 +08:00
wangershi
82353b20e1 black format 2021-03-01 21:10:46 +08:00
wangershi
51baf57b40 Merge remote-tracking branch 'remoteGit/main' into addFund 2021-02-28 19:29:18 +08:00
wangershi
3082f6ac1b ready for dump_bin 2021-02-28 19:06:40 +08:00
wangershi
db80b620d8 ready for collector 2021-02-28 17:03:14 +08:00
wangershi
6e56396217 add crawler 2021-02-28 12:24:26 +08:00
Young
24024d51c7 qlib auto init basedon project & black format 2021-02-27 09:44:01 +00:00
Wendi Li
a96f0c2e5f Update README.md
Typos fixed.
2021-02-27 11:37:45 +08:00
Young
1e5cf1c174 init version of online serving and rolling 2021-02-26 09:14:40 +00:00
wangershi
719074d306 touch file 2021-02-25 19:20:14 +08:00
Meng Dong
70575e8a1c Delete workflow_by_code_lgb_risk_demo.py 2021-02-24 16:10:38 +08:00
meng-ustc
ce60097722 Add README and Formatted 2021-02-24 16:59:31 +09:00
meng-ustc
1a990fdd25 Add Risk Prediction Demo 2021-02-23 19:08:11 +09:00
Charles Young
527718a440 Allow enhanced indexing to generate portfolio without industry related restriction. 2021-02-22 19:04:31 +08:00
Charles Young
d3caea60ee Add unittest for TestStructuredCovEstimator. 2021-02-22 17:32:03 +08:00
Charles Young
f947a2fdef Correct two mistakes in annotation. 2021-02-22 15:15:51 +08:00
Jactus
dc4aa67503 Black format 2021-02-22 11:42:36 +08:00
Charles Young
37871389b9 Format code with the latest version of black. 2021-02-22 11:25:42 +08:00
Charles Young
2f9d45e03a Reformat code with black. 2021-02-22 10:29:29 +08:00
Charles Young
b8647c13c7 Reformat code to follow PEP 8. 2021-02-22 10:20:51 +08:00
Charles Young
164687d54b Add scikit-learn to dependencies. 2021-02-22 10:13:08 +08:00
Charles Young
58f74cfd84 Reformat code to follow PEP 8. 2021-02-22 10:07:03 +08:00
Charles Young
f7d3e56561 Merge optimization related portfolio construction back to portfolio/optimizer. 2021-02-22 09:57:41 +08:00
Charles Young
42f882504e Reformat code to follow PEP 8. 2021-02-22 09:25:48 +08:00
Charles Young
9448a6e2c7 Add a abstract class as the base class for all optimization related portfolio constructions. 2021-02-22 09:23:48 +08:00
Charles Young
2cc057e438 Fix minor mismatches of type hints. 2021-02-22 09:09:03 +08:00
Charles Young
b2e2142594 Applied slight modification to follow PEP 8. 2021-02-22 09:00:12 +08:00
Charles Young
4000518698 Separate specific implementation of Portfolio Optimizer to folder. 2021-02-22 08:41:35 +08:00
you-n-g
fa8f1cba06 Update filter.py 2021-02-19 22:23:22 +08:00
you-n-g
a72911e4f8 Update filter.py 2021-02-19 22:23:22 +08:00
meng-ustc
cd5b721bc6 Update 2021-02-19 11:56:50 +09:00
meng-ustc
42590972e4 Modify run_all_model.py 2021-02-18 19:15:02 +09:00
meng-ustc
d27dc8bab8 Add A New Baseline: DoubleEnsemble 2021-02-18 19:02:33 +09:00
Young
50d5fcf61e yml afe load 2021-02-18 08:43:12 +08:00
Young
77830a546e safe yaml loader 2021-02-18 08:43:12 +08:00
Young
83237ba4ed yml afe load 2021-02-17 05:17:18 +00:00
Young
04b916c8ae safe yaml loader 2021-02-16 15:07:14 +00:00
Kenneth Tang
b90bd66ac6 Merge branch 'main' into Fix_collector_doc 2021-02-16 14:16:53 +08:00
Kenneth Tang
63d05e4a1a Fix typo 2021-02-16 14:08:58 +08:00
lbaiao
0b11dc5167 Update workflow.rst
Corrected an identation problem on the configuration.yaml file.
2021-02-11 10:54:32 +08:00
Charles Young
9c2653f125 Add an implementation of Enhanced Indexing to optimizer.py 2021-02-09 20:31:00 +08:00
Charles Young
7b01c5cae7 Add an implementation of Enhanced Indexing to optimizer.py 2021-02-09 20:30:26 +08:00
Charles Young
988b42e159 Add Structured Covariance Estimator to riskmodel.py 2021-02-09 20:28:42 +08:00
Miae Kim
12c8bfa545 Fix typo in Data Layer
Commit changes
2021-02-08 10:29:37 +08:00
Jactus
c948385e76 Add model saving for qrun and workflow example 2021-02-07 21:32:35 +08:00
bxdd
07b905c153 update 2021-02-05 16:09:06 +08:00
bxdd
0192f28bf4 add docstring & fix code 2021-02-05 16:09:06 +08:00
bxdd
cbf97f56a4 fix market 2021-02-05 16:09:06 +08:00
bxdd
d702c8bcb1 add Cut ops 2021-02-05 16:09:06 +08:00
Jactus
b84686b215 Update models to enable save/load 2021-02-05 13:14:12 +08:00
bxdd
6a670828a5 Update serial.rst 2021-02-05 12:33:47 +08:00
bxdd
ca6c2ffc27 Update serial.rst 2021-02-05 12:33:47 +08:00
bxdd
914637b3ef update index 2021-02-05 12:33:47 +08:00
bxdd
d8da94de10 update docs 2021-02-05 12:33:47 +08:00
bxdd
477a548fe9 update data.rst docs 2021-02-05 12:33:47 +08:00
bxdd
35af9ad954 update docs 2021-02-05 12:33:47 +08:00
bxdd
8a91e7d34d black format 2021-02-05 12:33:47 +08:00
bxdd
4ed8b8e233 add docs & fix reinit of datatset 2021-02-05 12:33:47 +08:00
Jactus
c71b645777 Update docs about record-temp 2021-02-05 12:32:47 +08:00
HUAN-PING SU
f2ffb80a0b Update workflow_by_code.py 2021-02-04 17:01:34 +08:00
HUAN-PING SU
cda1d4be40 Update workflow_by_code.py
Fix typo in workflow_by_code.py
2021-02-04 17:01:34 +08:00
zhupr
fc1431cd4e update 1min docs 2021-02-04 16:49:58 +08:00
Jactus
06158fb621 Update readme and setup.py 2021-02-04 16:09:36 +08:00
Jactus
1e2e02368c Update readme 2021-02-03 14:42:58 +08:00
Jactus
d87d29aca9 Update Windows CI 2021-02-03 14:42:58 +08:00
Jactus
005da6306c Update CI 2021-02-03 14:42:58 +08:00
bxdd
090b68e44e Update workflow.py 2021-02-03 13:12:19 +08:00
Young
bf748ba4b7 update version number to dev 2021-02-03 05:11:39 +00:00
Meng Dong
fd5c68a7d1 Update workflow_config_doubleensemble_Alpha158.yaml 2021-02-02 12:39:07 +08:00
meng-ustc
8c3ec164ff Add A New Baseline: DoubleEnsemble 2021-02-02 11:46:37 +09:00
meng-ustc
acdc469e39 Add A New Baseline: DoubleEnsemble 2021-02-01 21:05:34 +09:00
174 changed files with 10700 additions and 1767 deletions

12
.deepsource.toml Normal file
View File

@@ -0,0 +1,12 @@
version = 1
test_patterns = ["tests/test_*.py"]
exclude_patterns = ["examples/**"]
[[analyzers]]
name = "python"
enabled = true
[analyzers.meta]
runtime_version = "3.x.x"

62
.github/stale.yml vendored
View File

@@ -1,62 +0,0 @@
# Configuration for probot-stale - https://github.com/probot/stale
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 60
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
daysUntilClose: 7
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
onlyLabels: []
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
exemptLabels:
- bug
- pinned
- security
- "[Status] Maybe Later"
# Set to true to ignore issues in a project (defaults to false)
exemptProjects: false
# Set to true to ignore issues in a milestone (defaults to false)
exemptMilestones: false
# Set to true to ignore issues with an assignee (defaults to false)
exemptAssignees: false
# Label to use when marking as stale
staleLabel: wontfix
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
# Comment to post when removing the stale label.
# unmarkComment: >
# Your comment here.
# Comment to post when closing a stale Issue or Pull Request.
# closeComment: >
# Your comment here.
# Limit the number of actions per hour, from 1-30. Default is 30
limitPerRun: 30
# Limit to only `issues` or `pulls`
# only: issues
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
# pulls:
# daysUntilStale: 30
# markComment: >
# This pull request has been automatically marked as stale because it has not had
# recent activity. It will be closed if no further activity occurs. Thank you
# for your contributions.
# issues:
# exemptLabels:
# - confirmed

24
.github/workflows/stale.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Mark stale issues and pull requests
on:
schedule:
- cron: "0 0/3 * * *"
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue otherwise this will be closed in 5 days'
stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR otherwise this will be closed in 5 days'
stale-issue-label: 'stale'
stale-pr-label: 'stale'
days-before-stale: 90
days-before-close: 5
operations-per-run: 100
exempt-issue-labels: 'bug,enhancement'
remove-stale-when-updated: true

View File

@@ -39,9 +39,11 @@ jobs:
- name: Install Qlib with pip
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml --user
$CONDA\\python.exe -m pip install numpy==1.19.5
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
else
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml
sudo $CONDA/bin/python -m pip install numpy==1.19.5
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
fi
shell: bash

4
.gitignore vendored
View File

@@ -34,3 +34,7 @@ tags
.pytest_cache/
.vscode/
*.swp
./pretrain

View File

@@ -7,6 +7,20 @@
[![License](https://img.shields.io/pypi/l/pyqlib)](LICENSE)
[![Join the chat at https://gitter.im/Microsoft/qlib](https://badges.gitter.im/Microsoft/qlib.svg)](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
## :newspaper: **What's NEW!** &nbsp; :sparkling_heart:
Recent released features
| Feature | Status |
| -- | ------ |
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
| High-frequency trading example | [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
| High-frequency data(1min) | [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
| Tabnet Model | [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
Features released before 2021 are not listed here.
<p align="center">
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
@@ -17,10 +31,11 @@ Qlib is an AI-oriented quantitative investment platform, which aims to realize t
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
With Qlib, user can easily try ideas to create better Quant investment strategies.
With Qlib, users can easily try ideas to create better Quant investment strategies.
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
- [**Plans**](#plans)
- [Framework of Qlib](#framework-of-qlib)
- [Quick Start](#quick-start)
- [Installation](#installation)
@@ -35,9 +50,20 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
- [Related Reports](#related-reports)
- [Contact Us](#contact-us)
- [Contributing](#contributing)
# Plans
New features under development(order by estimated release time).
Your feedbacks about the features are very important.
| Feature | Status |
| -- | ------ |
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
| Meta-Learning-based data selection | Initial opensource version under development |
# Framework of Qlib
@@ -46,11 +72,11 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
</div>
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
| Name | Description |
| ------ | ----- |
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
@@ -118,14 +144,20 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
## Data Preparation
Load and prepare data by running the following code:
```bash
# get 1d data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
# get 1min data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
```
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
the same repository.
Users could create the same dataset with it.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
<!--
- Run the initialization code and get stock data:
@@ -213,9 +245,10 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
- Rank Label
![Rank Label](docs/_static/img/rank_label.png)
-->
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
## Building Customized Quant Research Workflow by Code
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
# [Quant Model Zoo](examples/benchmarks)
@@ -232,6 +265,7 @@ Here is a list of models built on `Qlib`.
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
Your PR of new Quant models is highly welcomed.
@@ -241,10 +275,10 @@ The performance of each model on the `Alpha158` and `Alpha360` dataset can be fo
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
## Run multiple models
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
@@ -307,17 +341,27 @@ which creates a dataset (14 features/factors) from the basic OHLCV daily data of
* `+(-)E` indicates with (out) `ExpressionCache`
* `+(-)D` indicates with (out) `DatasetCache`
Most general-purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
Most general-purpose databases take too much time to load data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
Such overheads greatly slow down the data loading process.
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
# Related Reports
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
- [Guide To Qlib: Microsofts AI Investment Platform](https://analyticsindiamag.com/qlib/)
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
- [微软也搞AI量化平台还是开源的](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
- [微矿Qlib业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
# Contact Us
- If you have any issues, please create issue [here](https://github.com/microsoft/qlib/issues/new/choose) or send messages in [gitter](https://gitter.im/Microsoft/qlib).
- If you want to make contributions to `Qlib`, please [create pull requests](https://github.com/microsoft/qlib/compare).
- For other reasons, you are welcome to contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)).
- We are recruiting new members(both FTEs and interns), your resumes are welcome!
Join IM discussion groups:
|[Gitter](https://gitter.im/Microsoft/qlib)|
|----|
|![image](http://fintech.msra.cn/images_v060/qrcode/gitter_qr.png)|
# Contributing

View File

@@ -70,3 +70,31 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
------------------------------------------------------------------------------------------------------------------------------------
.. code-block:: python
#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "qlib/qlib/__init__.py", line 19, in init
from .data.cache import H
File "qlib/qlib/data/__init__.py", line 8, in <module>
from .data import (
File "qlib/qlib/data/data.py", line 20, in <module>
from .cache import H
File "qlib/qlib/data/cache.py", line 36, in <module>
from .ops import Operators
File "qlib/qlib/data/ops.py", line 19, in <module>
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
- If the error occurs when importing ``qlib`` package with ``PyCharm`` IDE, users can execute the following command in the project root folder to compile Cython files and generate executable files:
.. code-block:: bash
python setup.py build_ext --inplace
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.

BIN
docs/_static/img/online_serving.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 440 KiB

BIN
docs/_static/img/qrcode/gitter_qr.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

45
docs/advanced/serial.rst Normal file
View File

@@ -0,0 +1,45 @@
.. _serial:
=================================
Serialization
=================================
.. currentmodule:: qlib
Introduction
===================
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.
Serializable Class
========================
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
Example
==========================
``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``.
Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.
.. code-block:: Python
##=============dump dataset=============
dataset.to_pickle(path="dataset.pkl") # dataset is an instance of qlib.data.dataset.DatasetH
##=============reload dataset=============
with open("dataset.pkl", "rb") as file_dataset:
dataset = pickle.load(file_dataset)
.. note::
Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.
After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk).
A more detailed example is in this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
API
===================
Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.

View File

@@ -0,0 +1,89 @@
.. _task_management:
=================================
Task Management
=================================
.. currentmodule:: qlib
Introduction
=============
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
This whole process can be used in `Online Serving <../component/online.html>`_.
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
Task Generating
===============
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
The specific task template can be viewed in
`Task Section <../component/workflow.html#task-section>`_.
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
Here is the base class of ``TaskGen``:
.. autoclass:: qlib.workflow.task.gen.TaskGen
:members:
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
Task Storing
===============
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
.. code-block:: python
from qlib.config import C
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
"task_db_name" : "rolling_db" # database name
}
.. autoclass:: qlib.workflow.task.manage.TaskManager
:members:
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
Task Training
===============
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
.. autofunction:: qlib.workflow.task.manage.run_task
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
.. autoclass:: qlib.model.trainer.Trainer
:members:
``Trainer`` will train a list of tasks and return a list of model recorders.
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
Task Collecting
===============
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
For example: {C1: object, C2: object} ---``Ensemble``---> object
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.

View File

@@ -31,7 +31,7 @@ Qlib Format Data
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
``Qlib`` provides two different off-the-shelf dataset, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
======================== ================= ================
Dataset US Market China Market
@@ -41,6 +41,7 @@ Alpha360 √ √
Alpha158 √ √
======================== ================= ================
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
Qlib Format Dataset
--------------------
@@ -48,15 +49,19 @@ Qlib Format Dataset
.. code-block:: bash
# download 1d
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
# download 1min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` directory and ``~/.qlib/qlib_data/us_data`` directory respectively.
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
@@ -67,12 +72,19 @@ Converting CSV Format into Qlib Format
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
Here are some example:
.. code-block:: bash
for daily data:
.. code-block:: bash
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
for 1min data:
.. code-block:: bash
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
@@ -140,6 +152,16 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
Stock Pool (Market)
--------------------------------
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
.. code-block:: bash
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
Multiple Stock Modes
--------------------------------
@@ -158,7 +180,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
- If users use ``Qlib`` in china-stock mode, china-stock data is required. Users can use ``Qlib`` in china-stock mode according to the following steps:
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in china-stock mode
Supposed that users download their Qlib format data in the directory ``~/.qlib/csv_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
.. code-block:: python
@@ -167,9 +189,9 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in US-stock mode
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
.. code-block:: python
@@ -177,6 +199,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
.. note::
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
Data API
========================
@@ -213,6 +240,25 @@ Filter
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
.. code-block:: yaml
filter: &filter
filter_type: ExpressionDFilter
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
filter_start_time: 2010-01-01
filter_end_time: 2010-01-07
keep: False
data_handler_config: &data_handler_config
start_time: 2010-01-01
end_time: 2021-01-22
fit_start_time: 2010-01-01
fit_end_time: 2015-12-31
instruments: *market
filter_pipe: [*filter]
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
Reference
@@ -274,9 +320,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
:members: __init__, fetch, get_cols
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
Processor
@@ -313,7 +360,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
.. code-block:: Python
import qlib
@@ -340,6 +386,9 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
# fetch all the features
print(h.fetch(col_set="feature"))
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
API
---------
@@ -364,8 +413,7 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
API
---------
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
Cache

46
docs/component/online.rst Normal file
View File

@@ -0,0 +1,46 @@
.. _online:
=================================
Online Serving
=================================
.. currentmodule:: qlib
Introduction
=============
.. image:: ../_static/img/online_serving.png
:align: center
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
``Online Serving`` is a set of modules for online models using the latest data,
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
Online Manager
=============
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
=============
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
=============
.. automodule:: qlib.workflow.online.utils
:members:
Updater
=============
.. automodule:: qlib.workflow.online.update
:members:

View File

@@ -34,6 +34,7 @@ Here is a general view of the structure of the system:
- Recorder 2
- ...
- ...
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
@@ -94,6 +95,52 @@ The ``RecordTemp`` class is a class that enables generate experiment results suc
- ``SignalRecord``: This class generates the `prediction` results of the model.
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
Here is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, Long-Short Return with their own prediction and label.
.. code-block:: Python
from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
Here is a simple exampke of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
.. code-block:: Python
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
# backtest
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
}
BACKTEST_CONFIG = {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
# analysis
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
analysis_df = pd.concat(analysis) # type: pd.DataFrame
print(analysis_df)
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.

View File

@@ -101,7 +101,7 @@ Graphical Result
- Axis Y:
- `ic`
The `Pearson correlation coefficient` series between `label` and `prediction score`.
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
- `rank_ic`
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.

View File

@@ -111,8 +111,6 @@ Usage & Example
pred_score, strategy=strategy, **BACKTEST_CONFIG
)
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.

View File

@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.

View File

@@ -42,6 +42,7 @@ Document Structure
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Qlib Recorder: Experiment Management <component/recorder.rst>
Analysis: Evaluation & Results Analysis <component/report.rst>
Online Serving: Online Management & Strategy & Tool <component/online.rst>
.. toctree::
:maxdepth: 3
@@ -49,6 +50,8 @@ Document Structure
Building Formulaic Alphas <advanced/alpha.rst>
Online & Offline mode <advanced/server.rst>
Serialization <advanced/serial.rst>
Task Management <advanced/task_management.rst>
.. toctree::
:maxdepth: 3

View File

@@ -53,6 +53,34 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:
Storage
-------------
.. autoclass:: qlib.data.storage.storage.BaseStorage
:members:
.. autoclass:: qlib.data.storage.storage.CalendarStorage
:members:
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
:members:
.. autoclass:: qlib.data.storage.storage.FeatureStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
:members:
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
:members:
Dataset
---------------
@@ -152,4 +180,81 @@ Recorder
Record Template
--------------------
.. automodule:: qlib.workflow.record_temp
:members:
:members:
Task Management
====================
TaskGen
--------------------
.. automodule:: qlib.workflow.task.gen
:members:
TaskManager
--------------------
.. automodule:: qlib.workflow.task.manage
:members:
Trainer
--------------------
.. automodule:: qlib.model.trainer
:members:
Collector
--------------------
.. automodule:: qlib.workflow.task.collect
:members:
Group
--------------------
.. automodule:: qlib.model.ens.group
:members:
Ensemble
--------------------
.. automodule:: qlib.model.ens.ensemble
:members:
Utils
--------------------
.. automodule:: qlib.workflow.task.utils
:members:
Online Serving
====================
Online Manager
--------------------
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
--------------------
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
--------------------
.. automodule:: qlib.workflow.online.utils
:members:
RecordUpdater
--------------------
.. automodule:: qlib.workflow.online.update
:members:
Utils
====================
Serializable
--------------------
.. automodule:: qlib.utils.serial.Serializable
:members:

View File

@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
"default_exp_name": "Experiment",
}
})
- `mongo`
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
.. code-block:: Python
# For example, you can initialize qlib below
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
"task_url": "mongodb://localhost:27017/", # your mongo url
"task_db_name": "rolling_db", # the database name of Task Management
})

View File

@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
- Override the `finetune` method (Optional)
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
- The parameters must include the parameter `dataset`.
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
.. code-block:: Python

View File

@@ -0,0 +1,4 @@
# DoubleEnsemble
* DoubleEnsemble is an ensemble framework leveraging learning trajectory based sample reweighting and shuffling based feature selection, to solve both the low signal-to-noise ratio and increasing number of features problems. They identify the key samples based on the training dynamics on each sample and elicit key features based on the ablation impact of each feature via shuffling. The model is applicable to a wide range of base models, capable of extracting complex patterns, while mitigating the overfitting and instability issues for financial market prediction.
* This code used in Qlib is implemented by ourselves.
* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf).

View File

@@ -0,0 +1,3 @@
pandas==1.1.2
numpy==1.17.4
lightgbm==3.1.0

View File

@@ -0,0 +1,90 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: DEnsembleModel
module_path: qlib.contrib.model.double_ensemble
kwargs:
base_model: "gbm"
loss: mse
num_models: 6
enable_sr: True
enable_fs: True
alpha1: 1
alpha2: 1
bins_sr: 10
bins_fs: 5
decay: 0.5
sample_ratios:
- 0.8
- 0.7
- 0.6
- 0.5
- 0.4
sub_weights:
- 1
- 0.2
- 0.2
- 0.2
- 0.2
- 0.2
epochs: 28
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
verbosity: -1
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,97 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: []
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: DEnsembleModel
module_path: qlib.contrib.model.double_ensemble
kwargs:
base_model: "gbm"
loss: mse
num_models: 6
enable_sr: True
enable_fs: True
alpha1: 1
alpha2: 1
bins_sr: 10
bins_fs: 5
decay: 0.5
sample_ratios:
- 0.8
- 0.7
- 0.6
- 0.5
- 0.4
sub_weights:
- 1
- 0.2
- 0.2
- 0.2
- 0.2
- 0.2
epochs: 136
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
verbosity: -1
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -29,7 +29,7 @@ data_handler_config: &data_handler_config
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:

View File

@@ -0,0 +1,81 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
instruments: *market
data_loader:
class: QlibDataLoader
kwargs:
config:
feature:
- ["Resi($close, 15)/$close", "Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)", "Rsquare($close, 5)", "($high-$low)/$open", "Rsquare($close, 10)", "Corr($close, Log($volume+1), 5)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)", "Corr($close, Log($volume+1), 10)", "Rsquare($close, 20)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)", "Corr($close, Log($volume+1), 20)", "(Less($open, $close)-$low)/$open"]
- ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
label:
- ["Ref($close, -2)/Ref($close, -1) - 1"]
- ["LABEL0"]
freq: day
learn_processors:
- class: DropnaLabel
- class: CSZScoreNorm
kwargs:
fields_group: label
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: DataHandlerLP
module_path: qlib.data.dataset.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -16,6 +16,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
@@ -25,11 +27,13 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.

View File

@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
return -1, -1
def get_column_definition(self):
""""Returns formatted column definition in order expected by the TFT."""
"""Returns formatted column definition in order expected by the TFT."""
column_definition = self._column_definition

View File

@@ -44,6 +44,7 @@ task:
class: TabnetModel
module_path: qlib.contrib.model.pytorch_tabnet
kwargs:
d_feat: 158
pretrain: True
dataset:
class: DatasetH
@@ -55,7 +56,7 @@ task:
kwargs: *data_handler_config
segments:
pretrain: [2008-01-01, 2014-12-31]
pretrain_validation: [2015-01-01, 2020-08-01]
pretrain_validation: [2015-01-01, 2016-12-31]
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]

View File

@@ -0,0 +1,75 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TabnetModel
module_path: qlib.contrib.model.pytorch_tabnet
kwargs:
d_feat: 360
pretrain: True
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
pretrain: [2008-01-01, 2014-12-31]
pretrain_validation: [2015-01-01, 2016-12-31]
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

208
examples/data/monitor.py Normal file
View File

@@ -0,0 +1,208 @@
"""
This script is the demonstrating the implementation of Metric Extractor and Detector
NOTE: A lot of details is not considered in this script
- Corner case that will raise error( std == 0)
The following functions are used to demonstrate the following examples
· Metric Extractor:
case 1) Basic statistics on different slices of the DataFrame df:
1) The statistics include:
· STD, Mean, Skewnes, Kurtosis
2) The above statistics can be calculated on the following data slices:
· df.groupby(['datetime'])
· df.groupby(['datetime', 'industry' ])
3) The statistics could be calculated on the time dimension for each instruments and factor(the factor can be represented by experssion)
· <df implemented by expresion>.groupby(['instrument', 'factor'])
case 2) Advanced statistics on different slices of the DataFrame df:
1) Auto-correlation:
· Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
2) Correlation between factors:
· For any pair of factors (i, j): calculate corr(df.loc[t, :, i], df.loc[t, :, j]). The result is a correlation matrix with each element corresponds to a correlation value between a pair of factors.
· Detector: detect the abnormality of the extracted metric;
a) Algorithms:
§ Basic checks: NaN.
§ Point anomaly detection.
§ Segment anomaly detection.
b) Scenarios:
§ Online anomaly detection: monitoring streaming data.
The usage of the detectors are demonstrated in the `case_1_*`and `case_2_*`
case 3): Examples to use MetricExt to monitor IC and rank IC
1) IC(Information Coefficient) #case_3_1
2) RankIC #case_3_2
"""
# AUTO download data
from typing import List, Union
from qlib.utils import exists_qlib_data
from qlib.tests.data import GetData
from qlib.config import REG_CN
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
import qlib
import pandas as pd
from qlib.contrib.data.handler import Alpha158
from qlib.data.dataset.loader import QlibDataLoader
from qlib.data.monitor.metric import format_conv
from qlib.data.monitor.metric import MeanM, SkewM, KurtM, StdM, AutoCM, CorrM
from qlib.data.monitor.detector import NDDetector, SWNDD, ThresholdD
from qlib.data import D
import fire
UNIVERSE = "csi300"
START_TIME = "20200101"
# ------------------ a helper function to get data to demonstrate the functionality --------------------
def get_data_df(col_idx: Union[int, List[int]] = 0, verbose: bool = True):
"""
a helper function to get data to demonstrate the functionality.
Parameters
----------
col_idx : Union[int, List[int]]
column index of the metrics
"""
dh = Alpha158(instruments=UNIVERSE, infer_processors=[], learn_processors=[], start_time=START_TIME)
df = dh.fetch()
if verbose:
print(df.head())
# We don't have industries in dataframe, we generate the with fake data
industry = pd.Series(df.index.get_level_values("instrument").str.slice(stop=2).to_list(), index=df.index)
# select a factor
factor_df = format_conv(df.iloc[:, col_idx], industry=industry)
if verbose:
print(f"Selected metric: {df.columns[col_idx]}")
print(factor_df)
return factor_df
def get_target(horizon=5):
target = f"Ref($close, -{horizon + 1})/Ref($close, -1) - 1" # There are lots of targets: return is one of them
qdl = QlibDataLoader(config=([target], ["target"]))
df = qdl.load(instruments=UNIVERSE, start_time=START_TIME) # Aligning with factor will improve performance
df = format_conv(df["target"])
return df
# ----------------- Cases to demonstrate the usage of detector and examples ----------------------
def case_1_1():
factor_df = get_data_df()
# 1) Extract metrics
# 1.1) df.groupby(["datetime"])
mtrc = MeanM()
m_mean = mtrc.extract(factor_df)
print(m_mean)
ndd = NDDetector()
ndd.fit(m_mean) # use historical data to fit detector
check_res = ndd.check(m_mean)
print(check_res) # detecting on new data or historical data
print(check_res.value_counts())
def case_1_2():
factor_df = get_data_df()
# 1.2) df.groupby("datetime", "industry")
mtrc = MeanM(group=["industry"])
m_multi = mtrc.extract(factor_df)
print(m_multi)
for col_name, s in m_multi.iteritems():
print(col_name)
ndd = NDDetector()
ndd.fit(s) # use historical data to fit detector
check_res = ndd.check(s)
print(check_res) # detecting on new data or historical data
print(check_res.value_counts())
def case_1_3():
# case 1.3
# factor_df = get_data_df()
qdl = QlibDataLoader(config=(["$close/Ref($close, 1) - 1"], ["return"]))
df = qdl.load(instruments=["SH600519"], start_time=START_TIME)
df = format_conv(df)
s = df.iloc[:, 0]
print(s)
dtc = SWNDD(window=20)
dtc.fit(s) # fit use historical data (TODO: updating will be supported in the future)
check_res = dtc.check(s) #
print(check_res)
print(check_res.value_counts())
print(check_res[check_res])
def case_2_1():
# · Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
factor_df = get_data_df()
acm = AutoCM()
mtrc = acm.extract(factor_df)
print(mtrc)
thd = ThresholdD(0.0, reverse=True)
check_res = thd.check(mtrc)
print(check_res)
print(check_res.value_counts())
def case_2_2():
factor_df1, factor_df2 = get_data_df(0), get_data_df(1)
cm = CorrM()
mtrc = cm.extract(factor_df1, factor_df2)
print(mtrc)
thd = ThresholdD(0.0, reverse=True)
check_res = thd.check(mtrc)
print(check_res)
print(check_res.value_counts())
def case_3_1_3_2():
target, factor = get_target(), get_data_df(0)
ic_m, rank_ic_m = CorrM(), CorrM(mode="spearman")
ic, rank_ic = ic_m.extract(factor, target), rank_ic_m.extract(factor, target)
print(pd.DataFrame({"ic": ic, "rank_ic": rank_ic}))
def run(test_list=["case_1_1", "case_1_2", "case_1_3", "case_2_1", "case_2_2", "case_3_1_3_2"]):
"""
run the specific tests
python monitor.py case_3_1_3_2
Parameters
----------
test_list : str[]
The tests to run
"""
if isinstance(test_list, str):
test_list = [test_list]
for fn in test_list:
globals()[fn]()
if __name__ == "__main__":
qlib.init()
fire.Fire(run)

View File

@@ -0,0 +1,130 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0e62a81e",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from tqdm.auto import tqdm\n",
"%matplotlib inline\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c503217b",
"metadata": {},
"outputs": [],
"source": [
"from qlib.data.monitor.analyser import Analyser\n",
"import qlib\n",
"qlib.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c276470",
"metadata": {},
"outputs": [],
"source": [
"class SimpleDFA(Analyser):\n",
" \"\"\"Simple (D)ata(F)rame (A)nalyser\"\"\"\n",
" def analyse(self, data: pd.DataFrame, *args, **kwargs):\n",
" data.plot(*args, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "110262e4",
"metadata": {},
"outputs": [],
"source": [
"from monitor import get_data_df, AutoCM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ea38c62",
"metadata": {},
"outputs": [],
"source": [
"# get data\n",
"factor_df = get_data_df([1], verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbded6fe",
"metadata": {},
"outputs": [],
"source": [
"# metric extractor\n",
"acm = AutoCM()\n",
"mtrc = acm.extract(factor_df)\n",
"print(mtrc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65517c81",
"metadata": {},
"outputs": [],
"source": [
"# Analyser\n",
"sa = SimpleDFA()\n",
"sa.analyse(mtrc, title='Auto Correlation')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dab6fb2e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,35 @@
# High-Frequency Dataset
This dataset is an example for RL high frequency trading.
## Get High-Frequency Data
Get high-frequency data by running the following command:
```bash
python workflow.py get_data
```
## Dump & Reload & Reinitialize the Dataset
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in the `workflow.py`. `DatatsetH` is the subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped in or loaded from disk in `pickle` format.
### About Reinitialization
After reloading `Dataset` from disk, `Qlib` also support reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states.
The example is given in `workflow.py`, users can run the code as follows.
### Run the Code
Run the example by running the following command:
```bash
python workflow.py dump_and_load_dataset
```
## Benchmarks Performance
### Signal Test
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|---|---|---|---|---|---|---|---|---|---|
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |

View File

@@ -62,9 +62,9 @@ class HighFreqHandler(DataHandlerLP):
def get_normalized_price_feature(price_field, shift=0):
"""Get normalized price feature ops"""
if shift == 0:
template_norm = "{0}/Ref(DayLast({1}), 240)"
template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)"
else:
template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)"
template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)"
feature_ops = template_norm.format(
template_if.format(
@@ -90,7 +90,7 @@ class HighFreqHandler(DataHandlerLP):
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
fields += [
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(
"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
@@ -101,7 +101,7 @@ class HighFreqHandler(DataHandlerLP):
]
names += ["$volume"]
fields += [
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format(
"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
@@ -112,7 +112,7 @@ class HighFreqHandler(DataHandlerLP):
]
names += ["$volume_1"]
fields += [template_paused.format("Date($close)")]
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
names += ["date"]
return fields, names
@@ -149,18 +149,20 @@ class HighFreqBacktestHandler(DataHandler):
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
fields += [
template_fillnan.format(template_paused.format("$close")),
"Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))),
]
names += ["$close0"]
fields += [
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(simpson_vwap),
"Cut({0}, 240, None)".format(
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(simpson_vwap),
)
)
]
names += ["$vwap0"]
fields += [
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),

View File

@@ -8,6 +8,20 @@ from qlib.data.data import Cal
def get_calendar_day(freq="day", future=False):
"""Load High-Freq Calendar Date Using Memcache.
Parameters
----------
freq : str
frequency of read calendar file.
future : bool
whether including future trading day.
Returns
-------
_calendar:
array of date.
"""
flag = f"{freq}_future_{future}_day"
if flag in H["c"]:
_calendar = H["c"][flag]
@@ -18,6 +32,19 @@ def get_calendar_day(freq="day", future=False):
class DayLast(ElemOperator):
"""DayLast Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a series of that each value equals the last value of its day
"""
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
@@ -25,18 +52,57 @@ class DayLast(ElemOperator):
class FFillNan(ElemOperator):
"""FFillNan Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a forward fill nan feature
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="ffill")
class BFillNan(ElemOperator):
"""BFillNan Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a backfoward fill nan feature
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="bfill")
class Date(ElemOperator):
"""Date Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a series of that each value is the date corresponding to feature.index
"""
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
@@ -44,6 +110,22 @@ class Date(ElemOperator):
class Select(PairOperator):
"""Select Operator
Parameters
----------
feature_left : Expression
feature instance, select condition
feature_right : Expression
feature instance, select value
Returns
----------
feature:
value(feature_right) that meets the condition(feature_left)
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
@@ -51,6 +133,58 @@ class Select(PairOperator):
class IsNull(ElemOperator):
"""IsNull Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
A series indicating whether the feature is nan
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.isnull()
class Cut(ElemOperator):
"""Cut Operator
Parameters
----------
feature : Expression
feature instance
l : int
l > 0, delete the first l elements of feature (default is None, which means 0)
r : int
r < 0, delete the last -r elements of feature (default is None, which means 0)
Returns
----------
feature:
A series with the first l and last -r elements deleted from the feature.
Note: It is deleted from the raw data, not the sliced data
"""
def __init__(self, feature, l=None, r=None):
self.l = l
self.r = r
if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0):
raise ValueError("Cut operator l shoud > 0 and r should < 0")
super(Cut, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.iloc[self.l : self.r]
def get_extended_window_size(self):
ll = 0 if self.l is None else self.l
rr = 0 if self.r is None else abs(self.r)
lft_etd, rght_etd = self.feature.get_extended_window_size()
lft_etd = lft_etd + ll
rght_etd = rght_etd + rr
return lft_etd, rght_etd

View File

@@ -1,40 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import fire
from pathlib import Path
import qlib
import pickle
import numpy as np
import pandas as pd
from qlib.config import HIGH_FREQ_CONFIG
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.utils import init_instance_by_config, exists_qlib_data
from qlib.utils import init_instance_by_config
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
from qlib.tests.data import GetData
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
class HighfreqWorkflow(object):
class HighfreqWorkflow:
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None}
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
MARKET = "all"
BENCHMARK = "SH000300"
start_time = "2020-09-14 00:00:00"
start_time = "2020-09-15 00:00:00"
end_time = "2021-01-18 16:00:00"
train_end_time = "2020-11-30 16:00:00"
test_start_time = "2020-12-01 00:00:00"
@@ -42,7 +30,6 @@ class HighfreqWorkflow(object):
DATA_HANDLER_CONFIG0 = {
"start_time": start_time,
"end_time": end_time,
"freq": "1min",
"fit_start_time": start_time,
"fit_end_time": train_end_time,
"instruments": MARKET,
@@ -51,7 +38,6 @@ class HighfreqWorkflow(object):
DATA_HANDLER_CONFIG1 = {
"start_time": start_time,
"end_time": end_time,
"freq": "1min",
"instruments": MARKET,
}
@@ -99,9 +85,7 @@ class HighfreqWorkflow(object):
# use yahoo_cn_1min data
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
qlib.init(**QLIB_INIT_CONFIG)
def _prepare_calender_cache(self):
@@ -125,8 +109,7 @@ class HighfreqWorkflow(object):
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
print(backtest_train, backtest_test)
del xtrain, xtest
del backtest_train, backtest_test
return
def dump_and_load_dataset(self):
"""dump and load dataset state on disk"""
@@ -148,18 +131,44 @@ class HighfreqWorkflow(object):
dataset_backtest = pickle.load(file_dataset_backtest)
self._prepare_calender_cache()
##=============reload_dataset=============
dataset.init(init_type=DataHandlerLP.IT_LS)
dataset_backtest.init()
##=============reinit dataset=============
dataset.config(
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segments={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset.setup_data(
handler_kwargs={
"init_type": DataHandlerLP.IT_LS,
},
)
dataset_backtest.config(
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segments={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset_backtest.setup_data(handler_kwargs={})
##=============get data=============
xtrain, xtest = dataset.prepare(["train", "test"])
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
xtest = dataset.prepare("test")
backtest_test = dataset_backtest.prepare("test")
print(xtrain, xtest)
print(backtest_train, backtest_test)
del xtrain, xtest
del backtest_train, backtest_test
print(xtest, backtest_test)
return
if __name__ == "__main__":

View File

@@ -0,0 +1,65 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
region: cn
market: &market 'csi300'
start_time: &start_time "2020-09-15 00:00:00"
end_time: &end_time "2021-01-18 16:00:00"
train_end_time: &train_end_time "2020-11-15 16:00:00"
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
test_start_time: &test_start_time "2020-12-01 00:00:00"
data_handler_config: &data_handler_config
start_time: *start_time
end_time: *end_time
fit_start_time: *start_time
fit_end_time: *train_end_time
instruments: *market
freq: '1min'
infer_processors:
- class: 'RobustZScoreNorm'
kwargs:
fields_group: 'feature'
clip_outlier: false
- class: "Fillna"
kwargs:
fields_group: 'feature'
learn_processors:
- class: 'DropnaLabel'
- class: 'CSRankNorm'
kwargs:
fields_group: 'label'
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
task:
model:
class: "HFLGBModel"
module_path: "qlib.contrib.model.highfreq_gdbt_model"
kwargs:
objective: 'binary'
metric: ['binary_logloss','auc']
verbosity: -1
learning_rate: 0.01
max_depth: 8
num_leaves: 150
lambda_l1: 1.5
lambda_l2: 1
num_threads: 20
dataset:
class: "DatasetH"
module_path: "qlib.data.dataset"
kwargs:
handler:
class: "Alpha158"
module_path: "qlib.contrib.data.handler"
kwargs: *data_handler_config
segments:
train: [*start_time, *train_end_time]
valid: [*train_end_time, *valid_end_time]
test: [*test_start_time, *end_time]
record:
- class: "SignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}
- class: "HFSignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}

View File

@@ -0,0 +1,23 @@
# LightGBM hyperparameter
## Alpha158
First terminal
```
optuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
```
Second terminal
```
python hyperparameter_158.py
```
## Alpha360
First terminal
```
optuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
```
Second terminal
```
python hyperparameter_360.py
```

View File

@@ -0,0 +1,46 @@
import qlib
import optuna
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.config import CSI300_DATASET_CONFIG
from qlib.tests.data import GetData
def objective(trial):
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
"subsample": trial.suggest_uniform("subsample", 0, 1),
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
"max_depth": 10,
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
},
},
}
evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region="cn")
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -0,0 +1,49 @@
import qlib
import optuna
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
def objective(trial):
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
"subsample": trial.suggest_uniform("subsample", 0, 1),
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
"max_depth": 10,
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
},
},
}
evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
dataset = init_instance_by_config(DATASET_CONFIG)
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -0,0 +1,5 @@
pandas==1.1.2
numpy==1.17.4
lightgbm==3.1.0
optuna==2.7.0
optuna-dashboard==0.4.1

View File

@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
###################################
# train model
###################################
# model initialization
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
model.fit(dataset)
# get model feature importance
feature_importance = model.get_feature_importance()
print("feature importance:")
print(feature_importance)

View File

@@ -0,0 +1,105 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
"""
from pprint import pprint
import fire
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
class RollingTaskExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region=REG_CN,
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
task_config=None,
rolling_step=550,
rolling_type=RollingGen.ROLL_SD,
):
# TaskManager config
if task_config is None:
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.experiment_name = experiment_name
self.task_pool = task_pool
self.task_config = task_config
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
TaskManager(task_pool=self.task_pool).remove()
exp = R.get_exp(experiment_name=self.experiment_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
def task_generating(self):
print("========== task_generating ==========")
tasks = task_generator(
tasks=self.task_config,
generators=self.rolling_gen, # generate different date segments
)
pprint(tasks)
return tasks
def task_training(self, tasks):
print("========== task_training ==========")
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
def task_collecting(self):
print("========== task_collecting ==========")
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
collector = RecorderCollector(
experiment=self.experiment_name,
process_list=RollingGroup(),
rec_key_func=rec_key,
rec_filter_func=my_filter,
)
print(collector())
def main(self):
self.reset()
tasks = self.task_generating()
self.task_training(tasks)
self.task_collecting()
if __name__ == "__main__":
## to see the whole process with your own parameters, use the command below
# python task_manager_rolling.py main --experiment_name="your_exp_name"
fire.Fire(RollingTaskExample)

View File

@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example is about how can simulate the OnlineManager based on rolling tasks.
"""
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
class OnlineSimulationExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
tasks=None,
):
"""
Init OnlineManagerExample.
Args:
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
region (str, optional): the stock region. Defaults to "cn".
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
task_db_name (str, optional): database name. Defaults to "rolling_db".
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
rolling_step (int, optional): the step for rolling. Defaults to 80.
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
self.end_time = end_time
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
begin_time=self.start_time,
)
self.tasks = tasks
# Reset all things to the first status, be careful to save important data
def reset(self):
TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Run this to run all workflow automatically
def main(self):
print("========== reset ==========")
self.reset()
print("========== simulate ==========")
self.rolling_online_manager.simulate(end_time=self.end_time)
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
fire.Fire(OnlineSimulationExample)

View File

@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineManager works with rolling tasks.
There are four parts including first train, routine 1, add strategy and routine 2.
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
Finally, the OnlineManager will finish second routine and update all strategies.
"""
import os
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
class RollingOnlineExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=550,
tasks=None,
add_tasks=None,
):
if add_tasks is None:
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.tasks = tasks
self.add_tasks = add_tasks
self.rolling_step = rolling_step
strategies = []
for task in tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategies.append(
RollingStrategy(
name_id,
task,
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager = OnlineManager(strategies)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
def first_run(self):
print("========== reset ==========")
self.reset()
print("========== first_run ==========")
self.rolling_online_manager.first_train()
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def routine(self):
print("========== load ==========")
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== routine ==========")
self.rolling_online_manager.routine()
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def add_strategy(self):
print("========== load ==========")
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== add strategy ==========")
strategies = []
for task in self.add_tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategies.append(
RollingStrategy(
name_id,
task,
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager.add_strategy(strategies=strategies)
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def main(self):
self.first_run()
self.routine()
self.add_strategy()
self.routine()
if __name__ == "__main__":
####### to train the first version's models, use the command below
# python rolling_online_management.py first_run
####### to update the models and predictions after the trading time, use the command below
# python rolling_online_management.py routine
####### to define your own parameters, use `--`
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
fire.Fire(RollingOnlineExample)

View File

@@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineTool works when we need update prediction.
There are two parts including first_train and update_online_pred.
Firstly, we will finish the training and set the trained models to the `online` models.
Next, we will finish updating online predictions.
"""
import copy
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.online.utils import OnlineToolR
from qlib.tests.config import CSI300_GBDT_TASK
task = copy.deepcopy(CSI300_GBDT_TASK)
task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
}
class UpdatePredExample:
def __init__(
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
):
qlib.init(provider_uri=provider_uri, region=region)
self.experiment_name = experiment_name
self.online_tool = OnlineToolR(self.experiment_name)
self.task_config = task_config
def first_train(self):
rec = task_train(self.task_config, experiment_name=self.experiment_name)
self.online_tool.reset_online_tag(rec) # set to online model
def update_online_pred(self):
self.online_tool.update_online_pred()
def main(self):
self.first_train()
self.update_online_pred()
if __name__ == "__main__":
## to train a model and set it to online model, use the command below
# python update_online_pred.py first_train
## to update online predictions once a day, use the command below
# python update_online_pred.py update_online_pred
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire(UpdatePredExample)

View File

@@ -0,0 +1,17 @@
# Rolling Process Data
This workflow is an example for `Rolling Process Data`.
## Background
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
## Run the Code
Run the example by running the following command:
```bash
python workflow.py rolling_process
```

View File

@@ -0,0 +1,32 @@
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.loader import DataLoaderDH
from qlib.contrib.data.handler import check_transform_proc
class RollingDataHandler(DataHandlerLP):
def __init__(
self,
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
data_loader_kwargs={},
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "DataLoaderDH",
"kwargs": {**data_loader_kwargs},
}
super().__init__(
instruments=None,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
)

View File

@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
import fire
import pickle
from datetime import datetime
from qlib.config import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
class RollingDataWorkflow:
MARKET = "csi300"
start_time = "2010-01-01"
end_time = "2019-12-31"
rolling_cnt = 5
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def _dump_pre_handler(self, path):
handler_config = {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": self.start_time,
"end_time": self.end_time,
"instruments": self.MARKET,
"infer_processors": [],
"learn_processors": [],
},
}
pre_handler = init_instance_by_config(handler_config)
pre_handler.config(dump_all=True)
pre_handler.to_pickle(path)
def _load_pre_handler(self, path):
with open(path, "rb") as file_dataset:
pre_handler = pickle.load(file_dataset)
return pre_handler
def rolling_process(self):
self._init_qlib()
self._dump_pre_handler("pre_handler.pkl")
pre_handler = self._load_pre_handler("pre_handler.pkl")
train_start_time = (2010, 1, 1)
train_end_time = (2012, 12, 31)
valid_start_time = (2013, 1, 1)
valid_end_time = (2013, 12, 31)
test_start_time = (2014, 1, 1)
test_end_time = (2014, 12, 31)
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "RollingDataHandler",
"module_path": "rolling_handler",
"kwargs": {
"start_time": datetime(*train_start_time),
"end_time": datetime(*test_end_time),
"fit_start_time": datetime(*train_start_time),
"fit_end_time": datetime(*train_end_time),
"infer_processors": [
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
],
"learn_processors": [
{"class": "DropnaLabel"},
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
],
"data_loader_kwargs": {
"handler_config": pre_handler,
},
},
},
"segments": {
"train": (datetime(*train_start_time), datetime(*train_end_time)),
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
"test": (datetime(*test_start_time), datetime(*test_end_time)),
},
},
}
dataset = init_instance_by_config(dataset_config)
for rolling_offset in range(self.rolling_cnt):
print(f"===========rolling{rolling_offset} start===========")
if rolling_offset:
dataset.config(
handler_kwargs={
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
"processor_kwargs": {
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
},
},
segments={
"train": (
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
),
"valid": (
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
),
"test": (
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
),
},
)
dataset.setup_data(
handler_kwargs={
"init_type": DataHandlerLP.IT_FIT_SEQ,
}
)
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
print(dtrain, dvalid, dtest)
## print or dump data
print(f"===========rolling{rolling_offset} end===========")
if __name__ == "__main__":
fire.Fire(RollingDataWorkflow)

View File

@@ -5,13 +5,11 @@ import os
import sys
import fire
import time
import venv
import glob
import shutil
import signal
import inspect
import tempfile
import traceback
import functools
import statistics
import subprocess
@@ -23,8 +21,7 @@ from pprint import pprint
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.cli import workflow
from qlib.utils import exists_qlib_data
from qlib.tests.data import GetData
# init qlib
@@ -39,12 +36,8 @@ exp_manager = {
"default_exp_name": "Experiment",
},
}
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
# decorator to check the arguments

View File

@@ -28,11 +28,17 @@
"import sys, site\n",
"from pathlib import Path\n",
"\n",
"################################# NOTE #################################\n",
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
"# in this cell, users should RESTART the runtime in order to run the #\n",
"# following cells successfully. #\n",
"########################################################################\n",
"\n",
"try:\n",
" import qlib\n",
"except ImportError:\n",
" # install qlib\n",
" ! pip install --upgrade numpy\n",
" ! pip install pyqlib\n",
" # reload\n",
" site.main()\n",
@@ -238,9 +244,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
@@ -359,7 +363,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
"version": "3.8.3"
},
"toc": {
"base_numbering": 1,
@@ -377,4 +381,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@@ -1,82 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
@@ -90,7 +30,7 @@ if __name__ == "__main__":
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"benchmark": CSI300_BENCH,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
@@ -99,9 +39,9 @@ if __name__ == "__main__":
},
}
# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
# model initialization
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# NOTE: This line is optional
# It demonstrates that the dataset can be used standalone.
@@ -110,14 +50,16 @@ if __name__ == "__main__":
# start exp
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()

View File

@@ -2,7 +2,8 @@
# Licensed under the MIT License.
__version__ = "0.6.3"
__version__ = "0.6.3.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
@@ -10,12 +11,13 @@ import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .log import get_module_logger
# init qlib
def init(default_conf="client", **kwargs):
from .config import C
from .log import get_module_logger
from .data.cache import H
H.clear()
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
def _mount_nfs_uri(C):
from .log import get_module_logger
LOG = get_module_logger("mount nfs", level=logging.INFO)
@@ -147,7 +148,78 @@ def init_from_yaml_conf(conf_path, **kwargs):
"""
with open(conf_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
config = yaml.safe_load(f)
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
- There is a file named `config.yaml` in qlib.
For example:
If your project file system stucuture follows such a pattern
<project_path>/
- config.yaml
- ...some folders...
- qlib/
This folder will return <project_path>
NOTE: link is not supported here.
This method is often used when
- user want to use a relative config path instead of hard-coding qlib config path in code
Raises
------
FileNotFoundError:
If project path is not found
"""
if cur_path is None:
cur_path = Path(__file__).absolute().resolve()
while True:
if (cur_path / config_name).exists():
return cur_path
if cur_path == cur_path.parent:
raise FileNotFoundError("We can't find the project path")
cur_path = cur_path.parent
def auto_init(**kwargs):
"""
This function will init qlib automatically with following priority
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
"""
try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
conf_type = conf.get("conf_type", "origin")
if conf_type == "origin":
# The type of config is just like original qlib config
init_from_yaml_conf(conf_pp, **kwargs)
elif conf_type == "ref":
# This config type will be more convenient in following scenario
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -33,6 +33,9 @@ class Config:
raise AttributeError(f"No such {attr} in self._config")
def get(self, key, default=None):
return self.__dict__["_config"].get(key, default)
def __setitem__(self, key, value):
self.__dict__["_config"][key] = value
@@ -105,7 +108,7 @@ _default_config = {
"redis_port": 6379,
"redis_task_db": 1,
# This value can be reset via qlib.init
"logging_level": "INFO",
"logging_level": logging.INFO,
# Global configuration of qlib log
# logging_level can control the logging level more finely
"logging_config": {
@@ -124,14 +127,14 @@ _default_config = {
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"level": logging.DEBUG,
"formatter": "logger_format",
"filters": ["field_not_found"],
}
},
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
},
# Defatult config for experiment manager
# Default config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
@@ -140,6 +143,11 @@ _default_config = {
"default_exp_name": "Experiment",
},
},
# Default config for MongoDB
"mongo": {
"task_url": "mongodb://localhost:27017/",
"task_db_name": "default_task_db",
},
}
MODE_CONF = {
@@ -185,7 +193,7 @@ MODE_CONF = {
# The nfs should be auto-mounted by qlib on other
# serversS(such as PAI) [auto_mount:True]
"timeout": 100,
"logging_level": "INFO",
"logging_level": logging.INFO,
"region": REG_CN,
## Custom Operator
"custom_ops": [],
@@ -310,8 +318,22 @@ class QlibConfig(Config):
# clean up experiment when python program ends
experiment_exit_handler()
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
self.reset_qlib_version()
self._registered = True
def reset_qlib_version(self):
import qlib
reset_version = self.get("qlib_reset_version", None)
if reset_version is not None:
qlib.__version__ = reset_version
else:
qlib.__version__ = getattr(qlib, "__version__bak")
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
# Using __version__bak instead of __version__
@property
def registered(self):
return self._registered

View File

@@ -104,10 +104,9 @@ class Account:
# if suspend, no new price to be updated, profit is 0
if trader.check_stock_suspended(code, today):
continue
else:
today_close = trader.get_close(code, today)
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=today_close)
today_close = trader.get_close(code, today)
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
self.current.update_stock_price(stock_id=code, price=today_close)
self.rtn += profit
# update holding day count
self.current.add_count_all()

View File

@@ -15,7 +15,8 @@ LOG = get_module_logger("backtest")
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
"""Parameters
"""
Parameters
----------
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column
@@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account,
def update_account(trade_account, trade_info, trade_exchange, trade_date):
"""Update the account and strategy
"""
Update the account and strategy
Parameters
----------
trade_account : Account()

View File

@@ -1,10 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import copy
import pathlib
import pandas as pd
import numpy as np
from .order import Order
"""
@@ -128,7 +128,7 @@ class Position:
return self.position["cash"]
def get_stock_amount_dict(self):
"""generate stock amount dict {stock_id : amount of stock} """
"""generate stock amount dict {stock_id : amount of stock}"""
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:
@@ -166,7 +166,7 @@ class Position:
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash = pd.Series(dtype=float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]

View File

@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
# FIXME: the `module_path` parameter is missed.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)

View File

@@ -8,6 +8,59 @@ import pandas as pd
from typing import Tuple
def calc_long_short_prec(
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
) -> Tuple[pd.Series, pd.Series]:
"""
calculate the precision for long and short operation
:param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.
.. code-block:: python
score
datetime instrument
2020-12-01 09:30:00 SH600068 0.553634
SH600195 0.550017
SH600276 0.540321
SH600584 0.517297
SH600715 0.544674
label :
label
date_col :
date_col
Returns
-------
(pd.Series, pd.Series)
long precision and short precision in time level
"""
if is_alpha:
label = label - label.mean(level=date_col)
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
raise ValueError("Need more instruments to calculate precision")
df = pd.DataFrame({"pred": pred, "label": label})
if dropna:
df.dropna(inplace=True)
group = df.groupby(level=date_col)
N = lambda x: int(len(x) * quantile)
# find the top/low quantile of prediction and treat them as long and short target
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
groupll = long.groupby(date_col)
l_dom = groupll.apply(lambda x: x > 0)
l_c = groupll.count()
groups = short.groupby(date_col)
s_dom = groups.apply(lambda x: x < 0)
s_c = groups.count()
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
"""calc_ic.

View File

@@ -61,7 +61,7 @@ def get_position_value(evaluate_date, position):
# load close price for position
# position should also consider cash
instruments = list(position.keys())
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
fields = ["$close"]
close_data_df = D.features(
instruments,
@@ -80,7 +80,7 @@ def get_position_list_value(positions):
instruments = set()
for day, position in positions.items():
instruments.update(position.keys())
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
instruments.sort()
day_list = list(positions.keys())
day_list.sort()

View File

@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
try:
from .catboost_model import CatBoostModel
except ModuleNotFoundError:
CatBoostModel = None
print("Please install necessary libs for CatBoostModel.")
try:
from .double_ensemble import DEnsembleModel
from .gbdt import LGBModel
except ModuleNotFoundError:
DEnsembleModel, LGBModel = None, None
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
try:
from .xgboost import XGBModel
except ModuleNotFoundError:
XGBModel = None
print("Please install necessary libs for XGBModel, such as xgboost.")
try:
from .linear import LinearModel
except ModuleNotFoundError:
LinearModel = None
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
# import pytorch models
try:
from .pytorch_alstm import ALSTM
from .pytorch_gats import GATs
from .pytorch_gru import GRU
from .pytorch_lstm import LSTM
from .pytorch_nn import DNNModelPytorch
from .pytorch_tabnet import TabnetModel
from .pytorch_sfm import SFM_Model
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
except ModuleNotFoundError:
pytorch_classes = ()
print("Please install necessary libs for PyTorch models.")
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes

View File

@@ -3,15 +3,17 @@
import numpy as np
import pandas as pd
from typing import Text, Union
from catboost import Pool, CatBoost
from catboost.utils import get_gpu_device_count
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
class CatBoostModel(Model):
class CatBoostModel(Model, FeatureInt):
"""CatBoost Model"""
def __init__(self, loss="RMSE", **kwargs):
@@ -62,12 +64,24 @@ class CatBoostModel(Model):
evals_result["train"] = list(evals_result["learn"].values())[0]
evals_result["valid"] = list(evals_result["validation"].values())[0]
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-----
parameters references:
https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance
"""
return pd.Series(
data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_
).sort_values(ascending=False)
if __name__ == "__main__":
cat = CatBoostModel()

View File

@@ -0,0 +1,265 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import lightgbm as lgb
import numpy as np
import pandas as pd
from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
from ...log import get_module_logger
class DEnsembleModel(Model, FeatureInt):
"""Double Ensemble Model"""
def __init__(
self,
base_model="gbm",
loss="mse",
num_models=6,
enable_sr=True,
enable_fs=True,
alpha1=1.0,
alpha2=1.0,
bins_sr=10,
bins_fs=5,
decay=None,
sample_ratios=None,
sub_weights=None,
epochs=100,
**kwargs
):
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
self.num_models = num_models # the number of sub-models
self.enable_sr = enable_sr
self.enable_fs = enable_fs
self.alpha1 = alpha1
self.alpha2 = alpha2
self.bins_sr = bins_sr
self.bins_fs = bins_fs
self.decay = decay
if sample_ratios is None: # the default values for sample_ratios
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
if sub_weights is None: # the default values for sub_weights
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
if not len(sample_ratios) == bins_fs:
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
self.sample_ratios = sample_ratios
if not len(sub_weights) == num_models:
raise ValueError("The length of sub_weights should be equal to num_models.")
self.sub_weights = sub_weights
self.epochs = epochs
self.logger = get_module_logger("DEnsembleModel")
self.logger.info("Double Ensemble Model...")
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
self.sub_features = [] # the features for each sub model in the form of pandas.Index
self.params = {"objective": loss}
self.params.update(kwargs)
self.loss = loss
def fit(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
# initialize the sample weights
N, F = x_train.shape
weights = pd.Series(np.ones(N, dtype=float))
# initialize the features
features = x_train.columns
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
# train sub-models
for k in range(self.num_models):
self.sub_features.append(features)
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
model_k = self.train_submodel(df_train, df_valid, weights, features)
self.ensemble.append(model_k)
# no further sample re-weight and feature selection needed for the last sub-model
if k + 1 == self.num_models:
break
self.logger.info("Retrieving loss curve and loss values...")
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
pred_k = self.predict_sub(model_k, df_train, features)
pred_sub.iloc[:, k] = pred_k
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
if self.enable_sr:
self.logger.info("Sample re-weighting...")
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
if self.enable_fs:
self.logger.info("Feature selection...")
features = self.feature_selection(df_train, loss_values)
def train_submodel(self, df_train, df_valid, weights, features):
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
evals_result = dict()
model = lgb.train(
self.params,
dtrain,
num_boost_round=self.epochs,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
verbose_eval=20,
evals_result=evals_result,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
return model
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train, label=y_train, weight=weights)
dvalid = lgb.Dataset(x_valid, label=y_valid)
return dtrain, dvalid
def sample_reweight(self, loss_curve, loss_values, k_th):
"""
the SR module of Double Ensemble
:param loss_curve: the shape is NxT
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
after the t-th iteration in the training of the previous sub-model.
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:param k_th: the index of the current sub-model, starting from 1
:return: weights
the weights for all the samples.
"""
# normalize loss_curve and loss_values with ranking
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
loss_values_norm = (-loss_values).rank(pct=True)
# calculate l_start and l_end from loss_curve
N, T = loss_curve.shape
part = np.maximum(int(T * 0.1), 1)
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
# calculate h-value for each sample
h1 = loss_values_norm
h2 = (l_end / l_start).rank(pct=True)
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
# calculate weights
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
h_avg = h.groupby("bins")["h_value"].mean()
weights = pd.Series(np.zeros(N, dtype=float))
for i_b, b in enumerate(h_avg.index):
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
return weights
def feature_selection(self, df_train, loss_values):
"""
the FS module of Double Ensemble
:param df_train: the shape is NxF
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:return: res_feat: in the form of pandas.Index
"""
x_train, y_train = df_train["feature"], df_train["label"]
features = x_train.columns
N, F = x_train.shape
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
M = len(self.ensemble)
# shuffle specific columns and calculate g-value for each feature
x_train_tmp = x_train.copy()
for i_f, feat in enumerate(features):
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
for i_s, submodel in enumerate(self.ensemble):
pred += (
pd.Series(
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
)
/ M
)
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
# one column in train features is all-nan # if g['g_value'].isna().any()
g["g_value"].replace(np.nan, 0, inplace=True)
# divide features into bins_fs bins
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
# randomly sample features from bins to construct the new features
res_feat = []
sorted_bins = sorted(g["bins"].unique(), reverse=True)
for i_b, b in enumerate(sorted_bins):
b_feat = features[g["bins"] == b]
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist()
return pd.Index(set(res_feat))
def get_loss(self, label, pred):
if self.loss == "mse":
return (label - pred) ** 2
else:
raise ValueError("not implemented yet")
def retrieve_loss_curve(self, model, df_train, features):
if self.base_model == "gbm":
num_trees = model.num_trees()
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train = np.squeeze(y_train.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
N = x_train.shape[0]
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
pred_tree = np.zeros(N, dtype=float)
for i_tree in range(num_trees):
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
else:
raise ValueError("not implemented yet")
return loss_curve
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.ensemble is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
for i_sub, submodel in enumerate(self.ensemble):
feat_sub = self.sub_features[i_sub]
pred += (
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
* self.sub_weights[i_sub]
)
return pred
def predict_sub(self, submodel, df_data, features):
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
return pred_sub
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-----
parameters reference:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
"""
res = []
for _model, _weight in zip(self.ensemble, self.sub_weights):
res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight)
return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False)

View File

@@ -4,13 +4,14 @@
import numpy as np
import pandas as pd
import lightgbm as lgb
from typing import Text, Union
from ...model.base import ModelFT
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import LightGBMFInt
class LGBModel(ModelFT):
class LGBModel(ModelFT, LightGBMFInt):
"""LightGBM Model"""
def __init__(self, loss="mse", **kwargs):
@@ -33,8 +34,8 @@ class LGBModel(ModelFT):
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train.values, label=y_train)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
dtrain = lgb.Dataset(x_train, label=y_train)
dvalid = lgb.Dataset(x_valid, label=y_valid)
return dtrain, dvalid
def fit(
@@ -61,10 +62,10 @@ class LGBModel(ModelFT):
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):

View File

@@ -0,0 +1,158 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import warnings
import numpy as np
import pandas as pd
import lightgbm as lgb
from ...model.base import ModelFT
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import LightGBMFInt
class HFLGBModel(ModelFT, LightGBMFInt):
"""LightGBM Model for high frequency prediction"""
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self.params = {"objective": loss, "verbosity": -1}
self.params.update(kwargs)
self.model = None
def _cal_signal_metrics(self, y_test, l_cut, r_cut):
"""
Calcaute the signal metrics by daily level
"""
up_pre, down_pre = [], []
up_alpha_ll, down_alpha_ll = [], []
for date in y_test.index.get_level_values(0).unique():
df_res = y_test.loc[date].sort_values("pred")
if int(l_cut * len(df_res)) < 10:
warnings.warn("Warning: threhold is too low or instruments number is not enough")
continue
top = df_res.iloc[: int(l_cut * len(df_res))]
bottom = df_res.iloc[int(r_cut * len(df_res)) :]
down_precision = len(top[top[top.columns[0]] < 0]) / (len(top))
up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom))
down_alpha = top[top.columns[0]].mean()
up_alpha = bottom[bottom.columns[0]].mean()
up_pre.append(up_precision)
down_pre.append(down_precision)
up_alpha_ll.append(up_alpha)
down_alpha_ll.append(down_alpha)
return (
np.array(up_pre).mean(),
np.array(down_pre).mean(),
np.array(up_alpha_ll).mean(),
np.array(down_alpha_ll).mean(),
)
def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
"""
Test the sigal in high frequency test set
"""
if self.model == None:
raise ValueError("Model hasn't been trained yet")
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
df_test.dropna(inplace=True)
x_test, y_test = df_test["feature"], df_test["label"]
# Convert label into alpha
y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0)
res = pd.Series(self.model.predict(x_test.values), index=x_test.index)
y_test["pred"] = res
up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
print("===============================")
print("High frequency signal test")
print("===============================")
print("Test set precision: ")
print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
print("Test Alpha Average in test set: ")
print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_train["feature"], df_valid["label"]
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
l_name = df_train["label"].columns[0]
# Convert label into alpha
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
mapping_fn = lambda x: 0 if x < 0 else 1
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
x_train, y_train = df_train["feature"], df_train["label_c"].values
x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train, label=y_train)
dvalid = lgb.Dataset(x_valid, label=y_valid)
return dtrain, dvalid
def fit(
self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
dtrain, dvalid = self._prepare_data(dataset)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
"""
finetune model
Parameters
----------
dataset : DatasetH
dataset for finetuning
num_boost_round : int
number of round to finetune model
verbose_eval : int
verbose level
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
)

View File

@@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
from typing import Text, Union
from scipy.optimize import nnls
from sklearn.linear_model import LinearRegression, Ridge, Lasso
@@ -84,8 +84,8 @@ class LinearModel(Model):
self.coef_ = coef
self.intercept_ = 0.0
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.coef_ is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)

View File

@@ -8,21 +8,16 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -39,8 +34,8 @@ class ALSTM(Model):
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
GPU : int
the GPU ID used for training
"""
def __init__(
@@ -76,8 +71,7 @@ class ALSTM(Model):
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_gpu = torch.cuda.is_available()
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger.info(
@@ -93,7 +87,7 @@ class ALSTM(Model):
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
d_feat,
@@ -107,7 +101,7 @@ class ALSTM(Model):
early_stop,
optimizer.lower(),
loss,
GPU,
self.device,
self.use_gpu,
seed,
)
@@ -123,6 +117,9 @@ class ALSTM(Model):
num_layers=self.num_layers,
dropout=self.dropout,
)
self.logger.info("model:\n{:}".format(self.ALSTM_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
@@ -130,9 +127,13 @@ class ALSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.ALSTM_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -201,12 +202,13 @@ class ALSTM(Model):
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.ALSTM_model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
with torch.no_grad():
pred = self.ALSTM_model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
@@ -214,7 +216,6 @@ class ALSTM(Model):
self,
dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
@@ -227,8 +228,7 @@ class ALSTM(Model):
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
@@ -238,7 +238,7 @@ class ALSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -269,11 +269,11 @@ class ALSTM(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.ALSTM_model.eval()
x_values = x_test.values
@@ -290,10 +290,7 @@ class ALSTM(Model):
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
else:
pred = self.ALSTM_model(x_batch).detach().numpy()
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
preds.append(pred)

View File

@@ -8,22 +8,17 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -40,8 +35,8 @@ class ALSTM(Model):
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
GPU : int
the GPU ID used for training
"""
def __init__(
@@ -78,9 +73,8 @@ class ALSTM(Model):
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
self.use_gpu = torch.cuda.is_available()
self.seed = seed
self.logger.info(
@@ -96,7 +90,7 @@ class ALSTM(Model):
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\ndevice : {}"
"\nn_jobs : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
@@ -111,7 +105,7 @@ class ALSTM(Model):
early_stop,
optimizer.lower(),
loss,
GPU,
self.device,
n_jobs,
self.use_gpu,
seed,
@@ -127,7 +121,10 @@ class ALSTM(Model):
hidden_size=self.hidden_size,
num_layers=self.num_layers,
dropout=self.dropout,
).to(self.device)
)
self.logger.info("model:\n{:}".format(self.ALSTM_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
@@ -135,9 +132,13 @@ class ALSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.ALSTM_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -188,12 +189,13 @@ class ALSTM(Model):
# feature[torch.isnan(feature)] = 0
label = data[:, -1, -1].to(self.device)
pred = self.ALSTM_model(feature.float())
loss = self.loss_fn(pred, label)
losses.append(loss.item())
with torch.no_grad():
pred = self.ALSTM_model(feature.float())
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
@@ -201,7 +203,6 @@ class ALSTM(Model):
self,
dataset,
evals_result=dict(),
verbose=True,
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -210,11 +211,14 @@ class ALSTM(Model):
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
@@ -225,7 +229,7 @@ class ALSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -256,11 +260,11 @@ class ALSTM(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.ALSTM_model.eval()
@@ -271,10 +275,7 @@ class ALSTM(Model):
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
else:
pred = self.ALSTM_model(feature.float()).detach().numpy()
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
preds.append(pred)

View File

@@ -8,20 +8,15 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -42,8 +37,8 @@ class GATs(Model):
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
GPU : int
the GPU ID used for training
"""
def __init__(
@@ -83,8 +78,7 @@ class GATs(Model):
self.base_model = base_model
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_gpu = torch.cuda.is_available()
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger.info(
@@ -102,7 +96,7 @@ class GATs(Model):
"\nbase_model : {}"
"\nwith_pretrain : {}"
"\nmodel_path : {}"
"\nvisible_GPU : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
d_feat,
@@ -118,7 +112,7 @@ class GATs(Model):
base_model,
with_pretrain,
model_path,
GPU,
self.device,
self.use_gpu,
seed,
)
@@ -135,6 +129,9 @@ class GATs(Model):
dropout=self.dropout,
base_model=self.base_model,
)
self.logger.info("model:\n{:}".format(self.GAT_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
@@ -142,9 +139,13 @@ class GATs(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.GAT_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -232,7 +233,6 @@ class GATs(Model):
self,
dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
@@ -245,8 +245,7 @@ class GATs(Model):
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
best_score = -np.inf
best_epoch = 0
@@ -275,7 +274,7 @@ class GATs(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -306,11 +305,11 @@ class GATs(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature")
index = x_test.index
self.GAT_model.eval()
x_values = x_test.values
@@ -324,10 +323,7 @@ class GATs(Model):
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.GAT_model(x_batch).detach().cpu().numpy()
else:
pred = self.GAT_model(x_batch).detach().numpy()
pred = self.GAT_model(x_batch).detach().cpu().numpy()
preds.append(pred)

View File

@@ -9,21 +9,15 @@ import os
import numpy as np
import pandas as pd
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -62,8 +56,8 @@ class GATs(Model):
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
GPU : int
the GPU ID used for training
"""
def __init__(
@@ -104,9 +98,8 @@ class GATs(Model):
self.base_model = base_model
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
self.use_gpu = torch.cuda.is_available()
self.seed = seed
self.logger.info(
@@ -157,6 +150,9 @@ class GATs(Model):
dropout=self.dropout,
base_model=self.base_model,
)
self.logger.info("model:\n{:}".format(self.GAT_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
@@ -164,9 +160,13 @@ class GATs(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.GAT_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -245,7 +245,6 @@ class GATs(Model):
self,
dataset,
evals_result=dict(),
verbose=True,
save_path=None,
):
@@ -258,11 +257,10 @@ class GATs(Model):
sampler_train = DailyBatchSampler(dl_train)
sampler_valid = DailyBatchSampler(dl_valid)
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
@@ -297,7 +295,7 @@ class GATs(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -329,7 +327,7 @@ class GATs(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
@@ -345,10 +343,7 @@ class GATs(Model):
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
else:
pred = self.GAT_model(feature.float()).detach().numpy()
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
preds.append(pred)

View File

@@ -8,21 +8,16 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -76,8 +71,7 @@ class GRU(Model):
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_gpu = torch.cuda.is_available()
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger.info(
@@ -123,6 +117,9 @@ class GRU(Model):
num_layers=self.num_layers,
dropout=self.dropout,
)
self.logger.info("model:\n{:}".format(self.gru_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
@@ -130,9 +127,13 @@ class GRU(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.gru_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -201,12 +202,13 @@ class GRU(Model):
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.gru_model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
with torch.no_grad():
pred = self.gru_model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
@@ -214,7 +216,6 @@ class GRU(Model):
self,
dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
@@ -227,8 +228,7 @@ class GRU(Model):
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
@@ -238,7 +238,7 @@ class GRU(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -269,11 +269,11 @@ class GRU(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.gru_model.eval()
x_values = x_test.values
@@ -290,10 +290,7 @@ class GRU(Model):
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.gru_model(x_batch).detach().cpu().numpy()
else:
pred = self.gru_model(x_batch).detach().numpy()
pred = self.gru_model(x_batch).detach().cpu().numpy()
preds.append(pred)

View File

@@ -9,21 +9,15 @@ import os
import numpy as np
import pandas as pd
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -78,9 +72,8 @@ class GRU(Model):
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
self.use_gpu = torch.cuda.is_available()
self.seed = seed
self.logger.info(
@@ -96,7 +89,7 @@ class GRU(Model):
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\ndevice : {}"
"\nn_jobs : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
@@ -111,7 +104,7 @@ class GRU(Model):
early_stop,
optimizer.lower(),
loss,
GPU,
self.device,
n_jobs,
self.use_gpu,
seed,
@@ -127,7 +120,10 @@ class GRU(Model):
hidden_size=self.hidden_size,
num_layers=self.num_layers,
dropout=self.dropout,
).to(self.device)
)
self.logger.info("model:\n{:}".format(self.GRU_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
@@ -135,9 +131,13 @@ class GRU(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.GRU_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -188,12 +188,13 @@ class GRU(Model):
# feature[torch.isnan(feature)] = 0
label = data[:, -1, -1].to(self.device)
pred = self.GRU_model(feature.float())
loss = self.loss_fn(pred, label)
losses.append(loss.item())
with torch.no_grad():
pred = self.GRU_model(feature.float())
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
@@ -201,7 +202,6 @@ class GRU(Model):
self,
dataset,
evals_result=dict(),
verbose=True,
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -210,11 +210,14 @@ class GRU(Model):
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
@@ -225,7 +228,7 @@ class GRU(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -257,7 +260,7 @@ class GRU(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
@@ -271,10 +274,7 @@ class GRU(Model):
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
else:
pred = self.GRU_model(feature.float()).detach().numpy()
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
preds.append(pred)

View File

@@ -8,16 +8,10 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
@@ -76,8 +70,7 @@ class LSTM(Model):
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_gpu = torch.cuda.is_available()
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger.info(
@@ -130,9 +123,13 @@ class LSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.lstm_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -214,7 +211,6 @@ class LSTM(Model):
self,
dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
@@ -227,8 +223,7 @@ class LSTM(Model):
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
@@ -238,7 +233,7 @@ class LSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -269,11 +264,11 @@ class LSTM(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.lstm_model.eval()
x_values = x_test.values
@@ -281,20 +276,13 @@ class LSTM(Model):
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.lstm_model(x_batch).detach().cpu().numpy()
else:
pred = self.lstm_model(x_batch).detach().numpy()
pred = self.lstm_model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)

View File

@@ -9,15 +9,8 @@ import os
import numpy as np
import pandas as pd
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
@@ -78,9 +71,8 @@ class LSTM(Model):
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
self.use_gpu = torch.cuda.is_available()
self.seed = seed
self.logger.info(
@@ -96,7 +88,7 @@ class LSTM(Model):
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\ndevice : {}"
"\nn_jobs : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
@@ -111,7 +103,7 @@ class LSTM(Model):
early_stop,
optimizer.lower(),
loss,
GPU,
self.device,
n_jobs,
self.use_gpu,
seed,
@@ -135,9 +127,13 @@ class LSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.LSTM_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
@@ -201,7 +197,6 @@ class LSTM(Model):
self,
dataset,
evals_result=dict(),
verbose=True,
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -210,11 +205,14 @@ class LSTM(Model):
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
@@ -225,7 +223,7 @@ class LSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -257,7 +255,7 @@ class LSTM(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
@@ -271,10 +269,7 @@ class LSTM(Model):
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
else:
pred = self.LSTM_model(feature.float()).detach().numpy()
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
preds.append(pred)

View File

@@ -6,20 +6,21 @@ from __future__ import division
from __future__ import print_function
import os
import logging
import numpy as np
import pandas as pd
from typing import Text, Union
from sklearn.metrics import roc_auc_score, mean_squared_error
import torch
import torch.nn as nn
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
from ...log import get_module_logger, TimeInspector
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path
from ...log import get_module_logger
from ...workflow import R
@@ -42,14 +43,14 @@ class DNNModelPytorch(Model):
learning rate decay steps
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
GPU : int
the GPU ID used for training
"""
def __init__(
self,
input_dim,
output_dim,
input_dim=360,
output_dim=1,
layers=(256,),
lr=0.001,
max_steps=300,
@@ -80,8 +81,7 @@ class DNNModelPytorch(Model):
self.lr_decay_steps = lr_decay_steps
self.optimizer = optimizer.lower()
self.loss_type = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_GPU = torch.cuda.is_available()
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.weight_decay = weight_decay
@@ -99,7 +99,7 @@ class DNNModelPytorch(Model):
"\nloss_type : {}"
"\neval_steps : {}"
"\nseed : {}"
"\nvisible_GPU : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
"\nweight_decay : {}".format(
layers,
@@ -114,8 +114,8 @@ class DNNModelPytorch(Model):
loss,
eval_steps,
seed,
GPU,
self.use_GPU,
self.device,
self.use_gpu,
weight_decay,
)
)
@@ -129,6 +129,9 @@ class DNNModelPytorch(Model):
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
self.logger.info("model:\n{:}".format(self.dnn_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
elif optimizer.lower() == "gd":
@@ -150,9 +153,13 @@ class DNNModelPytorch(Model):
eps=1e-08,
)
self._fitted = False
self.fitted = False
self.dnn_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def fit(
self,
dataset: DatasetH,
@@ -172,7 +179,7 @@ class DNNModelPytorch(Model):
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_loss = np.inf
@@ -180,7 +187,7 @@ class DNNModelPytorch(Model):
evals_result["valid"] = []
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
# return
# prepare training data
x_train_values = torch.from_numpy(x_train.values).float()
@@ -215,7 +222,8 @@ class DNNModelPytorch(Model):
# validation
train_loss += loss.val
if step and step % self.eval_steps == 0:
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
stop_steps += 1
train_loss /= self.eval_steps
@@ -248,9 +256,9 @@ class DNNModelPytorch(Model):
# update learning rate
self.scheduler.step(cur_loss_val)
# restore the optimal parameters after training ??
# restore the optimal parameters after training
self.dnn_model.load_state_dict(torch.load(save_path))
if self.use_GPU:
if self.use_gpu:
torch.cuda.empty_cache()
def get_loss(self, pred, w, target, loss_type):
@@ -264,18 +272,14 @@ class DNNModelPytorch(Model):
else:
raise NotImplementedError("loss {} is not supported!".format(loss_type))
def predict(self, dataset):
if not self._fitted:
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test_pd = dataset.prepare("test", col_set="feature")
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
self.dnn_model.eval()
with torch.no_grad():
if self.use_GPU:
preds = self.dnn_model(x_test).detach().cpu().numpy()
else:
preds = self.dnn_model(x_test).detach().numpy()
preds = self.dnn_model(x_test).detach().cpu().numpy()
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
def save(self, filename, **kwargs):

View File

@@ -7,22 +7,17 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -196,8 +191,8 @@ class SFM(Model):
learning rate
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
GPU : int
the GPU ID used for training
"""
def __init__(
@@ -216,7 +211,7 @@ class SFM(Model):
eval_steps=5,
loss="mse",
optimizer="gd",
GPU="0",
GPU=0,
seed=None,
**kwargs
):
@@ -239,8 +234,7 @@ class SFM(Model):
self.eval_steps = eval_steps
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_gpu = torch.cuda.is_available()
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger.info(
@@ -259,7 +253,7 @@ class SFM(Model):
"\neval_steps : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
d_feat,
@@ -276,7 +270,7 @@ class SFM(Model):
eval_steps,
optimizer.lower(),
loss,
GPU,
self.device,
self.use_gpu,
seed,
)
@@ -295,6 +289,9 @@ class SFM(Model):
dropout_U=self.dropout_U,
device=self.device,
)
self.logger.info("model:\n{:}".format(self.sfm_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.sfm_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
@@ -302,9 +299,13 @@ class SFM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.sfm_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def test_epoch(self, data_x, data_y):
# prepare training data
@@ -365,7 +366,6 @@ class SFM(Model):
self,
dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
@@ -377,6 +377,7 @@ class SFM(Model):
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
@@ -386,7 +387,7 @@ class SFM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -409,7 +410,10 @@ class SFM(Model):
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.sfm_model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.device != "cpu":
torch.cuda.empty_cache()
@@ -434,11 +438,11 @@ class SFM(Model):
raise ValueError("unknown metric `%s`" % self.metric)
def predict(self, dataset):
if not self._fitted:
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.sfm_model.eval()
x_values = x_test.values
@@ -451,10 +455,7 @@ class SFM(Model):
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float()
if self.device != "cpu":
x_batch = x_batch.to(self.device)
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.sfm_model(x_batch).detach().cpu().numpy()

View File

@@ -6,16 +6,10 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
@@ -23,6 +17,7 @@ import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Function
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -49,12 +44,12 @@ class TabnetModel(Model):
loss="mse",
metric="",
early_stop=20,
GPU="1",
GPU=0,
pretrain_loss="custom",
ps=0.3,
lr=0.01,
pretrain=True,
pretrain_file="./pretrain/best.model",
pretrain_file=None,
):
"""
TabNet model for Qlib
@@ -75,28 +70,27 @@ class TabnetModel(Model):
self.n_epochs = n_epochs
self.logger = get_module_logger("TabNet")
self.pretrain_n_epochs = pretrain_n_epochs
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() else "cpu"
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
self.loss = loss
self.metric = metric
self.early_stop = early_stop
self.pretrain = pretrain
self.pretrain_file = pretrain_file
self.pretrain_file = get_or_create_path(pretrain_file)
self.logger.info(
"TabNet:"
"\nbatch_size : {}"
"\nvirtual bs : {}"
"\nGPU : {}"
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
"\ndevice : {}"
"\npretrain: {}".format(self.batch_size, vbs, self.device, self.pretrain)
)
self.fitted = False
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.tabnet_model = TabNet(
inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device
).to(self.device)
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
self.device
)
self.tabnet_model = TabNet(inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax).to(self.device)
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device)
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
if optimizer.lower() == "adam":
self.pretrain_optimizer = optim.Adam(
@@ -112,11 +106,12 @@ class TabnetModel(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
# make a directory if pretrian director does not exist
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
self.logger.info("make folder to store model...")
os.makedirs("pretrain")
get_or_create_path(pretrain_file)
[df_train, df_valid] = dataset.prepare(
["pretrain", "pretrain_validation"],
@@ -158,7 +153,6 @@ class TabnetModel(Model):
self,
dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
if self.pretrain:
@@ -178,16 +172,17 @@ class TabnetModel(Model):
df_train.fillna(df_train.mean(), inplace=True)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = np.inf
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
self.logger.info("training...")
self._fitted = True
self.fitted = True
for epoch_idx in range(self.n_epochs):
self.logger.info("epoch: %s" % (epoch_idx))
@@ -200,22 +195,29 @@ class TabnetModel(Model):
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score < best_score:
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = epoch_idx
best_param = copy.deepcopy(self.tabnet_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
def predict(self, dataset):
if not self._fitted:
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.tabnet_model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.tabnet_model.eval()
x_values = torch.from_numpy(x_test.values)
@@ -259,12 +261,13 @@ class TabnetModel(Model):
feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)
label = y_values[indices[i : i + self.batch_size]].float().to(self.device)
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
pred = self.tabnet_model(feature, priors)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
with torch.no_grad():
pred = self.tabnet_model(feature, priors)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
@@ -347,10 +350,11 @@ class TabnetModel(Model):
label = y_train_values.float().to(self.device)
S_mask = S_mask.to(self.device)
priors = 1 - S_mask
(vec, sparse_loss) = self.tabnet_model(feature, priors)
f = self.tabnet_decoder(vec)
with torch.no_grad():
(vec, sparse_loss) = self.tabnet_model(feature, priors)
f = self.tabnet_decoder(vec)
loss = self.pretrain_loss_fn(label, f, S_mask)
loss = self.pretrain_loss_fn(label, f, S_mask)
losses.append(loss.item())
return np.mean(losses)
@@ -396,9 +400,9 @@ class FinetuneModel(nn.Module):
class DecoderStep(nn.Module):
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
super().__init__()
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device)
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs)
self.fc = nn.Linear(out_dim, out_dim)
def forward(self, x):
@@ -407,13 +411,12 @@ class DecoderStep(nn.Module):
class TabNet_Decoder(nn.Module):
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device):
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps):
"""
TabNet decoder that is used in pre-training
"""
self.out_dim = out_dim
super().__init__()
self.out_dim = out_dim
if n_shared > 0:
self.shared = nn.ModuleList()
self.shared.append(nn.Linear(inp_dim, 2 * out_dim))
@@ -424,7 +427,7 @@ class TabNet_Decoder(nn.Module):
self.n_steps = n_steps
self.steps = nn.ModuleList()
for x in range(n_steps):
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device))
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs))
def forward(self, x):
out = torch.zeros(x.size(0), self.out_dim).to(x.device)
@@ -434,9 +437,7 @@ class TabNet_Decoder(nn.Module):
class TabNet(nn.Module):
def __init__(
self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu"
):
def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024):
"""
TabNet AKA the original encoder
@@ -460,10 +461,10 @@ class TabNet(nn.Module):
else:
self.shared = None
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device)
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs)
self.steps = nn.ModuleList()
for x in range(n_steps - 1):
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device))
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
self.fc = nn.Linear(n_d, out_dim)
self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01)
self.n_d = n_d
@@ -472,14 +473,14 @@ class TabNet(nn.Module):
assert not torch.isnan(x).any()
x = self.bn(x)
x_a = self.first_step(x)[:, self.n_d :]
sparse_loss = torch.zeros(1).to(x.device)
sparse_loss = []
out = torch.zeros(x.size(0), self.n_d).to(x.device)
for step in self.steps:
x_te, l = step(x, x_a, priors)
out += F.relu(x_te[:, : self.n_d]) # split the feautre from feat_transformer
x_a = x_te[:, self.n_d :]
sparse_loss += l
return self.fc(out), sparse_loss
sparse_loss.append(l)
return self.fc(out), sum(sparse_loss)
class GBN(nn.Module):
@@ -497,9 +498,12 @@ class GBN(nn.Module):
self.vbs = vbs
def forward(self, x):
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
res = [self.bn(y) for y in chunk]
return torch.cat(res, 0)
if x.size(0) <= self.vbs: # can not be chunked
return self.bn(x)
else:
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
res = [self.bn(y) for y in chunk]
return torch.cat(res, 0)
class GLU(nn.Module):
@@ -547,7 +551,7 @@ class AttentionTransformer(nn.Module):
class FeatureTransformer(nn.Module):
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
super().__init__()
first = True
self.shared = nn.ModuleList()
@@ -563,7 +567,7 @@ class FeatureTransformer(nn.Module):
self.independ.append(GLU(inp, out_dim, vbs=vbs))
for x in range(first, n_ind):
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
self.scale = torch.sqrt(torch.tensor([0.5], device=device))
self.scale = float(np.sqrt(0.5))
def forward(self, x):
if self.shared:
@@ -582,10 +586,10 @@ class DecisionStep(nn.Module):
One step for the TabNet
"""
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device):
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs):
super().__init__()
self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device)
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)
def forward(self, x, a, priors):
mask = self.atten_tran(a, priors)

View File

@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.nn as nn
def count_parameters(models_or_parameters, unit="m"):
"""
This function is to obtain the storage size unit of a (or multiple) models.
Parameters
----------
models_or_parameters : PyTorch model(s) or a list of parameters.
unit : the storage size unit.
Returns
-------
The number of parameters of the given model(s) or parameters.
"""
if isinstance(models_or_parameters, nn.Module):
counts = sum(v.numel() for v in models_or_parameters.parameters())
elif isinstance(models_or_parameters, nn.Parameter):
counts = models_or_parameters.numel()
elif isinstance(models_or_parameters, (list, tuple)):
return sum(count_parameters(x, unit) for x in models_or_parameters)
else:
counts = sum(v.numel() for v in models_or_parameters)
unit = unit.lower()
if unit == "kb" or unit == "k":
counts /= 2 ** 10
elif unit == "mb" or unit == "m":
counts /= 2 ** 20
elif unit == "gb" or unit == "g":
counts /= 2 ** 30
elif unit is not None:
raise ValueError("Unknow unit: {:}".format(unit))
return counts

View File

@@ -4,13 +4,14 @@
import numpy as np
import pandas as pd
import xgboost as xgb
from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
class XGBModel(Model):
class XGBModel(Model, FeatureInt):
"""XGBModel Model"""
def __init__(self, **kwargs):
@@ -42,8 +43,8 @@ class XGBModel(Model):
else:
raise ValueError("XGBoost doesn't support multi-label training")
dtrain = xgb.DMatrix(x_train.values, label=y_train_1d)
dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d)
dtrain = xgb.DMatrix(x_train, label=y_train_1d)
dvalid = xgb.DMatrix(x_valid, label=y_valid_1d)
self.model = xgb.train(
self._params,
dtrain=dtrain,
@@ -57,8 +58,18 @@ class XGBModel(Model):
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-------
parameters reference:
https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score
"""
return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False)

View File

@@ -63,7 +63,7 @@ class UserManager:
account_path = self.data_path / user_id
strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id)
model_file = self.data_path / user_id / "model_{}.pickle".format(user_id)
cur_user_list = [user_id for user_id in self.users]
cur_user_list = list(self.users)
if user_id in cur_user_list:
raise ValueError("User {} has been loaded".format(user_id))
else:
@@ -110,7 +110,7 @@ class UserManager:
raise ValueError("User data for {} already exists".format(user_id))
with config_file.open("r") as fp:
config = yaml.load(fp)
config = yaml.safe_load(fp)
# load model
model = init_instance_by_config(config["model"])

View File

@@ -148,7 +148,7 @@ class Operator:
for user_id, user in um.users.items():
dates, trade_exchange = prepare(um, trade_date, user_id, exchange_config)
executor = SimulatorExecutor(trade_exchange=trade_exchange)
if not str(dates[0].date()) == str(pred_date.date()):
if str(dates[0].date()) != str(pred_date.date()):
raise ValueError(
"The account data is not newest! last trading date {}, today {}".format(
dates[0].date(), trade_date.date()

View File

@@ -88,7 +88,7 @@ def prepare(um, today, user_id, exchange_config=None):
dates.append(get_next_trading_date(dates[-1], future=True))
if exchange_config:
with pathlib.Path(exchange_config).open("r") as fp:
exchange_paras = yaml.load(fp)
exchange_paras = yaml.safe_load(fp)
else:
exchange_paras = {}
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)

View File

@@ -214,7 +214,7 @@ def cumulative_return_graph(
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
Graph desc:

View File

@@ -94,7 +94,7 @@ def rank_label_graph(
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.

View File

@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
qcr.report_graph(report_normal_df)
qcr.analysis_position.report_graph(report_normal_df)
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.

View File

@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
class BaseGraph:
""""""
""" """
_name = None
@@ -161,7 +161,7 @@ class DistplotGraph(BaseGraph):
"""
_t_df = self._df.dropna()
_data_list = [_t_df[_col] for _col in self._name_dict]
_label_list = [_name for _name in self._name_dict.values()]
_label_list = list(self._name_dict.values())
_fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs)
return _fig["data"]

View File

@@ -7,7 +7,6 @@ import numpy as np
import pandas as pd
from ..backtest.order import Order
from ...utils import get_pre_trading_date
from .order_generator import OrderGenWInteract
@@ -252,7 +251,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
"""
Gnererate order list according to score_series at trade_date, will not change current.
Generate order list according to score_series at trade_date, will not change current.
Parameters
-----------
@@ -390,11 +389,11 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
current_stock_list = current_temp.get_stock_list()
value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it
# as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
# consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
# value = value / (1+trade_exchange.open_cost) # set open_cost limit
for code in buy:
# check is stock supended
# check is stock suspended
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
continue
# buy order

View File

@@ -14,7 +14,7 @@ class TunerConfigManager:
self.config_path = config_path
with open(config_path) as fp:
config = yaml.load(fp)
config = yaml.safe_load(fp)
self.config = copy.deepcopy(config)
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .record_temp import MultiSegRecord
from .record_temp import SignalMseRecord

View File

@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error
from typing import Dict, Text, Any
from ...contrib.eva.alpha import calc_ic
from ...workflow.record_temp import RecordTemp
from ...workflow.record_temp import SignalRecord
from ...data import dataset as qlib_dataset
from ...log import get_module_logger
logger = get_module_logger("workflow", logging.INFO)
class MultiSegRecord(RecordTemp):
"""
This is the multiple segments signal record class that generates the signal prediction.
This class inherits the ``RecordTemp`` class.
"""
def __init__(self, model, dataset, recorder=None):
super().__init__(recorder=recorder)
if not isinstance(dataset, qlib_dataset.DatasetH):
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
self.model = model
self.dataset = dataset
def generate(self, segments: Dict[Text, Any], save: bool = False):
for key, segment in segments.items():
predics = self.model.predict(self.dataset, segment)
if isinstance(predics, pd.Series):
predics = predics.to_frame("score")
labels = self.dataset.prepare(
segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R
)
# Compute the IC and Rank IC
ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])
results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()}
logger.info("--- Results for {:} ({:}) ---".format(key, segment))
ic_x100, ric_x100 = ic * 100, ric * 100
logger.info("IC: {:.4f}%".format(ic_x100.mean()))
logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std()))
logger.info("Rank IC: {:.4f}%".format(ric_x100.mean()))
logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std()))
if save:
save_name = "results-{:}.pkl".format(key)
self.recorder.save_objects(**{save_name: results})
logger.info(
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
save_name, self.recorder.experiment_id
)
)
class SignalMseRecord(SignalRecord):
"""
This is the Signal MSE Record class that computes the mean squared error (MSE).
This class inherits the ``SignalMseRecord`` class.
"""
artifact_path = "sig_analysis"
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder, **kwargs)
def generate(self, **kwargs):
try:
self.check(parent=True)
except FileExistsError:
super().generate()
pred = self.load("pred.pkl")
label = self.load("label.pkl")
masks = ~np.isnan(label.values)
mse = mean_squared_error(pred.values[masks], label[masks])
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
def list(self):
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
return paths

View File

@@ -1045,9 +1045,6 @@ class SimpleDatasetCache(DatasetCache):
class DatasetURICache(DatasetCache):
"""Prepared cache mechanism for server."""
def __init__(self, provider):
super(DatasetURICache, self).__init__(provider)
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)

View File

@@ -6,7 +6,9 @@ from __future__ import division
from __future__ import print_function
import os
import re
import abc
import copy
import time
import queue
import bisect
@@ -27,12 +29,41 @@ from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
class CalendarProvider(abc.ABC):
class ProviderBackendMixin:
def get_default_backend(self):
backend = {}
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
# set default storage class
backend.setdefault("class", f"File{provider_name}Storage")
# set default storage module
backend.setdefault("module_path", "qlib.data.storage.file_storage")
return backend
def backend_obj(self, **kwargs):
backend = self.backend if self.backend else self.get_default_backend()
backend = copy.deepcopy(backend)
# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {})
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
freq = kwargs.get("freq", "day")
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
class CalendarProvider(abc.ABC, ProviderBackendMixin):
"""Calendar provider base class
Provide calendar data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
@@ -127,12 +158,15 @@ class CalendarProvider(abc.ABC):
return hash_args(start_time, end_time, freq, future)
class InstrumentProvider(abc.ABC):
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class
Provide instrument data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@staticmethod
def instruments(market="all", filter_pipe=None):
"""Get the general config dictionary for a base market adding several dynamic filters.
@@ -215,12 +249,15 @@ class InstrumentProvider(abc.ABC):
raise ValueError(f"Unknown instrument type {inst}")
class FeatureProvider(abc.ABC):
class FeatureProvider(abc.ABC, ProviderBackendMixin):
"""Feature provider class
Provide feature data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
"""Get feature data.
@@ -478,13 +515,13 @@ class DatasetProvider(abc.ABC):
data = pd.DataFrame(obj)
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(np.int)]
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]
if spans is None:
return data
else:
mask = np.zeros(len(data), dtype=np.bool)
mask = np.zeros(len(data), dtype=bool)
for begin, end in spans:
mask |= (data.index >= begin) & (data.index <= end)
return data[mask]
@@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
"""
def __init__(self, **kwargs):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
@@ -517,18 +555,22 @@ class LocalCalendarProvider(CalendarProvider):
list
list of timestamps
"""
if future:
fname = self._uri_cal.format(freq + "_future")
# if future calendar not exists, return current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
if not os.path.exists(fname):
raise ValueError("calendar not exists for freq " + freq)
with open(fname) as f:
return [pd.Timestamp(x.strip()) for x in f]
try:
backend_obj = self.backend_obj(freq=freq, future=future).data
except ValueError:
if future:
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
backend_obj = self.backend_obj(freq=freq, future=False).data
else:
raise
return [pd.Timestamp(x) for x in backend_obj]
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
@@ -559,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""
def __init__(self):
pass
@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)
_instruments = dict()
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
else:
_instruments = self._load_instruments(market)
_instruments = self._load_instruments(market, freq=freq)
H["i"][market] = _instruments
# strip
# use calendar boundary
@@ -601,7 +625,7 @@ class LocalInstrumentProvider(InstrumentProvider):
inst: list(
filter(
lambda x: x[0] <= x[1],
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
)
)
for inst, spans in _instruments.items()
@@ -627,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider):
"""
def __init__(self, **kwargs):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
@@ -638,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
series = read_bin(uri_data, start_index, end_index)
return series
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
class LocalExpressionProvider(ExpressionProvider):
@@ -654,9 +672,6 @@ class LocalExpressionProvider(ExpressionProvider):
Provide expression data from local data source.
"""
def __init__(self):
super().__init__()
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
expression = self.get_expression_instance(field)
start_time = pd.Timestamp(start_time)
@@ -1019,7 +1034,8 @@ class ClientProvider(BaseProvider):
self.logger = get_module_logger(self.__class__.__name__)
if isinstance(Cal, ClientCalendarProvider):
Cal.set_conn(self.client)
Inst.set_conn(self.client)
if isinstance(Inst, ClientInstrumentProvider):
Inst.set_conn(self.client)
if hasattr(DatasetD, "provider"):
DatasetD.provider.set_conn(self.client)
else:
@@ -1064,7 +1080,8 @@ def register_all_wrappers(C):
register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
register_wrapper(Inst, C.instrument_provider, "qlib.data")
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
register_wrapper(Inst, _instrument_provider, "qlib.data")
logger.debug(f"registering Inst {C.instrument_provider}")
if getattr(C, "feature_provider", None) is not None:

View File

@@ -1,8 +1,9 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
from inspect import getfullargspec
import pandas as pd
import numpy as np
@@ -16,22 +17,28 @@ class Dataset(Serializable):
Preparing data for model training and inferencing.
"""
def __init__(self, *args, **kwargs):
def __init__(self, **kwargs):
"""
init is designed to finish following steps:
- init the sub instance and the state of the dataset(info to prepare the data)
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
- setup data
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
- initialize the state of the dataset(info to prepare the data)
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
The data could specify the info to caculate the essential data for preparation
The data could specify the info to calculate the essential data for preparation
"""
self.setup_data(*args, **kwargs)
self.setup_data(**kwargs)
super().__init__()
def setup_data(self, *args, **kwargs):
def config(self, **kwargs):
"""
config is designed to configure and parameters that cannot be learned from the data
"""
super().config(**kwargs)
def setup_data(self, **kwargs):
"""
Setup the data.
@@ -39,7 +46,7 @@ class Dataset(Serializable):
- User have a Dataset object with learned status on disk.
- User load the Dataset object from the disk(Note the init function is skiped).
- User load the Dataset object from the disk.
- User call `setup_data` to load new data.
@@ -47,7 +54,7 @@ class Dataset(Serializable):
"""
pass
def prepare(self, *args, **kwargs) -> object:
def prepare(self, **kwargs) -> object:
"""
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
The parameters should specify the scope for the prepared data
@@ -76,22 +83,7 @@ class DatasetH(Dataset):
- The processing is related to data split.
"""
def __init__(self, handler: Union[dict, DataHandler], segments: dict):
"""
Parameters
----------
handler : Union[dict, DataHandler]
handler will be passed into setup_data.
segments : dict
handler will be passed into setup_data.
"""
super().__init__(handler, segments)
def init(self, **kwargs):
"""Initialize the DatasetH, Only parameters belonging to handler.init will be passed in"""
self.handler.init(**kwargs)
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
"""
Setup the underlying data.
@@ -100,7 +92,7 @@ class DatasetH(Dataset):
handler : Union[dict, DataHandler]
handler could be:
- insntance of `DataHandler`
- instance of `DataHandler`
- config of `DataHandler`. Please refer to `DataHandler`
@@ -120,8 +112,57 @@ class DatasetH(Dataset):
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
self.fetch_kwargs = {}
super().__init__(**kwargs)
def config(self, handler_kwargs: dict = None, **kwargs):
"""
Initialize the DatasetH
Parameters
----------
handler_kwargs : dict
Config of DataHandler, which could include the following arguments:
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
kwargs : dict
Config of DatasetH, such as
- segments : dict
Config of segments which is same as 'segments' in self.__init__
"""
if handler_kwargs is not None:
self.handler.config(**handler_kwargs)
if "segments" in kwargs:
self.segments = deepcopy(kwargs.pop("segments"))
super().config(**kwargs)
def setup_data(self, handler_kwargs: dict = None, **kwargs):
"""
Setup the Data
Parameters
----------
handler_kwargs : dict
init arguments of DataHandler, which could include the following arguments:
- init_type : Init Type of Handler
- enable_cache : whether to enable cache
"""
super().setup_data(**kwargs)
if handler_kwargs is not None:
self.handler.setup_data(**handler_kwargs)
def __repr__(self):
return "{name}(handler={handler}, segments={segments})".format(
name=self.__class__.__name__, handler=self.handler, segments=self.segments
)
def _prepare_seg(self, slc: slice, **kwargs):
"""
@@ -131,11 +172,14 @@ class DatasetH(Dataset):
----------
slc : slice
"""
return self.handler.fetch(slc, **kwargs)
if hasattr(self, "fetch_kwargs"):
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
else:
return self.handler.fetch(slc, **kwargs)
def prepare(
self,
segments: Union[List[str], Tuple[str], str, slice],
segments: Union[List[Text], Tuple[Text], Text, slice],
col_set=DataHandler.CS_ALL,
data_key=DataHandlerLP.DK_I,
**kwargs,
@@ -145,7 +189,7 @@ class DatasetH(Dataset):
Parameters
----------
segments : Union[List[str], Tuple[str], str, slice]
segments : Union[List[Text], Tuple[Text], Text, slice]
Describe the scope of the data to be prepared
Here are some examples:
@@ -159,6 +203,12 @@ class DatasetH(Dataset):
The data to fetch: DK_*
Default is DK_I, which indicate fetching data for **inference**.
kwargs :
The parameters that kwargs may contain:
flt_col : str
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
This parameter is only supported when it is an instance of TSDatasetH.
Returns
-------
Union[List[pd.DataFrame], pd.DataFrame]:
@@ -191,7 +241,7 @@ class TSDataSampler:
(T)ime-(S)eries DataSampler
This is the result of TSDatasetH
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
@@ -203,7 +253,9 @@ class TSDataSampler:
"""
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
def __init__(
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -225,6 +277,11 @@ class TSDataSampler:
ffill with previous sample
ffill+bfill:
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data.
None:
kepp all data
"""
self.start = start
self.end = end
@@ -232,24 +289,51 @@ class TSDataSampler:
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But
# NOTE: append last line with full NaN for better performance in `__getitem__`
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
kwargs = {"object": self.data}
if dtype is not None:
kwargs["dtype"] = dtype
self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
# NOTE:
# - append last line with full NaN for better performance in `__getitem__`
# - Keep the same dtype will result in a better performance
self.data_arr = np.append(
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
)
self.nan_idx = -1 # The last line is all NaN
# the data type will be changed
# The index of usable data is between start_idx and end_idx
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
# self.index_link = self.build_link(self.data)
self.idx_df, self.idx_map = self.build_index(self.data)
self.data_index = deepcopy(self.data.index)
if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
@staticmethod
def flt_idx_map(flt_data, idx_map):
idx = 0
new_idx_map = {}
for i, exist in enumerate(flt_data):
if exist:
new_idx_map[idx] = idx_map[i]
idx += 1
return new_idx_map
def get_index(self):
"""
Get the pandas index of the data, it will be useful in following scenarios
- Special sampler will be used (e.g. user want to sample day by day)
"""
return self.data.index[self.start_idx : self.end_idx]
return self.data_index[self.start_idx : self.end_idx]
def config(self, **kwargs):
# Config the attributes
@@ -273,7 +357,7 @@ class TSDataSampler:
# get the previous index of a line given index
"""
# object incase of pandas converting int to flaot
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)
@@ -375,7 +459,7 @@ class TSDataSampler:
# 1) for better performance, use the last nan line for padding the lost date
# 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
# precision problems. It will not cause any problems in my tests at least
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(np.int)
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)
data = self.data_arr[indices]
if isinstance(idx, mtit):
@@ -393,7 +477,7 @@ class TSDatasetH(DatasetH):
(T)ime-(S)eries Dataset (H)andler
Covnert the tabular data to Time-Series data
Convert the tabular data to Time-Series data
Requirements analysis
@@ -407,18 +491,22 @@ class TSDatasetH(DatasetH):
- The dimension of a batch of data <batch_idx, feature, timestep>
"""
def __init__(self, step_len=30, *args, **kwargs):
def __init__(self, step_len=30, **kwargs):
self.step_len = step_len
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
def setup_data(self, *args, **kwargs):
super().setup_data(*args, **kwargs)
def config(self, **kwargs):
if "step_len" in kwargs:
self.step_len = kwargs.pop("step_len")
super().config(**kwargs)
def setup_data(self, **kwargs):
super().setup_data(**kwargs)
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
# Get the datatime index for building timestamp
self.cal = cal
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
@@ -427,6 +515,25 @@ class TSDatasetH(DatasetH):
# TSDatasetH will retrieve more data for complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
return data
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
"""
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
"""
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete
data = self._prepare_raw_seg(slc, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
return tsds

View File

@@ -6,7 +6,8 @@ import abc
import bisect
import logging
import warnings
from typing import Union, Tuple, List, Iterator, Optional
from inspect import getfullargspec
from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
import numpy as np
@@ -16,7 +17,7 @@ from ...data import D
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
from .utils import get_level_index, fetch_df_by_index
from .utils import fetch_df_by_index
from pathlib import Path
from .loader import DataLoader
@@ -35,7 +36,7 @@ class DataHandler(Serializable):
The data handler try to maintain a handler with 2 level.
`datetime` & `instruments`.
Any order of the index level can be suported(The order will implied in the data).
Any order of the index level can be supported (The order will be implied in the data).
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
Example of the data:
@@ -47,9 +48,12 @@ class DataHandler(Serializable):
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
datetime instrument
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
Tips for improving the performance of datahandler
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
"""
def __init__(
@@ -57,7 +61,7 @@ class DataHandler(Serializable):
instruments=None,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
data_loader: Union[dict, str, DataLoader] = None,
init_data=True,
fetch_orig=True,
):
@@ -70,10 +74,10 @@ class DataHandler(Serializable):
start_time of the original data.
end_time :
end_time of the original data.
data_loader : Tuple[dict, str, DataLoader]
data_loader : Union[dict, str, DataLoader]
data loader to load the data.
init_data :
intialize the original data in the constructor.
initialize the original data in the constructor.
fetch_orig : bool
Return the original data instead of copy if possible.
"""
@@ -99,10 +103,10 @@ class DataHandler(Serializable):
self.fetch_orig = fetch_orig
if init_data:
with TimeInspector.logt("Init data"):
self.init()
self.setup_data()
super().__init__()
def conf_data(self, **kwargs):
def config(self, **kwargs):
"""
configuration of data.
# what data to be loaded from data source
@@ -115,13 +119,16 @@ class DataHandler(Serializable):
for k, v in kwargs.items():
if k in attr_list:
setattr(self, k, v)
else:
raise KeyError("Such config is not supported.")
def init(self, enable_cache: bool = False):
for attr in attr_list:
if attr in kwargs:
kwargs.pop(attr)
super().config(**kwargs)
def setup_data(self, enable_cache: bool = False):
"""
initialize the data.
In case of running intialization for multiple time, it will do nothing for the second time.
Set Up the data in case of running initialization for multiple time
It is responsible for maintaining following variable
1) self._data
@@ -159,6 +166,7 @@ class DataHandler(Serializable):
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -181,6 +189,14 @@ class DataHandler(Serializable):
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels
proc_func: Callable
- Give a hook for processing data before fetching
- An example to explain the necessity of the hook:
- A Dataset learned some processors to process data which is related to data segmentation
- It will apply them every time when preparing data.
- The learned processor require the dataframe remains the same format when fitting and applying
- However the data format will change according to the parameters.
- So the processors should be applied to the underlayer data.
squeeze : bool
whether squeeze columns and index
@@ -189,8 +205,15 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
if proc_func is None:
df = self._data
else:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(self._data, col_set)
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
@@ -257,6 +280,10 @@ class DataHandler(Serializable):
class DataHandlerLP(DataHandler):
"""
DataHandler with **(L)earnable (P)rocessor**
Tips to improving the performance of data handler
- To reduce the memory cost
- `drop_raw=True`: this will modify the data inplace on raw data;
"""
# data key
@@ -278,7 +305,7 @@ class DataHandlerLP(DataHandler):
instruments=None,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
data_loader: Union[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
process_type=PTYPE_A,
@@ -405,14 +432,28 @@ class DataHandlerLP(DataHandler):
if self.drop_raw:
del self._data
def config(self, processor_kwargs: dict = None, **kwargs):
"""
configuration of data.
# what data to be loaded from data source
This method will be used when loading pickled handler from dataset.
The data will be initialized with different time range.
"""
super().config(**kwargs)
if processor_kwargs is not None:
for processor in self.get_all_processors():
processor.config(**processor_kwargs)
# init type
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
IT_LS = "load_state" # The state of the object has been load by pickle
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
"""
Initialize the data of Qlib
Set up the data in case of running initialization for multiple time
Parameters
----------
@@ -427,7 +468,7 @@ class DataHandlerLP(DataHandler):
when we call `init` next time
"""
# init raw data
super().init(enable_cache=enable_cache)
super().setup_data(**kwargs)
with TimeInspector.logt("fit & process data"):
if init_type == DataHandlerLP.IT_FIT_IND:
@@ -456,6 +497,7 @@ class DataHandlerLP(DataHandler):
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -470,12 +512,18 @@ class DataHandlerLP(DataHandler):
select a set of meaningful columns.(e.g. features, columns).
data_key : str
the data to fetch: DK_*.
proc_func: Callable
please refer to the doc of DataHandler.fetch
Returns
-------
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)

View File

@@ -13,6 +13,7 @@ from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config
from qlib.log import get_module_logger
class DataLoader(abc.ABC):
@@ -217,3 +218,68 @@ class StaticDataLoader(DataLoader):
join=self.join,
)
self._data.sort_index(inplace=True)
class DataLoaderDH(DataLoader):
"""DataLoaderDH
DataLoader based on (D)ata (H)andler
It is designed to load multiple data from data handler
- If you just want to load data from single datahandler, you can write them in single data handler
TODO: What make this module not that easy to use.
- For online scenario
- The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
"""
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
"""
Parameters
----------
handler_config : dict
handler_config will be used to describe the handlers
.. code-block::
<handler_config> := {
"group_name1": <handler>
"group_name2": <handler>
}
or
<handler_config> := <handler>
<handler> := DataHandler Instance | DataHandler Config
fetch_kwargs : dict
fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
is_group: bool
is_group will be used to describe whether the key of handler_config is group
"""
from qlib.data.dataset.handler import DataHandler
if is_group:
self.handlers = {
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
}
else:
self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)
self.is_group = is_group
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
self.fetch_kwargs.update(fetch_kwargs)
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is not None:
get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
if self.is_group:
df = pd.concat(
{
grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
for grp, dh in self.handlers.items()
},
axis=1,
)
else:
df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
return df

18
qlib/data/dataset/processor.py Executable file → Normal file
View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import abc
from typing import Union, Text
import numpy as np
import pandas as pd
import copy
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
EPS = 1e-12
def get_group_columns(df: pd.DataFrame, group: str):
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
"""
get a group of columns from multi-index columns DataFrame
@@ -72,6 +73,17 @@ class Processor(Serializable):
"""
return True
def config(self, **kwargs):
attr_list = {"fit_start_time", "fit_end_time"}
for k, v in kwargs.items():
if k in attr_list and hasattr(self, k):
setattr(self, k, v)
for attr in attr_list:
if attr in kwargs:
kwargs.pop(attr)
super().config(**kwargs)
class DropnaProcessor(Processor):
def __init__(self, fields_group=None):
@@ -118,7 +130,7 @@ class FilterCol(Processor):
class TanhProcess(Processor):
""" Use tanh to process noise data"""
"""Use tanh to process noise data"""
def __call__(self, df):
def tanh_denoise(data):
@@ -133,7 +145,7 @@ class TanhProcess(Processor):
class ProcessInf(Processor):
"""Process infinity """
"""Process infinity"""
def __call__(self, df):
def replace_inf(data):

View File

@@ -355,6 +355,7 @@ class ExpressionDFilter(SeriesDFilter):
all_filter_series = _features[rule_expression_field_name]
return all_filter_series
@staticmethod
def from_config(config):
return ExpressionDFilter(
rule_expression=config["rule_expression"],

Some files were not shown because too many files have changed in this diff Show More