1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

...

188 Commits

Author SHA1 Message Date
Young
3fc2f8c93c updategrade version number 2021-09-16 02:15:16 +00:00
Anurag Kumar
66ff3e5bf6 Update python-publish.yml
added python 3.9
2021-09-16 10:09:39 +08:00
Anurag Kumar
8ff68a182e Update setup.py
change to matplotlib==3.3
2021-09-16 10:09:39 +08:00
Anurag Kumar
a105ef1d76 Update setup.py
updated classifiers
2021-09-16 10:09:39 +08:00
zhupr
d02965ea70 Fix SimpleDatasetCache 2021-09-16 10:08:56 +08:00
Christian Clauss
b8d1e08010 Fix undefined names in Python code (#599)
* Update pytorch_tabnet.py

$ `flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics`
```
./qlib/qlib/contrib/model/pytorch_tabnet.py:567:38: F821 undefined name 'inp'
            self.independ.append(GLU(inp, out_dim, vbs=vbs))
                                     ^
./qlib/examples/model_rolling/task_manager_rolling.py:75:18: F821 undefined name 'task_train'
        run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
                 ^
2     F821 undefined name 'task_train'
2
```

* Fix undefined names in Python code

* from qlib.model.trainer import task_train
2021-09-14 12:13:27 +08:00
you-n-g
51709c20d8 Supporting shared processor (#596)
* Supporting shared processor

* fix readonly reverse bug

* remove pytests dependency

* with fit bug

* fix parameter error
2021-09-13 17:11:08 +08:00
Christian Clauss
28c99c77be test.yml: Remove redundant code (#595) 2021-09-13 14:31:32 +08:00
you-n-g
bb5cdfe050 Update Release Note 2021-09-12 17:06:00 +08:00
SaintMalik
fb21c591bb fix typos (#592) 2021-09-12 16:39:22 +08:00
Dong Zhou
5279e71423 Merge pull request #591 from evanzd/fix_tra
Fix TRA
2021-09-11 18:48:13 +08:00
Dong Zhou
f35254c288 update README 2021-09-10 07:38:22 +00:00
Pengrong Zhu
5e82c18cb2 Modify the Feature to be case sensitive (#589) 2021-09-10 11:47:23 +08:00
demon143
2759e8c28d Update the docs of TaskManager (#586)
* Update manage.py
2021-09-09 20:13:45 +08:00
you-n-g
2461575d30 Update README.md
Fix wrong link
2021-09-09 08:28:48 +08:00
Pengrong Zhu
867667531d Update FAQ.rst 2021-09-08 18:06:51 +08:00
zhupr
0fc52333b7 Add wheel package to github CI 2021-09-07 20:41:10 +08:00
zhupr
ab9b6dc47a Modify client-server mode and dataset-cache to disable inst_processor 2021-09-07 20:41:10 +08:00
zhupr
4c5a4d5cd7 Modify the default value in the multi_freq example 2021-09-07 20:41:10 +08:00
zhupr
e84cc23589 Add DataPathManager to QlibConfig && modify inst_processors to supports list only 2021-09-07 20:41:10 +08:00
zhupr
707399a245 Fix duplicate mlflow directories in tests 2021-09-07 20:41:10 +08:00
zhupr
6e88ccca88 Fix the index type of the multi-freq example 2021-09-07 20:41:10 +08:00
zhupr
ee5f3de800 Fix typo 2021-09-07 20:41:10 +08:00
zhupr
3605cd7b96 Add inst_processors to D.features 2021-09-07 20:41:10 +08:00
zhupr
d1cbf4c3d9 support multi-freq uri 2021-09-07 20:41:10 +08:00
zhupr
6011a21308 get_cls_kwargs renamed get_callable_kwargs 2021-09-07 20:41:10 +08:00
zhupr
76a05f37a9 add multi-freq example 2021-09-07 20:41:10 +08:00
zhupr
c99494eb76 Add sample_config to QlibDataLoader, support multi-freq 2021-09-07 20:41:10 +08:00
zhupr
e8126b0c39 Add backend_freq_config parameter, support multi-freq uri 2021-09-07 20:41:10 +08:00
Dong Zhou
8f4d320832 bug fix & use oracle transport pretrain 2021-08-30 07:32:04 +00:00
cslwqxx
e2739ac72c Update README.md 2021-08-29 12:29:11 +08:00
you-n-g
19d15ddc38 Merge pull request #513 from 2796gaurav/main
MVP for Indian Stocks in qlib using yahooquery
2021-08-26 20:59:26 +08:00
you-n-g
12af8f304b Delete .DS_Store 2021-08-26 15:36:35 +08:00
Mark Zhao
25b771ddf1 check lexsort in the 'lazy_sort_index' function (#566)
* check lexsort

* check lexsort

* lexsort comment

* lexsort comment
2021-08-25 18:07:30 +08:00
Pengrong Zhu
1158472489 Fix multi-process loop calls (#574) 2021-08-25 18:05:35 +08:00
you-n-g
84d2cb3226 Update gen.py (#576) 2021-08-25 18:05:10 +08:00
Wangwuyi123
509bfcb02e Fix CI Bug (#575)
Co-authored-by: yuxwang <anduinnn@foxmail.com>
2021-08-25 08:51:39 +08:00
demon143
6608a40965 Update ensemble.py (#560) 2021-08-14 18:07:49 +08:00
you-n-g
3e75cead93 code standard docs 2021-08-12 09:19:57 +00:00
you-n-g
6697f209d4 Conda Suggestion 2021-08-12 16:30:46 +08:00
you-n-g
e3b57b1901 Update README.md 2021-08-06 09:59:30 +08:00
you-n-g
82a5223166 Update README.md 2021-08-06 09:59:30 +08:00
ZhangTP1996
398131cff7 Update strategy.py 2021-08-05 17:21:10 +08:00
Dong Zhou
e71e2f941c fix tra when logdir is None 2021-08-02 19:02:37 +08:00
Dong Zhou
0483406c12 fix tra when logdir is None 2021-08-02 03:57:14 -07:00
Dong Zhou
da1f4db968 update README 2021-07-30 16:05:07 +08:00
Dong Zhou
a7c41b6969 improve pretrain 2021-07-30 16:05:07 +08:00
Dong Zhou
5b7b48e376 clean up 2021-07-30 16:05:07 +08:00
Dong Zhou
4f9f978909 fix TRA when use single head 2021-07-30 16:05:07 +08:00
Dong Zhou
319a2f38cc fix horizon 2021-07-30 16:05:07 +08:00
Dong Zhou
a2c38c979e format by black 2021-07-30 16:05:07 +08:00
Dong Zhou
07655f2d5b refactor TRA 2021-07-30 16:05:07 +08:00
Young
9303415666 refactor online serving rolling api 2021-07-29 18:13:12 +08:00
you-n-g
05d28469ad sort index after loader (#538)
make sure the fetch method is based on a index-sorted pd.DataFrame
2021-07-29 12:06:59 +08:00
you-n-g
dc6859bdd9 Fix docs of QlibRecorder 2021-07-26 19:00:47 +08:00
you-n-g
a6f9dde006 Update README.md 2021-07-26 18:36:09 +08:00
Young
1d22ee56d3 recorder support upload both raw file and directory 2021-07-25 16:35:16 +00:00
panshuaiyin
3810a4cd33 Update data.rst
use own alpha-factor
2021-07-22 20:07:04 +08:00
you-n-g
48af7126b6 Update news about models 2021-07-22 11:07:09 +08:00
Ying-Tao Luo
025b1dcff9 Add two new models in model zoo 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
29e66b2dea Add two new model in zoo
Add transformer and localformer (SLGT) models for time series prediction in finance in the Quant Model Zoo.
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
698e59ac72 Add performance of two new models
Add the performance of transformer and localformer.
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
e006ef40ad Update pytorch_localformer_ts.py 2021-07-22 11:05:39 +08:00
Young
59d4bc9394 update run_all_model and black format 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
b07e0bffb1 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
161343018f Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
bee031af68 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
35840606a8 Update pytorch_localformer.py 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
2df9b6e076 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
0c3eaf3f16 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
2eee064eb8 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
096ef5a62b Update pytorch_transformer.py
Have passed black
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
dd0eebed53 Update pytorch_localformer.py
Have passed black.
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
7b20abeda1 Add files via upload
Add naive transformer model and a improved transformer model.
2021-07-22 11:05:39 +08:00
you-n-g
5519420efd Update test_macos.yml
Give more comments about the MacOS test yaml
2021-07-21 18:30:25 +08:00
zhupr
eb3c5b3088 macos-test-ci split out separately 2021-07-21 18:25:31 +08:00
zhupr
f03df874bf fix macos-test-ci 2021-07-21 18:25:31 +08:00
2796gaurav
8fa22bd2e1 added 1min for IN and also updated readme 2021-07-21 14:16:22 +05:30
Gaurav
d1c8d885aa cleaned the code 2021-07-21 17:59:50 +05:30
zhupr
bf7732e284 fix df_features.index conține np.nan 2021-07-21 14:28:20 +08:00
wuzhe1234
3f5334ab39 Update qrun to automaticly save the config to the artifacts uri 2021-07-19 13:32:14 +08:00
zhupr
c97a96363d Add a check if change is mutated to YahooNormalize1d 2021-07-18 20:28:46 +08:00
slowy07
2023f714c9 [fixed] lgtm issue : unused imported module of 'signal' and change to PEP8 style code imported module 2021-07-18 15:25:18 +08:00
slowy07
f8a2b0533b lgtm issue: fixing unused import of 'time' 2021-07-18 15:25:18 +08:00
chaosyu
3183a232df update doc str 2021-07-18 15:24:23 +08:00
chaosyu
8b715268bd use list_kwargs instead filter_string 2021-07-18 15:24:23 +08:00
chaosyu
28cb827a23 fix lint issue 2021-07-18 15:24:23 +08:00
chaosyu
b723f14619 apply filter string to recorder collector 2021-07-18 15:24:23 +08:00
chaosyu
47535ba530 add mlflow filter string support to limit too much run number 2021-07-18 15:24:23 +08:00
Gaurav
d70e5a4f88 add YahooNormalizeIN and YahooNormalizeIN1d 2021-07-17 10:40:16 +05:30
you-n-g
3b8087677c Update online.rst 2021-07-16 12:24:33 +08:00
zhupr
4ec41ea0e7 Add a check if change is mutated to YahooNormalize1d 2021-07-15 19:13:25 +08:00
Gaurav
cfcd9fb1f8 cleaned with black 2021-07-15 11:24:41 +05:30
Gaurav
457dcaa466 cleaned with black 2021-07-14 20:12:00 +05:30
Gaurav
3c740fc2de MVP for Indian Stocks in qlib using yahooquery 2021-07-14 19:54:55 +05:30
you-n-g
6d91f28474 Update README.md 2021-07-14 10:07:02 +08:00
you-n-g
be8653c505 Update contributing section 2021-07-14 09:56:12 +08:00
chaosyu
a8974ce535 bug fix: ClientProvider cannot set connection to calendar and instrument providers 2021-07-13 10:49:21 +08:00
chaosyu
79026e5390 fix bug that duplicate rows will cause reindex failed when dumping with csv files 2021-07-13 10:49:21 +08:00
Gaurav Chauhan
4610e16ac2 updated readme of yahoo collector where region parameter was incorrect (#504)
* updated readme of yahoo collector where region parameter was incorrect

* changes

update readme of yahoo collector where region parameter was incorrect

* update readme of yahoo collector

update readme of yahoo collector where region parameter was incorrect

* updated changes

* updated readme of cn1d data

Co-authored-by: Gaurav Chauhan01/HO/Analytics/General <Gaurav.Chauhan01@bajajallianz.in>
2021-07-13 09:46:13 +08:00
wangwenxi.handsome
b504cc6ac8 update readme and rst 2021-07-12 21:51:08 +08:00
Young
d5059e609f change to dev version 2021-07-12 02:49:25 +00:00
Young
215f7e0d22 update version for release 0.7.0 2021-07-11 14:34:44 +00:00
xiaowuhu
dafef0ac08 Update workflow.rst
should be China instead of china
2021-07-06 09:22:11 +08:00
xiaowuhu
1cb43ea69b Update workflow.rst
remove 空格 before module_path, kwargs, etc, otherwise, yaml parser will report error: ruamel.yaml.scanner.ScannerError: mapping values are not allowed here
2021-07-06 09:21:14 +08:00
you-n-g
7ca9cf79f7 Update README.md 2021-07-05 19:47:49 +08:00
you-n-g
35f090a6e4 Update what's new 2021-07-04 16:47:33 +08:00
Lewen Wang
ace7484304 Update TCTS. (#495)
* Update TCTS Model.

Co-authored-by: lewwang <lwwang@microsoft.com>
2021-07-04 16:45:05 +08:00
bxdd
2d4f0e80f9 black format 2021-07-02 08:47:52 +08:00
bxdd
946c9392a1 support check_transform_proc module_path 2021-07-02 08:47:52 +08:00
lzh222333
b523b27d5a add docstring 2021-06-30 10:59:34 +08:00
lzh222333
0b83fb3564 more general exception 2021-06-30 10:59:34 +08:00
lzh222333
d96f7a67c6 bug & docs fixed 2021-06-30 10:59:34 +08:00
lzh222333
a7862387a2 fixed update bugs 2021-06-30 10:59:34 +08:00
lzh222333
c4c438249c modify OnlineToolR 2021-06-30 10:59:34 +08:00
you-n-g
8709dde65b Merge pull request #481 from ai4stocks/working_workflow_fix_ipynb
examples/workflow_by_code.ipynd: fix an error in R.get_recorder() par…
2021-06-26 19:57:24 +08:00
Guodong Xu
d66733c358 examples/workflow_by_code.ipynd: fix an error in R.get_recorder() parameters
get_recorder() needs specify 'recorder_id='. However workflow_by_code.ipynd
didn't. This patch fixes it.

Without this fix, here is the error message jupyter-notebook reports:

"---------------------------------------------------------------------------
TypeError Traceback (most recent call last)

<ipython-input-7-e6a7b5f4da00> in <module>
26 # backtest and analysis
27 with R.start(experiment_name="backtest_analysis"):
---> 28 recorder = R.get_recorder(rid, experiment_name="train_model")
29 model = recorder.load_object("trained_model")
30

TypeError: get_recorder() takes 1 positional argument but 2 positional arguments (and 1 keyword-only argument) were given"

Signed-off-by: Guodong Xu <guodong.xu@linaro.org>
2021-06-26 18:25:47 +08:00
Dong Zhou
9cf574b697 Merge pull request #479 from linhx25/main
Add TRA Model
2021-06-25 18:08:23 +08:00
linhx25
107e40f3ee Add TRA Model 2021-06-25 16:12:50 +08:00
you-n-g
4837ba8db3 Merge pull request #476 from bxdd/qlib_ops_config
Support using config to register custom operators
2021-06-24 20:41:48 +08:00
Qian Chen
2ab4a9adb3 Set self.fitted = True instead of self._fitted. 2021-06-24 20:40:59 +08:00
bxdd
8d0b673341 add custom_ops docstring 2021-06-24 15:00:45 +08:00
you-n-g
8ebdb1e873 Merge pull request #463 from zhupr/support_extend_data
Support extend data
2021-06-24 13:53:30 +08:00
zhupr
39340fbf06 fix: typo 2021-06-24 11:07:40 +08:00
zhupr
0e277723a3 Merge remote-tracking branch 'qlib/main' into qlib_main
# Conflicts:
#	scripts/data_collector/yahoo/README.md
2021-06-24 00:09:54 +08:00
zhupr
1418417034 fix automatic update of daily frequency data 2021-06-23 23:59:59 +08:00
you-n-g
b261f7b501 Update README.md 2021-06-23 20:51:21 +08:00
zhupr
bab50e8837 fix YahooNormalize1min && update docs 2021-06-23 16:13:26 +08:00
bxdd
0eee4a0f2e support config custom_ops 2021-06-23 15:56:36 +08:00
Young
21eb71d4a9 update framework for online serving 2021-06-23 02:05:38 +00:00
zhupr
46714adf4c modify the YahooNormalize1min factor calculation 2021-06-22 11:15:09 +08:00
zhupr
99fb49650a add end_date parameter to collector.normalize_data 2021-06-21 17:20:37 +08:00
zhupr
985fd0816c Fix cn_index.collector network error 2021-06-21 17:18:04 +08:00
Young
d0f54343c7 support subclass of TSDatasetH 2021-06-21 00:24:31 +08:00
Young
a3679e6758 simplify the code and prevent float when shifting 2021-06-21 00:24:31 +08:00
zhupr
b6c31540e8 add function to automatically update daily frequency data 2021-06-17 23:07:56 +08:00
zhupr
a4f6e04199 modify dump_update starts with the last end date of each symbol 2021-06-17 22:33:31 +08:00
you-n-g
0aee46ee79 Merge pull request #466 from you-n-g/online_hotfix
Online bug fix, enhancement &  docs for dataset, workflow, trainer ...
2021-06-17 11:38:44 +08:00
Young
9c8d423a86 fix ModelUpdater 2021-06-16 14:10:51 +00:00
zhupr
b4efbd53b2 Fix 'report' compatibility with matplotlib versions 2021-06-16 22:00:43 +08:00
you-n-g
5a50d7c952 Merge pull request #471 from Derek-Wds/main
Update Recorder Wrapper to prevent reinitialization
2021-06-16 17:46:31 +08:00
Jactus
0fe8b281ba Update R wrapper logic 2021-06-16 12:28:20 +08:00
lewwang
5331ab93f8 Update TCTS README. 2021-06-16 12:23:22 +08:00
Jactus
64582e9d46 Add QlibException 2021-06-15 15:02:11 +08:00
Jactus
9e0e2ff736 Update QlibRecorder wrapper 2021-06-15 14:46:31 +08:00
Young
973c4137e4 fix mlflow & task bug 2021-06-12 13:54:26 +00:00
Young
730f6258d6 add warning and * 2021-06-11 10:40:56 +00:00
Young
5850490b24 simplify the code and add docs 2021-06-11 08:29:10 +00:00
Young
d4b36bdab4 Online fix
- Skip duplicated qlib.auto_init()
- Fix TSDatasetH flt_col bug!
- Resolve qlib log attribute confliction
- Trainer API enhancement
- More docs and user-friendly warning
2021-06-11 02:06:07 +00:00
you-n-g
40416d8c30 Merge pull request #464 from lwwang1995/main
Add TCTS baseline.
2021-06-10 10:18:20 +08:00
lewwang
567e42840c asdf 2021-06-09 18:37:25 +08:00
lewwang
65ddca133f asdf 2021-06-09 18:36:12 +08:00
lewwang
d199256d34 asdf 2021-06-09 18:35:14 +08:00
lewwang
073fe4668e asdf 2021-06-09 18:34:31 +08:00
lewwang
89d53853e5 asdf 2021-06-09 18:30:42 +08:00
lewwang
bb6c1572ca asdf 2021-06-09 18:29:55 +08:00
lewwang
4c4e77b11f asdf 2021-06-09 18:28:31 +08:00
lewwang
38c7b7303a dsaf 2021-06-09 18:26:50 +08:00
lewwang
02d0eedd68 update 2021-06-09 18:21:16 +08:00
lewwang
5a3dde93a8 update 2021-06-09 18:15:06 +08:00
lewwang
177f6a59d2 asdf 2021-06-09 17:47:24 +08:00
lewwang
492a62a569 tcts demo page 2021-06-09 17:32:24 +08:00
zhupr
9a44fbf9c1 fix PEP8: qlib/scripts/data_collector/fund/collector.py 2021-06-08 22:52:31 +08:00
zhupr
03eb0882de fix YahooNormalizeCN1minOffline bugs 2021-06-08 22:23:05 +08:00
zhupr
a845a2271b add normalize 1min to use local data && change the default parameters for collecting 1min 2021-06-08 14:45:20 +08:00
you-n-g
ba021f6007 Merge pull request #462 from arisliang/patch-1
Remove non-existing parameter description
2021-06-08 13:03:43 +08:00
al
7d9544fb91 Remove non-existing parameter from doc
Remove non-existing TradeExchange parameter from generate_target_weight_position doc
2021-06-08 09:35:36 +08:00
you-n-g
12b7be333d Merge pull request #461 from Derek-Wds/main
Fix exception hook bug
2021-06-07 21:07:33 +08:00
Jactus
ed54f1213c Fix exception hook bug 2021-06-07 17:13:36 +08:00
zhupr
554b9c7826 fix YahooCollector getting 1min data occasionally missing 2021-06-05 23:43:48 +08:00
zhupr
6f150f3fd6 Add YahooCollector support for extend data 2021-06-04 22:28:42 +08:00
you-n-g
2a0d991d9b Merge pull request #459 from you-n-g/online_srv
fix DelayTrainerRM
2021-06-03 15:55:11 +08:00
lzh222333
1320e53f81 fix DelayTrainerRM 2021-06-03 03:23:48 +00:00
Young
8222795ac4 fix format with black 2021-06-02 09:16:46 +00:00
you-n-g
616a742db7 Merge pull request #435 from you-n-g/online_srv
Multiprocessing support for Online Serving
2021-06-02 17:12:19 +08:00
lzh222333
811d2c975e update & fix 2021-06-02 08:56:15 +00:00
lzh222333
6272ce108f Merge remote-tracking branch 'microsoft/main' into online_srv 2021-06-02 08:32:12 +00:00
you-n-g
64896745d0 Merge pull request #457 from zhupr/fix_XGBoost_predict_error
fix XGBoost predict error
2021-06-02 16:14:18 +08:00
zhupr
b2fe2385d5 fix XGBoost predict error 2021-06-01 21:02:32 +08:00
lzh222333
8d05cd2daf modify tests.config.py 2021-06-01 09:40:53 +00:00
lzh222333
231bdf8608 Merge remote-tracking branch 'microsoft/main' into online_srv 2021-06-01 08:29:02 +00:00
lzh222333
ab6b88ce14 delete useless import 2021-06-01 07:48:14 +00:00
lzh222333
94ab4bbf3f add docs 2021-06-01 07:45:39 +00:00
lzh222333
ca0363ded8 update trainer and manage 2021-05-27 06:04:46 +00:00
lzh222333
a467e10974 Merge remote-tracking branch 'microsoft/main' into online_srv 2021-05-24 05:10:15 +00:00
lzh222333
6dfbf00a23 Merge branch 'microsoft_main' into online_srv 2021-05-24 05:07:53 +00:00
lzh222333
b24af7fff6 multiprocessing support 2021-05-24 05:07:38 +00:00
lwwang1995
45f73361e3 add tcts baseline 2021-03-18 11:17:42 +08:00
121 changed files with 8843 additions and 1041 deletions

View File

@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: [windows-latest, macos-latest]
python-version: [3.6, 3.7, 3.8]
python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2

View File

@@ -1,4 +1,4 @@
name: Test
name: Test
on:
push:
@@ -12,8 +12,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04]
python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -25,96 +25,41 @@ jobs:
- name: Lint with Black
run: |
cd ..
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install black
$CONDA\\python.exe -m black qlib -l 120 --check --diff
else
sudo $CONDA/bin/python -m pip install black
$CONDA/bin/python -m black qlib -l 120 --check --diff
fi
shell: bash
pip install --upgrade pip
pip install black wheel
black qlib -l 120 --check --diff
# Test Qlib installed with pip
- name: Install Qlib with pip
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install numpy==1.19.5
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
else
sudo $CONDA/bin/python -m pip install numpy==1.19.5
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
fi
shell: bash
- name: Install Lightgbm for MacOS
if: runner.os == 'macOS'
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
pip install numpy==1.19.5 ruamel.yaml
pip install pyqlib --ignore-installed
- name: Test data downloads
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
else
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
fi
shell: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Test workflow by config (install from pip)
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
$CONDA\\python.exe -m pip uninstall -y pyqlib
else
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
fi
shell: bash
# Test Qlib installed from source
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
python -m pip uninstall -y pyqlib
# Test Qlib installed from source
- name: Install Qlib from source
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install --upgrade cython
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
$CONDA\\python.exe setup.py install
else
sudo $CONDA/bin/python -m pip install --upgrade cython
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
sudo $CONDA/bin/python setup.py install
fi
shell: bash
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
pip install -e .
- name: Install test dependencies
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install --upgrade pip
$CONDA\\python.exe -m pip install black pytest
else
sudo $CONDA/bin/python -m pip install --upgrade pip
sudo $CONDA/bin/python -m pip install black pytest
fi
shell: bash
pip install --upgrade pip
pip install black pytest
- name: Unit tests with Pytest
run: |
cd tests
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pytest . --durations=0
else
$CONDA/bin/python -m pytest . --durations=0
fi
shell: bash
python -m pytest . --durations=10
- name: Test workflow by config (install from source)
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
else
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
fi
shell: bash
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

67
.github/workflows/test_macos.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
name: Test MacOS
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: macos-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Lint with Black
run: |
cd ..
python -m pip install pip --upgrade
python -m pip install wheel --upgrade
python -m pip install black
python -m black qlib -l 120 --check --diff
# Test Qlib installed with pip
- name: Install Qlib with pip
run: |
python -m pip install numpy==1.19.5
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
- name: Install Lightgbm for MacOS
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Test workflow by config (install from pip)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
python -m pip uninstall -y pyqlib
# Test Qlib installed from source
- name: Install Qlib from source
run: |
python -m pip install --upgrade cython
python -m pip install numpy jupyter jupyter_contrib_nbextensions
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
python setup.py install
- name: Install test dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U pyopenssl idna
python -m pip install black pytest
- name: Unit tests with Pytest
run: |
cd tests
python -m pytest . --durations=0
- name: Test workflow by config (install from source)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -11,6 +11,10 @@
Recent released features
| Feature | Status |
| -- | ------ |
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
@@ -42,7 +46,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [Data Preparation](#data-preparation)
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
- [**Quant Model Zoo**](#quant-model-zoo)
- [**Quant Model(Paper) Zoo**](#quant-model-paper-zoo)
- [Run a single model](#run-a-single-model)
- [Run multiple models](#run-multiple-models)
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
@@ -68,7 +72,7 @@ Your feedbacks about the features are very important.
# Framework of Qlib
<div style="align: center">
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.2" />
</div>
@@ -104,8 +108,9 @@ This table demonstrates the supported Python version of `Qlib`:
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
**Note**:
1. **Conda** is suggested for managing your Python environment.
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
### Install with pip
Users can easily install ``Qlib`` by pip according to the following command.
@@ -159,6 +164,28 @@ Users could create the same dataset with it.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
### Automatic update of daily frequency data (from yahoo finance)
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
* Automatic update of data to the "qlib" directory each trading day(Linux)
* use *crontab*: `crontab -e`
* set up timed tasks:
```
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
```
* **script path**: *scripts/data_collector/yahoo/collector.py*
* Manual update of data
```
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
```
* *trading_date*: start of trading day
* *end_date*: end of trading day(not included)
<!--
- Run the initialization code and get stock data:
@@ -251,21 +278,24 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
# [Quant Model Zoo](examples/benchmarks)
# [Quant Model (Paper) Zoo](examples/benchmarks)
Here is a list of models built on `Qlib`.
- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py)
- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py)
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py)
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural omputation 1997)](qlib/contrib/model/pytorch_lstm.py)
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py)
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
Your PR of new Quant models is highly welcomed.
@@ -281,7 +311,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
## Run multiple models
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.)
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
@@ -346,9 +376,7 @@ Such overheads greatly slow down the data loading process.
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
# Related Reports
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
- [Guide To Qlib: Microsofts AI Investment Platform](https://analyticsindiamag.com/qlib/)
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
- [微软也搞AI量化平台还是开源的](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
- [微矿Qlib业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
@@ -365,7 +393,12 @@ Join IM discussion groups:
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
This project welcomes contributions and suggestions.
**Here are some
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.

View File

@@ -97,4 +97,57 @@ Also, feel free to post a new issue in our GitHub repository. We always check ea
python setup.py build_ext --inplace
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
4. BadNamespaceError: / is not a connected namespace
------------------------------------------------------------------------------------------------------------------------------------
.. code-block:: python
File "qlib_online.py", line 35, in <module>
cal = D.calendar()
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
return Cal.calendar(start_time, end_time, freq, future=future)
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
self.conn.send_request(
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
self.sio.emit(request_type + "_request", request_content)
File "G:\apps\miniconda\envs\qlib\lib\site-packages\python_socketio-5.3.0-py3.8.egg\socketio\client.py", line 369, in emit
raise exceptions.BadNamespaceError(
BadNamespaceError: / is not a connected namespace.
- The version of ``python-socketio`` in qlib needs to be the same as the version of ``python-socketio`` in qlib-server:
.. code-block:: bash
pip install -U python-socketio==<qlib-server python-socketio version>
5. TypeError: send() got an unexpected keyword argument 'binary'
------------------------------------------------------------------------------------------------------------------------------------
.. code-block:: python
File "qlib_online.py", line 35, in <module>
cal = D.calendar()
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
return Cal.calendar(start_time, end_time, freq, future=future)
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
self.conn.send_request(
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
self.sio.emit(request_type + "_request", request_content)
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 263, in emit
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 339, in _send_packet
self.eio.send(ep, binary=binary)
TypeError: send() got an unexpected keyword argument 'binary'
- The ``python-engineio`` version needs to be compatible with the ``python-socketio`` version, reference: https://github.com/miguelgrinberg/python-socketio#version-compatibility
.. code-block:: bash
pip install -U python-engineio==<compatible python-socketio version>
# or
pip install -U python-socketio==3.1.2 python-engineio==3.13.2

Binary file not shown.

Before

Width:  |  Height:  |  Size: 271 KiB

After

Width:  |  Height:  |  Size: 208 KiB

View File

@@ -67,6 +67,34 @@ After running the above command, users can find china-stock and us-stock data in
When ``Qlib`` is initialized with this dataset, users could build and evaluate their own models with it. Please refer to `Initialization <../start/initialization.html>`_ for more details.
Automatic update of daily frequency data
----------------------------------------
**It is recommended that users update the data manually once (\-\-trading_date 2021-05-25) and then set it to update automatically.**
For more information refer to: `yahoo collector <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#Automatic-update-of-daily-frequency-data>`_
- Automatic update of data to the "qlib" directory each trading day(Linux)
- use *crontab*: `crontab -e`
- set up timed tasks:
.. code-block:: bash
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
- **script path**: *scripts/data_collector/yahoo/collector.py*
- Manual update of data
.. code-block:: bash
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
- *trading_date*: start of trading day
- *end_date*: end of trading day(not included)
Converting CSV Format into Qlib Format
-------------------------------------------
@@ -151,6 +179,7 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
Stock Pool (Market)
--------------------------------

View File

@@ -21,6 +21,8 @@ which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online S
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
**NOTE**: User should keep his data source updated to support online serving. For example, Qlib provides `a batch of scripts <https://github.com/microsoft/qlib/blob/main/scripts/data_collector/yahoo/README.md#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_ to help users update Yahoo daily data.
Online Manager
=============
@@ -43,4 +45,4 @@ Updater
=============
.. automodule:: qlib.workflow.online.update
:members:
:members:

View File

@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
@@ -142,7 +142,7 @@ The meaning of each field is as follows:
- `region`
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
- If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
.. note::

View File

@@ -0,0 +1,20 @@
.. _code_standard:
=================================
Code Standard
=================================
Docstring
=================================
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
Continuous Integration
=================================
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
.. code-block:: python
pip install black
python -m black . -l 120

View File

@@ -61,7 +61,6 @@ task:
metric: loss
loss: mse
base_model: LSTM
with_pretrain: True
model_path: "benchmarks/LSTM/csi300_lstm_ts.pkl"
GPU: 0
dataset:

View File

@@ -54,7 +54,6 @@ task:
metric: loss
loss: mse
base_model: LSTM
with_pretrain: True
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
GPU: 0
dataset:
@@ -81,4 +80,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -0,0 +1,16 @@
import datetime
import pandas as pd
from qlib.data.inst_processor import InstProcessor
class Resample1minProcessor(InstProcessor):
def __init__(self, hour: int, minute: int, **kwargs):
self.hour = hour
self.minute = minute
def __call__(self, df: pd.DataFrame, *args, **kwargs):
df.index = pd.to_datetime(df.index)
df = df.loc[df.index.time == datetime.time(self.hour, self.minute)]
df.index = df.index.normalize()
return df

View File

@@ -0,0 +1,83 @@
qlib_init:
provider_uri:
day: "~/.qlib/qlib_data/cn_data"
1min: "~/.qlib/qlib_data/cn_data_1min"
region: cn
dataset_cache: null
maxtasksperchild: 1
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
# 1min closing time is 15:00:00
end_time: "2020-08-01 15:00:00"
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
freq:
label: day
feature: 1min
# with label as reference
inst_processor:
feature:
- class: Resample1minProcessor
module_path: features_sample.py
kwargs:
hour: 14
minute: 56
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -78,4 +78,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
numpy==1.17.4
pandas==1.1.2
torch==1.2.0

View File

@@ -0,0 +1,82 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LocalformerModel
module_path: qlib.contrib.model.pytorch_localformer_ts
kwargs:
seed: 0
n_jobs: 20
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,73 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LocalformerModel
module_path: qlib.contrib.model.pytorch_localformer
kwargs:
d_feat: 6
seed: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,9 +1,13 @@
# Benchmarks Performance
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
>
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
## Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
@@ -18,6 +22,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
@@ -34,6 +42,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.

View File

@@ -0,0 +1,52 @@
# Temporally Correlated Task Scheduling for Sequence Learning
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
### Background
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
<p align="center">
<img src="task_description.png" width="600" height="200"/>
</p>
### Method
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
<p align="center">
<img src="workflow.png"/>
</p>
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
### DataSet
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
### Experiments
#### Task Description
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
<div align=center>
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
</div>
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
#### Baselines
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
#### Result
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
| :----: | :----: | :----: | :----: |
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

View File

@@ -0,0 +1,93 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -1) / $close - 1",
"Ref($close, -2) / Ref($close, -1) - 1",
"Ref($close, -3) / Ref($close, -2) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TCTS
module_path: qlib.contrib.model.pytorch_tcts
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
loss: mse
GPU: 0
fore_optimizer: adam
weight_optimizer: adam
output_dim: 3
fore_lr: 5e-4
weight_lr: 5e-4
steps: 3
target_label: 1
lowest_valid_performance: 0.993
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
label_col: 1
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,92 @@
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in the stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.
If you find our work useful in your research, please cite:
```
@inproceedings{HengxuKDD2021,
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
series = {KDD '21},
year = {2021},
publisher = {ACM},
}
@article{yang2020qlib,
title={Qlib: An AI-oriented Quantitative Investment Platform},
author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},
journal={arXiv preprint arXiv:2009.11189},
year={2020}
}
```
## Usage (Recommended)
**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and `Alpha158/Alpha360` dataset.
Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files:
- `workflow_config_tra_Alpha360.yaml`: running `TRA` with `Alpha360` dataset
- `workflow_config_tra_Alpha158.yaml`: running `TRA` with `Alpha158` dataset (with feature subsampling)
- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with `Alpha158` dataset (without feature subsampling)
The performances of `TRA` are reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
## Usage (Not Maintained)
This section is used to reproduce the results in the paper.
### Running
We attach our running scripts for the paper in `run.sh`.
And here are two ways to run the model:
* Running from scripts with default parameters
You can directly run from Qlib command `qrun`:
```
qrun configs/config_alstm.yaml
```
* Running from code with self-defined parameters
Setting different parameters is also allowed. See codes in `example.py`:
```
python example.py --config_file configs/config_alstm.yaml
```
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
### Results
After running the scripts, you can find result files in path `./output`:
* `info.json` - config settings and result metrics.
* `log.csv` - running logs.
* `model.bin` - the model parameter dictionary.
* `pred.pkl` - the prediction scores and output for inference.
Evaluation metrics reported in the paper:
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|-------|-------|------|-----|-----|-----|-----|-----|-----|
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
|SFM|0.159(0.001) |0.321(0.001) |0.047 |0.381 |7.1% |14.3% |0.497 |22.9%|
|ALSTM|0.158(0.001) |0.320(0.001) |0.053 |0.419 |12.3% |13.7% |0.897 |20.2%|
|Trans.|0.158(0.001) |0.322(0.001) |0.051 |0.400 |14.5% |14.2% |1.028 |22.5%|
|ALSTM+TS|0.160(0.002) |0.321(0.002) |0.039 |0.291 |6.7% |14.6% |0.480|22.3%|
|Trans.+TS|0.160(0.004) |0.324(0.005) |0.037 |0.278 |10.4% |14.7% |0.722 |23.7%|
|ALSTM+TRA(Ours)|0.157(0.000) |0.318(0.000) |0.059 |0.460 |12.4% |14.0% |0.885 |20.4%|
|Trans.+TRA(Ours)|0.157(0.000) |0.320(0.000) |0.056 |0.442 |16.1% |14.2% |1.133 |23.1%|
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
## Common Issues
For help or issues using TRA, please submit a GitHub issue.
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 256
num_layers: 2
num_heads: 2
use_attn: True
dropout: 0.1
num_states: &num_states 1
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0002
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/alstm
model_type: LSTM
model_config: *model_config
tra_config: *tra_config
lamb: 1.0
rho: 0.99
freeze_model: False
model_init_state:
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 1024

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 256
num_layers: 2
num_heads: 2
use_attn: True
dropout: 0.1
num_states: &num_states 10
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0001
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/alstm_tra
model_type: LSTM
model_config: *model_config
tra_config: *tra_config
lamb: 2.0
rho: 0.99
freeze_model: True
model_init_state: output/test/alstm_tra_init/model.bin
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 1024

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 256
num_layers: 2
num_heads: 2
use_attn: True
dropout: 0.1
num_states: &num_states 3
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0002
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/alstm_tra_init
model_type: LSTM
model_config: *model_config
tra_config: *tra_config
lamb: 1.0
rho: 0.99
freeze_model: False
model_init_state:
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 512

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 64
num_layers: 2
num_heads: 4
use_attn: False
dropout: 0.1
num_states: &num_states 1
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0002
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/transformer
model_type: Transformer
model_config: *model_config
tra_config: *tra_config
lamb: 1.0
rho: 0.99
freeze_model: False
model_init_state:
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 1024

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 64
num_layers: 2
num_heads: 4
use_attn: False
dropout: 0.1
num_states: &num_states 3
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0005
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/transformer_tra
model_type: Transformer
model_config: *model_config
tra_config: *tra_config
lamb: 1.0
rho: 0.99
freeze_model: True
model_init_state: output/test/transformer_tra_init/model.bin
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 512

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 64
num_layers: 2
num_heads: 4
use_attn: False
dropout: 0.1
num_states: &num_states 3
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0002
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/transformer_tra_init
model_type: Transformer
model_config: *model_config
tra_config: *tra_config
lamb: 1.0
rho: 0.99
freeze_model: False
model_init_state:
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 512

View File

@@ -0,0 +1 @@
Data Link: https://drive.google.com/drive/folders/1fMqZYSeLyrHiWmVzygeI4sw3vp5Gt8cY?usp=sharing

View File

@@ -0,0 +1,39 @@
import argparse
import qlib
import ruamel.yaml as yaml
from qlib.utils import init_instance_by_config
def main(seed, config_file="configs/config_alstm.yaml"):
# set random seed
with open(config_file) as f:
config = yaml.safe_load(f)
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
seed_suffix = ""
config["task"]["model"]["kwargs"].update(
{"seed": seed, "logdir": config["task"]["model"]["kwargs"]["logdir"] + seed_suffix}
)
# initialize workflow
qlib.init(
provider_uri=config["qlib_init"]["provider_uri"],
region=config["qlib_init"]["region"],
)
dataset = init_instance_by_config(config["task"]["dataset"])
model = init_instance_by_config(config["task"]["model"])
# train model
model.fit(dataset)
if __name__ == "__main__":
# set params from cmd
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--seed", type=int, default=1000, help="random seed")
parser.add_argument("--config_file", type=str, default="configs/config_alstm.yaml", help="config file")
args = parser.parse_args()
main(**vars(args))

View File

@@ -0,0 +1,29 @@
#!/bin/bash
# we used random seed(1 1000 2000 3000 4000 5000) in our experiments
# Directly run from Qlib command `qrun`
qrun configs/config_alstm.yaml
qrun configs/config_transformer.yaml
qrun configs/config_transformer_tra_init.yaml
qrun configs/config_transformer_tra.yaml
qrun configs/config_alstm_tra_init.yaml
qrun configs/config_alstm_tra.yaml
# Or setting different parameters with example.py
python example.py --config_file configs/config_alstm.yaml
python example.py --config_file configs/config_transformer.yaml
python example.py --config_file configs/config_transformer_tra_init.yaml
python example.py --config_file configs/config_transformer_tra.yaml
python example.py --config_file configs/config_alstm_tra_init.yaml
python example.py --config_file configs/config_alstm_tra.yaml

View File

@@ -0,0 +1,253 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import torch
import numpy as np
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH, DataHandler
device = "cuda" if torch.cuda.is_available() else "cpu"
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, dtype=torch.float, device=device)
return x
def _create_ts_slices(index, seq_len):
"""
create time series slices from pandas index
Args:
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
seq_len (int): sequence length
"""
assert index.is_lexsorted(), "index should be sorted"
# number of dates for each code
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
# start_index for each code
start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
start_index_of_codes[0] = 0
# all the [start, stop) indices of features
# features btw [start, stop) are used to predict the `stop - 1` label
slices = []
for cur_loc, cur_cnt in zip(start_index_of_codes, sample_count_by_codes):
for stop in range(1, cur_cnt + 1):
end = cur_loc + stop
start = max(end - seq_len, 0)
slices.append(slice(start, end))
slices = np.array(slices)
return slices
def _get_date_parse_fn(target):
"""get date parse function
This method is used to parse date arguments as target type.
Example:
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
if isinstance(target, pd.Timestamp):
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
elif isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
else:
_fn = lambda x: x
return _fn
class MTSDatasetH(DatasetH):
"""Memory Augmented Time Series Dataset
Args:
handler (DataHandler): data handler
segments (dict): data split segments
seq_len (int): time series sequence length
horizon (int): label horizon (to mask historical loss for TRA)
num_states (int): how many memory states to be added (for TRA)
batch_size (int): batch size (<0 means daily batch)
shuffle (bool): whether shuffle data
pin_memory (bool): whether pin data to gpu memory
drop_last (bool): whether drop last batch < batch_size
"""
def __init__(
self,
handler,
segments,
seq_len=60,
horizon=0,
num_states=1,
batch_size=-1,
shuffle=True,
pin_memory=False,
drop_last=False,
**kwargs
):
assert horizon > 0, "please specify `horizon` to avoid data leakage"
self.seq_len = seq_len
self.horizon = horizon
self.num_states = num_states
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.pin_memory = pin_memory
self.params = (batch_size, drop_last, shuffle) # for train/eval switch
super().__init__(handler, segments, **kwargs)
def setup_data(self, handler_kwargs: dict = None, **kwargs):
super().setup_data()
# change index to <code, date>
# NOTE: we will use inplace sort to reduce memory use
df = self.handler._data
df.index = df.index.swaplevel()
df.sort_index(inplace=True)
self._data = df["feature"].values.astype("float32")
self._label = df["label"].squeeze().astype("float32")
self._index = df.index
# add memory to feature
self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)]
# padding tensor
self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32)
# pin memory
if self.pin_memory:
self._data = _to_tensor(self._data)
self._label = _to_tensor(self._label)
self.zeros = _to_tensor(self.zeros)
# create batch slices
self.batch_slices = _create_ts_slices(self._index, self.seq_len)
# create daily slices
index = [slc.stop - 1 for slc in self.batch_slices]
act_index = self.restore_index(index)
daily_slices = {date: [] for date in sorted(act_index.unique(level=1))}
for i, (code, date) in enumerate(act_index):
daily_slices[date].append(self.batch_slices[i])
self.daily_slices = list(daily_slices.values())
def _prepare_seg(self, slc, **kwargs):
fn = _get_date_parse_fn(self._index[0][1])
start_date = fn(slc.start)
end_date = fn(slc.stop)
obj = copy.copy(self) # shallow copy
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
obj._data = self._data
obj._label = self._label
obj._index = self._index
new_batch_slices = []
for batch_slc in self.batch_slices:
date = self._index[batch_slc.stop - 1][1]
if start_date <= date <= end_date:
new_batch_slices.append(batch_slc)
obj.batch_slices = np.array(new_batch_slices)
new_daily_slices = []
for daily_slc in self.daily_slices:
date = self._index[daily_slc[0].stop - 1][1]
if start_date <= date <= end_date:
new_daily_slices.append(daily_slc)
obj.daily_slices = new_daily_slices
return obj
def restore_index(self, index):
if isinstance(index, torch.Tensor):
index = index.cpu().numpy()
return self._index[index]
def assign_data(self, index, vals):
if isinstance(self._data, torch.Tensor):
vals = _to_tensor(vals)
elif isinstance(vals, torch.Tensor):
vals = vals.detach().cpu().numpy()
index = index.detach().cpu().numpy()
self._data[index, -self.num_states :] = vals
def clear_memory(self):
self._data[:, -self.num_states :] = 0
# TODO: better train/eval mode design
def train(self):
"""enable traning mode"""
self.batch_size, self.drop_last, self.shuffle = self.params
def eval(self):
"""enable evaluation mode"""
self.batch_size = -1
self.drop_last = False
self.shuffle = False
def _get_slices(self):
if self.batch_size < 0:
slices = self.daily_slices.copy()
batch_size = -1 * self.batch_size
else:
slices = self.batch_slices.copy()
batch_size = self.batch_size
return slices, batch_size
def __len__(self):
slices, batch_size = self._get_slices()
if self.drop_last:
return len(slices) // batch_size
return (len(slices) + batch_size - 1) // batch_size
def __iter__(self):
slices, batch_size = self._get_slices()
if self.shuffle:
np.random.shuffle(slices)
for i in range(len(slices))[::batch_size]:
if self.drop_last and i + batch_size > len(slices):
break
# get slices for this batch
slices_subset = slices[i : i + batch_size]
if self.batch_size < 0:
slices_subset = np.concatenate(slices_subset)
# collect data
data = []
label = []
index = []
for slc in slices_subset:
_data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy()
if len(_data) != self.seq_len:
if self.pin_memory:
_data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
else:
_data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
if self.num_states > 0:
_data[-self.horizon :, -self.num_states :] = 0
data.append(_data)
label.append(self._label[slc.stop - 1])
index.append(slc.stop - 1)
# concate
index = torch.tensor(index, device=device)
if isinstance(data[0], torch.Tensor):
data = torch.stack(data)
label = torch.stack(label)
else:
data = _to_tensor(np.stack(data))
label = _to_tensor(np.stack(label))
# yield -> generator
yield {"data": data, "label": label, "index": index}

View File

@@ -0,0 +1,603 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import copy
import math
import json
import collections
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from qlib.utils import get_or_create_path
from qlib.log import get_module_logger
from qlib.model.base import Model
device = "cuda" if torch.cuda.is_available() else "cpu"
class TRAModel(Model):
def __init__(
self,
model_config,
tra_config,
model_type="LSTM",
lr=1e-3,
n_epochs=500,
early_stop=50,
smooth_steps=5,
max_steps_per_epoch=None,
freeze_model=False,
model_init_state=None,
lamb=0.0,
rho=0.99,
seed=0,
logdir=None,
eval_train=True,
eval_test=False,
avg_params=True,
**kwargs,
):
np.random.seed(seed)
torch.manual_seed(seed)
self.logger = get_module_logger("TRA")
self.logger.info("TRA Model...")
self.model = eval(model_type)(**model_config).to(device)
if model_init_state:
self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"])
if freeze_model:
for param in self.model.parameters():
param.requires_grad_(False)
else:
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()]))
self.tra = TRA(self.model.output_size, **tra_config).to(device)
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()]))
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)
self.model_config = model_config
self.tra_config = tra_config
self.lr = lr
self.n_epochs = n_epochs
self.early_stop = early_stop
self.smooth_steps = smooth_steps
self.max_steps_per_epoch = max_steps_per_epoch
self.lamb = lamb
self.rho = rho
self.seed = seed
self.logdir = logdir
self.eval_train = eval_train
self.eval_test = eval_test
self.avg_params = avg_params
if self.tra.num_states > 1 and not self.eval_train:
self.logger.warn("`eval_train` will be ignored when using TRA")
if self.logdir is not None:
if os.path.exists(self.logdir):
self.logger.warn(f"logdir {self.logdir} is not empty")
os.makedirs(self.logdir, exist_ok=True)
self.fitted = False
self.global_step = -1
def train_epoch(self, data_set):
self.model.train()
self.tra.train()
data_set.train()
max_steps = self.n_epochs
if self.max_steps_per_epoch is not None:
max_steps = min(self.max_steps_per_epoch, self.n_epochs)
count = 0
total_loss = 0
total_count = 0
for batch in tqdm(data_set, total=max_steps):
count += 1
if count > max_steps:
break
self.global_step += 1
data, label, index = batch["data"], batch["label"], batch["index"]
feature = data[:, :, : -self.tra.num_states]
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
hidden = self.model(feature)
pred, all_preds, prob = self.tra(hidden, hist_loss)
loss = (pred - label).pow(2).mean()
L = (all_preds.detach() - label[:, None]).pow(2)
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
data_set.assign_data(index, L) # save loss to memory
if prob is not None:
P = sinkhorn(-L, epsilon=0.01) # sample assignment matrix
lamb = self.lamb * (self.rho ** self.global_step)
reg = prob.log().mul(P).sum(dim=-1).mean()
loss = loss - lamb * reg
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
total_loss += loss.item()
total_count += len(pred)
total_loss /= total_count
return total_loss
def test_epoch(self, data_set, return_pred=False):
self.model.eval()
self.tra.eval()
data_set.eval()
preds = []
metrics = []
for batch in tqdm(data_set):
data, label, index = batch["data"], batch["label"], batch["index"]
feature = data[:, :, : -self.tra.num_states]
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
with torch.no_grad():
hidden = self.model(feature)
pred, all_preds, prob = self.tra(hidden, hist_loss)
L = (all_preds - label[:, None]).pow(2)
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
data_set.assign_data(index, L) # save loss to memory
X = np.c_[
pred.cpu().numpy(),
label.cpu().numpy(),
]
columns = ["score", "label"]
if prob is not None:
X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]
columns += ["score_%d" % d for d in range(all_preds.shape[1])] + [
"prob_%d" % d for d in range(all_preds.shape[1])
]
pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)
metrics.append(evaluate(pred))
if return_pred:
preds.append(pred)
metrics = pd.DataFrame(metrics)
metrics = {
"MSE": metrics.MSE.mean(),
"MAE": metrics.MAE.mean(),
"IC": metrics.IC.mean(),
"ICIR": metrics.IC.mean() / metrics.IC.std(),
}
if return_pred:
preds = pd.concat(preds, axis=0)
preds.index = data_set.restore_index(preds.index)
preds.index = preds.index.swaplevel()
preds.sort_index(inplace=True)
return metrics, preds
def fit(self, dataset, evals_result=dict()):
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
best_score = -1
best_epoch = 0
stop_rounds = 0
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
params_list = {
"model": collections.deque(maxlen=self.smooth_steps),
"tra": collections.deque(maxlen=self.smooth_steps),
}
evals_result["train"] = []
evals_result["valid"] = []
evals_result["test"] = []
# train
self.fitted = True
self.global_step = -1
if self.tra.num_states > 1:
self.logger.info("init memory...")
self.test_epoch(train_set)
for epoch in range(self.n_epochs):
self.logger.info("Epoch %d:", epoch)
self.logger.info("training...")
self.train_epoch(train_set)
self.logger.info("evaluating...")
# average params for inference
params_list["model"].append(copy.deepcopy(self.model.state_dict()))
params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
self.model.load_state_dict(average_params(params_list["model"]))
self.tra.load_state_dict(average_params(params_list["tra"]))
# NOTE: during evaluating, the whole memory will be refreshed
if self.tra.num_states > 1 or self.eval_train:
train_set.clear_memory() # NOTE: clear the shared memory
train_metrics = self.test_epoch(train_set)[0]
evals_result["train"].append(train_metrics)
self.logger.info("\ttrain metrics: %s" % train_metrics)
valid_metrics = self.test_epoch(valid_set)[0]
evals_result["valid"].append(valid_metrics)
self.logger.info("\tvalid metrics: %s" % valid_metrics)
if self.eval_test:
test_metrics = self.test_epoch(test_set)[0]
evals_result["test"].append(test_metrics)
self.logger.info("\ttest metrics: %s" % test_metrics)
if valid_metrics["IC"] > best_score:
best_score = valid_metrics["IC"]
stop_rounds = 0
best_epoch = epoch
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
else:
stop_rounds += 1
if stop_rounds >= self.early_stop:
self.logger.info("early stop @ %s" % epoch)
break
# restore parameters
self.model.load_state_dict(params_list["model"][-1])
self.tra.load_state_dict(params_list["tra"][-1])
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_params["model"])
self.tra.load_state_dict(best_params["tra"])
metrics, preds = self.test_epoch(test_set, return_pred=True)
self.logger.info("test metrics: %s" % metrics)
if self.logdir:
self.logger.info("save model & pred to local directory")
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
self.logdir + "/logs.csv", index=False
)
torch.save(best_params, self.logdir + "/model.bin")
preds.to_pickle(self.logdir + "/pred.pkl")
info = {
"config": {
"model_config": self.model_config,
"tra_config": self.tra_config,
"lr": self.lr,
"n_epochs": self.n_epochs,
"early_stop": self.early_stop,
"smooth_steps": self.smooth_steps,
"max_steps_per_epoch": self.max_steps_per_epoch,
"lamb": self.lamb,
"rho": self.rho,
"seed": self.seed,
"logdir": self.logdir,
},
"best_eval_metric": -best_score, # NOTE: minux -1 for minimize
"metric": metrics,
}
with open(self.logdir + "/info.json", "w") as f:
json.dump(info, f)
def predict(self, dataset, segment="test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
test_set = dataset.prepare(segment)
metrics, preds = self.test_epoch(test_set, return_pred=True)
self.logger.info("test metrics: %s" % metrics)
return preds
class LSTM(nn.Module):
"""LSTM Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of hidden layers
use_attn (bool): whether use attention layer.
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
dropout (float): dropout rate
input_drop (float): input dropout for data augmentation
noise_level (float): add gaussian noise to input for data augmentation
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
use_attn=True,
dropout=0.0,
input_drop=0.0,
noise_level=0.0,
*args,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.use_attn = use_attn
self.noise_level = noise_level
self.input_drop = nn.Dropout(input_drop)
self.rnn = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
)
if self.use_attn:
self.W = nn.Linear(hidden_size, hidden_size)
self.u = nn.Linear(hidden_size, 1, bias=False)
self.output_size = hidden_size * 2
else:
self.output_size = hidden_size
def forward(self, x):
x = self.input_drop(x)
if self.training and self.noise_level > 0:
noise = torch.randn_like(x).to(x)
x = x + noise * self.noise_level
rnn_out, _ = self.rnn(x)
last_out = rnn_out[:, -1]
if self.use_attn:
laten = self.W(rnn_out).tanh()
scores = self.u(laten).softmax(dim=1)
att_out = (rnn_out * scores).sum(dim=1).squeeze()
last_out = torch.cat([last_out, att_out], dim=1)
return last_out
class PositionalEncoding(nn.Module):
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
class Transformer(nn.Module):
"""Transformer Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of transformer layers
num_heads (int): number of heads in transformer
dropout (float): dropout rate
input_drop (float): input dropout for data augmentation
noise_level (float): add gaussian noise to input for data augmentation
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
num_heads=2,
dropout=0.0,
input_drop=0.0,
noise_level=0.0,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.noise_level = noise_level
self.input_drop = nn.Dropout(input_drop)
self.input_proj = nn.Linear(input_size, hidden_size)
self.pe = PositionalEncoding(input_size, dropout)
layer = nn.TransformerEncoderLayer(
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
self.output_size = hidden_size
def forward(self, x):
x = self.input_drop(x)
if self.training and self.noise_level > 0:
noise = torch.randn_like(x).to(x)
x = x + noise * self.noise_level
x = x.permute(1, 0, 2).contiguous() # the first dim need to be sequence
x = self.pe(x)
x = self.input_proj(x)
out = self.encoder(x)
return out[-1]
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction erros & latent representation as inputs,
then routes the input sample to a specific predictor for training & inference.
Args:
input_size (int): input size (RNN/Transformer's hidden size)
num_states (int): number of latent states (i.e., trading patterns)
If `num_states=1`, then TRA falls back to traditional methods
hidden_size (int): hidden size of the router
tau (float): gumbel softmax temperature
"""
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
super().__init__()
self.num_states = num_states
self.tau = tau
self.src_info = src_info
if num_states > 1:
self.router = nn.LSTM(
input_size=num_states,
hidden_size=hidden_size,
num_layers=1,
batch_first=True,
)
self.fc = nn.Linear(hidden_size + input_size, num_states)
self.predictors = nn.Linear(input_size, num_states)
def forward(self, hidden, hist_loss):
preds = self.predictors(hidden)
if self.num_states == 1:
return preds.squeeze(-1), preds, None
# information type
router_out, _ = self.router(hist_loss)
if "LR" in self.src_info:
latent_representation = hidden
else:
latent_representation = torch.randn(hidden.shape).to(hidden)
if "TPE" in self.src_info:
temporal_pred_error = router_out[:, -1]
else:
temporal_pred_error = torch.randn(router_out[:, -1].shape).to(hidden)
out = self.fc(torch.cat([temporal_pred_error, latent_representation], dim=-1))
prob = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=False)
if self.training:
final_pred = (preds * prob).sum(dim=-1)
else:
final_pred = preds[range(len(preds)), prob.argmax(dim=-1)]
return final_pred, preds, prob
def evaluate(pred):
pred = pred.rank(pct=True) # transform into percentiles
score = pred.score
label = pred.label
diff = score - label
MSE = (diff ** 2).mean()
MAE = (diff.abs()).mean()
IC = score.corr(label)
return {"MSE": MSE, "MAE": MAE, "IC": IC}
def average_params(params_list):
assert isinstance(params_list, (tuple, list, collections.deque))
n = len(params_list)
if n == 1:
return params_list[0]
new_params = collections.OrderedDict()
keys = None
for i, params in enumerate(params_list):
if keys is None:
keys = params.keys()
for k, v in params.items():
if k not in keys:
raise ValueError("the %d-th model has different params" % i)
if k not in new_params:
new_params[k] = v / n
else:
new_params[k] += v / n
return new_params
def shoot_infs(inp_tensor):
"""Replaces inf by maximum of tensor"""
mask_inf = torch.isinf(inp_tensor)
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
if len(ind_inf) > 0:
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = 0
elif len(ind) == 1:
inp_tensor[ind[0]] = 0
m = torch.max(inp_tensor)
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = m
elif len(ind) == 1:
inp_tensor[ind[0]] = m
return inp_tensor
def sinkhorn(Q, n_iters=3, epsilon=0.01):
# epsilon should be adjusted according to logits value's scale
with torch.no_grad():
Q = shoot_infs(Q)
Q = torch.exp(Q / epsilon)
for i in range(n_iters):
Q /= Q.sum(dim=0, keepdim=True)
Q /= Q.sum(dim=1, keepdim=True)
return Q

View File

@@ -0,0 +1,129 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
rnn_arch: LSTM
hidden_size: 32
num_layers: 1
dropout: 0.0
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 20
hidden_size: 64
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.0
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch:
early_stop: 20
logdir: output/Alpha158
seed: 0
lamb: 1.0
rho: 0.99
alpha: 0.5
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,123 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
rnn_arch: LSTM
hidden_size: 32
num_layers: 1
dropout: 0.0
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 158
hidden_size: 256
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.2
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch:
early_stop: 20
logdir: output/Alpha158_full
seed: 0
lamb: 1.0
rho: 0.99
alpha: 0.5
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,123 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
rnn_arch: LSTM
hidden_size: 32
num_layers: 1
dropout: 0.0
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 6
hidden_size: 64
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.0
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch:
early_stop: 20
logdir: output/Alpha360
seed: 0
lamb: 1.0
rho: 0.99
alpha: 0.5
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size: 6
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
numpy==1.17.4
pandas==1.1.2
torch==1.2.0

View File

@@ -0,0 +1,82 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TransformerModel
module_path: qlib.contrib.model.pytorch_transformer_ts
kwargs:
seed: 0
n_jobs: 20
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,73 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TransformerModel
module_path: qlib.contrib.model.pytorch_transformer
kwargs:
d_feat: 6
seed: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,7 +1,5 @@
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
from qlib.data.dataset.processor import Processor
from qlib.utils import get_cls_kwargs
from qlib.log import TimeInspector
from qlib.contrib.data.handler import check_transform_proc
class HighFreqHandler(DataHandlerLP):
@@ -16,20 +14,9 @@ class HighFreqHandler(DataHandlerLP):
fit_end_time=None,
drop_raw=True,
):
def check_transform_proc(proc_l):
new_l = []
for p in proc_l:
p["kwargs"].update(
{
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
}
)
new_l.append(p)
return new_l
infer_processors = check_transform_proc(infer_processors)
learn_processors = check_transform_proc(learn_processors)
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "QlibDataLoader",
@@ -112,8 +99,6 @@ class HighFreqHandler(DataHandlerLP):
]
names += ["$volume_1"]
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
names += ["date"]
return fields, names

View File

@@ -26,7 +26,7 @@ def get_calendar_day(freq="day", future=False):
if flag in H["c"]:
_calendar = H["c"][flag]
else:
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
_calendar = np.array(list(map(lambda x: pd.Timestamp(x.date()), Cal.load_calendar(freq, future))))
H["c"][flag] = _calendar
return _calendar

View File

@@ -33,6 +33,9 @@ class HighFreqNorm(Processor):
self.feature_vmin[name] = np.nanmin(part_values)
def __call__(self, df_features):
df_features["date"] = pd.to_datetime(
df_features.index.get_level_values(level="datetime").to_series().dt.date.values
)
df_features.set_index("date", append=True, drop=True, inplace=True)
df_values = df_features.values
names = {

View File

@@ -33,7 +33,7 @@ class HighfreqWorkflow:
"fit_start_time": start_time,
"fit_end_time": train_end_time,
"instruments": MARKET,
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor"}],
}
DATA_HANDLER_CONFIG1 = {
"start_time": start_time,

View File

@@ -0,0 +1 @@
xgboost

View File

@@ -4,6 +4,7 @@
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
"""
from pprint import pprint
@@ -13,10 +14,10 @@ import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
from qlib.model.trainer import TrainerRM, task_train
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
@@ -68,6 +69,11 @@ class RollingTaskExample:
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
def worker(self):
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
print("========== worker ==========")
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
def task_collecting(self):
print("========== task_collecting ==========")

View File

@@ -5,6 +5,7 @@
This example is about how can simulate the OnlineManager based on rolling tasks.
"""
from pprint import pprint
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
class OnlineSimulationExample:
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
def worker(self):
# train tasks by other progress or machines for multiprocessing
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
self.trainer.worker()
else:
print(f"{type(self.trainer)} is not supported for worker.")
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below

View File

@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
import os
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
from qlib.workflow.task.manage import TaskManager
class RollingOnlineExample:
@@ -25,16 +27,17 @@ class RollingOnlineExample:
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
rolling_step=550,
tasks=None,
add_tasks=None,
):
if add_tasks is None:
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
@@ -53,17 +56,28 @@ class RollingOnlineExample:
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager = OnlineManager(strategies)
self.trainer = trainer
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
def worker(self):
# train tasks by other progress or machines for multiprocessing
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
self.trainer.worker(experiment_name=name_id)
else:
print(f"{type(self.trainer)} is not supported for worker.")
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
TaskManager(task_pool=name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)

View File

@@ -23,7 +23,6 @@ from qlib.config import REG_CN
from qlib.workflow import R
from qlib.tests.data import GetData
# init qlib
provider_uri = "~/.qlib/qlib_data/cn_data"
exp_folder_name = "run_all_model_records"
@@ -40,6 +39,7 @@ exp_manager = {
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
# decorator to check the arguments
def only_allow_defined_args(function_to_decorate):
@functools.wraps(function_to_decorate)
@@ -92,7 +92,8 @@ def create_env():
# function to execute the cmd
def execute(cmd):
def execute(cmd, wait_when_err=False):
print("Running CMD:", cmd)
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
for line in p.stdout:
sys.stdout.write(line.split("\b")[0])
@@ -102,6 +103,8 @@ def execute(cmd):
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
if p.returncode != 0:
if wait_when_err:
input("Press Enter to Continue")
return p.stderr
else:
return None
@@ -184,7 +187,15 @@ def gen_and_save_md_table(metrics, dataset):
# function to run the all the models
@only_allow_defined_args
def run(times=1, models=None, dataset="Alpha360", exclude=False):
def run(
times=1,
models=None,
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
@@ -200,6 +211,13 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
wait when errors raised when executing commands
Usage:
-------
@@ -240,32 +258,36 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
sys.stderr.write("\n")
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
execute(f"{python_path} -m pip install -r {req_path}")
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
# setup gpu for tft
if fn == "TFT":
execute(
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
wait_when_err=wait_when_err,
)
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
if fn == "TFT":
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/microsoft/qlib#egg=pyqlib"
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
else:
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/microsoft/qlib#egg=pyqlib"
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
errs = execute(
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
wait_when_err=wait_when_err,
)
if errs is not None:
_errs = errors.get(fn, {})
@@ -274,6 +296,8 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
if wait_before_rm_env:
input("Press Enter to Continue")
shutil.rmtree(env_path)
# getting all results
sys.stderr.write(f"Retrieving results...\n")

View File

@@ -220,7 +220,7 @@
"\n",
"# backtest and analysis\n",
"with R.start(experiment_name=\"backtest_analysis\"):\n",
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
" recorder = R.get_recorder(recorder_id=rid, experiment_name=\"train_model\")\n",
" model = recorder.load_object(\"trained_model\")\n",
"\n",
" # prediction\n",
@@ -249,7 +249,7 @@
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
__version__ = "0.6.3.99"
__version__ = "0.7.1"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
@@ -20,76 +20,83 @@ def init(default_conf="client", **kwargs):
from .config import C
from .data.cache import H
H.clear()
# FIXME: this logger ignored the level in config
logger = get_module_logger("Initialization", level=logging.INFO)
skip_if_reg = kwargs.pop("skip_if_reg", False)
if skip_if_reg and C.registered:
# if we reinitialize Qlib during running an experiment `R.start`.
# it will result in loss of the recorder
logger.warning("Skip initialization because `skip_if_reg is True`")
return
H.clear()
C.set(default_conf, **kwargs)
# check path if server/local
if C.get_uri_type() == C.LOCAL_URI:
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
)
else:
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
elif C.get_uri_type() == C.NFS_URI:
_mount_nfs_uri(C)
else:
raise NotImplementedError(f"This type of URI is not supported")
# mount nfs
for _freq, provider_uri in C.provider_uri.items():
mount_path = C["mount_path"][_freq]
# check path if server/local
uri_type = C.dpm.get_uri_type(provider_uri)
if uri_type == C.LOCAL_URI:
if not Path(provider_uri).exists():
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
)
else:
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
elif uri_type == C.NFS_URI:
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
else:
raise NotImplementedError(f"This type of URI is not supported")
C.register()
if "flask_server" in C:
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
logger.info(f"data_path={C.get_data_path()}")
data_path = {_freq: C.dpm.get_data_path(_freq) for _freq in C.dpm.provider_uri.keys()}
logger.info(f"data_path={data_path}")
def _mount_nfs_uri(C):
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
LOG = get_module_logger("mount nfs", level=logging.INFO)
# FIXME: the C["provider_uri"] is modified in this function
# If it is not modified, we can pass only provider_uri or mount_path instead of C
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
# If the provider uri looks like this 172.23.233.89//data/csdesign'
# It will be a nfs path. The client provider will be used
if not C["auto_mount"]:
if not os.path.exists(C["mount_path"]):
if not auto_mount:
if not Path(mount_path).exists():
raise FileNotFoundError(
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
)
else:
# Judging system type
sys_type = platform.system()
if "win" in sys_type.lower():
# system: window
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
result = exec_result.read()
if "85" in result:
LOG.warning("already mounted or window mount path already exists")
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
elif "53" in result:
raise OSError("not find network path")
elif "error" in result or "错误" in result:
raise OSError("Invalid mount path")
elif C["provider_uri"] in result:
elif provider_uri in result:
LOG.info("window success mount..")
else:
raise OSError(f"unknown error: {result}")
# config mount path
C["mount_path"] = C["mount_path"] + ":\\"
else:
# system: linux/Unix/Mac
# check mount
_remote_uri = C["provider_uri"]
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
_mount_path = C["mount_path"]
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
_remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
_mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
_check_level_num = 2
_is_mount = False
while _check_level_num:
@@ -115,11 +122,9 @@ def _mount_nfs_uri(C):
if not _is_mount:
try:
os.makedirs(C["mount_path"], exist_ok=True)
Path(mount_path).mkdir(parents=True, exist_ok=True)
except Exception:
raise OSError(
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
)
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
# check nfs-common
command_res = os.popen("dpkg -l | grep nfs-common")
@@ -130,11 +135,11 @@ def _mount_nfs_uri(C):
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
)
elif command_status == 32512:
# LOG.error("Command error")
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
elif command_status == 0:
LOG.info("Mount finished")
else:
@@ -197,14 +202,15 @@ def auto_init(**kwargs):
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
- Skip initialization if already initialized
"""
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)

View File

@@ -15,8 +15,10 @@ import os
import re
import copy
import logging
import platform
import multiprocessing
from pathlib import Path
from typing import Union
class Config:
@@ -73,6 +75,12 @@ REG_US = "us"
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
DISK_DATASET_CACHE = "DiskDatasetCache"
SIMPLE_DATASET_CACHE = "SimpleDatasetCache"
DISK_EXPRESSION_CACHE = "DiskExpressionCache"
DEPENDENCY_REDIS_CACHE = (DISK_DATASET_CACHE, DISK_EXPRESSION_CACHE)
_default_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
@@ -82,6 +90,15 @@ _default_config = {
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
# "provider_uri" str or dict:
# # str
# "~/.qlib/stock_data/cn_data"
# # dict
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
# NOTE: provider_uri priority
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
# 3. qlib.init: provider_uri
"provider_uri": "",
# cache
"expression_cache": None,
@@ -167,8 +184,9 @@ MODE_CONF = {
"redis_task_db": 1,
"kernels": NUM_USABLE_CPU,
# cache
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
"expression_cache": DISK_EXPRESSION_CACHE,
"dataset_cache": DISK_DATASET_CACHE,
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
"mount_path": None,
},
"client": {
@@ -183,8 +201,10 @@ MODE_CONF = {
"provider_uri": "~/.qlib/qlib_data/cn_data",
# cache
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
"expression_cache": DISK_EXPRESSION_CACHE,
"dataset_cache": DISK_DATASET_CACHE,
# SimpleDatasetCache directory
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
"calendar_cache": None,
# client config
"kernels": NUM_USABLE_CPU,
@@ -195,7 +215,10 @@ MODE_CONF = {
"timeout": 100,
"logging_level": logging.INFO,
"region": REG_CN,
## Custom Operator
# custom operator
# each element of custom_ops should be Type[ExpressionOps] or dict
# if element of custom_ops is Type[ExpressionOps], it represents the custom operator class
# if element of custom_ops is dict, it represents the config of custom operator and should include `class` and `module_path` keys.
"custom_ops": [],
},
}
@@ -225,11 +248,43 @@ class QlibConfig(Config):
# URI_TYPE
LOCAL_URI = "local"
NFS_URI = "nfs"
DEFAULT_FREQ = "__DEFAULT_FREQ"
def __init__(self, default_conf):
super().__init__(default_conf)
self._registered = False
class DataPathManager:
def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
self.provider_uri = provider_uri
self.mount_path = mount_path
@staticmethod
def get_uri_type(uri: Union[str, Path]):
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def get_data_path(self, freq: str = None) -> Path:
if freq is None or freq not in self.provider_uri:
freq = QlibConfig.DEFAULT_FREQ
_provider_uri = self.provider_uri[freq]
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
return Path(_provider_uri)
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
if "win" in platform.system().lower():
# windows, mount_path is the drive
return Path(f"{self.mount_path[freq]}:\\")
return Path(self.mount_path[freq])
else:
raise NotImplementedError(f"This type of uri is not supported")
def set_mode(self, mode):
# raise KeyError
self.update(MODE_CONF[mode])
@@ -239,32 +294,43 @@ class QlibConfig(Config):
# raise KeyError
self.update(_default_region_config[region])
@staticmethod
def is_depend_redis(cache_name: str):
return cache_name in DEPENDENCY_REDIS_CACHE
@property
def dpm(self):
return self.DataPathManager(self["provider_uri"], self["mount_path"])
def resolve_path(self):
# resolve path
if self["mount_path"] is not None:
self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve())
_mount_path = self["mount_path"]
_provider_uri = self["provider_uri"]
if _provider_uri is None:
raise ValueError("provider_uri cannot be None")
if not isinstance(_provider_uri, dict):
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
if not isinstance(_mount_path, dict):
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
if self.get_uri_type() == QlibConfig.LOCAL_URI:
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
# check provider_uri and mount_path
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
def get_uri_type(self):
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
is_nfs_or_win = (
re.match("^[^/]+:.+", self["provider_uri"]) is not None
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
# resolve
for _freq, _uri in _provider_uri.items():
# provider_uri
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
# mount_path
_mount_path[_freq] = (
_mount_path[_freq]
if _mount_path[_freq] is None
else str(Path(_mount_path[_freq]).expanduser().resolve())
)
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def get_data_path(self):
if self.get_uri_type() == QlibConfig.LOCAL_URI:
return self["provider_uri"]
elif self.get_uri_type() == QlibConfig.NFS_URI:
return self["mount_path"]
else:
raise NotImplementedError(f"This type of uri is not supported")
self["provider_uri"] = _provider_uri
self["mount_path"] = _mount_path
def set(self, default_conf="client", **kwargs):
from .utils import set_log_with_config, get_module_logger, can_use_cache
@@ -296,11 +362,20 @@ class QlibConfig(Config):
if not (self["expression_cache"] is None and self["dataset_cache"] is None):
# check redis
if not can_use_cache():
logger.warning(
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
)
self["expression_cache"] = None
self["dataset_cache"] = None
log_str = ""
# check expression cache
if self.is_depend_redis(self["expression_cache"]):
log_str += self["expression_cache"]
self["expression_cache"] = None
# check dataset cache
if self.is_depend_redis(self["dataset_cache"]):
log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"]
self["dataset_cache"] = None
if log_str:
logger.warning(
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), "
f"{log_str} will not be used!"
)
def register(self):
from .utils import init_instance_by_config

View File

@@ -16,6 +16,7 @@ def get_benchmark_weight(
start_date=None,
end_date=None,
path=None,
freq="day",
):
"""get_benchmark_weight
@@ -25,6 +26,7 @@ def get_benchmark_weight(
:param start_date:
:param end_date:
:param path:
:param freq:
:return: The weight distribution of the the benchmark described by a pandas dataframe
Every row corresponds to a trading day.
@@ -33,7 +35,7 @@ def get_benchmark_weight(
"""
if not path:
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegent way
# TODO: The benchmark is not consistant with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
@@ -222,6 +224,7 @@ def brinson_pa(
group_method="category",
group_n=None,
deal_price="vwap",
freq="day",
):
"""brinson profit attribution
@@ -243,7 +246,7 @@ def brinson_pa(
start_date, end_date = min(dates), max(dates)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq)
# The attributes for allocation will not
if not group_field.startswith("$"):
@@ -259,13 +262,14 @@ def brinson_pa(
start_time=shift_start_date,
end_time=end_date,
as_list=True,
freq=freq,
)
stock_df = D.features(
instruments,
[group_field, deal_price],
start_time=shift_start_date,
end_time=end_date,
freq="day",
freq=freq,
)
stock_df.columns = [group_field, "deal_price"]

View File

@@ -0,0 +1,346 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import torch
import warnings
import numpy as np
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH, DataHandler
device = "cuda" if torch.cuda.is_available() else "cpu"
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, dtype=torch.float, device=device)
return x
def _create_ts_slices(index, seq_len):
"""
create time series slices from pandas index
Args:
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
seq_len (int): sequence length
"""
assert isinstance(index, pd.MultiIndex), "unsupported index type"
assert seq_len > 0, "sequence length should be larger than 0"
assert index.is_monotonic_increasing, "index should be sorted"
# number of dates for each instrument
sample_count_by_insts = index.to_series().groupby(level=0).size().values
# start index for each instrument
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
start_index_of_insts[0] = 0
# all the [start, stop) indices of features
# features between [start, stop) will be used to predict label at `stop - 1`
slices = []
for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):
for stop in range(1, cur_cnt + 1):
end = cur_loc + stop
start = max(end - seq_len, 0)
slices.append(slice(start, end))
slices = np.array(slices, dtype="object")
assert len(slices) == len(index) # the i-th slice = index[i]
return slices
def _get_date_parse_fn(target):
"""get date parse function
This method is used to parse date arguments as target type.
Example:
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
if isinstance(target, pd.Timestamp):
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
elif isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
else:
_fn = lambda x: x # '2021-01-01'
return _fn
def _maybe_padding(x, seq_len, zeros=None):
"""padding 2d <time * feature> data with zeros
Args:
x (np.ndarray): 2d data with shape <time * feature>
seq_len (int): target sequence length
zeros (np.ndarray): zeros with shape <seq_len * feature>
"""
assert seq_len > 0, "sequence length should be larger than 0"
if zeros is None:
zeros = np.zeros((seq_len, x.shape[1]), dtype=np.float32)
else:
assert len(zeros) >= seq_len, "zeros matrix is not large enough for padding"
if len(x) != seq_len: # padding zeros
x = np.concatenate([zeros[: seq_len - len(x), : x.shape[1]], x], axis=0)
return x
class MTSDatasetH(DatasetH):
"""Memory Augmented Time Series Dataset
Args:
handler (DataHandler): data handler
segments (dict): data split segments
seq_len (int): time series sequence length
horizon (int): label horizon
num_states (int): how many memory states to be added
memory_mode (str): memory mode (daily or sample)
batch_size (int): batch size (<0 will use daily sampling)
n_samples (int): number of samples in the same day
shuffle (bool): whether shuffle data
drop_last (bool): whether drop last batch < batch_size
input_size (int): reshape flatten rows as this input_size (backward compatibility)
"""
def __init__(
self,
handler,
segments,
seq_len=60,
horizon=0,
num_states=0,
memory_mode="sample",
batch_size=-1,
n_samples=None,
shuffle=True,
drop_last=False,
input_size=None,
**kwargs,
):
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
assert batch_size != 0, "invalid batch size"
if batch_size > 0 and n_samples is not None:
warnings.warn("`n_samples` can only be used for daily sampling (`batch_size < 0`)")
self.seq_len = seq_len
self.horizon = horizon
self.num_states = num_states
self.memory_mode = memory_mode
self.batch_size = batch_size
self.n_samples = n_samples
self.shuffle = shuffle
self.drop_last = drop_last
self.input_size = input_size
self.params = (batch_size, n_samples, drop_last, shuffle) # for train/eval switch
super().__init__(handler, segments, **kwargs)
def setup_data(self, handler_kwargs: dict = None, **kwargs):
super().setup_data(**kwargs)
if handler_kwargs is not None:
self.handler.setup_data(**handler_kwargs)
# pre-fetch data and change index to <code, date>
# NOTE: we will use inplace sort to reduce memory use
try:
df = self.handler._learn.copy() # use copy otherwise recorder will fail
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
except:
warnings.warn("cannot access `_learn`, will load raw data")
df = self.handler._data.copy()
df.index = df.index.swaplevel()
df.sort_index(inplace=True)
# convert to numpy
self._data = df["feature"].values.astype("float32")
np.nan_to_num(self._data, copy=False) # NOTE: fillna in case users forget using the fillna processor
self._label = df["label"].squeeze().values.astype("float32")
self._index = df.index
if self.input_size is not None and self.input_size != self._data.shape[1]:
warnings.warn("the data has different shape from input_size and the data will be reshaped")
assert self._data.shape[1] % self.input_size == 0, "data mismatch, please check `input_size`"
# create batch slices
self._batch_slices = _create_ts_slices(self._index, self.seq_len)
# create daily slices
daily_slices = {date: [] for date in sorted(self._index.unique(level=1))} # sorted by date
for i, (code, date) in enumerate(self._index):
daily_slices[date].append(self._batch_slices[i])
self._daily_slices = np.array(list(daily_slices.values()), dtype="object")
self._daily_index = pd.Series(list(daily_slices.keys())) # index is the original date index
# add memory (sample wise and daily)
if self.memory_mode == "sample":
self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32)
elif self.memory_mode == "daily":
self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32)
else:
raise ValueError(f"invalid memory_mode `{self.memory_mode}`")
# padding tensor
self._zeros = np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32)
def _prepare_seg(self, slc, **kwargs):
fn = _get_date_parse_fn(self._index[0][1])
start_date = fn(slc.start)
end_date = fn(slc.stop)
obj = copy.copy(self) # shallow copy
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
obj._data = self._data # reference (no copy)
obj._label = self._label
obj._index = self._index
obj._memory = self._memory
obj._zeros = self._zeros
# update index for this batch
date_index = self._index.get_level_values(1)
obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)]
mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date)
obj._daily_slices = self._daily_slices[mask]
obj._daily_index = self._daily_index[mask]
return obj
def restore_index(self, index):
return self._index[index]
def restore_daily_index(self, daily_index):
return pd.Index(self._daily_index.loc[daily_index])
def assign_data(self, index, vals):
if self.num_states == 0:
raise ValueError("cannot assign data as `num_states==0`")
if isinstance(vals, torch.Tensor):
vals = vals.detach().cpu().numpy()
self._memory[index] = vals
def clear_memory(self):
if self.num_states == 0:
raise ValueError("cannot clear memory as `num_states==0`")
self._memory[:] = 0
def train(self):
"""enable traning mode"""
self.batch_size, self.n_samples, self.drop_last, self.shuffle = self.params
def eval(self):
"""enable evaluation mode"""
self.batch_size = -1
self.n_samples = None
self.drop_last = False
self.shuffle = False
def _get_slices(self):
if self.batch_size < 0: # daily sampling
slices = self._daily_slices.copy()
batch_size = -1 * self.batch_size
else: # normal sampling
slices = self._batch_slices.copy()
batch_size = self.batch_size
return slices, batch_size
def __len__(self):
slices, batch_size = self._get_slices()
if self.drop_last:
return len(slices) // batch_size
return (len(slices) + batch_size - 1) // batch_size
def __iter__(self):
slices, batch_size = self._get_slices()
indices = np.arange(len(slices))
if self.shuffle:
np.random.shuffle(indices)
for i in range(len(indices))[::batch_size]:
if self.drop_last and i + batch_size > len(indices):
break
data = [] # store features
label = [] # store labels
index = [] # store index
state = [] # store memory states
daily_index = [] # store daily index
daily_count = [] # store number of samples for each day
for j in indices[i : i + batch_size]:
# normal sampling: self.batch_size > 0 => slices is a list => slices_subset is a slice
# daily sampling: self.batch_size < 0 => slices is a nested list => slices_subset is a list
slices_subset = slices[j]
# daily sampling
# each slices_subset contains a list of slices for multiple stocks
# NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0
if self.batch_size < 0:
# store daily index
idx = self._daily_index.index[j] # daily_index.index is the index of the original data
daily_index.append(idx)
# store daily memory if specified
# NOTE: daily memory always requires daily sampling (self.batch_size < 0)
if self.memory_mode == "daily":
slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0))
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros))
# down-sample stocks and store count
if self.n_samples and 0 < self.n_samples < len(slices_subset): # intraday subsample
slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False)
daily_count.append(len(slices_subset))
# normal sampling
# each slices_subset is a single slice
# NOTE: normal sampling is used in train mode with self.batch_size > 0
else:
slices_subset = [slices_subset]
for slc in slices_subset:
# legacy support for Alpha360 data by `input_size`
if self.input_size:
data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T)
else:
data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros))
if self.memory_mode == "sample":
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)[: -self.horizon])
label.append(self._label[slc.stop - 1])
index.append(slc.stop - 1)
# end slices loop
# end indices batch loop
# concate
data = _to_tensor(np.stack(data))
state = _to_tensor(np.stack(state))
label = _to_tensor(np.stack(label))
index = np.array(index)
daily_index = np.array(daily_index)
daily_count = np.array(daily_count)
# yield -> generator
yield {
"data": data,
"label": label,
"state": state,
"index": index,
"daily_index": daily_index,
"daily_count": daily_count,
}
# end indice loop

View File

@@ -3,7 +3,7 @@
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_cls_kwargs
from ...utils import get_callable_kwargs
from ...data.dataset import processor as processor_module
from ...log import TimeInspector
from inspect import getfullargspec
@@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
new_l = []
for p in proc_l:
if not isinstance(p, Processor):
klass, pkwargs = get_cls_kwargs(p, processor_module)
klass, pkwargs = get_callable_kwargs(p, processor_module)
args = getfullargspec(klass).args
if "fit_start_time" in args and "fit_end_time" in args:
assert (
@@ -26,8 +26,10 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
# FIXME: the `module_path` parameter is missed.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
proc_config = {"class": klass.__name__, "kwargs": pkwargs}
if isinstance(p, dict) and "module_path" in p:
proc_config["module_path"] = p["module_path"]
new_l.append(proc_config)
else:
new_l.append(p)
return new_l
@@ -56,6 +58,7 @@ class Alpha360(DataHandlerLP):
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
inst_processor=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -70,6 +73,7 @@ class Alpha360(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
},
}
@@ -142,6 +146,7 @@ class Alpha158(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -156,6 +161,7 @@ class Alpha158(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
},
}
super().__init__(

View File

@@ -53,7 +53,6 @@ class GATs(Model):
early_stop=20,
loss="mse",
base_model="GRU",
with_pretrain=True,
model_path=None,
optimizer="adam",
GPU=0,
@@ -76,7 +75,6 @@ class GATs(Model):
self.optimizer = optimizer.lower()
self.loss = loss
self.base_model = base_model
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
@@ -94,7 +92,6 @@ class GATs(Model):
"\noptimizer : {}"
"\nloss_type : {}"
"\nbase_model : {}"
"\nwith_pretrain : {}"
"\nmodel_path : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
@@ -110,7 +107,6 @@ class GATs(Model):
optimizer.lower(),
loss,
base_model,
with_pretrain,
model_path,
self.device,
self.use_gpu,
@@ -253,24 +249,22 @@ class GATs(Model):
evals_result["valid"] = []
# load pretrained base_model
if self.with_pretrain:
if self.model_path == None:
raise ValueError("the path of the pretrained model should be given first!")
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
pretrained_model = LSTMModel()
pretrained_model.load_state_dict(torch.load(self.model_path))
elif self.base_model == "GRU":
pretrained_model = GRUModel()
pretrained_model.load_state_dict(torch.load(self.model_path))
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
if self.base_model == "LSTM":
pretrained_model = LSTMModel()
elif self.base_model == "GRU":
pretrained_model = GRUModel()
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
# train
self.logger.info("training...")

View File

@@ -29,8 +29,8 @@ class DailyBatchSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
self.data = self.data_source.data.loc[self.data_source.get_index()]
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
# calculate number of samples in each batch
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
self.daily_index[0] = 0
@@ -72,7 +72,6 @@ class GATs(Model):
early_stop=20,
loss="mse",
base_model="GRU",
with_pretrain=True,
model_path=None,
optimizer="adam",
GPU="0",
@@ -96,7 +95,6 @@ class GATs(Model):
self.optimizer = optimizer.lower()
self.loss = loss
self.base_model = base_model
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
@@ -115,7 +113,6 @@ class GATs(Model):
"\noptimizer : {}"
"\nloss_type : {}"
"\nbase_model : {}"
"\nwith_pretrain : {}"
"\nmodel_path : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}"
@@ -131,7 +128,6 @@ class GATs(Model):
optimizer.lower(),
loss,
base_model,
with_pretrain,
model_path,
GPU,
self.use_gpu,
@@ -270,28 +266,22 @@ class GATs(Model):
evals_result["valid"] = []
# load pretrained base_model
if self.with_pretrain:
if self.model_path == None:
raise ValueError("the path of the pretrained model should be given first!")
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
pretrained_model = LSTMModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
)
pretrained_model.load_state_dict(torch.load(self.model_path))
elif self.base_model == "GRU":
pretrained_model = GRUModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
)
pretrained_model.load_state_dict(torch.load(self.model_path))
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
if self.base_model == "LSTM":
pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
elif self.base_model == "GRU":
pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
# train
self.logger.info("training...")

View File

@@ -0,0 +1,331 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList
# qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml ”
class LocalformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 2048,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values)
self.model.train()
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.model(feature)
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_x, data_y):
# prepare training data
x_values = data_x.values
y_values = np.squeeze(data_y.values)
self.model.eval()
scores = []
losses = []
indices = np.arange(len(x_values))
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
with torch.no_grad():
pred = self.model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
def _get_clones(module, N):
return ModuleList([copy.deepcopy(module) for i in range(N)])
class LocalformerEncoder(nn.Module):
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, d_model):
super(LocalformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
self.num_layers = num_layers
def forward(self, src, mask):
output = src
out = src
for i, mod in enumerate(self.layers):
# [T, N, F] --> [N, T, F] --> [N, F, T]
out = output.transpose(1, 0).transpose(2, 1)
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
output = mod(output + out, src_mask=mask)
return output + out
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.rnn = nn.GRU(
input_size=d_model,
hidden_size=d_model,
num_layers=num_layers,
batch_first=False,
dropout=dropout,
)
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, F*T] --> [N, T, F]
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
src = self.feature_layer(src)
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
output, _ = self.rnn(output)
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -0,0 +1,308 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList
class LocalformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 8192,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info(
"Improved Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, data_loader):
self.model.train()
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_loader):
self.model.eval()
scores = []
losses = []
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.model.eval()
preds = []
for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
def _get_clones(module, N):
return ModuleList([copy.deepcopy(module) for i in range(N)])
class LocalformerEncoder(nn.Module):
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, d_model):
super(LocalformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
self.num_layers = num_layers
def forward(self, src, mask):
output = src
out = src
for i, mod in enumerate(self.layers):
# [T, N, F] --> [N, T, F] --> [N, F, T]
out = output.transpose(1, 0).transpose(2, 1)
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
output = mod(output + out, src_mask=mask)
return output + out
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.rnn = nn.GRU(
input_size=d_model,
hidden_size=d_model,
num_layers=num_layers,
batch_first=False,
dropout=dropout,
)
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, T, F], [512, 60, 6]
src = self.feature_layer(src) # [512, 60, 8]
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
output, _ = self.rnn(output)
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -297,7 +297,7 @@ class DNNModelPytorch(Model):
_model_path = os.path.join(model_dir, _model_name)
# Load model
self.dnn_model.load_state_dict(torch.load(_model_path))
self._fitted = True
self.fitted = True
class AverageMeter:

View File

@@ -564,7 +564,7 @@ class FeatureTransformer(nn.Module):
self.shared = None
self.independ = nn.ModuleList()
if first:
self.independ.append(GLU(inp, out_dim, vbs=vbs))
self.independ.append(GLU(inp_dim, out_dim, vbs=vbs))
for x in range(first, n_ind):
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
self.scale = float(np.sqrt(0.5))

View File

@@ -0,0 +1,420 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
import random
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
import torch
import torch.nn as nn
import torch.optim as optim
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
class TCTS(Model):
"""TCTS Model
Parameters
----------
d_feat : int
input dimension for each time step
metric: str
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
"""
def __init__(
self,
d_feat=6,
hidden_size=64,
num_layers=2,
dropout=0.0,
n_epochs=200,
batch_size=2000,
early_stop=20,
loss="mse",
fore_optimizer="adam",
weight_optimizer="adam",
output_dim=5,
fore_lr=5e-7,
weight_lr=5e-7,
steps=3,
GPU=0,
seed=0,
target_label=0,
lowest_valid_performance=0.993,
**kwargs
):
# Set logger.
self.logger = get_module_logger("TCTS")
self.logger.info("TCTS pytorch version...")
# set hyper-parameters.
self.d_feat = d_feat
self.hidden_size = hidden_size
self.num_layers = num_layers
self.dropout = dropout
self.n_epochs = n_epochs
self.batch_size = batch_size
self.early_stop = early_stop
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.use_gpu = torch.cuda.is_available()
self.seed = seed
self.output_dim = output_dim
self.fore_lr = fore_lr
self.weight_lr = weight_lr
self.steps = steps
self.target_label = target_label
self.lowest_valid_performance = lowest_valid_performance
self._fore_optimizer = fore_optimizer
self._weight_optimizer = weight_optimizer
self.logger.info(
"TCTS parameters setting:"
"\nd_feat : {}"
"\nhidden_size : {}"
"\nnum_layers : {}"
"\ndropout : {}"
"\nn_epochs : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
d_feat,
hidden_size,
num_layers,
dropout,
n_epochs,
batch_size,
early_stop,
loss,
GPU,
self.use_gpu,
seed,
)
)
def loss_fn(self, pred, label, weight):
loc = torch.argmax(weight, 1)
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
return torch.mean(loss)
def train_epoch(self, x_train, y_train, x_valid, y_valid):
x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values)
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)
init_fore_model = copy.deepcopy(self.fore_model)
for p in init_fore_model.parameters():
p.init_fore_model = False
self.fore_model.train()
self.weight_model.train()
for p in self.weight_model.parameters():
p.requires_grad = False
for p in self.fore_model.parameters():
p.requires_grad = True
for i in range(self.steps):
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
init_pred = init_fore_model(feature)
pred = self.fore_model(feature)
dis = init_pred - label.transpose(0, 1)
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
weight = self.weight_model(weight_feature)
loss = self.loss_fn(pred, label, weight) # hard
self.fore_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)
self.fore_optimizer.step()
x_valid_values = x_valid.values
y_valid_values = np.squeeze(y_valid.values)
indices = np.arange(len(x_valid_values))
np.random.shuffle(indices)
for p in self.weight_model.parameters():
p.requires_grad = True
for p in self.fore_model.parameters():
p.requires_grad = False
# fix forecasting model and valid weight model
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.fore_model(feature)
dis = pred - label.transpose(0, 1)
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
weight = self.weight_model(weight_feature)
loc = torch.argmax(weight, 1)
valid_loss = torch.mean((pred - label[:, 0]) ** 2)
loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
self.weight_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)
self.weight_optimizer.step()
def test_epoch(self, data_x, data_y):
# prepare training data
x_values = data_x.values
y_values = np.squeeze(data_y.values)
self.fore_model.eval()
scores = []
losses = []
indices = np.arange(len(x_values))
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.fore_model(feature)
loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
losses.append(loss.item())
return np.mean(losses)
def fit(
self,
dataset: DatasetH,
verbose=True,
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
x_test, y_test = df_test["feature"], df_test["label"]
if save_path == None:
save_path = get_or_create_path(save_path)
best_loss = np.inf
while best_loss > self.lowest_valid_performance:
if best_loss < np.inf:
print("Failed! Start retraining.")
self.seed = random.randint(0, 1000) # reset random seed
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
best_loss = self.training(
x_train, y_train, x_valid, y_valid, x_test, y_test, verbose=verbose, save_path=save_path
)
def training(
self,
x_train,
y_train,
x_valid,
y_valid,
x_test,
y_test,
verbose=True,
save_path=None,
):
self.fore_model = GRUModel(
d_feat=self.d_feat,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
dropout=self.dropout,
)
self.weight_model = MLPModel(
d_feat=360 + 2 * self.output_dim + 1,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
dropout=self.dropout,
output_dim=self.output_dim,
)
if self._fore_optimizer.lower() == "adam":
self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)
elif self._fore_optimizer.lower() == "gd":
self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)
else:
raise NotImplementedError("optimizer {} is not supported!".format(self._fore_optimizer))
if self._weight_optimizer.lower() == "adam":
self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)
elif self._weight_optimizer.lower() == "gd":
self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)
else:
raise NotImplementedError("optimizer {} is not supported!".format(self._weight_optimizer))
self.fitted = False
self.fore_model.to(self.device)
self.weight_model.to(self.device)
best_loss = np.inf
best_epoch = 0
stop_round = 0
fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
for epoch in range(self.n_epochs):
print("Epoch:", epoch)
print("training...")
self.train_epoch(x_train, y_train, x_valid, y_valid)
print("evaluating...")
val_loss = self.test_epoch(x_valid, y_valid)
test_loss = self.test_epoch(x_test, y_test)
if verbose:
print("valid %.6f, test %.6f" % (val_loss, test_loss))
if val_loss < best_loss:
best_loss = val_loss
stop_round = 0
best_epoch = epoch
torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin")
torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin")
else:
stop_round += 1
if stop_round >= self.early_stop:
print("early stop")
break
print("best loss:", best_loss, "@", best_epoch)
best_param = torch.load(save_path + "_fore_model.bin")
self.fore_model.load_state_dict(best_param)
best_param = torch.load(save_path + "_weight_model.bin")
self.weight_model.load_state_dict(best_param)
self.fitted = True
if self.use_gpu:
torch.cuda.empty_cache()
return best_loss
def predict(self, dataset):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
index = x_test.index
self.fore_model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.fore_model(x_batch).detach().cpu().numpy()
else:
pred = self.fore_model(x_batch).detach().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)
class MLPModel(nn.Module):
def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):
super().__init__()
self.mlp = nn.Sequential()
self.softmax = nn.Softmax(dim=1)
for i in range(num_layers):
if i > 0:
self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout))
self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))
self.mlp.add_module("relu_%d" % i, nn.ReLU())
self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim))
def forward(self, x):
# feature
# [N, F]
out = self.mlp(x).squeeze()
out = self.softmax(out)
return out
class GRUModel(nn.Module):
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
super().__init__()
self.rnn = nn.GRU(
input_size=d_feat,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
)
self.fc_out = nn.Linear(hidden_size, 1)
self.d_feat = d_feat
def forward(self, x):
# x: [N, F*T]
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
x = x.permute(0, 2, 1) # [N, T, F]
out, _ = self.rnn(x)
return self.fc_out(out[:, -1, :]).squeeze()

View File

@@ -0,0 +1,944 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import io
import os
import copy
import math
import json
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
try:
from torch.utils.tensorboard import SummaryWriter
except:
SummaryWriter = None
from tqdm import tqdm
from qlib.utils import get_or_create_path
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.contrib.data.dataset import MTSDatasetH
device = "cuda" if torch.cuda.is_available() else "cpu"
class TRAModel(Model):
"""
TRA Model
Args:
model_config (dict): model config (will be used by RNN or Transformer)
tra_config (dict): TRA config (will be used by TRA)
model_type (str): which backbone model to use (RNN/Transformer)
lr (float): learning rate
n_epochs (int): number of total epochs
early_stop (int): early stop when performance not improved at this step
update_freq (int): gradient update frequency
max_steps_per_epoch (int): maximum number of steps in one epoch
lamb (float): regularization parameter
rho (float): exponential decay rate for `lamb`
alpha (float): fusion parameter for calculating transport loss matrix
seed (int): random seed
logdir (str): local log directory
eval_train (bool): whether evaluate train set between epochs
eval_test (bool): whether evaluate test set between epochs
pretrain (bool): whether pretrain the backbone model before training TRA.
Note that only TRA will be optimized after pretraining
init_state (str): model init state path
freeze_model (bool): whether freeze backbone model parameters
freeze_predictors (bool): whether freeze predictors parameters
transport_method (str): transport method, can be none/router/oracle
memory_mode (str): memory mode, the same argument for MTSDatasetH
"""
def __init__(
self,
model_config,
tra_config,
model_type="RNN",
lr=1e-3,
n_epochs=500,
early_stop=50,
update_freq=1,
max_steps_per_epoch=None,
lamb=0.0,
rho=0.99,
alpha=1.0,
seed=0,
logdir=None,
eval_train=False,
eval_test=False,
pretrain=False,
init_state=None,
reset_router=False,
freeze_model=False,
freeze_predictors=False,
transport_method="none",
memory_mode="sample",
):
self.logger = get_module_logger("TRA")
assert memory_mode in ["sample", "daily"], "invalid memory mode"
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1"
assert (
memory_mode != "daily" or tra_config["src_info"] == "TPE"
), "daily transport can only support TPE as `src_info`"
if transport_method == "router" and not eval_train:
self.logger.warning("`eval_train` will be ignored when using TRA.router")
np.random.seed(seed)
torch.manual_seed(seed)
self.model_config = model_config
self.tra_config = tra_config
self.model_type = model_type
self.lr = lr
self.n_epochs = n_epochs
self.early_stop = early_stop
self.update_freq = update_freq
self.max_steps_per_epoch = max_steps_per_epoch
self.lamb = lamb
self.rho = rho
self.alpha = alpha
self.seed = seed
self.logdir = logdir
self.eval_train = eval_train
self.eval_test = eval_test
self.pretrain = pretrain
self.init_state = init_state
self.reset_router = reset_router
self.freeze_model = freeze_model
self.freeze_predictors = freeze_predictors
self.transport_method = transport_method
self.use_daily_transport = memory_mode == "daily"
self.transport_fn = transport_daily if self.use_daily_transport else transport_sample
self._writer = None
if self.logdir is not None:
if os.path.exists(self.logdir):
self.logger.warning(f"logdir {self.logdir} is not empty")
os.makedirs(self.logdir, exist_ok=True)
if SummaryWriter is not None:
self._writer = SummaryWriter(log_dir=self.logdir)
self._init_model()
def _init_model(self):
self.logger.info("init TRAModel...")
self.model = eval(self.model_type)(**self.model_config).to(device)
print(self.model)
self.tra = TRA(self.model.output_size, **self.tra_config).to(device)
print(self.tra)
if self.init_state:
self.logger.warning(f"load state dict from `init_state`")
state_dict = torch.load(self.init_state, map_location="cpu")
self.model.load_state_dict(state_dict["model"])
res = load_state_dict_unsafe(self.tra, state_dict["tra"])
self.logger.warning(str(res))
if self.reset_router:
self.logger.warning(f"reset TRA.router parameters")
self.tra.fc.reset_parameters()
self.tra.router.reset_parameters()
if self.freeze_model:
self.logger.warning(f"freeze model parameters")
for param in self.model.parameters():
param.requires_grad_(False)
if self.freeze_predictors:
self.logger.warning(f"freeze TRA.predictors parameters")
for param in self.tra.predictors.parameters():
param.requires_grad_(False)
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
self.fitted = False
self.global_step = -1
def train_epoch(self, epoch, data_set, is_pretrain=False):
self.model.train()
self.tra.train()
data_set.train()
self.optimizer.zero_grad()
P_all = []
prob_all = []
choice_all = []
max_steps = len(data_set)
if self.max_steps_per_epoch is not None:
if epoch == 0 and self.max_steps_per_epoch < max_steps:
self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}")
max_steps = min(self.max_steps_per_epoch, max_steps)
cur_step = 0
total_loss = 0
total_count = 0
for batch in tqdm(data_set, total=max_steps):
cur_step += 1
if cur_step > max_steps:
break
if not is_pretrain:
self.global_step += 1
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
with torch.set_grad_enabled(not self.freeze_model):
hidden = self.model(data)
all_preds, choice, prob = self.tra(hidden, state)
if is_pretrain or self.transport_method != "none":
# NOTE: use oracle transport for pre-training
loss, pred, L, P = self.transport_fn(
all_preds,
label,
choice,
prob,
state.mean(dim=1),
count,
self.transport_method if not is_pretrain else "oracle",
self.alpha,
training=True,
)
data_set.assign_data(index, L) # save loss to memory
if self.use_daily_transport: # only save for daily transport
P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))
prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))
choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))
decay = self.rho ** (self.global_step // 100) # decay every 100 steps
lamb = 0 if is_pretrain else self.lamb * decay
reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment
if self._writer is not None and not is_pretrain:
self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
self._writer.add_scalar("training/lamb", lamb, self.global_step)
if not self.use_daily_transport:
P_mean = P.mean(axis=0).detach()
self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step)
loss = loss - lamb * reg
else:
pred = all_preds.mean(dim=1)
loss = loss_fn(pred, label)
(loss / self.update_freq).backward()
if cur_step % self.update_freq == 0:
self.optimizer.step()
self.optimizer.zero_grad()
if self._writer is not None and not is_pretrain:
self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)
total_loss += loss.item()
total_count += 1
if self.use_daily_transport and len(P_all):
P_all = pd.concat(P_all, axis=0)
prob_all = pd.concat(prob_all, axis=0)
choice_all = pd.concat(choice_all, axis=0)
P_all.index = data_set.restore_daily_index(P_all.index)
prob_all.index = P_all.index
choice_all.index = P_all.index
if not is_pretrain:
self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC")
self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC")
self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC")
total_loss /= total_count
if self._writer is not None and not is_pretrain:
self._writer.add_scalar("training/loss", total_loss, epoch)
return total_loss
def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False):
self.model.eval()
self.tra.eval()
data_set.eval()
preds = []
probs = []
P_all = []
metrics = []
for batch in tqdm(data_set):
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
with torch.no_grad():
hidden = self.model(data)
all_preds, choice, prob = self.tra(hidden, state)
if is_pretrain or self.transport_method != "none":
loss, pred, L, P = self.transport_fn(
all_preds,
label,
choice,
prob,
state.mean(dim=1),
count,
self.transport_method if not is_pretrain else "oracle",
self.alpha,
training=False,
)
data_set.assign_data(index, L) # save loss to memory
if P is not None and return_pred:
P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))
else:
pred = all_preds.mean(dim=1)
X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]
columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])]
pred = pd.DataFrame(X, index=batch["index"], columns=columns)
metrics.append(evaluate(pred))
if return_pred:
preds.append(pred)
if prob is not None:
columns = ["prob_%d" % d for d in range(all_preds.shape[1])]
probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))
metrics = pd.DataFrame(metrics)
metrics = {
"MSE": metrics.MSE.mean(),
"MAE": metrics.MAE.mean(),
"IC": metrics.IC.mean(),
"ICIR": metrics.IC.mean() / metrics.IC.std(),
}
if self._writer is not None and epoch >= 0 and not is_pretrain:
for key, value in metrics.items():
self._writer.add_scalar(prefix + "/" + key, value, epoch)
if return_pred:
preds = pd.concat(preds, axis=0)
preds.index = data_set.restore_index(preds.index)
preds.index = preds.index.swaplevel()
preds.sort_index(inplace=True)
if probs:
probs = pd.concat(probs, axis=0)
if self.use_daily_transport:
probs.index = data_set.restore_daily_index(probs.index)
else:
probs.index = data_set.restore_index(probs.index)
probs.index = probs.index.swaplevel()
probs.sort_index(inplace=True)
if len(P_all):
P_all = pd.concat(P_all, axis=0)
if self.use_daily_transport:
P_all.index = data_set.restore_daily_index(P_all.index)
else:
P_all.index = data_set.restore_index(P_all.index)
P_all.index = P_all.index.swaplevel()
P_all.sort_index(inplace=True)
return metrics, preds, probs, P_all
def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):
best_score = -1
best_epoch = 0
stop_rounds = 0
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
# train
if not is_pretrain and self.transport_method != "none":
self.logger.info("init memory...")
self.test_epoch(-1, train_set)
for epoch in range(self.n_epochs):
self.logger.info("Epoch %d:", epoch)
self.logger.info("training...")
self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)
self.logger.info("evaluating...")
# NOTE: during evaluating, the whole memory will be refreshed
if not is_pretrain and (self.transport_method == "router" or self.eval_train):
train_set.clear_memory() # NOTE: clear the shared memory
train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0]
evals_result["train"].append(train_metrics)
self.logger.info("train metrics: %s" % train_metrics)
valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0]
evals_result["valid"].append(valid_metrics)
self.logger.info("valid metrics: %s" % valid_metrics)
if self.eval_test:
test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0]
evals_result["test"].append(test_metrics)
self.logger.info("test metrics: %s" % test_metrics)
if valid_metrics["IC"] > best_score:
best_score = valid_metrics["IC"]
stop_rounds = 0
best_epoch = epoch
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
if self.logdir is not None:
torch.save(best_params, self.logdir + "/model.bin")
else:
stop_rounds += 1
if stop_rounds >= self.early_stop:
self.logger.info("early stop @ %s" % epoch)
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_params["model"])
self.tra.load_state_dict(best_params["tra"])
return best_score
def fit(self, dataset, evals_result=dict()):
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
self.fitted = True
self.global_step = -1
evals_result["train"] = []
evals_result["valid"] = []
evals_result["test"] = []
if self.pretrain:
self.logger.info("pretraining...")
self.optimizer = optim.Adam(
list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr
)
self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
# reset optimizer
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
self.logger.info("training...")
best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)
self.logger.info("inference")
train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)
self.logger.info("train metrics: %s" % train_metrics)
valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)
self.logger.info("valid metrics: %s" % valid_metrics)
test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)
self.logger.info("test metrics: %s" % test_metrics)
if self.logdir:
self.logger.info("save model & pred to local directory")
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
self.logdir + "/logs.csv", index=False
)
torch.save({"model": self.model.state_dict(), "tra": self.tra.state_dict()}, self.logdir + "/model.bin")
train_preds.to_pickle(self.logdir + "/train_pred.pkl")
valid_preds.to_pickle(self.logdir + "/valid_pred.pkl")
test_preds.to_pickle(self.logdir + "/test_pred.pkl")
if len(train_probs):
train_probs.to_pickle(self.logdir + "/train_prob.pkl")
valid_probs.to_pickle(self.logdir + "/valid_prob.pkl")
test_probs.to_pickle(self.logdir + "/test_prob.pkl")
if len(train_P):
train_P.to_pickle(self.logdir + "/train_P.pkl")
valid_P.to_pickle(self.logdir + "/valid_P.pkl")
test_P.to_pickle(self.logdir + "/test_P.pkl")
info = {
"config": {
"model_config": self.model_config,
"tra_config": self.tra_config,
"model_type": self.model_type,
"lr": self.lr,
"n_epochs": self.n_epochs,
"early_stop": self.early_stop,
"max_steps_per_epoch": self.max_steps_per_epoch,
"lamb": self.lamb,
"rho": self.rho,
"alpha": self.alpha,
"seed": self.seed,
"logdir": self.logdir,
"pretrain": self.pretrain,
"init_state": self.init_state,
"transport_method": self.transport_method,
"use_daily_transport": self.use_daily_transport,
},
"best_eval_metric": -best_score, # NOTE: -1 for minimize
"metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics},
}
with open(self.logdir + "/info.json", "w") as f:
json.dump(info, f)
def predict(self, dataset, segment="test"):
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
if not self.fitted:
raise ValueError("model is not fitted yet!")
test_set = dataset.prepare(segment)
metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)
self.logger.info("test metrics: %s" % metrics)
return preds
class RNN(nn.Module):
"""RNN Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of hidden layers
rnn_arch (str): rnn architecture
use_attn (bool): whether use attention layer.
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
dropout (float): dropout rate
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
rnn_arch="GRU",
use_attn=True,
dropout=0.0,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn_arch = rnn_arch
self.use_attn = use_attn
if hidden_size < input_size:
# compression
self.input_proj = nn.Linear(input_size, hidden_size)
else:
self.input_proj = None
self.rnn = getattr(nn, rnn_arch)(
input_size=min(input_size, hidden_size),
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
)
if self.use_attn:
self.W = nn.Linear(hidden_size, hidden_size)
self.u = nn.Linear(hidden_size, 1, bias=False)
self.output_size = hidden_size * 2
else:
self.output_size = hidden_size
def forward(self, x):
if self.input_proj is not None:
x = self.input_proj(x)
rnn_out, last_out = self.rnn(x)
if self.rnn_arch == "LSTM":
last_out = last_out[0]
last_out = last_out.mean(dim=0)
if self.use_attn:
laten = self.W(rnn_out).tanh()
scores = self.u(laten).softmax(dim=1)
att_out = (rnn_out * scores).sum(dim=1)
last_out = torch.cat([last_out, att_out], dim=1)
return last_out
class PositionalEncoding(nn.Module):
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
class Transformer(nn.Module):
"""Transformer Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of transformer layers
num_heads (int): number of heads in transformer
dropout (float): dropout rate
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
num_heads=2,
dropout=0.0,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.input_proj = nn.Linear(input_size, hidden_size)
self.pe = PositionalEncoding(input_size, dropout)
layer = nn.TransformerEncoderLayer(
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
self.output_size = hidden_size
def forward(self, x):
x = x.permute(1, 0, 2).contiguous() # the first dim need to be time
x = self.pe(x)
x = self.input_proj(x)
out = self.encoder(x)
return out[-1]
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction erros & latent representation as inputs,
then routes the input sample to a specific predictor for training & inference.
Args:
input_size (int): input size (RNN/Transformer's hidden size)
num_states (int): number of latent states (i.e., trading patterns)
If `num_states=1`, then TRA falls back to traditional methods
hidden_size (int): hidden size of the router
tau (float): gumbel softmax temperature
src_info (str): information for the router
"""
def __init__(
self,
input_size,
num_states=1,
hidden_size=8,
rnn_arch="GRU",
num_layers=1,
dropout=0.0,
tau=1.0,
src_info="LR_TPE",
):
super().__init__()
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
self.num_states = num_states
self.tau = tau
self.rnn_arch = rnn_arch
self.src_info = src_info
self.predictors = nn.Linear(input_size, num_states)
if self.num_states > 1:
if "TPE" in src_info:
self.router = getattr(nn, rnn_arch)(
input_size=num_states,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
)
self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states)
else:
self.fc = nn.Linear(input_size, num_states)
def reset_parameters(self):
for child in self.children():
child.reset_parameters()
def forward(self, hidden, hist_loss):
preds = self.predictors(hidden)
if self.num_states == 1: # no need for router when having only one prediction
return preds, None, None
if "TPE" in self.src_info:
out = self.router(hist_loss)[1] # TPE
if self.rnn_arch == "LSTM":
out = out[0]
out = out.mean(dim=0)
if "LR" in self.src_info:
out = torch.cat([hidden, out], dim=-1) # LR_TPE
else:
out = hidden # LR
out = self.fc(out)
choice = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=True)
prob = torch.softmax(out / self.tau, dim=-1)
return preds, choice, prob
def evaluate(pred):
pred = pred.rank(pct=True) # transform into percentiles
score = pred.score
label = pred.label
diff = score - label
MSE = (diff ** 2).mean()
MAE = (diff.abs()).mean()
IC = score.corr(label, method="spearman")
return {"MSE": MSE, "MAE": MAE, "IC": IC}
def shoot_infs(inp_tensor):
"""Replaces inf by maximum of tensor"""
mask_inf = torch.isinf(inp_tensor)
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
if len(ind_inf) > 0:
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = 0
elif len(ind) == 1:
inp_tensor[ind[0]] = 0
m = torch.max(inp_tensor)
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = m
elif len(ind) == 1:
inp_tensor[ind[0]] = m
return inp_tensor
def sinkhorn(Q, n_iters=3, epsilon=0.1):
# epsilon should be adjusted according to logits value's scale
with torch.no_grad():
Q = torch.exp(Q / epsilon)
Q = shoot_infs(Q)
for i in range(n_iters):
Q /= Q.sum(dim=0, keepdim=True)
Q /= Q.sum(dim=1, keepdim=True)
return Q
def loss_fn(pred, label):
mask = ~torch.isnan(label)
if len(pred.shape) == 2:
label = label[:, None]
return (pred[mask] - label[mask]).pow(2).mean(dim=0)
def minmax_norm(x):
xmin = x.min(dim=-1, keepdim=True).values
xmax = x.max(dim=-1, keepdim=True).values
mask = (xmin == xmax).squeeze()
x = (x - xmin) / (xmax - xmin + 1e-12)
x[mask] = 1
return x
def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
"""
sample-wise transport
Args:
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
label (torch.Tensor): label, [sample]
choice (torch.Tensor): gumbel softmax choice, [sample x states]
prob (torch.Tensor): router predicted probility, [sample x states]
hist_loss (torch.Tensor): history loss matrix, [sample x states]
count (list): sample counts for each day, empty list for sample-wise transport
transport_method (str): transportation method
alpha (float): fusion parameter for calculating transport loss matrix
training (bool): indicate training or inference
"""
assert all_preds.shape == choice.shape
assert len(all_preds) == len(label)
assert transport_method in ["oracle", "router"]
all_loss = torch.zeros_like(all_preds)
mask = ~torch.isnan(label)
all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2) # [sample x states]
L = minmax_norm(all_loss.detach())
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
Lh = minmax_norm(Lh)
P = sinkhorn(-Lh)
del Lh
if transport_method == "router":
if training:
pred = (all_preds * choice).sum(dim=1) # gumbel softmax
else:
pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax
else:
pred = (all_preds * P).sum(dim=1)
if transport_method == "router":
loss = loss_fn(pred, label)
else:
loss = (all_loss * P).sum(dim=1).mean()
return loss, pred, L, P
def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
"""
daily transport
Args:
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
label (torch.Tensor): label, [sample]
choice (torch.Tensor): gumbel softmax choice, [days x states]
prob (torch.Tensor): router predicted probility, [days x states]
hist_loss (torch.Tensor): history loss matrix, [days x states]
count (list): sample counts for each day, [days]
transport_method (str): transportation method
alpha (float): fusion parameter for calculating transport loss matrix
training (bool): indicate training or inference
"""
assert len(prob) == len(count)
assert len(all_preds) == sum(count)
assert transport_method in ["oracle", "router"]
all_loss = [] # loss of all predictions
start = 0
for i, cnt in enumerate(count):
slc = slice(start, start + cnt) # samples from the i-th day
start += cnt
tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day
all_loss.append(tloss)
all_loss = torch.stack(all_loss, dim=0) # [days x states]
L = minmax_norm(all_loss.detach())
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
Lh = minmax_norm(Lh)
P = sinkhorn(-Lh)
del Lh
pred = []
start = 0
for i, cnt in enumerate(count):
slc = slice(start, start + cnt) # samples from the i-th day
start += cnt
if transport_method == "router":
if training:
tpred = all_preds[slc] @ choice[i] # gumbel softmax
else:
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
else:
tpred = all_preds[slc] @ P[i]
pred.append(tpred)
pred = torch.cat(pred, dim=0) # [samples]
if transport_method == "router":
loss = loss_fn(pred, label)
else:
loss = (all_loss * P).sum(dim=1).mean()
return loss, pred, L, P
def load_state_dict_unsafe(model, state_dict):
"""
Load state dict to provided model while ignore exceptions.
"""
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(model)
load = None # break load->load reference cycle
return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs}
def plot(P):
assert isinstance(P, pd.DataFrame)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
P.plot.area(ax=axes[0], xlabel="")
P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="")
plt.tight_layout()
with io.BytesIO() as buf:
plt.savefig(buf, format="png")
buf.seek(0)
img = plt.imread(buf)
plt.close()
return np.uint8(img * 255)

View File

@@ -0,0 +1,294 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ”
class TransformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 2048,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values)
self.model.train()
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.model(feature)
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_x, data_y):
# prepare training data
x_values = data_x.values
y_values = np.squeeze(data_y.values)
self.model.eval()
scores = []
losses = []
indices = np.arange(len(x_values))
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
with torch.no_grad():
pred = self.model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, F*T] --> [N, T, F]
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
src = self.feature_layer(src)
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -0,0 +1,269 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
class TransformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 8192,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, data_loader):
self.model.train()
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_loader):
self.model.eval()
scores = []
losses = []
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.model.eval()
preds = []
for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, T, F], [512, 60, 6]
src = self.feature_layer(src) # [512, 60, 8]
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -62,7 +62,7 @@ class XGBModel(Model, FeatureInt):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance

View File

@@ -3,7 +3,6 @@
import pandas as pd
import plotly.tools as tls
import plotly.graph_objs as go
import statsmodels.api as sm
@@ -80,9 +79,35 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
:param dist:
:return:
"""
fig, ax = plt.subplots(figsize=(8, 5))
_mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax)
return tls.mpl_to_plotly(_mpl_fig)
# NOTE: plotly.tools.mpl_to_plotly not actively maintained, resulting in errors in the new version of matplotlib,
# ref: https://github.com/plotly/plotly.py/issues/2913#issuecomment-730071567
# removing plotly.tools.mpl_to_plotly for greater compatibility with matplotlib versions
_plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45")
plt.close(_plt_fig)
qqplot_data = _plt_fig.gca().lines
fig = go.Figure()
fig.add_trace(
{
"type": "scatter",
"x": qqplot_data[0].get_xdata(),
"y": qqplot_data[0].get_ydata(),
"mode": "markers",
"marker": {"color": "#19d3f3"},
}
)
fig.add_trace(
{
"type": "scatter",
"x": qqplot_data[1].get_xdata(),
"y": qqplot_data[1].get_ydata(),
"mode": "lines",
"line": {"color": "#636efa"},
}
)
del qqplot_data
return fig
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:

View File

@@ -148,7 +148,6 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
pred score for this trade date, index is stock_id, contain 'score' column.
current : Position()
current position.
trade_exchange : Exchange()
trade_date : pd.Timestamp
trade date.
"""
@@ -222,9 +221,9 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
only_tradable : bool
will the strategy only consider the tradable stock when buying and selling.
if only_tradable:
strategy will make buy sell decision without checking the tradable state of the stock.
the strategy will peek at the information in the short future to avoid untradable stocks (untradable stocks include stocks that meet suspension, or hit limit up or limit down).
else:
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
the strategy will generate orders without peeking any information in the future, so the order generated by the strategies may fail.
"""
super(TopkDropoutStrategy, self).__init__()
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))

View File

@@ -196,9 +196,9 @@ class Feature(Expression):
def __init__(self, name=None):
if name:
self._name = name.lower()
self._name = name
else:
self._name = type(self).__name__.lower()
self._name = type(self).__name__
def __str__(self):
return "$" + self._name

View File

@@ -17,6 +17,7 @@ import abc
from pathlib import Path
import numpy as np
import pandas as pd
from typing import Union, Iterable
from collections import OrderedDict
from ..config import C
@@ -216,12 +217,14 @@ class CacheUtils:
redis_lock.reset_all(r)
@staticmethod
def visit(cache_path):
def visit(cache_path: Union[str, Path]):
# FIXME: Because read_lock was canceled when reading the cache, multiple processes may have read and write exceptions here
try:
with open(cache_path + ".meta", "rb") as f:
cache_path = Path(cache_path)
meta_path = cache_path.with_suffix(".meta")
with meta_path.open("rb") as f:
d = pickle.load(f)
with open(cache_path + ".meta", "wb") as f:
with meta_path.open("wb") as f:
try:
d["meta"]["last_visit"] = str(time.time())
d["meta"]["visits"] = d["meta"]["visits"] + 1
@@ -237,7 +240,7 @@ class CacheUtils:
lock.acquire()
except redis_lock.AlreadyAcquired:
raise QlibCacheException(
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
You can use the following command to clear your redis keys and rerun your commands:
$ redis-cli
> select {C.redis_task_db}
@@ -249,17 +252,17 @@ class CacheUtils:
@staticmethod
@contextlib.contextmanager
def reader_lock(redis_t, lock_name):
lock_name = f"{C.provider_uri}:{lock_name}"
current_cache_rlock = redis_lock.Lock(redis_t, "%s-rlock" % lock_name)
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name)
def reader_lock(redis_t, lock_name: str):
current_cache_rlock = redis_lock.Lock(redis_t, f"{lock_name}-rlock")
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock")
lock_reader = f"{lock_name}-reader"
# make sure only one reader is entering
current_cache_rlock.acquire(timeout=60)
try:
current_cache_readers = redis_t.get("%s-reader" % lock_name)
current_cache_readers = redis_t.get(lock_reader)
if current_cache_readers is None or int(current_cache_readers) == 0:
CacheUtils.acquire(current_cache_wlock, lock_name)
redis_t.incr("%s-reader" % lock_name)
redis_t.incr(lock_reader)
finally:
current_cache_rlock.release()
try:
@@ -268,9 +271,9 @@ class CacheUtils:
# make sure only one reader is leaving
current_cache_rlock.acquire(timeout=60)
try:
redis_t.decr("%s-reader" % lock_name)
if int(redis_t.get("%s-reader" % lock_name)) == 0:
redis_t.delete("%s-reader" % lock_name)
redis_t.decr(lock_reader)
if int(redis_t.get(lock_reader)) == 0:
redis_t.delete(lock_reader)
current_cache_wlock.reset()
finally:
current_cache_rlock.release()
@@ -278,8 +281,7 @@ class CacheUtils:
@staticmethod
@contextlib.contextmanager
def writer_lock(redis_t, lock_name):
lock_name = f"{C.provider_uri}:{lock_name}"
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name, id=CacheUtils.LOCK_ID)
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID)
CacheUtils.acquire(current_cache_wlock, lock_name)
try:
yield
@@ -297,6 +299,30 @@ class BaseProviderCache:
def __getattr__(self, attr):
return getattr(self.provider, attr)
@staticmethod
def check_cache_exists(cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta")) -> bool:
cache_path = Path(cache_path)
for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]:
if not p.exists():
return False
return True
@staticmethod
def clear_cache(cache_path: Union[str, Path]):
for p in [
cache_path,
cache_path.with_suffix(".meta"),
cache_path.with_suffix(".index"),
]:
if p.exists():
p.unlink()
@staticmethod
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name)
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
class ExpressionCache(BaseProviderCache):
"""Expression cache mechanism base class.
@@ -330,15 +356,16 @@ class ExpressionCache(BaseProviderCache):
"""
raise NotImplementedError("Implement this method if you want to use expression cache")
def update(self, cache_uri):
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
"""Update expression cache to latest calendar.
Overide this method to define how to update expression cache corresponding to users' own cache mechanism.
Parameters
----------
cache_uri : str
cache_uri : str or Path
the complete uri of expression cache file (include dir path).
freq : str
Returns
-------
@@ -358,7 +385,9 @@ class DatasetCache(BaseProviderCache):
HDF_KEY = "df"
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get feature dataset.
.. note:: Same interface as `dataset` method in dataset provider
@@ -369,13 +398,19 @@ class DatasetCache(BaseProviderCache):
"""
if disk_cache == 0:
# skip cache
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
else:
# use and replace cache
try:
return self._dataset(instruments, fields, start_time, end_time, freq, disk_cache)
return self._dataset(
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
)
except NotImplementedError:
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs):
"""Get dataset cache file uri.
@@ -384,14 +419,18 @@ class DatasetCache(BaseProviderCache):
"""
raise NotImplementedError("Implement this function to match your own cache mechanism")
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get feature dataset using cache.
Override this method to define how to get feature dataset corresponding to users' own cache mechanism.
"""
raise NotImplementedError("Implement this method if you want to use dataset feature cache")
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get a uri of feature dataset using cache.
specially:
disk_cache=1 means using data set cache and return the uri of cache file.
@@ -403,15 +442,16 @@ class DatasetCache(BaseProviderCache):
"Implement this method if you want to use dataset feature cache as a cache file for client"
)
def update(self, cache_uri):
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
"""Update dataset cache to latest calendar.
Overide this method to define how to update dataset cache corresponding to users' own cache mechanism.
Parameters
----------
cache_uri : str
cache_uri : str or Path
the complete uri of dataset cache file (include dir path).
freq : str
Returns
-------
@@ -452,25 +492,19 @@ class DiskExpressionCache(ExpressionCache):
self.r = get_redis_connection()
# remote==True means client is using this module, writing behaviour will not be allowed.
self.remote = kwargs.get("remote", False)
self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name)
os.makedirs(self.expr_cache_path, exist_ok=True)
def get_cache_dir(self, freq: str = None) -> Path:
return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq)
def _uri(self, instrument, field, start_time, end_time, freq):
field = remove_fields_space(field)
instrument = str(instrument).lower()
return hash_args(instrument, field, freq)
@staticmethod
def check_cache_exists(cache_path):
for p in [cache_path, cache_path + ".meta"]:
if not Path(p).exists():
return False
return True
def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
_cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq)
_instrument_dir = os.path.join(self.expr_cache_path, instrument.lower())
cache_path = os.path.join(_instrument_dir, _cache_uri)
_instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
cache_path = _instrument_dir.joinpath(_cache_uri)
# get calendar
from .data import Cal
@@ -478,7 +512,7 @@ class DiskExpressionCache(ExpressionCache):
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
if self.check_cache_exists(cache_path):
if self.check_cache_exists(cache_path, suffix_list=[".meta"]):
"""
In most cases, we do not need reader_lock.
Because updating data is a small probability event compare to reading data.
@@ -502,8 +536,7 @@ class DiskExpressionCache(ExpressionCache):
# normalize field
field = remove_fields_space(field)
# cache unavailable, generate the cache
if not os.path.exists(_instrument_dir):
os.makedirs(_instrument_dir, exist_ok=True)
_instrument_dir.mkdir(parents=True, exist_ok=True)
if not isinstance(eval(parse_field(field)), Feature):
# When the expression is not a raw feature
# generate expression cache if the feature is not a Feature
@@ -511,7 +544,7 @@ class DiskExpressionCache(ExpressionCache):
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
if not series.empty:
# This expresion is empty, we don't generate any cache for it.
with CacheUtils.writer_lock(self.r, "expression-%s" % _cache_uri):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"):
self.gen_expression_cache(
expression_data=series,
cache_path=cache_path,
@@ -527,14 +560,6 @@ class DiskExpressionCache(ExpressionCache):
# If the expression is a raw feature(such as $close, $open)
return self.provider.expression(instrument, field, start_time, end_time, freq)
@staticmethod
def clear_cache(cache_path):
meta_path = cache_path + ".meta"
for p in [cache_path, meta_path]:
p = Path(p)
if p.exists():
p.unlink()
def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update):
"""use bin file to save like feature-data."""
# Make sure the cache runs right when the directory is deleted
@@ -544,27 +569,28 @@ class DiskExpressionCache(ExpressionCache):
"meta": {"last_visit": time.time(), "visits": 1},
}
self.logger.debug(f"generating expression cache: {meta}")
os.makedirs(self.expr_cache_path, exist_ok=True)
self.clear_cache(cache_path)
meta_path = cache_path + ".meta"
meta_path = cache_path.with_suffix(".meta")
with open(meta_path, "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(meta, f)
os.chmod(meta_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
df = expression_data.to_frame()
r = np.hstack([df.index[0], expression_data]).astype("<f")
r.tofile(str(cache_path))
def update(self, sid, cache_uri):
cp_cache_uri = os.path.join(self.expr_cache_path, sid, cache_uri)
if not self.check_cache_exists(cp_cache_uri):
def update(self, sid, cache_uri, freq: str = "day"):
cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
self.clear_cache(cp_cache_uri)
return 2
with CacheUtils.writer_lock(self.r, "expression-%s" % cache_uri):
with open(cp_cache_uri + ".meta", "rb") as f:
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:expression-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instrument = d["info"]["instrument"]
field = d["info"]["field"]
@@ -611,7 +637,7 @@ class DiskExpressionCache(ExpressionCache):
f.write(data)
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with open(cp_cache_uri + ".meta", "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(d, f)
return 0
@@ -623,22 +649,16 @@ class DiskDatasetCache(DatasetCache):
super(DiskDatasetCache, self).__init__(provider)
self.r = get_redis_connection()
self.remote = kwargs.get("remote", False)
self.dtst_cache_path = os.path.join(C.get_data_path(), C.dataset_cache_dir_name)
os.makedirs(self.dtst_cache_path, exist_ok=True)
@staticmethod
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache)
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
@staticmethod
def check_cache_exists(cache_path):
for p in [cache_path, cache_path + ".index", cache_path + ".meta"]:
if not Path(p).exists():
return False
return True
def get_cache_dir(self, freq: str = None) -> Path:
return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq)
@classmethod
def read_data_from_cache(cls, cache_path, start_time, end_time, fields):
def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields):
"""read_cache_from
This function can read data from the disk cache dataset
@@ -671,17 +691,32 @@ class DiskDatasetCache(DatasetCache):
df = pd.DataFrame(columns=fields)
return df
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
_cache_uri = self._uri(
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
instruments=instruments,
fields=fields,
start_time=None,
end_time=None,
freq=freq,
disk_cache=disk_cache,
inst_processors=inst_processors,
)
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
features = pd.DataFrame()
gen_flag = False
@@ -689,7 +724,7 @@ class DiskDatasetCache(DatasetCache):
if self.check_cache_exists(cache_path):
if disk_cache == 1:
# use cache
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
elif disk_cache == 2:
@@ -699,15 +734,21 @@ class DiskDatasetCache(DatasetCache):
if gen_flag:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
features = self.gen_dataset_cache(
cache_path=cache_path, instruments=instruments, fields=fields, freq=freq
cache_path=cache_path,
instruments=instruments,
fields=fields,
freq=freq,
inst_processors=inst_processors,
)
if not features.empty:
features = features.sort_index().loc(axis=0)[:, start_time:end_time]
return features
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if disk_cache == 0:
# In this case, server only checks the expression cache.
# The client will load the cache data by itself.
@@ -715,21 +756,38 @@ class DiskDatasetCache(DatasetCache):
LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
return ""
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
_cache_uri = self._uri(
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
instruments=instruments,
fields=fields,
start_time=None,
end_time=None,
freq=freq,
disk_cache=disk_cache,
inst_processors=inst_processors,
)
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
if self.check_cache_exists(cache_path):
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
return _cache_uri
else:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq)
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
self.gen_dataset_cache(
cache_path=cache_path,
instruments=instruments,
fields=fields,
freq=freq,
inst_processors=inst_processors,
)
return _cache_uri
class IndexManager:
@@ -740,8 +798,9 @@ class DiskDatasetCache(DatasetCache):
KEY = "df"
def __init__(self, cache_path):
self.index_path = cache_path + ".index"
def __init__(self, cache_path: Union[str, Path]):
self.index_path = cache_path.with_suffix(".index")
self._data = None
self.logger = get_module_logger(self.__class__.__name__)
@@ -757,7 +816,7 @@ class DiskDatasetCache(DatasetCache):
self._data.sort_index(inplace=True)
self._data.to_hdf(self.index_path, key=self.KEY, mode="w", format="table")
# The index should be readable for all users
os.chmod(self.index_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
self.index_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
def sync_from_disk(self):
# The file will not be closed directly if we read_hdf from the disk directly
@@ -784,10 +843,10 @@ class DiskDatasetCache(DatasetCache):
def build_index_from_data(data, start_index=0):
if data.empty:
return pd.DataFrame()
line_data = data.iloc[:, 0].fillna(0).groupby("datetime").count()
line_data = data.groupby("datetime").size()
line_data.sort_index(inplace=True)
index_end = line_data.cumsum()
index_start = index_end.shift(1).fillna(0)
index_start = index_end.shift(1, fill_value=0)
index_data = pd.DataFrame()
index_data["start"] = index_start
@@ -795,15 +854,7 @@ class DiskDatasetCache(DatasetCache):
index_data += start_index
return index_data
@staticmethod
def clear_cache(cache_path):
meta_path = cache_path + ".meta"
for p in [cache_path, meta_path, cache_path + ".index", cache_path + ".data"]:
p = Path(p)
if p.exists():
p.unlink()
def gen_dataset_cache(self, cache_path, instruments, fields, freq):
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):
"""gen_dataset_cache
.. note:: This function does not consider the cache read write lock. Please
@@ -838,20 +889,23 @@ class DiskDatasetCache(DatasetCache):
:param instruments: The instruments to store the cache.
:param fields: The fields to store the cache.
:param freq: The freq to store the cache.
:param inst_processors: Instrument processors.
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
"""
# get calendar
from .data import Cal
cache_path = Path(cache_path)
_calendar = Cal.calendar(freq=freq)
self.logger.debug(f"Generating dataset cache {cache_path}")
# Make sure the cache runs right when the directory is deleted
# while running
os.makedirs(self.dtst_cache_path, exist_ok=True)
self.clear_cache(cache_path)
features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq)
features = self.provider.dataset(
instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors
)
if features.empty:
return features
@@ -860,7 +914,7 @@ class DiskDatasetCache(DatasetCache):
features = features.swaplevel("instrument", "datetime").sort_index()
# write cache data
with pd.HDFStore(cache_path + ".data") as store:
with pd.HDFStore(str(cache_path.with_suffix(".data"))) as store:
cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns))
orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns)))
cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map)
@@ -876,12 +930,13 @@ class DiskDatasetCache(DatasetCache):
"fields": cache_columns,
"freq": freq,
"last_update": str(_calendar[-1]), # The last_update to store the cache
"inst_processors": inst_processors, # The last_update to store the cache
},
"meta": {"last_visit": time.time(), "visits": 1},
}
with open(cache_path + ".meta", "wb") as f:
with cache_path.with_suffix(".meta").open("wb") as f:
pickle.dump(meta, f)
os.chmod(cache_path + ".meta", stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
# write index file
im = DiskDatasetCache.IndexManager(cache_path)
index_data = im.build_index_from_data(features)
@@ -890,26 +945,27 @@ class DiskDatasetCache(DatasetCache):
# rename the file after the cache has been generated
# this doesn't work well on windows, but our server won't use windows
# temporarily
os.replace(cache_path + ".data", cache_path)
cache_path.with_suffix(".data").rename(cache_path)
# the fields of the cached features are converted to the original fields
return features.swaplevel("datetime", "instrument")
def update(self, cache_uri):
cp_cache_uri = os.path.join(self.dtst_cache_path, cache_uri)
def update(self, cache_uri, freq: str = "day"):
cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
self.clear_cache(cp_cache_uri)
return 2
im = DiskDatasetCache.IndexManager(cp_cache_uri)
with CacheUtils.writer_lock(self.r, "dataset-%s" % cache_uri):
with open(cp_cache_uri + ".meta", "rb") as f:
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instruments = d["info"]["instruments"]
fields = d["info"]["fields"]
freq = d["info"]["freq"]
last_update_time = d["info"]["last_update"]
inst_processors = d["info"]["inst_processors"]
index_data = im.get_index()
self.logger.debug("Updating dataset: {}".format(d))
@@ -960,7 +1016,12 @@ class DiskDatasetCache(DatasetCache):
)
data = self.provider.dataset(
instruments, fields, whole_calendar[current_index - rm_n_period], new_calendar[-1], freq
instruments,
fields,
whole_calendar[current_index - rm_n_period],
new_calendar[-1],
freq,
inst_processors=inst_processors,
)
if not data.empty:
@@ -995,7 +1056,7 @@ class DiskDatasetCache(DatasetCache):
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with open(cp_cache_uri + ".meta", "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(d, f)
return 0
@@ -1006,26 +1067,36 @@ class SimpleDatasetCache(DatasetCache):
def __init__(self, provider):
super(SimpleDatasetCache, self).__init__(provider)
try:
self.local_cache_path = C["local_cache_path"]
except KeyError as e:
self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve()
except (KeyError, TypeError) as e:
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
raise
self.logger.info(
f"DatasetCache directory: {self.local_cache_path}, "
f"modify the cache directory via the local_cache_path in the config"
)
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
local_cache_path = str(Path(self.local_cache_path).expanduser().resolve())
return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, local_cache_path)
return hash_args(
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
)
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
os.makedirs(os.path.expanduser(self.local_cache_path), exist_ok=True)
cache_file = os.path.join(
self.local_cache_path, self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache)
self.local_cache_path.mkdir(exist_ok=True, parents=True)
cache_file = self.local_cache_path.joinpath(
self._uri(
instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors
)
)
gen_flag = False
if os.path.exists(cache_file):
if cache_file.exists():
if disk_cache == 1:
# use cache
df = pd.read_pickle(cache_file)
@@ -1037,7 +1108,9 @@ class SimpleDatasetCache(DatasetCache):
gen_flag = True
if gen_flag:
data = self.provider.dataset(instruments, normalize_cache_fields(fields), start_time, end_time, freq)
data = self.provider.dataset(
instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors
)
data.to_pickle(cache_file)
return self.cache_to_origin_data(data, fields)
@@ -1045,26 +1118,53 @@ class SimpleDatasetCache(DatasetCache):
class DatasetURICache(DatasetCache):
"""Prepared cache mechanism for server."""
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if "local" in C.dataset_provider.lower():
# use LocalDatasetProvider
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
if disk_cache == 0:
# do not use data_set cache, load data from remote expression cache directly
return self.provider.dataset(instruments, fields, start_time, end_time, freq, disk_cache, return_uri=False)
return self.provider.dataset(
instruments,
fields,
start_time,
end_time,
freq,
disk_cache,
return_uri=False,
inst_processors=inst_processors,
)
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
# use ClientDatasetProvider
feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache)
feature_uri = self._uri(
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
)
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
if value is None or expire or not os.path.exists(mnt_feature_uri):
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
if value is None or expire or not mnt_feature_uri.exists():
df, uri = self.provider.dataset(
instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True
instruments,
fields,
start_time,
end_time,
freq,
disk_cache,
return_uri=True,
inst_processors=inst_processors,
)
# cache uri
MemCacheExpire.set_cache(H["f"], uri, uri)
@@ -1072,7 +1172,6 @@ class DatasetURICache(DatasetCache):
# HZ['f'][uri] = df.copy()
get_module_logger("cache").debug(f"get feature from {C.dataset_provider}")
else:
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("cache").debug("get feature from uri cache")

View File

@@ -5,28 +5,34 @@
from __future__ import division
from __future__ import print_function
import os
import re
import abc
import copy
import time
import queue
import bisect
import logging
import importlib
import traceback
import numpy as np
import pandas as pd
from multiprocessing import Pool
from typing import Iterable, Union
from .cache import H
from ..config import C
from .ops import Operators
from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
from .base import Feature
from .ops import Operators
from .inst_processor import InstProcessor
from ..log import get_module_logger
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
from ..utils import (
Wrapper,
init_instance_by_config,
register_wrapper,
get_module_by_module_path,
parse_field,
hash_args,
normalize_cache_fields,
code_to_fname,
)
class ProviderBackendMixin:
@@ -48,8 +54,14 @@ class ProviderBackendMixin:
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
# NOTE: provider_uri priority
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
# 3. qlib.init: provider_uri
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
freq = kwargs.get("freq", "day")
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
if freq not in provider_uri_map:
provider_uri_map[freq] = C.dpm.get_data_path(freq)
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
@@ -341,7 +353,7 @@ class DatasetProvider(abc.ABC):
"""
@abc.abstractmethod
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]):
"""Get dataset data.
Parameters
@@ -356,6 +368,8 @@ class DatasetProvider(abc.ABC):
end of the time range.
freq : str
time frequency.
inst_processors: Iterable[Union[dict, InstProcessor]]
the operations performed on each instrument
Returns
----------
@@ -372,6 +386,7 @@ class DatasetProvider(abc.ABC):
end_time=None,
freq="day",
disk_cache=1,
inst_processors=[],
**kwargs,
):
"""Get task uri, used when generating rabbitmq task in qlib_server
@@ -392,7 +407,8 @@ class DatasetProvider(abc.ABC):
whether to skip(0)/use(1)/replace(2) disk_cache.
"""
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
# TODO: qlib-server support inst_processors
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors)
@staticmethod
def get_instruments_d(instruments, freq):
@@ -433,7 +449,7 @@ class DatasetProvider(abc.ABC):
return [ExpressionD.get_expression_instance(f) for f in fields]
@staticmethod
def dataset_processor(instruments_d, column_names, start_time, end_time, freq):
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):
"""
Load and process the data, return the data set.
- default using multi-kernel method.
@@ -459,6 +475,7 @@ class DatasetProvider(abc.ABC):
normalize_column_names,
spans,
C,
inst_processors,
),
)
else:
@@ -473,6 +490,7 @@ class DatasetProvider(abc.ABC):
normalize_column_names,
None,
C,
inst_processors,
),
)
@@ -494,7 +512,9 @@ class DatasetProvider(abc.ABC):
return data
@staticmethod
def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None):
def expression_calculator(
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
):
"""
Calculate the expressions for one instrument, return a df result.
If the expression has been calculated before, load from cache.
@@ -518,13 +538,17 @@ class DatasetProvider(abc.ABC):
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]
if spans is None:
return data
else:
if spans is not None:
mask = np.zeros(len(data), dtype=bool)
for begin, end in spans:
mask |= (data.index >= begin) & (data.index <= end)
return data[mask]
data = data[mask]
for _processor in inst_processors:
if _processor:
_processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)
data = _processor_obj(data)
return data
class LocalCalendarProvider(CalendarProvider):
@@ -537,11 +561,6 @@ class LocalCalendarProvider(CalendarProvider):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
def _uri_cal(self):
"""Calendar file uri."""
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
def load_calendar(self, freq, future):
"""Load original calendar timestamp from file.
@@ -601,11 +620,6 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""
@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
@@ -654,14 +668,9 @@ class LocalFeatureProvider(FeatureProvider):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
def _uri_data(self):
"""Static feature file uri."""
return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin")
def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
field = str(field)[1:]
instrument = code_to_fname(instrument)
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
@@ -703,7 +712,15 @@ class LocalDatasetProvider(DatasetProvider):
def __init__(self):
pass
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
def dataset(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
inst_processors=[],
):
instruments_d = self.get_instruments_d(instruments, freq)
column_names = self.get_column_names(fields)
cal = Cal.calendar(start_time, end_time, freq)
@@ -712,7 +729,9 @@ class LocalDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
data = self.dataset_processor(
instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors
)
return data
@@ -855,6 +874,7 @@ class ClientDatasetProvider(DatasetProvider):
freq="day",
disk_cache=0,
return_uri=False,
inst_processors=[],
):
if Inst.get_inst_type(instruments) == Inst.DICT:
get_module_logger("data").warning(
@@ -894,7 +914,7 @@ class ClientDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors)
if return_uri:
return data, feature_uri
else:
@@ -907,6 +927,13 @@ class ClientDatasetProvider(DatasetProvider):
- using single-process implementation.
"""
# TODO: support inst_processors, need to change the code of qlib-server at the same time
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
self.conn.send_request(
request_type="feature",
request_content={
@@ -926,7 +953,7 @@ class ClientDatasetProvider(DatasetProvider):
get_module_logger("data").debug("get result")
try:
# pre-mound nfs, used for demo
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("data").debug("finish slicing data")
if return_uri:
@@ -964,6 +991,7 @@ class BaseProvider:
end_time=None,
freq="day",
disk_cache=None,
inst_processors=[],
):
"""
Parameters:
@@ -978,9 +1006,11 @@ class BaseProvider:
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
fields = list(fields) # In case of tuple.
try:
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
return DatasetD.dataset(
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
)
except TypeError:
return DatasetD.dataset(instruments, fields, start_time, end_time, freq)
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors)
class LocalProvider(BaseProvider):
@@ -1028,13 +1058,21 @@ class ClientProvider(BaseProvider):
"""
def __init__(self):
def is_instance_of_provider(instance: object, cls: type):
if isinstance(instance, Wrapper):
p = getattr(instance, "_provider", None)
return False if p is None else isinstance(p, cls)
return isinstance(instance, cls)
from .client import Client
self.client = Client(C.flask_server, C.flask_port)
self.logger = get_module_logger(self.__class__.__name__)
if isinstance(Cal, ClientCalendarProvider):
if is_instance_of_provider(Cal, ClientCalendarProvider):
Cal.set_conn(self.client)
if isinstance(Inst, ClientInstrumentProvider):
if is_instance_of_provider(Inst, ClientInstrumentProvider):
Inst.set_conn(self.client)
if hasattr(DatasetD, "provider"):
DatasetD.provider.set_conn(self.client)

View File

@@ -1,6 +1,6 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
@@ -243,6 +243,8 @@ class TSDataSampler:
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
data.
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
more powerful subclasses.
@@ -309,11 +311,19 @@ class TSDataSampler:
self.data_index = deepcopy(self.data.index)
if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
if isinstance(flt_data, pd.DataFrame):
assert len(flt_data.columns) == 1
flt_data = flt_data.iloc[:, 0]
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
@@ -341,7 +351,7 @@ class TSDataSampler:
setattr(self, k, v)
@staticmethod
def build_index(data: pd.DataFrame) -> dict:
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
"""
The relation of the data
@@ -352,9 +362,15 @@ class TSDataSampler:
Returns
-------
dict:
{<index>: <prev_index or None>}
# get the previous index of a line given index
Tuple[pd.DataFrame, dict]:
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
datetime
2021-01-11 0 1 2 3 4 5 ...
2021-01-12 4146 4147 4148 4149 4150 4151 ...
2021-01-13 8293 8294 8295 8296 8297 8298 ...
2021-01-14 12441 12442 12443 12444 12445 12446 ...
2) the second element: {<original index>: <row, col>}
"""
# object incase of pandas converting int to flaot
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
@@ -491,7 +507,9 @@ class TSDatasetH(DatasetH):
- The dimension of a batch of data <batch_idx, feature, timestep>
"""
def __init__(self, step_len=30, **kwargs):
DEFAULT_STEP_LEN = 30
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
self.step_len = step_len
super().__init__(**kwargs)

View File

@@ -18,6 +18,7 @@ from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
from .utils import fetch_df_by_index
from ...utils import lazy_sort_index
from pathlib import Path
from .loader import DataLoader
@@ -146,7 +147,8 @@ class DataHandler(Serializable):
# Setup data.
# _data may be with multiple column index level. The outer level indicates the feature set name
with TimeInspector.logt("Loading data"):
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
# make sure the fetch method is based on a index-sorted pd.DataFrame
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
# TODO: cache
CS_ALL = "__all" # return all columns with single-level index column
@@ -293,11 +295,14 @@ class DataHandlerLP(DataHandler):
# process type
PTYPE_I = "independent"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by learn_processors
# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + learn_processors
# NOTE:
PTYPE_A = "append"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by infer_processors + learn_processors
# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + infer_processors + learn_processors
# - (e.g. self._infer processed by learn_processors )
def __init__(
@@ -306,8 +311,9 @@ class DataHandlerLP(DataHandler):
start_time=None,
end_time=None,
data_loader: Union[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
infer_processors: List = [],
learn_processors: List = [],
shared_processors: List = [],
process_type=PTYPE_A,
drop_raw=False,
**kwargs,
@@ -358,7 +364,8 @@ class DataHandlerLP(DataHandler):
# Setup preprocessor
self.infer_processors = [] # for lint
self.learn_processors = [] # for lint
for pname in "infer_processors", "learn_processors":
self.shared_processors = [] # for lint
for pname in "infer_processors", "learn_processors", "shared_processors":
for proc in locals()[pname]:
getattr(self, pname).append(
init_instance_by_config(
@@ -373,9 +380,12 @@ class DataHandlerLP(DataHandler):
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
def get_all_processors(self):
return self.infer_processors + self.learn_processors
return self.shared_processors + self.infer_processors + self.learn_processors
def fit(self):
"""
fit data without processing the data
"""
for proc in self.get_all_processors():
with TimeInspector.logt(f"{proc.__class__.__name__}"):
proc.fit(self._data)
@@ -388,30 +398,68 @@ class DataHandlerLP(DataHandler):
"""
self.process_data(with_fit=True)
@staticmethod
def _run_proc_l(
df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool
) -> pd.DataFrame:
for proc in proc_l:
if check_for_infer and not proc.is_for_infer():
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(df)
df = proc(df)
return df
@staticmethod
def _is_proc_readonly(proc_l: List[processor_module.Processor]):
"""
NOTE: it will return True if `len(proc_l) == 0`
"""
for p in proc_l:
if not p.readonly():
return False
return True
def process_data(self, with_fit: bool = False):
"""
process_data data. Fun `processor.fit` if necessary
Notation: (data) [processor]
# data processing flow of self.process_type == DataHandlerLP.PTYPE_I
(self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df)
\
-[infer_processors]-(_infer_df)
# data processing flow of self.process_type == DataHandlerLP.PTYPE_A
(self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)
Parameters
----------
with_fit : bool
The input of the `fit` will be the output of the previous processor
"""
# data for inference
_infer_df = self._data
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
_infer_df = _infer_df.copy()
# shared data processors
# 1) assign
_shared_df = self._data
if not self._is_proc_readonly(self.shared_processors): # avoid modifying the original data
_shared_df = _shared_df.copy()
# 2) process
_shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True)
# data for inference
# 1) assign
_infer_df = _shared_df
if not self._is_proc_readonly(self.infer_processors): # avoid modifying the original data
_infer_df = _infer_df.copy()
# 2) process
_infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True)
for proc in self.infer_processors:
if not proc.is_for_infer():
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_infer_df)
_infer_df = proc(_infer_df)
self._infer = _infer_df
# data for learning
# 1) assign
if self.process_type == DataHandlerLP.PTYPE_I:
_learn_df = self._data
elif self.process_type == DataHandlerLP.PTYPE_A:
@@ -419,14 +467,11 @@ class DataHandlerLP(DataHandler):
_learn_df = _infer_df
else:
raise NotImplementedError(f"This type of input is not supported")
if len(self.learn_processors) > 0: # avoid modifying the original data
if not self._is_proc_readonly(self.learn_processors): # avoid modifying the original data
_learn_df = _learn_df.copy()
for proc in self.learn_processors:
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_learn_df)
_learn_df = proc(_learn_df)
# 2) process
_learn_df = self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False)
self._learn = _learn_df
if self.drop_raw:

View File

@@ -7,12 +7,12 @@ import warnings
import numpy as np
import pandas as pd
from typing import Tuple, Union
from typing import Tuple, Union, List, Type
from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_callable_kwargs
from qlib.log import get_module_logger
@@ -62,11 +62,11 @@ class DLWParser(DataLoader):
Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields.
"""
def __init__(self, config: Tuple[list, tuple, dict]):
def __init__(self, config: Union[list, tuple, dict]):
"""
Parameters
----------
config : Tuple[list, tuple, dict]
config : Union[list, tuple, dict]
Config will be used to describe the fields and column names
.. code-block::
@@ -88,7 +88,7 @@ class DLWParser(DataLoader):
else:
self.fields = self._parse_fields_info(config)
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
if len(fields_info) == 0:
raise ValueError("The size of fields must be greater than 0")
@@ -104,7 +104,15 @@ class DLWParser(DataLoader):
return exprs, names
@abc.abstractmethod
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
def load_group_df(
self,
instruments,
exprs: list,
names: list,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
gp_name: str = None,
) -> pd.DataFrame:
"""
load the dataframe for specific group
@@ -128,7 +136,7 @@ class DLWParser(DataLoader):
if self.is_group:
df = pd.concat(
{
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
for grp, (exprs, names) in self.fields.items()
},
axis=1,
@@ -142,7 +150,14 @@ class DLWParser(DataLoader):
class QlibDataLoader(DLWParser):
"""Same as QlibDataLoader. The fields can be define by config"""
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
def __init__(
self,
config: Tuple[list, tuple, dict],
filter_pipe: List = None,
swap_level: bool = True,
freq: Union[str, dict] = "day",
inst_processor: dict = None,
):
"""
Parameters
----------
@@ -152,6 +167,11 @@ class QlibDataLoader(DLWParser):
Filter pipe for the instruments
swap_level :
Whether to swap level of MultiIndex
freq: dict or str
If type(config) == dict and type(freq) == str, load config data using freq.
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
inst_processor: dict
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
"""
if filter_pipe is not None:
assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list."
@@ -163,9 +183,32 @@ class QlibDataLoader(DLWParser):
self.filter_pipe = filter_pipe
self.swap_level = swap_level
self.freq = freq
# sample
self.inst_processor = inst_processor if inst_processor is not None else {}
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
super().__init__(config)
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:
# check sample config
if isinstance(freq, dict):
for _gp in config.keys():
if _gp not in freq:
raise ValueError(f"freq(={freq}) missing group(={_gp})")
assert (
self.inst_processor
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
def load_group_df(
self,
instruments,
exprs: list,
names: list,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
gp_name: str = None,
) -> pd.DataFrame:
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
instruments = "all"
@@ -174,7 +217,10 @@ class QlibDataLoader(DLWParser):
elif self.filter_pipe is not None:
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
df = D.features(instruments, exprs, start_time, end_time, self.freq)
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
df = D.features(
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
)
df.columns = names
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
@@ -199,6 +245,10 @@ class StaticDataLoader(DataLoader):
self.join = join
self._data = None
def __getstate__(self) -> dict:
# avoid pickling `self._data`
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
self._maybe_load_raw_data()
if instruments is None:
@@ -207,7 +257,10 @@ class StaticDataLoader(DataLoader):
df = self._data.loc(axis=0)[:, instruments]
if start_time is None and end_time is None:
return df # NOTE: avoid copy by loc
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
start_time = time_to_slc_point(start_time)
end_time = time_to_slc_point(end_time)
return df.loc[start_time:end_time]
def _maybe_load_raw_data(self):
if self._data is not None:

View File

@@ -73,6 +73,14 @@ class Processor(Serializable):
"""
return True
def readonly(self) -> bool:
"""
Does the processor treat the input data readonly (i.e. does not write the input data) when processsing
Knowning the readonly information is helpful to the Handler to avoid uncessary copy
"""
return False
def config(self, **kwargs):
attr_list = {"fit_start_time", "fit_end_time"}
for k, v in kwargs.items():
@@ -92,6 +100,9 @@ class DropnaProcessor(Processor):
def __call__(self, df):
return df.dropna(subset=get_group_columns(df, self.fields_group))
def readonly(self):
return True
class DropnaLabel(DropnaProcessor):
def __init__(self, fields_group="label"):
@@ -113,6 +124,9 @@ class DropCol(Processor):
mask = df.columns.isin(self.col_list)
return df.loc[:, ~mask]
def readonly(self):
return True
class FilterCol(Processor):
def __init__(self, fields_group="feature", col_list=[]):
@@ -128,6 +142,9 @@ class FilterCol(Processor):
mask = df.columns.get_level_values(-1).isin(self.col_list)
return df.loc[:, mask]
def readonly(self):
return True
class TanhProcess(Processor):
"""Use tanh to process noise data"""

View File

@@ -0,0 +1,23 @@
import abc
import json
import pandas as pd
class InstProcessor:
@abc.abstractmethod
def __call__(self, df: pd.DataFrame, *args, **kwargs):
"""
process the data
NOTE: **The processor could change the content of `df` inplace !!!!! **
User should keep a copy of data outside
Parameters
----------
df : pd.DataFrame
The raw_df of handler or result from previous processor.
"""
pass
def __str__(self):
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"

View File

@@ -10,10 +10,12 @@ import abc
import numpy as np
import pandas as pd
from typing import Union, List, Type
from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps
from ..log import get_module_logger
from ..utils import get_callable_kwargs
try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
@@ -1495,16 +1497,34 @@ class OpsWrapper:
def reset(self):
self._ops = {}
def register(self, ops_list):
for operator in ops_list:
if not issubclass(operator, ExpressionOps):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):
"""register operator
if operator.__name__ in self._ops:
Parameters
----------
ops_list : List[Union[Type[ExpressionOps], dict]]
- if type(ops_list) is List[Type[ExpressionOps]], each element of ops_list represents the operator class, which should be the subclass of `ExpressionOps`.
- if type(ops_list) is List[dict], each element of ops_list represents the config of operator, which has the following format:
{
"class": class_name,
"module_path": path,
}
Note: `class` should be the class name of operator, `module_path` should be a python module or path of file.
"""
for _operator in ops_list:
if isinstance(_operator, dict):
_ops_class, _ = get_callable_kwargs(_operator)
else:
_ops_class = _operator
if not issubclass(_ops_class, ExpressionOps):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
if _ops_class.__name__ in self._ops:
get_module_logger(self.__class__.__name__).warning(
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
"The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__)
)
self._ops[operator.__name__] = operator
self._ops[_ops_class.__name__] = _ops_class
def __getattr__(self, key):
if key not in self._ops:

View File

@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
def __init__(self, module_name):
self.module_name = module_name
self.level = 0
# this feature name conflicts with the attribute with Logger
# rename it to avoid some corner cases that result in comparing `str` and `int`
self.__level = 0
@property
def logger(self):
logger = logging.getLogger(self.module_name)
logger.setLevel(self.level)
logger.setLevel(self.__level)
return logger
def setLevel(self, level):
self.level = level
self.__level = level
def __getattr__(self, name):
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
@@ -68,7 +70,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge
class TimeInspector:
timer_logger = get_module_logger("timer", level=logging.WARNING)
timer_logger = get_module_logger("timer", level=logging.INFO)
time_marks = []

View File

@@ -97,7 +97,7 @@ class ModelFT(Model):
# Finetune model based on previous trained model
with R.start(experiment_name="finetune model"):
recorder = R.get_recorder(rid, experiment_name="init models")
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
model = recorder.load_object("init_model")
model.finetune(dataset, num_boost_round=10)

View File

@@ -105,6 +105,20 @@ class AverageEnsemble(Ensemble):
"""
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
"""using sample:
from qlib.model.ens.ensemble import AverageEnsemble
pred_res['new_key_name'] = AverageEnsemble()(predict_dict)
Parameters
----------
ensemble_dict : dict
Dictionary you want to ensemble
Returns
-------
pd.DataFrame
The dictionary including ensenbling result
"""
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
values = list(ensemble_dict.values())

View File

@@ -8,15 +8,16 @@ There are two steps in each Trainer including ``train``(make model recorder) and
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
"""
import socket
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
@@ -70,7 +71,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": rec}
else:
@@ -151,6 +152,9 @@ class Trainer:
"""
return self.delay
def __call__(self, *args, **kwargs) -> list:
return self.end_train(self.train(*args, **kwargs))
class TrainerR(Trainer):
"""
@@ -190,6 +194,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -213,6 +219,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -250,6 +258,8 @@ class DelayTrainerR(TrainerR):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -275,7 +285,12 @@ class TrainerRM(Trainer):
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"
def __init__(
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
):
"""
Init TrainerR.
@@ -283,11 +298,16 @@ class TrainerRM(Trainer):
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default training method. Defaults to `task_train`.
skip_run_task (bool):
If skip_run_task == True:
Only run_task in the worker. Otherwise skip run_task.
"""
super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
self.skip_run_task = skip_run_task
def train(
self,
@@ -315,6 +335,8 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -326,19 +348,26 @@ class TrainerRM(Trainer):
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
run_task(
train_func,
task_pool,
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
query = {"_id": {"$in": _id_list}}
if not self.skip_run_task:
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
if not self.is_delay():
tm.wait(query=query)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
rec.set_tags(**{self.TM_ID: _id})
recs.append(rec)
return recs
@@ -352,10 +381,33 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(
self,
train_func: Callable = None,
experiment_name: str = None,
):
"""
The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
Args:
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
"""
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
class DelayTrainerRM(TrainerRM):
"""
@@ -369,6 +421,7 @@ class DelayTrainerRM(TrainerRM):
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
skip_run_task: bool = False,
):
"""
Init DelayTrainerRM.
@@ -378,10 +431,15 @@ class DelayTrainerRM(TrainerRM):
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
skip_run_task (bool):
If skip_run_task == True:
Only run_task in the worker. Otherwise skip run_task.
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
self.skip_run_task = skip_run_task
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
@@ -395,6 +453,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
return super().train(
@@ -410,8 +470,6 @@ class DelayTrainerRM(TrainerRM):
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
@@ -421,7 +479,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -429,18 +488,45 @@ class DelayTrainerRM(TrainerRM):
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tasks = []
_id_list = []
for rec in recs:
tasks.append(rec.load_object("task"))
_id_list.append(rec.list_tags()[self.TM_ID])
query = {"_id": {"$in": _id_list}}
if not self.skip_run_task:
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
TaskManager(task_pool=task_pool).wait(query=query)
run_task(
end_train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(self, end_train_func=None, experiment_name: str = None):
"""
The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
Args:
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(
end_train_func,
task_pool=task_pool,
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
)

View File

@@ -43,17 +43,29 @@ RECORD_CONFIG = [
]
def get_data_handler_config(market=CSI300_MARKET):
def get_data_handler_config(
start_time="2008-01-01",
end_time="2020-08-01",
fit_start_time="2008-01-01",
fit_end_time="2014-12-31",
instruments=CSI300_MARKET,
):
return {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
"start_time": start_time,
"end_time": end_time,
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
"instruments": instruments,
}
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
def get_dataset_config(
dataset_class=DATASET_ALPHA158_CLASS,
train=("2008-01-01", "2014-12-31"),
valid=("2015-01-01", "2016-12-31"),
test=("2017-01-01", "2020-08-01"),
handler_kwargs={"instruments": CSI300_MARKET},
):
return {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
"handler": {
"class": dataset_class,
"module_path": "qlib.contrib.data.handler",
"kwargs": get_data_handler_config(market),
"kwargs": get_data_handler_config(**handler_kwargs),
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
"train": train,
"valid": valid,
"test": test,
},
},
}
def get_gbdt_task(market=CSI300_MARKET):
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": GBDT_MODEL,
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
}
def get_record_lgb_config(market=CSI300_MARKET):
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
def get_record_xgboost_config(market=CSI300_MARKET):
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
# use for rolling_online_managment.py
ROLLING_HANDLER_CONFIG = {
"start_time": "2013-01-01",
"end_time": "2020-09-25",
"fit_start_time": "2013-01-01",
"fit_end_time": "2014-12-31",
"instruments": CSI100_MARKET,
}
ROLLING_DATASET_CONFIG = {
"train": ("2013-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2015-12-31"),
"test": ("2016-01-01", "2020-07-10"),
}
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
)
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
)
# use for online_management_simulate.py
ONLINE_HANDLER_CONFIG = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": CSI100_MARKET,
}
ONLINE_DATASET_CONFIG = {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
}
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
)
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
)

View File

@@ -43,8 +43,9 @@ def get_redis_connection():
#################### Data ####################
def read_bin(file_path, start_index, end_index):
with open(file_path, "rb") as f:
def read_bin(file_path: Union[str, Path], start_index, end_index):
file_path = Path(file_path.expanduser().resolve())
with file_path.open("rb") as f:
# read start_index
ref_start_index = int(np.frombuffer(f.read(4), dtype="<f")[0])
si = max(ref_start_index, start_index)
@@ -189,9 +190,9 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
return module
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
"""
extract class and kwargs from config info
extract class/func and kwargs from config info
Parameters
----------
@@ -206,22 +207,22 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
Returns
-------
(type, dict):
the class object and it's arguments.
the class/func object and it's arguments.
"""
if isinstance(config, dict):
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
klass = getattr(module, config["class"])
_callable = getattr(module, config["class" if "class" in config else "func"])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
klass = getattr(module, config)
_callable = getattr(module, config)
kwargs = {}
else:
raise NotImplementedError(f"This type of input is not supported")
return klass, kwargs
return _callable, kwargs
def init_instance_by_config(
@@ -272,7 +273,7 @@ def init_instance_by_config(
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
return pickle.load(f)
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)
@@ -570,9 +571,11 @@ def get_pre_trading_date(trading_date, future=False):
def transform_end_date(end_date=None, freq="day"):
"""get previous trading date
"""handle the end date with various format
If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
Otherwise, returns the end_date
----------
end_date: str
end trading date
@@ -642,6 +645,28 @@ def split_pred(pred, number=None, split_date=None):
return pred_left, pred_right
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
"""
Time slicing in Qlib or Pandas is a frequently-used action.
However, user often input all kinds of data format to represent time.
This function will help user to convert these inputs into a uniform format which is friendly to time slicing.
Parameters
----------
t : Union[None, str, pd.Timestamp]
original time
Returns
-------
Union[None, pd.Timestamp]:
"""
if t is None:
# None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]).
return t
else:
return pd.Timestamp(t)
def can_use_cache():
res = True
r = get_redis_connection()
@@ -716,7 +741,8 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
sorted dataframe
"""
idx = df.index if axis == 0 else df.columns
if idx.is_monotonic_increasing:
# NOTE: MultiIndex.is_lexsorted() is a deprecated method in Pandas 1.3.0 and is suggested to be replaced by MultiIndex.is_monotonic_increasing (see discussion here: https://github.com/pandas-dev/pandas/issues/32259). However, in case older versions of Pandas is implemented, MultiIndex.is_lexsorted() is necessary to prevent certain fatal errors.
if idx.is_monotonic_increasing and not (isinstance(idx, pd.MultiIndex) and not idx.is_lexsorted()):
return df
else:
return df.sort_index(axis=axis)
@@ -770,7 +796,7 @@ class Wrapper:
return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)
def __getattr__(self, key):
if self._provider is None:
if self.__dict__.get("_provider", None) is None:
raise AttributeError("Please run qlib.init() first using qlib")
return getattr(self._provider, key)

17
qlib/utils/exceptions.py Normal file
View File

@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Base exception class
class QlibException(Exception):
def __init__(self, message):
super(QlibException, self).__init__(message)
# Error type for reinitialization when starting an experiment
class RecorderInitializationError(QlibException):
pass
# Error type for Recorder when can not load object
class LoadObjectError(QlibException):
pass

View File

@@ -1,10 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import pickle
import typing
import dill
from pathlib import Path
from typing import Union
@@ -18,6 +17,7 @@ class Serializable:
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
default_dump_all = False # if dump all things
FLAG_KEY = "_qlib_serial_flag"
def __init__(self):
self._dump_all = self.default_dump_all
@@ -45,8 +45,6 @@ class Serializable:
"""
return getattr(self, "_exclude", [])
FLAG_KEY = "_qlib_serial_flag"
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
"""
configure the serializable object
@@ -92,16 +90,16 @@ class Serializable:
@classmethod
def load(cls, filepath):
"""
Load the collector from a filepath.
Load the serializable class from a filepath.
Args:
filepath (str): the path of file
Raises:
TypeError: the pickled file must be `Collector`
TypeError: the pickled file must be `type(cls)`
Returns:
Collector: the instance of Collector
`type(cls)`: the instance of `type(cls)`
"""
with open(filepath, "rb") as f:
object = cls.get_backend().load(f)

View File

@@ -7,6 +7,7 @@ from .expm import MLflowExpManager
from .exp import Experiment
from .recorder import Recorder
from ..utils import Wrapper
from ..utils.exceptions import RecorderInitializationError
class QlibRecorder:
@@ -37,13 +38,13 @@ class QlibRecorder:
.. code-block:: Python
# start new experiment and recorder
with R.start('test', 'recorder_1'):
with R.start(experiment_name='test', recorder_name='recorder_1'):
model.fit(dataset)
R.log...
... # further operations
# resume previous experiment and recorder
with R.start('test', 'recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
with R.start(experiment_name='test', recorder_name='recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
... # further operations
Parameters
@@ -215,9 +216,9 @@ class QlibRecorder:
-------
A dictionary (id -> recorder) of recorder information that being stored.
"""
return self.get_exp(experiment_id, experiment_name).list_recorders()
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
"""
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
@@ -262,7 +263,7 @@ class QlibRecorder:
# Case 2
with R.start('test'):
exp = R.get_exp('test1')
exp = R.get_exp(experiment_name='test1')
# Case 3
exp = R.get_exp() -> a default experiment.
@@ -287,7 +288,9 @@ class QlibRecorder:
-------
An experiment instance with given id or name.
"""
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
return self.exp_manager.get_exp(
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
)
def delete_exp(self, experiment_id=None, experiment_name=None):
"""
@@ -331,7 +334,9 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
def get_recorder(
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
) -> Recorder:
"""
Method for retrieving a recorder.
@@ -384,7 +389,7 @@ class QlibRecorder:
-------
A recorder instance.
"""
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
recorder_id, recorder_name, create=False, start=False
)
@@ -525,14 +530,29 @@ class QlibRecorder:
self.get_exp().get_recorder().set_tags(**kwargs)
class RecorderWrapper(Wrapper):
"""
Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment.
"""
def register(self, provider):
if self._provider is not None:
expm = getattr(self._provider, "exp_manager")
if expm.active_experiment is not None:
raise RecorderInitializationError(
"Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified."
)
self._provider = provider
import sys
if sys.version_info >= (3, 9):
from typing import Annotated
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]
else:
QlibRecorderWrapper = QlibRecorder
# global record
R: QlibRecorderWrapper = Wrapper()
R: QlibRecorderWrapper = RecorderWrapper()

View File

@@ -53,7 +53,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
qlib.init(**config.get("qlib_init"), exp_manager=exp_manager)
task_train(config.get("task"), experiment_name=experiment_name)
recorder = task_train(config.get("task"), experiment_name=experiment_name)
recorder.save_objects(config=config)
# function to run worklflow by config

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
import mlflow, logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
@@ -213,11 +214,15 @@ class Experiment:
"""
raise NotImplementedError(f"Please implement the `_get_recorder` method")
def list_recorders(self):
def list_recorders(self, **flt_kwargs):
"""
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
flt_kwargs : dict
filter recorders by conditions
e.g. list_recorders(status=Recorder.STATUS_FI)
Returns
-------
A dictionary (id -> recorder) of recorder information that being stored.
@@ -320,11 +325,25 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results=UNLIMITED):
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""):
"""
Parameters
----------
max_results : int
the number limitation of the results
status : str
the criteria based on status to filter results.
`None` indicates no filtering.
filter_string : str
mlflow supported filter string like 'params."my_param"="a" and tags."my_tag"="b"', use this will help to reduce too much run number.
"""
runs = self._client.search_runs(
self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results, filter_string=filter_string
)
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
recorders[runs[i].info.run_id] = recorder
if status is None or recorder.status == status:
recorders[runs[i].info.run_id] = recorder
return recorders

View File

@@ -109,7 +109,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
"""
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
@@ -190,7 +190,7 @@ class ExpManager:
except ValueError:
if experiment_name is None:
experiment_name = self._default_exp_name
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
return self.create_exp(experiment_name), True
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
@@ -352,6 +352,8 @@ class MLflowExpManager(ExpManager):
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
if experiment_id is not None:
try:
# NOTE: the mlflow's experiment_id must be str type...
# https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment
exp = self.client.get_experiment(experiment_id)
if exp.lifecycle_stage.upper() == "DELETED":
raise MlflowException("No valid experiment has been found.")

View File

@@ -6,7 +6,7 @@ OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
So this module provides a series of methods to control this process.
So this module provides a series of methods to control this process.
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
Which means you can verify your strategy or find a better one.
@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
========================= ===================================================================================
Situations Description
========================= ===================================================================================
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
will train models task by task and strategy by strategy.
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
nothing until all tasks have been prepared. It makes user can train all tasks in
the end of `routine` or `first_train`.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
@@ -29,7 +31,7 @@ Simulation + Trainer When your models have some temporal dependence on the
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
for the ability to multitasking. It means all tasks in all routines
can be REAL trained at the end of simulating. The signals will be prepared well at
can be REAL trained at the end of simulating. The signals will be prepared well at
different time segments (based on whether or not any new model is online).
========================= ===================================================================================
"""
@@ -103,17 +105,23 @@ class OnlineManager(Serializable):
"""
if strategies is None:
strategies = self.strategies
for strategy in strategies:
models_list = []
for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
# start.
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
@@ -139,33 +147,41 @@ class OnlineManager(Serializable):
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
models_list = []
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.trainer.is_delay():
# The online model may changes in the above processes
# So updating the predictions of online models should be the last step
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.prepare_signals(**signal_kwargs)
def get_collector(self) -> MergeCollector:
def get_collector(self, **kwargs) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
This collector can be a basis as the signals preparation.
Args:
**kwargs: the params for get_collector.
Returns:
MergeCollector: the collector to merge other collectors.
"""
collector_dict = {}
for strategy in self.strategies:
collector_dict[strategy.name_id] = strategy.get_collector()
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
return MergeCollector(collector_dict, process_list=[])
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
@@ -225,7 +241,7 @@ class OnlineManager(Serializable):
SIM_LOG_NAME = "SIMULATE_INFO"
def simulate(
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
) -> Union[pd.Series, pd.DataFrame]:
"""
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
@@ -297,6 +313,7 @@ class OnlineManager(Serializable):
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
# FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning.
self.logger.warn(
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)

View File

@@ -10,6 +10,7 @@ from typing import List, Tuple, Union
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.utils import transform_end_date
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
@@ -52,6 +53,12 @@ class OnlineStrategy:
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
**NOTE**:
Current implementation is very naive. Here is a more complex situation which is more closer to the
practical scenarios.
1. Train new models at the day before `test_start` (at time stamp `T`)
2. Switch models at the `test_start` (at time timestamp `T + 1` typically)
Args:
models (list): a list of models.
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
@@ -112,6 +119,7 @@ class RollingStrategy(OnlineStrategy):
task_template = [task_template]
self.task_template = task_template
self.rg = rolling_gen
assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen"
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()
@@ -168,28 +176,20 @@ class RollingStrategy(OnlineStrategy):
Returns:
List[dict]: a list of new tasks.
"""
# TODO: filter recorders by latest test segments is not a necessary
latest_records, max_test = self._list_latest(self.tool.online_models())
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
calendar_latest = transform_end_date(cur_time)
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rec in latest_records:
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
res = []
for rec in latest_records:
task = rec.load_object("task")
res.extend(self.rg.gen_following_tasks(task, calendar_latest))
return res
def _list_latest(self, rec_list: List[Recorder]):
"""

View File

@@ -105,6 +105,8 @@ class PredUpdater(RecordUpdater):
if to_date == None:
to_date = D.calendar(freq=freq)[-1]
self.to_date = pd.Timestamp(to_date)
# FIXME: it will raise error when running routine with delay trainer
# should we use another predicition updater for delay trainer?
self.old_pred = record.load_object("pred.pkl")
self.last_end = self.old_pred.index.get_level_values("datetime").max()
@@ -135,10 +137,9 @@ class PredUpdater(RecordUpdater):
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time >= self.to_date:
if self.last_end >= self.to_date:
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
)
return

View File

@@ -8,8 +8,11 @@ This allows us to use efficient submodels as the market-style changing.
"""
from typing import List, Union
from qlib.data.dataset import TSDatasetH
from qlib.log import get_module_logger
from qlib.utils import get_callable_kwargs
from qlib.utils.exceptions import LoadObjectError
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
@@ -88,15 +91,15 @@ class OnlineToolR(OnlineTool):
The implementation of OnlineTool based on (R)ecorder.
"""
def __init__(self, experiment_name: str):
def __init__(self, default_exp_name: str = None):
"""
Init OnlineToolR.
Args:
experiment_name (str): the experiment name.
default_exp_name (str): the default experiment name.
"""
super().__init__()
self.exp_name = experiment_name
self.default_exp_name = default_exp_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
@@ -125,44 +128,68 @@ class OnlineToolR(OnlineTool):
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
def reset_online_tag(self, recorder: Union[Recorder, List]):
def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
exp_name = self._get_exp_name(exp_name)
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
recs = list_recorders(exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)
def online_models(self) -> list:
def online_models(self, exp_name: str = None) -> list:
"""
Get current `online` models
Args:
exp_name (str): the experiment name. If None, then use default_exp_name.
Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
exp_name = self._get_exp_name(exp_name)
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
def update_online_pred(self, to_date=None):
def update_online_pred(self, to_date=None, exp_name: str = None):
"""
Update the predictions of online models to to_date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
online_models = self.online_models()
exp_name = self._get_exp_name(exp_name)
online_models = self.online_models(exp_name=exp_name)
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
try:
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
except LoadObjectError as e:
# skip the recorder without pred
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
continue
updater.update()
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")
def _get_exp_name(self, exp_name):
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
return exp_name

View File

@@ -227,10 +227,11 @@ class SigAnaRecord(SignalRecord):
artifact_path = "sig_analysis"
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, **kwargs):
super().__init__(recorder=recorder, **kwargs)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
self.label_col = label_col
def generate(self, **kwargs):
try:
@@ -243,7 +244,7 @@ class SigAnaRecord(SignalRecord):
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
logger.warn(f"Empty label.")
return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, self.label_col])
metrics = {
"IC": ic.mean(),
"ICIR": ic.mean() / ic.std(),
@@ -252,7 +253,7 @@ class SigAnaRecord(SignalRecord):
}
objects = {"ic.pkl": ic, "ric.pkl": ric}
if self.ana_long_short:
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, self.label_col])
metrics.update(
{
"Long-Short Ann Return": long_short_r.mean() * self.ann_scaler,

View File

@@ -5,6 +5,8 @@ import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime
from qlib.utils.exceptions import LoadObjectError
from ..utils.objm import FileManager
from ..log import get_module_logger
@@ -297,7 +299,11 @@ class MLflowRecorder(Recorder):
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
if local_path is not None:
self.client.log_artifacts(self.id, local_path, artifact_path)
path = Path(local_path)
if path.is_dir():
self.client.log_artifacts(self.id, local_path, artifact_path)
else:
self.client.log_artifact(self.id, local_path, artifact_path)
else:
temp_dir = Path(tempfile.mkdtemp()).resolve()
for name, data in kwargs.items():
@@ -307,10 +313,26 @@ class MLflowRecorder(Recorder):
shutil.rmtree(temp_dir)
def load_object(self, name):
"""
Load object such as prediction file or model checkpoint in mlflow.
Args:
name (str): the object name
Raises:
LoadObjectError: if raise some exceptions when load the object
Returns:
object: the saved object in mlflow.
"""
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
try:
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
except Exception as e:
raise LoadObjectError(message=str(e))
def log_params(self, **kwargs):
for name, data in kwargs.items():

View File

@@ -6,6 +6,7 @@ Collector module can collect objects from everywhere and process them such as me
"""
from typing import Callable, Dict, List
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
from qlib.workflow import R
@@ -138,6 +139,7 @@ class RecorderCollector(Collector):
rec_filter_func=None,
artifacts_path={"pred": "pred.pkl"},
artifacts_key=None,
list_kwargs={},
):
"""
Init RecorderCollector.
@@ -149,6 +151,7 @@ class RecorderCollector(Collector):
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
list_kwargs (str): arguments for list_recorders function.
"""
super().__init__(process_list=process_list)
if isinstance(experiment, str):
@@ -162,6 +165,7 @@ class RecorderCollector(Collector):
self.rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self.rec_filter_func = rec_filter_func
self.list_kwargs = list_kwargs
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
"""
@@ -186,12 +190,13 @@ class RecorderCollector(Collector):
collect_dict = {}
# filter records
recs = self.experiment.list_recorders()
recs = self.experiment.list_recorders(**self.list_kwargs)
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
logger = get_module_logger("RecorderCollector")
for _, rec in recs_flt.items():
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
@@ -205,7 +210,13 @@ class RecorderCollector(Collector):
# only collect existing artifact
continue
raise e
collect_dict.setdefault(key, {})[rec_key] = artifact
# give user some warning if the values are overridden
cdd = collect_dict.setdefault(key, {})
if rec_key in cdd:
logger.warning(
f"key '{rec_key}' is duplicated. Previous value will be overrides. Please check you `rec_key_func`"
)
cdd[rec_key] = artifact
return collect_dict

View File

@@ -5,7 +5,10 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
"""
import abc
import copy
import pandas as pd
from typing import List, Union, Callable
from qlib.utils import transform_end_date
from .utils import TimeAdjuster
@@ -137,6 +140,53 @@ class RollingGen(TaskGen):
self.test_key = "test"
self.train_key = "train"
def _update_task_segs(self, task, segs):
# update segments of this task
task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segs)
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(task, self)
def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:
"""
generating following rolling tasks for `task` until test_end
Parameters
----------
task : dict
Qlib task format
test_end : pd.Timestamp
the latest rolling task includes `test_end`
Returns
-------
List[dict]:
the following tasks of `task`(`task` itself is excluded)
"""
prev_seg = task["dataset"]["kwargs"]["segments"]
while True:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
prev_seg = segments
t = copy.deepcopy(task) # deepcopy is necessary to avoid modify task inplace
self._update_task_segs(t, segments)
yield t
def generate(self, task: dict) -> List[dict]:
"""
Converting the task into a rolling task.
@@ -189,43 +239,23 @@ class RollingGen(TaskGen):
"""
res = []
prev_seg = None
test_end = None
while True:
t = copy.deepcopy(task)
t = copy.deepcopy(task)
# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
else:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
# calculate segments
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(t, self)
res.append(t)
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
# update segments of this task
self._update_task_segs(t, segments)
res.append(t)
# Update the following rolling
res.extend(self.gen_following_tasks(t, test_end))
return res

View File

@@ -47,6 +47,14 @@ class TaskManager:
The tasks manager assumes that you will only update the tasks you fetched.
The mongo fetch one and update will make it date updating secure.
This class can be used as a tool from commandline. Here are serveral examples
.. code-block:: shell
python -m qlib.workflow.task.manage -t <pool_name> wait
python -m qlib.workflow.task.manage -t <pool_name> task_stat
.. note::
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
@@ -69,28 +77,29 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool: str = None):
def __init__(self, task_pool: str):
"""
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
A TaskManager instance serves a specific task pool.
The static method of this module serves the whole MongoDB.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
if task_pool is not None:
self.task_pool = getattr(self.mdb, task_pool)
self.task_pool: pymongo.collection.Collection = getattr(get_mongodb(), task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self) -> list:
@staticmethod
def list() -> list:
"""
List the all collection(task_pool) of the db
List the all collection(task_pool) of the db.
Returns:
list
"""
return self.mdb.list_collection_names()
return get_mongodb().list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
@@ -100,6 +109,20 @@ class TaskManager:
return task
def _decode_task(self, task):
"""
_decode_task is Serialization tool.
Mongodb needs JSON, so it needs to convert Python objects into JSON objects through pickle
Parameters
----------
task : dict
task information
Returns
-------
dict
JSON required by mongodb
"""
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
@@ -109,6 +132,25 @@ class TaskManager:
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
def _decode_query(self, query):
"""
If the query includes any `_id`, then it needs `ObjectId` to decode.
For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
Args:
query (dict): query dict. Defaults to {}.
Returns:
dict: the query after decoding.
"""
if "_id" in query:
if isinstance(query["_id"], dict):
for key in query["_id"]:
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
else:
query["_id"] = ObjectId(query["_id"])
return query
def replace_task(self, task, new_task):
"""
Use a new task to replace a old one
@@ -191,6 +233,7 @@ class TaskManager:
r = self.task_pool.find_one({"filter": t})
except InvalidDocument:
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
# When r is none, it indicates that r s a new task
if r is None:
new_tasks.append(t)
if not dry_run:
@@ -224,8 +267,7 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
@@ -253,10 +295,10 @@ class TaskManager:
task = self.fetch_task(query=query, status=status)
try:
yield task
except Exception:
except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception
if task is not None:
self.logger.info("Returning task before raising error")
self.return_task(task)
self.return_task(task, status=status) # return task as the original status
self.logger.info("Task returned")
raise
@@ -283,12 +325,11 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
for t in self.task_pool.find(query):
yield self._decode_task(t)
def re_query(self, _id):
def re_query(self, _id) -> dict:
"""
Use _id to query task.
@@ -339,8 +380,7 @@ class TaskManager:
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
self.task_pool.delete_many(query)
def task_stat(self, query={}) -> dict:
@@ -354,8 +394,7 @@ class TaskManager:
dict
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
@@ -377,8 +416,7 @@ class TaskManager:
def reset_status(self, query, status):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
@@ -396,15 +434,29 @@ class TaskManager:
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def _get_undone_n(self, task_stat):
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
return (
task_stat.get(self.STATUS_WAITING, 0)
+ task_stat.get(self.STATUS_RUNNING, 0)
+ task_stat.get(self.STATUS_PART_DONE, 0)
)
def _get_total(self, task_stat):
return sum(task_stat.values())
def wait(self, query={}):
"""
When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks.
So main progress should wait until all tasks are trained well by other progress or machines.
Args:
query (dict, optional): the query dict. Defaults to {}.
"""
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
if last_undone_n == 0:
return
self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)
@@ -432,11 +484,11 @@ def run_task(
After running this method, here are 4 situations (before_status -> after_status):
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param, it means that the task has not been started
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param, it means that the task has been started but not completed
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param

Some files were not shown because too many files have changed in this diff Show More