1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

...

72 Commits

Author SHA1 Message Date
Xu Yang
2df211c320 merge all commit 2023-07-13 16:29:44 +08:00
Fivele-Li
effed382e9 Optimize prompt for entire learn loop (#1589)
* Adjust prompt and fix cases
* adjust summarizeTask & learn prompts;
* fix typos & drop duplicate task method;

* adjust learn prompts;
2023-07-11 18:13:52 +08:00
Fivele-Li
86ffd1799d Add knowledge module and tune summarizeTask (#1582)
* Add knowledge module
* add KnowledgeExperiment add KnowledgeBase;
* add knowledge associate prompts to template;

* Add Topic class
* add Topic to summarize knowledge;
* add recorder's metric to summarizeTask;

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-07-06 11:39:36 +08:00
Young
aef11536e3 rename & test 2023-07-04 20:28:08 +08:00
Xu Yang
8b0fdf1623 Merge pull request #1581 from microsoft/xuyang1/fix_singleton_bug
fix singleton bug
2023-07-04 16:51:51 +08:00
Xu Yang
9a36f8da20 fix singleton bug 2023-07-04 16:20:02 +08:00
Xu Yang
b7757d5008 Merge pull request #1580 from microsoft/xuyang1/refine_workflow_to_increase_success_rate
refine workflow to increase success rate
2023-07-03 17:59:54 +08:00
Xu Yang
ee5e5cfdd8 remove useless code 2023-07-03 17:57:13 +08:00
Xu Yang
6cb87ecfd1 refine code to use qrun 2023-07-03 17:56:22 +08:00
Xu Yang
9119bcdd3c Merge pull request #1576 from microsoft/xuyang1/add_config_and_code_dump_task
refine workflow and prompts
2023-06-30 14:43:49 +08:00
Xu Yang
4fccf8112d fix one workflow 2023-06-30 14:33:41 +08:00
Xu Yang
73bd79ca1a merge into one commit 2023-06-30 14:23:40 +08:00
Fivele-Li
7e84f3aae2 Add backtest and backforward task (#1568)
* * add TrainTask & BacktestTask;
* add BackForwardTask;
* adjust prompt_template.yaml which default config failed to backtest;
* run workflow in loop
* add update method to prompt_template.py

* remove debug code

* Adjust Learn Process
* add LearnManager class & use LearnManager to update system prompt;
* use qrun to replace recorder for training and backtesting;

* Adjust analyser
* analyser independent of recorder;
* rename analyser's workspace attribution;
* analyser load variable by recorder.

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-06-30 10:04:43 +08:00
Fivele-Li
1326ac614d Add docs to context and retrieve (#1566)
* add analyser docstring to context;
* add retrieve method to context manager;

* add notes to retrieve
2023-06-24 21:47:27 +08:00
Fivele-Li
f12184cc0f Add analyser task and optimize interact (#1552)
* * optimize interact
* add AnalyserTask
* optimize logger format and add render feature

* format optimize
2023-06-16 11:42:45 +08:00
Xu Yang
a70386ad52 Merge pull request #1550 from microsoft/xuyang1/refine_task_prompts
add datahandler and design action task according to component
2023-06-14 14:52:42 +08:00
Xu Yang
74619ed8d8 fix using defaut in record strategy and backtest 2023-06-14 14:52:16 +08:00
Fivele-Li
1a523df007 Optimize log and interact of FinCo (#1549)
* use FinCoLog for a better interact experience

* addition file changes

* optimize format

* optimize format
2023-06-14 14:48:17 +08:00
Xu Yang
f9cc8a5aaa remove useless prompt 2023-06-14 10:46:38 +08:00
Xu Yang
7762c5a1fd add datahandler and design action task according to component 2023-06-13 23:28:27 +08:00
Xu Yang
fa7ef29281 Merge pull request #1548 from microsoft/xuyang1/add_dump_to_file_task
add simple readme & move prompt templates to outer yaml file to make the code clean
2023-06-13 15:29:13 +08:00
Xu Yang
429c9a7c66 format 2023-06-13 15:27:59 +08:00
Xu Yang
80fbc00792 move prompt templates to yaml file to make code clean 2023-06-13 15:21:19 +08:00
Xu Yang
01accec24c update code 2023-06-12 16:25:16 +08:00
Fivele-Li
1d88830b0d Add recorder task and visualize (#1542)
* add recorder task

* add batch generate summarize report unittest.

* * add recorder to RecorderTask;
* add matplot figure to analyzer.py

* add image to markdown;

* Add some log

* update figure path.

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-06-12 15:48:00 +08:00
you-n-g
ad7498e287 Edit yaml task (#1538)
* Edit yaml task

* update comments
2023-06-02 00:44:41 +08:00
you-n-g
73d51f05b4 Init workspace and CMDTask (#1537)
* Update setup.py and config

* WIP

* init_workspace and CMDTask

* Delete test_sumarize.py
2023-06-01 23:32:35 +08:00
Fivele-Li
3b56b8e6c0 Optimize summarize task prompt and others (#1533)
* 1.update prompt;
2.update fetch information method.

* 1.update prompt;
2.save result to markdown;

* 1.get context info from context_manager;
2.run the entire process successfully.
2023-06-01 21:22:24 +08:00
you-n-g
40e0c329ba Add configurable dataset (#1535) 2023-06-01 20:05:02 +08:00
Xu Yang
e376648860 Merge pull request #1536 from microsoft/xuyang1/add_debug_mode_to_save_cache
add a debug mode to speed up debug process
2023-06-01 19:44:17 +08:00
Xu Yang
5f37f32184 update code 2023-06-01 19:38:26 +08:00
Xu Yang
d46b4c1ebf Merge pull request #1534 from microsoft/xuyang1/add_code_implementation_task
add code implementation task
2023-06-01 18:13:05 +08:00
Xu Yang
0515524b51 add code implementation code 2023-06-01 18:04:31 +08:00
Xu Yang
cda32d5703 Merge pull request #1532 from microsoft/xuyang1/add-plan-and-config-task-implementation
add the initial version of plan and config task implementation
2023-06-01 11:20:04 +08:00
Xu Yang
e2332a004b imporove some words in prompt 2023-06-01 01:09:14 +08:00
Xu Yang
08d9dbccc9 update v1 code containing SLplan and config action 2023-06-01 00:36:04 +08:00
Fivele-Li
e7cd93a36d add base method for summarization; (#1530) 2023-05-31 15:50:34 +08:00
Xu Yang
3919678028 split task into workflow and task to make the strcture more clear 2023-05-31 11:45:25 +08:00
Xu Yang
421b1403b2 Merge pull request #1528 from microsoft/xuyang1/refine_task_and_implement_workflow_task_as_example
Xuyang1/refine task and implement workflow task as example
2023-05-31 11:36:36 +08:00
Xu Yang
94102fb742 remove tasktype variable 2023-05-31 11:35:54 +08:00
Cadenza-Li
74a5d7c8af add parse method for summarization; 2023-05-31 00:08:21 +08:00
Xu Yang
ce39b4b6f8 add qlib auto init so logger can display info 2023-05-30 21:52:35 +08:00
Xu Yang
2af35d9c89 second commit 2023-05-30 20:20:16 +08:00
Xu Yang
f37643550b first round 2023-05-30 20:19:58 +08:00
Xu Yang
55611aa43e Merge pull request #1527 from microsoft/xuyang1/add_openai_api_support
add openai interface support
2023-05-30 13:44:10 +08:00
Xu Yang
f24253efd2 add openai interface support 2023-05-30 13:42:01 +08:00
Young
7c4f3b8a7d Initial interface for discussion 2023-05-24 12:18:31 +08:00
you-n-g
94268619c4 Update README.md 2023-05-23 09:50:00 +08:00
Huoran Li
8d60a6a02b Resolve RL FIXMES (#1503)
* Solve several small FIXMEs left in RL

* Add TODO in example

* Minor bugfix

* black
2023-05-17 16:57:08 +08:00
Fivele-Li
7234308651 Add base config in yml (#1500)
* path on Windows contains double '/' which may cause open file failed.

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* add baseConfig in yml,user can add new keys or update/drop keys in baseConfig;

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs. The pip version has been temporarily fixed to 23.0.1.

* 1.Search for baseConfig in multiple directories;
2.Add user instructions in qrun;

* fix format with black

* 1.modify baseConfig key to BASE_CONFIG_PATH;
2.only find config file in absolute path and relative path;

* load BASE_CONFIG_PATH on absolute path & relative path;

* fix Lint with black

---------

Co-authored-by: lijinhui <362237642@qq.com>
2023-05-12 17:35:37 +08:00
Chaoying
acf5df27ce Add support for redis password (#1508) 2023-05-08 16:17:15 +08:00
Chaoying
37a59f28d3 Fix deprecated syntax in numpy (#1507)
* Fix deprecated syntax in numpy

* Replace np.bool with bool
2023-05-08 16:17:02 +08:00
YQ Tsui
b084c352f5 provide dtype to empty series to surpress warning; fix type (#1449) 2023-05-05 17:47:44 +08:00
Maksim Zayakin
9e22e5168b Remove unused DNNModelPytorch params (#1470)
* Remove lr_decay and lr_decay_steps params

More flexible way to pass a scheduler (via callable function) is already
supported

* remove lr_decay and lr_decay_steps from mlp workflow configs
2023-04-28 17:48:40 +08:00
Fivele-Li
dceff7b471 Specify the tianshou version to match the dev environment to avoid the error in issue #1477. (#1502) 2023-04-28 13:50:25 +08:00
Huoran Li
7f1e8c5206 Refine Qlib RL data format (#1480)
* wip

* wip

* wip

* Fix naming errors

* Backtest test passed

* Why training stuck?

* Minor

* Refine train configs

* Use dummy in training

* Remove pickle_dataframe

* CI

* CI

* Add more strict condition to filter orders

* Pass test

* Add TODO in example

---------

Co-authored-by: Young <afe.young@gmail.com>
2023-04-26 21:14:30 +08:00
Fivele-Li
46264dfec9 normpath for Windows (#1495)
* path on Windows contains double '/' which may cause open file failed.

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs. The pip version has been temporarily fixed to 23.0.1.

---------

Co-authored-by: lijinhui <362237642@qq.com>
2023-04-26 16:26:12 +08:00
Fivele-Li
754799ab05 update ubuntu CI version; (#1488)
* update ubuntu CI version;
(End of standard support for 18.04 LTS - 31 May 2023)

* update ubuntu CI version;

---------

Co-authored-by: lijinhui <362237642@qq.com>
2023-04-10 17:06:48 +08:00
you-n-g
32c3070b73 Refine DDG-DA (#1472)
* Run ddg-da successfully

* Support include valid; More parameters

* Support L2 reg & visualization

* Blackformat

* Enable fill_method

* Support specify handler & optim dataset

* Fix Pylint
2023-04-07 15:00:21 +08:00
you-n-g
40de67265a Update Docs about some concepts in DataHandler (#1485) 2023-04-07 10:02:16 +08:00
saurabh dave
e6f9a94fc5 fix: removed extra blank link between sections (#1451) 2023-04-03 17:32:01 +08:00
Fivele-Li
73937863f1 Merge pull request #1475 from qianyun210603/bugfix
[BUGFIX] potential file// url parsing error
2023-03-24 11:22:57 +08:00
BookSword
d010219ba6 Merge branch 'main' into bugfix 2023-03-23 16:11:19 +08:00
BookSword
4fc8a5f25f merge 2023-03-23 16:05:09 +08:00
Linlang
0e8bfcb5d3 fix_pylint_w0719 (#1463)
* fix_pylint_w0719

* remove_fixme
2023-03-17 19:25:49 +08:00
you-n-g
e457ca8511 Improve annotation & documentation for handler (#1312)
* Improve annotation & documentation for handler

* Add type
2023-03-15 21:15:40 +08:00
Huoran Li
4dbb8ecb86 Remove (#1464) 2023-03-15 15:26:44 +08:00
Huoran Li
653c082e7a Order execution open source (#1447)
* Waiting for bin data

* Complete readme

* CI

* Add inst filter by time

* Update qlib/data/dataset/processor.py

* typo

* Fix time filter bug

* Add Filter and set Universe

* Complete data pipeline

* Fix Provider Logger Info Args

* Add DQN; a minor bugfix in ppo reward.

* update readme. modify assertion logic in strategy check.

* Fix Doc issues and fix black

* Fix pylint Error

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2023-03-13 12:06:28 +08:00
you-n-g
f98e04ca9d Fix Field Name Error 2023-03-03 16:28:47 +08:00
Cadenza-Li
76f2fb1a1a Add ipynb format check (#1439)
* Update test_qlib_from_source.yml

* add ipynb format check to workflow

* test ipynb CI

* modify nbqa check path

* add pylint flake8 mypy check to ipynb

* check ipynb with black and pylint

* reformat .ipynb files

* format line length

nbqa black . -l 120

* update nbqa .ipynb format CI

* format old ipynb files

* add nbconvert check to CI

* adjust CI order to avoid repeating download data
2023-02-21 09:23:22 +08:00
Huoran Li
5eb5ac1f1f RL backtest pipeline on 5-min data (#1417)
* Workflow runnable

* CI

* Slight changes to make the workflow runnable. The changes of handler/provider should be reverted before merging.

* Train experiment successful

* Refine handler & provider

* test passed

* Ready to test on server

* Minor

* Test passed

* TWAP training

* Add PPOReward

* Add a FIXME

* Refine PPO reward according to PR comments

* Minor

* Resolve PR comments

* CI issues

* CI issues

* CI issues
2023-02-13 12:43:22 +08:00
Young
6295939346 Update to Dev Version 2023-01-29 18:55:23 +08:00
121 changed files with 5338 additions and 962 deletions

View File

@@ -13,7 +13,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -28,8 +28,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Update pip to the latest version
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0.1
run: |
python -m pip install --upgrade pip
python -m pip install pip==23.0.1
- name: Installing pytorch for macos
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
@@ -37,15 +39,13 @@ jobs:
python -m pip install torch torchvision torchaudio
- name: Installing pytorch for ubuntu
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
run: |
python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- name: Installing pytorch for windows
if: ${{ matrix.os == 'windows-latest' }}
run: |
python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio
- name: Set up Python tools
@@ -120,6 +120,11 @@ jobs:
run: |
mypy qlib --install-types --non-interactive || true
mypy qlib --verbose
- name: Check Qlib ipynb with nbqa
run: |
nbqa black . -l 120 --check --diff
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
- name: Test data downloads
run: |
@@ -138,6 +143,12 @@ jobs:
brew unlink libomp
brew install libomp.rb
# Run after data downloads
- name: Check Qlib ipynb with nbconvert
run: |
# add more ipynb files in future
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
- name: Test workflow by config (install from source)
run: |
python -m pip install numba

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -28,9 +28,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Set up Python tools
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0.1
run: |
python -m pip install --upgrade pip
# python -m pip is necessary to upgrade pip.
python -m pip install pip==23.0.1
pip install --upgrade cython numpy
pip install -e .[dev]

4
.gitignore vendored
View File

@@ -10,7 +10,6 @@ _build
build/
dist/
*.pkl
*.hd5
*.csv
@@ -23,10 +22,13 @@ dist/
qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
qlib/finco/prompt_cache.json
examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
examples/rl/outputs/
examples/rl_order_execution/data/
examples/rl_order_execution/outputs/
*.egg-info/

View File

@@ -42,13 +42,11 @@ Features released before 2021 are not listed here.
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
</p>
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
An increasing number of SOTA Quant research works/papers in diverse paradigms are being released in Qlib to collaboratively solve key challenges in quantitative investment. For example, 1) using supervised learning to mine the market's complex non-linear patterns from rich and heterogeneous financial data, 2) modeling the dynamic nature of the financial market using adaptive concept drift technology, and 3) using reinforcement learning to model continuous investment decisions and assist investors in optimizing their trading strategies.
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
With Qlib, users can easily try ideas to create better Quant investment strategies.
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).

View File

@@ -29,13 +29,13 @@ class Avg15minHandler(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = Avg15minLoader(
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors
)
super().__init__(
instruments=instruments,

View File

@@ -18,7 +18,7 @@ data_handler_config: &data_handler_config
label: day
feature: 1min
# with label as reference
inst_processor:
inst_processors:
feature:
- class: Resample1minProcessor
module_path: features_sample.py

View File

@@ -19,7 +19,7 @@ data_handler_config: &data_handler_config
feature_15min: 1min
feature_day: day
# with label as reference
inst_processor:
inst_processors:
feature_15min:
- class: ResampleNProcessor
module_path: features_resample_N.py

View File

@@ -64,8 +64,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 8192

View File

@@ -64,8 +64,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 8192

View File

@@ -52,8 +52,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096

View File

@@ -52,8 +52,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096

View File

@@ -25,59 +25,65 @@
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"sns.set(style='white')\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"\n",
"sns.set(style=\"white\")\n",
"matplotlib.rcParams[\"pdf.fonttype\"] = 42\n",
"matplotlib.rcParams[\"ps.fonttype\"] = 42\n",
"\n",
"from tqdm.auto import tqdm\n",
"from joblib import Parallel, delayed\n",
"\n",
"\n",
"def func(x, N=80):\n",
" ret = x.ret.copy()\n",
" x = x.rank(pct=True)\n",
" x['ret'] = ret\n",
" x[\"ret\"] = ret\n",
" diff = x.score.sub(x.label)\n",
" r = x.nlargest(N, columns='score').ret.mean()\n",
" r -= x.nsmallest(N, columns='score').ret.mean()\n",
" return pd.Series({\n",
" 'MSE': diff.pow(2).mean(), \n",
" 'MAE': diff.abs().mean(), \n",
" 'IC': x.score.corr(x.label),\n",
" 'R': r\n",
" })\n",
" \n",
" r = x.nlargest(N, columns=\"score\").ret.mean()\n",
" r -= x.nsmallest(N, columns=\"score\").ret.mean()\n",
" return pd.Series(\n",
" {\n",
" \"MSE\": diff.pow(2).mean(),\n",
" \"MAE\": diff.abs().mean(),\n",
" \"IC\": x.score.corr(x.label),\n",
" \"R\": r,\n",
" }\n",
" )\n",
"\n",
"\n",
"ret = pd.read_pickle(\"data/ret.pkl\").clip(-0.1, 0.1)\n",
"\n",
"\n",
"def backtest(fname, **kwargs):\n",
" pred = pd.read_pickle(fname).loc['2018-09-21':'2020-06-30'] # test period\n",
" pred['ret'] = ret\n",
" pred = pd.read_pickle(fname).loc[\"2018-09-21\":\"2020-06-30\"] # test period\n",
" pred[\"ret\"] = ret\n",
" dates = pred.index.unique(level=0)\n",
" res = Parallel(n_jobs=-1)(delayed(func)(pred.loc[d], **kwargs) for d in dates)\n",
" res = {\n",
" dates[i]: res[i]\n",
" for i in range(len(dates))\n",
" }\n",
" res = {dates[i]: res[i] for i in range(len(dates))}\n",
" res = pd.DataFrame(res).T\n",
" r = res['R'].copy()\n",
" r = res[\"R\"].copy()\n",
" r.index = pd.to_datetime(r.index)\n",
" r = r.reindex(pd.date_range(r.index[0], r.index[-1])).fillna(0) # paper use 365 days\n",
" return {\n",
" 'MSE': res['MSE'].mean(),\n",
" 'MAE': res['MAE'].mean(),\n",
" 'IC': res['IC'].mean(),\n",
" 'ICIR': res['IC'].mean()/res['IC'].std(),\n",
" 'AR': r.mean()*365,\n",
" 'AV': r.std()*365**0.5,\n",
" 'SR': r.mean()/r.std()*365**0.5,\n",
" 'MDD': (r.cumsum().cummax() - r.cumsum()).max()\n",
" \"MSE\": res[\"MSE\"].mean(),\n",
" \"MAE\": res[\"MAE\"].mean(),\n",
" \"IC\": res[\"IC\"].mean(),\n",
" \"ICIR\": res[\"IC\"].mean() / res[\"IC\"].std(),\n",
" \"AR\": r.mean() * 365,\n",
" \"AV\": r.std() * 365**0.5,\n",
" \"SR\": r.mean() / r.std() * 365**0.5,\n",
" \"MDD\": (r.cumsum().cummax() - r.cumsum()).max(),\n",
" }, r\n",
"\n",
"\n",
"def fmt(x, p=3, scale=1, std=False):\n",
" _fmt = '{:.%df}'%p\n",
" _fmt = \"{:.%df}\" % p\n",
" string = _fmt.format((x.mean() if not isinstance(x, (float, np.floating)) else x) * scale)\n",
" if std and len(x) > 1:\n",
" string += ' ('+_fmt.format(x.std()*scale)+')'\n",
" string += \" (\" + _fmt.format(x.std() * scale) + \")\"\n",
" return string\n",
"\n",
"\n",
"def backtest_multi(files, **kwargs):\n",
" res = []\n",
" pnl = []\n",
@@ -88,14 +94,14 @@
" res = pd.DataFrame(res)\n",
" pnl = pd.concat(pnl, axis=1)\n",
" return {\n",
" 'MSE': fmt(res['MSE'], std=True),\n",
" 'MAE': fmt(res['MAE'], std=True),\n",
" 'IC': fmt(res['IC']),\n",
" 'ICIR': fmt(res['ICIR']),\n",
" 'AR': fmt(res['AR'], scale=100, p=1)+'%',\n",
" 'VR': fmt(res['AV'], scale=100, p=1)+'%',\n",
" 'SR': fmt(res['SR']),\n",
" 'MDD': fmt(res['MDD'], scale=100, p=1)+'%'\n",
" \"MSE\": fmt(res[\"MSE\"], std=True),\n",
" \"MAE\": fmt(res[\"MAE\"], std=True),\n",
" \"IC\": fmt(res[\"IC\"]),\n",
" \"ICIR\": fmt(res[\"ICIR\"]),\n",
" \"AR\": fmt(res[\"AR\"], scale=100, p=1) + \"%\",\n",
" \"VR\": fmt(res[\"AV\"], scale=100, p=1) + \"%\",\n",
" \"SR\": fmt(res[\"SR\"]),\n",
" \"MDD\": fmt(res[\"MDD\"], scale=100, p=1) + \"%\",\n",
" }, pnl"
]
},
@@ -124,16 +130,20 @@
"outputs": [],
"source": [
"exps = {\n",
" 'Linear': ['output/Linear/pred.pkl'],\n",
" 'LightGBM': ['output/GBDT/lr0.05_leaves128/pred.pkl'],\n",
" 'MLP': glob.glob('output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl'),\n",
" 'SFM': glob.glob('output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl'),\n",
" 'ALSTM': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'Trans.': glob.glob('output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'ALSTM+TS':glob.glob('output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'Trans.+TS':glob.glob('output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'ALSTM+TRA(Ours)': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'Trans.+TRA(Ours)': glob.glob('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl')\n",
" \"Linear\": [\"output/Linear/pred.pkl\"],\n",
" \"LightGBM\": [\"output/GBDT/lr0.05_leaves128/pred.pkl\"],\n",
" \"MLP\": glob.glob(\"output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl\"),\n",
" \"SFM\": glob.glob(\"output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl\"),\n",
" \"ALSTM\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"Trans.\": glob.glob(\"output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"ALSTM+TS\": glob.glob(\"output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"Trans.+TS\": glob.glob(\"output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"ALSTM+TRA(Ours)\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"Trans.+TRA(Ours)\": glob.glob(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl\"\n",
" ),\n",
"}"
]
},
@@ -160,14 +170,8 @@
}
],
"source": [
"res = {\n",
" name: backtest_multi(exps[name])\n",
" for name in tqdm(exps)\n",
"}\n",
"report = pd.DataFrame({\n",
" k: v[0]\n",
" for k, v in res.items()\n",
"}).T"
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
]
},
{
@@ -385,24 +389,40 @@
}
],
"source": [
"df = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl')\n",
"code = 'SH600157'\n",
"date = '2018-09-28'\n",
"df = pd.read_pickle(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl\"\n",
")\n",
"code = \"SH600157\"\n",
"date = \"2018-09-28\"\n",
"lookbackperiod = 50\n",
"\n",
"prob = df.iloc[:, -3:].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
"pred = df.loc[:,[\"score_0\",\"score_1\",\"score_2\",\"label\"]].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
"e_all = pred.iloc[:,:-1].sub(pred.iloc[:,-1], axis=0).pow(2)\n",
"pred = (\n",
" df.loc[:, [\"score_0\", \"score_1\", \"score_2\", \"label\"]]\n",
" .loc(axis=0)[:, code]\n",
" .reset_index(level=1, drop=True)\n",
" .loc[date:]\n",
" .iloc[:lookbackperiod]\n",
")\n",
"e_all = pred.iloc[:, :-1].sub(pred.iloc[:, -1], axis=0).pow(2)\n",
"e_all = e_all.sub(e_all.min(axis=1), axis=0)\n",
"e_all.columns = [r'$\\theta_%d$'%d for d in range(1, 4)]\n",
"e_all.columns = [r\"$\\theta_%d$\" % d for d in range(1, 4)]\n",
"prob = pd.Series(np.argmax(prob.values, axis=1), index=prob.index).rolling(7).mean().round()\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(7, 3))\n",
"e_all.plot(ax=axes[0], xlabel='', rot=30)\n",
"prob.plot(ax=axes[1], xlabel='', rot=30, color='red', linestyle='None', marker='^', markersize=5)\n",
"e_all.plot(ax=axes[0], xlabel=\"\", rot=30)\n",
"prob.plot(\n",
" ax=axes[1],\n",
" xlabel=\"\",\n",
" rot=30,\n",
" color=\"red\",\n",
" linestyle=\"None\",\n",
" marker=\"^\",\n",
" markersize=5,\n",
")\n",
"plt.yticks(np.array([0, 1, 2]), e_all.columns.values)\n",
"axes[0].set_ylabel('Predictor Loss')\n",
"axes[1].set_ylabel('Router Selection')\n",
"axes[0].set_ylabel(\"Predictor Loss\")\n",
"axes[1].set_ylabel(\"Router Selection\")\n",
"plt.tight_layout()\n",
"# plt.savefig('select.pdf', bbox_inches='tight')\n",
"plt.show()"
@@ -428,10 +448,18 @@
"outputs": [],
"source": [
"exps = {\n",
" 'Random': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'LR': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'TPE': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'LR+TPE': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl')\n",
" \"Random\": glob.glob(\n",
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"LR\": glob.glob(\n",
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"TPE\": glob.glob(\n",
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"LR+TPE\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
"}"
]
},
@@ -456,14 +484,8 @@
}
],
"source": [
"res = {\n",
" name: backtest_multi(exps[name])\n",
" for name in tqdm(exps)\n",
"}\n",
"report = pd.DataFrame({\n",
" k: v[0]\n",
" for k, v in res.items()\n",
"}).T"
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
]
},
{
@@ -597,18 +619,22 @@
}
],
"source": [
"a = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
"b = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
"a = pd.read_pickle(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
")\n",
"b = pd.read_pickle(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
")\n",
"a = a.iloc[:, -3:]\n",
"b = b.iloc[:, -3:]\n",
"b = np.eye(3)[b.values.argmax(axis=1)]\n",
"a = np.eye(3)[a.values.argmax(axis=1)]\n",
"\n",
"res = pd.DataFrame({\n",
" 'with OT': b.sum(axis=0) / b.sum(),\n",
" 'without OT': a.sum(axis=0)/ a.sum() \n",
"},index=[r'$\\theta_1$',r'$\\theta_2$',r'$\\theta_3$'])\n",
"res.plot.bar(rot=30, figsize=(5, 4), color=['b', 'g'])\n",
"res = pd.DataFrame(\n",
" {\"with OT\": b.sum(axis=0) / b.sum(), \"without OT\": a.sum(axis=0) / a.sum()},\n",
" index=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n",
")\n",
"res.plot.bar(rot=30, figsize=(5, 4), color=[\"b\", \"g\"])\n",
"del a, b"
]
},
@@ -633,11 +659,19 @@
"outputs": [],
"source": [
"exps = {\n",
" 'K=1': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json'),\n",
" 'K=3': glob.glob('output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
" 'K=5': glob.glob('output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
" 'K=10': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
" 'K=20': glob.glob('output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json')\n",
" \"K=1\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json\"),\n",
" \"K=3\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
" \"K=5\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
" \"K=10\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
" \"K=20\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
"}"
]
},
@@ -649,16 +683,11 @@
"source": [
"report = dict()\n",
"for k, v in exps.items():\n",
" \n",
" tmp = dict()\n",
" for fname in v:\n",
" with open(fname) as f:\n",
" info = json.load(f)\n",
" tmp[fname] = (\n",
" {\n",
" \"IC\":info[\"metric\"][\"IC\"],\n",
" \"MSE\":info[\"metric\"][\"MSE\"]\n",
" })\n",
" tmp[fname] = {\"IC\": info[\"metric\"][\"IC\"], \"MSE\": info[\"metric\"][\"MSE\"]}\n",
" tmp = pd.DataFrame(tmp).T\n",
" report[k] = tmp.mean()\n",
"report = pd.DataFrame(report).T"
@@ -681,13 +710,14 @@
}
],
"source": [
"fig, axes = plt.subplots(1, 2, figsize=(6,3)); axes = axes.flatten()\n",
"report['IC'].plot.bar(rot=30, ax=axes[0])\n",
"fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
"axes = axes.flatten()\n",
"report[\"IC\"].plot.bar(rot=30, ax=axes[0])\n",
"axes[0].set_ylim(0.045, 0.062)\n",
"axes[0].set_title('IC performance')\n",
"report['MSE'].astype(float).plot.bar(rot=30, ax=axes[1], color='green')\n",
"axes[0].set_title(\"IC performance\")\n",
"report[\"MSE\"].astype(float).plot.bar(rot=30, ax=axes[1], color=\"green\")\n",
"axes[1].set_ylim(0.155, 0.1585)\n",
"axes[1].set_title('MSE performance')\n",
"axes[1].set_title(\"MSE performance\")\n",
"plt.tight_layout()\n",
"# plt.savefig('sensitivity.pdf')"
]

View File

@@ -0,0 +1,107 @@
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
plt.rcParams["font.sans-serif"] = "SimHei"
plt.rcParams["axes.unicode_minus"] = False
from tqdm.auto import tqdm
# tqdm.pandas() # for progress_apply
# %matplotlib inline
# %load_ext autoreload
# # Meta Input
# +
with open("./internal_data_s20.pkl", "rb") as f:
data = pickle.load(f)
data.data_ic_df.columns.names = ["start_date", "end_date"]
data_sim = data.data_ic_df.droplevel(axis=1, level="end_date")
data_sim.index.name = "test datetime"
# -
plt.figure(figsize=(40, 20))
sns.heatmap(data_sim)
plt.figure(figsize=(40, 20))
sns.heatmap(data_sim.rolling(20).mean())
# # Meta Model
from qlib import auto_init
auto_init()
from qlib.workflow import R
exp = R.get_exp(experiment_name="DDG-DA")
meta_rec = exp.list_recorders(rtype="list", max_results=1)[0]
meta_m = meta_rec.load_object("model")
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].plot()
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean().plot()
# # Meta Output
# +
with open("./tasks_s20.pkl", "rb") as f:
tasks = pickle.load(f)
task_df = {}
for t in tasks:
test_seg = t["dataset"]["kwargs"]["segments"]["test"]
if None not in test_seg:
# The last rolling is skipped.
task_df[test_seg] = t["reweighter"].time_weight
task_df = pd.concat(task_df)
task_df.index.names = ["OS_start", "OS_end", "IS_start", "IS_end"]
task_df = task_df.droplevel(["OS_end", "IS_end"])
task_df = task_df.unstack("OS_start")
# -
plt.figure(figsize=(40, 20))
sns.heatmap(task_df.T)
plt.figure(figsize=(40, 20))
sns.heatmap(task_df.rolling(10).mean().T)
# # Sub Models
#
# NOTE:
# - this section assumes that the model is Linear model!!
# - Other models does not support this analysis
exp = R.get_exp(experiment_name="rolling_ds")
def show_linear_weight(exp):
coef_df = {}
for r in exp.list_recorders("list"):
t = r.load_object("task")
if None in t["dataset"]["kwargs"]["segments"]["test"]:
continue
m = r.load_object("params.pkl")
coef_df[t["dataset"]["kwargs"]["segments"]["test"]] = pd.Series(m.coef_)
coef_df = pd.concat(coef_df)
coef_df.index.names = ["test_start", "test_end", "coef_idx"]
coef_df = coef_df.droplevel("test_end").unstack("coef_idx").T
plt.figure(figsize=(40, 20))
sns.heatmap(coef_df)
plt.show()
show_linear_weight(R.get_exp(experiment_name="rolling_ds"))
show_linear_weight(R.get_exp(experiment_name="rolling_models"))

View File

@@ -10,8 +10,10 @@ import pandas as pd
import fire
import sys
import pickle
from typing import Optional
from qlib import auto_init
from qlib.model.trainer import TrainerR
from qlib.typehint import Literal
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.tests.data import GetData
@@ -30,7 +32,33 @@ class DDGDA:
- `rm -r mlruns`
"""
def __init__(self, sim_task_model="linear", forecast_model="linear"):
def __init__(
self,
sim_task_model: Literal["linear", "gbdt"] = "linear",
forecast_model: Literal["linear", "gbdt"] = "linear",
h_path: Optional[str] = None,
test_end: Optional[str] = None,
train_start: Optional[str] = None,
meta_1st_train_end: Optional[str] = None,
task_ext_conf: Optional[dict] = None,
alpha: float = 0.0,
proxy_hd: str = "handler_proxy.pkl",
):
"""
Parameters
----------
train_start: Optional[str]
the start datetime for data. It is used in training start time (for both tasks & meta learing)
test_end: Optional[str]
the end datetime for data. It is used in test end time
meta_1st_train_end: Optional[str]
the datetime of training end of the first meta_task
alpha: float
Setting the L2 regularization for ridge
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
"""
self.step = 20
# NOTE:
# the horizon must match the meaning in the base task template
@@ -38,10 +66,19 @@ class DDGDA:
self.meta_exp_name = "DDG-DA"
self.sim_task_model = sim_task_model # The model to capture the distribution of data.
self.forecast_model = forecast_model # downstream forecasting models' type
self.rb_kwargs = {
"h_path": h_path,
"test_end": test_end,
"train_start": train_start,
"task_ext_conf": task_ext_conf,
}
self.alpha = alpha
self.meta_1st_train_end = meta_1st_train_end
self.proxy_hd = proxy_hd
def get_feature_importance(self):
# this must be lightGBM, because it needs to get the feature importance
rb = RollingBenchmark(model_type="gbdt")
rb = RollingBenchmark(model_type="gbdt", **self.rb_kwargs)
task = rb.basic_task()
with R.start(experiment_name="feature_importance"):
@@ -69,7 +106,7 @@ class DDGDA:
fi = self.get_feature_importance()
col_selected = fi.nlargest(topk)
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
task = rb.basic_task()
dataset = init_instance_by_config(task["dataset"])
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -96,7 +133,7 @@ class DDGDA:
"kwargs": {"config": DIRNAME / "fea_label_df.pkl"},
}
)
handler.to_pickle(DIRNAME / "handler_proxy.pkl", dump_all=True)
handler.to_pickle(DIRNAME / self.proxy_hd, dump_all=True)
@property
def _internal_data_path(self):
@@ -108,7 +145,7 @@ class DDGDA:
This function will dump the input data for meta model
"""
# According to the experiments, the choice of the model type is very important for achieving good results
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
sim_task = rb.basic_task()
if self.sim_task_model == "gbdt":
@@ -122,24 +159,27 @@ class DDGDA:
with self._internal_data_path.open("wb") as f:
pickle.dump(internal_data, f)
def train_meta_model(self):
def train_meta_model(self, fill_method="max"):
"""
training a meta model based on a simplified linear proxy model;
"""
# 1) leverage the simplified proxy forecasting model to train meta model.
# - Only the dataset part is important, in current version of meta model will integrate the
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
sim_task = rb.basic_task()
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
proxy_forecast_model_task = {
# "model": "qlib.contrib.model.linear.LinearModel",
"dataset": {
"class": "qlib.data.dataset.DatasetH",
"kwargs": {
"handler": f"file://{(DIRNAME / 'handler_proxy.pkl').absolute()}",
"handler": f"file://{(DIRNAME / self.proxy_hd).absolute()}",
"segments": {
"train": ("2008-01-01", "2010-12-31"),
"test": ("2011-01-01", sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
"train": (train_start, train_end),
"test": (test_start, sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
},
},
},
@@ -156,7 +196,7 @@ class DDGDA:
segments=0.62, # keep test period consistent with the dataset yaml
trunc_days=1 + self.horizon,
hist_step_n=30,
fill_method="max",
fill_method=fill_method,
rolling_ext_days=0,
)
# NOTE:
@@ -165,12 +205,15 @@ class DDGDA:
# So the misalignment will not affect the effectiveness of the method.
with self._internal_data_path.open("rb") as f:
internal_data = pickle.load(f)
md = MetaDatasetDS(exp_name=internal_data, **kwargs)
# 3) train and logging meta model
with R.start(experiment_name=self.meta_exp_name):
R.log_params(**kwargs)
mm = MetaModelDS(step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43)
mm = MetaModelDS(
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
)
mm.fit(md)
R.save_objects(model=mm)
@@ -203,7 +246,7 @@ class DDGDA:
hist_step_n = int(param["hist_step_n"])
fill_method = param.get("fill_method", "max")
rb = RollingBenchmark(model_type=self.forecast_model)
rb = RollingBenchmark(model_type=self.forecast_model, **self.rb_kwargs)
task_l = rb.create_rolling_tasks()
# 2.2) create meta dataset for final dataset
@@ -233,13 +276,13 @@ class DDGDA:
"""
with self._task_path.open("rb") as f:
tasks = pickle.load(f)
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model)
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model, **self.rb_kwargs)
rb.train_rolling_tasks(tasks)
rb.ens_rolling()
rb.update_rolling_rec()
def run_all(self):
# 1) file: handler_proxy.pkl
# 1) file: handler_proxy.pkl (self.proxy_hd)
self.dump_data_for_proxy_model()
# 2)
# file: internal_data_s20.pkl

View File

@@ -1,13 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from qlib.model.ens.ensemble import RollingEnsemble
from qlib.utils import init_instance_by_config
import fire
import yaml
import pandas as pd
from qlib import auto_init
from pathlib import Path
from tqdm.auto import tqdm
from qlib.model.trainer import TrainerR
from qlib.log import get_module_logger
from qlib.utils.data import update_config
from qlib.workflow import R
from qlib.tests.data import GetData
@@ -25,11 +29,40 @@ class RollingBenchmark:
"""
def __init__(self, rolling_exp="rolling_models", model_type="linear") -> None:
def __init__(
self,
rolling_exp: str = "rolling_models",
model_type: str = "linear",
h_path: Optional[str] = None,
train_start: Optional[str] = None,
test_end: Optional[str] = None,
task_ext_conf: Optional[dict] = None,
) -> None:
"""
Parameters
----------
rolling_exp : str
The name for the experiments for rolling
model_type : str
The model to be boosted.
h_path : Optional[str]
the dumped data handler;
test_end : Optional[str]
the test end for the data. It is typically used together with the handler
train_start : Optional[str]
the train start for the data. It is typically used together with the handler.
task_ext_conf : Optional[dict]
some option to update the
"""
self.step = 20
self.horizon = 20
self.rolling_exp = rolling_exp
self.model_type = model_type
self.h_path = h_path
self.train_start = train_start
self.test_end = test_end
self.logger = get_module_logger("RollingBenchmark")
self.task_ext_conf = task_ext_conf
def basic_task(self):
"""For fast training rolling"""
@@ -42,6 +75,10 @@ class RollingBenchmark:
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
else:
raise AssertionError("Model type is not supported!")
if self.h_path is not None:
h_path = Path(self.h_path)
with conf_path.open("r") as f:
conf = yaml.safe_load(f)
@@ -52,6 +89,9 @@ class RollingBenchmark:
task = conf["task"]
if self.task_ext_conf is not None:
task = update_config(task, self.task_ext_conf)
if not h_path.exists():
h_conf = task["dataset"]["kwargs"]["handler"]
h = init_instance_by_config(h_conf)
@@ -59,6 +99,15 @@ class RollingBenchmark:
task["dataset"]["kwargs"]["handler"] = f"file://{h_path}"
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
if self.train_start is not None:
seg = task["dataset"]["kwargs"]["segments"]["train"]
task["dataset"]["kwargs"]["segments"]["train"] = pd.Timestamp(self.train_start), seg[1]
if self.test_end is not None:
seg = task["dataset"]["kwargs"]["segments"]["test"]
task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp(self.test_end)
self.logger.info(task)
return task
def create_rolling_tasks(self):
@@ -93,7 +142,7 @@ class RollingBenchmark:
"""
Evaluate the combined rolling results
"""
for rid, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
for _, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
for rt_cls in SigAnaRecord, PortAnaRecord:
rt = rt_cls(recorder=rec, skip_existing=True)
rt.generate()

View File

@@ -1,60 +0,0 @@
This folder contains a simple example of how to run Qlib RL. It contains:
```
.
├── experiment_config
│ ├── backtest # Backtest config
│ └── training # Training config
├── README.md # Readme (the current file)
└── scripts # Scripts for data pre-processing
```
## Data preparation
Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:
```
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
mv qlib_rl_example_data data
```
The downloaded data will be placed at `./data`. The original data are in `data/csv`. To create all data needed by the case, run:
```
bash scripts/data_pipeline.sh
```
After the execution finishes, the `data/` directory should be like:
```
data
├── backtest_orders.csv
├── bin
├── csv
├── pickle
├── pickle_dataframe
└── training_order_split
```
## Run training
Run:
```
python -m qlib.rl.contrib.train_onpolicy --config_path ./experiment_config/training/config.yml
```
After training, checkpoints will be stored under `checkpoints/`.
## Run backtest
```
python -m qlib.rl.contrib.backtest --config_path ./experiment_config/backtest/config.yml
```
The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
## Others
The RL module is designed in a loosely-coupled way. Currently, RL examples are integrated with concrete business logic.
But the core part of RL is much simpler than what you see.
To demonstrate the simple core of RL, [a dedicated notebook](./simple_example.ipynb) for RL without business loss is created.

View File

@@ -1,57 +0,0 @@
order_file: ./data/backtest_orders.csv
start_time: "9:45"
end_time: "14:44"
qlib:
provider_uri_1min: ./data/bin
feature_root_dir: ./data/pickle
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$volume",
]
feature_columns_yesterday: [
"$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
]
exchange:
limit_threshold: ['$close == 0', '$close == 0']
deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
volume_threshold:
all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
buy: ["current", "$close"]
sell: ["current", "$close"]
strategies:
30min:
class: TWAPStrategy
module_path: qlib.contrib.strategy.rule_strategy
kwargs: {}
1day:
class: SAOEIntStrategy
module_path: qlib.rl.order_execution.strategy
kwargs:
state_interpreter:
class: FullHistoryStateInterpreter
module_path: qlib.rl.order_execution.interpreter
kwargs:
max_step: 8
data_ticks: 240
data_dim: 6
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
action_interpreter:
class: CategoricalActionInterpreter
module_path: qlib.rl.order_execution.interpreter
kwargs:
values: 14
max_step: 8
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
kwargs: {}
policy:
class: PPO
module_path: qlib.rl.order_execution.policy
kwargs:
lr: 1.0e-4
weight_file: ./checkpoints/latest.pth
concurrency: 5

View File

@@ -1,59 +0,0 @@
simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 1
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 14
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 6
data_ticks: 240
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
kwargs:
penalty: 100.0
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/training_order_split
data_dir: ./data/pickle_dataframe/backtest
total_time: 240
default_start_time: 0
default_end_time: 240
proc_data_dim: 6
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 2
repeat_per_collect: 5
earlystop_patience: 2
episode_per_collect: 20
batch_size: 16
val_every_n_epoch: 1
checkpoint_path: ./checkpoints
checkpoint_every_n_iters: 1

View File

@@ -1,21 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import pickle
import pandas as pd
from tqdm import tqdm
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
for tag in ("backtest", "feature"):
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
df = pd.concat(list(df.values())).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")
instruments = sorted(set(df["instrument"]))
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
for instrument in tqdm(instruments):
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
cur = cur.set_index(["instrument", "datetime", "date"])
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))

View File

@@ -1,14 +0,0 @@
# Generate `bin` format data
set -e
python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min
# Generate pickle format data
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
if [ -e stat/ ]; then
rm -r stat/
fi
python scripts/collect_pickle_dataframe.py
# Sample orders
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py

View File

@@ -1,55 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import pandas as pd
import numpy as np
import pickle
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--num_order", type=int, default=10)
args = parser.parse_args()
np.random.seed(args.seed)
path = os.path.join("data", "pickle", "backtesttest.pkl")
df = pickle.load(open(path, "rb")).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")
instruments = sorted(set(df["instrument"]))
# TODO: The example is expected to be able to handle data containing missing values.
# TODO: Currently, we just simply skip dates that contain missing data. We will add
# TODO: this feature in the future.
skip_dates = {}
for instrument in instruments:
csv_df = pd.read_csv(os.path.join("data", "csv", f"{instrument}.csv"))
csv_df = csv_df[csv_df["close"].isna()]
dates = set([str(d).split(" ")[0] for d in csv_df["date"]])
skip_dates[instrument] = dates
df_list = []
for instrument in instruments:
print(instrument)
cur_df = df[df["instrument"] == instrument]
dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]]))
dates = [date for date in dates if date not in skip_dates[instrument]]
n = args.num_order
df_list.append(
pd.DataFrame(
{
"date": sorted(np.random.choice(dates, size=n, replace=False)),
"instrument": [instrument] * n,
"amount": np.random.randint(low=3, high=11, size=n) * 100.0,
"order_type": np.random.randint(low=0, high=2, size=n),
}
).set_index(["date", "instrument"]),
)
total_df = pd.concat(df_list)
total_df.to_csv("data/backtest_orders.csv")

View File

@@ -1,39 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import pandas as pd
import numpy as np
import pickle
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--stock", type=str, default="AAPL")
parser.add_argument("--train_size", type=int, default=10)
parser.add_argument("--valid_size", type=int, default=2)
parser.add_argument("--test_size", type=int, default=2)
args = parser.parse_args()
np.random.seed(args.seed)
os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)
for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
path = os.path.join("data", "pickle", f"backtest{group}.pkl")
df = pickle.load(open(path, "rb")).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")
dates = sorted(set([str(d).split(" ")[0] for d in df["date"]]))
data_df = pd.DataFrame(
{
"date": sorted(np.random.choice(dates, size=n, replace=False)),
"instrument": [args.stock] * n,
"amount": np.random.randint(low=3, high=11, size=n) * 100.0,
"order_type": [0] * n,
}
).set_index(["date", "instrument"])
os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True)
pickle.dump(data_df, open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb"))

View File

@@ -41,6 +41,7 @@
"\n",
"State = namedtuple(\"State\", [\"value\", \"last_action\"])\n",
"\n",
"\n",
"class SimpleSimulator(Simulator[float, State, float]):\n",
" def __init__(self, initial: float, nsteps: int, **kwargs: Any) -> None:\n",
" super().__init__(initial)\n",
@@ -92,6 +93,7 @@
"from gym import spaces\n",
"from qlib.rl.interpreter import StateInterpreter\n",
"\n",
"\n",
"class SimpleStateInterpreter(StateInterpreter[Tuple[float, float], np.ndarray]):\n",
" def interpret(self, state: State) -> np.ndarray:\n",
" # Convert state.value to a 1D Numpy array\n",
@@ -101,7 +103,8 @@
" @property\n",
" def observation_space(self) -> spaces.Box:\n",
" return spaces.Box(0, np.inf, shape=(1,), dtype=np.float32)\n",
" \n",
"\n",
"\n",
"state_interpreter = SimpleStateInterpreter()"
]
},
@@ -120,6 +123,7 @@
"source": [
"from qlib.rl.interpreter import ActionInterpreter\n",
"\n",
"\n",
"class SimpleActionInterpreter(ActionInterpreter[State, int, float]):\n",
" def __init__(self, n_value: int) -> None:\n",
" self.n_value = n_value\n",
@@ -132,7 +136,8 @@
" assert 0 <= action <= self.n_value\n",
" # simulator_state.value is used as the denominator\n",
" return simulator_state.value * (action / self.n_value)\n",
" \n",
"\n",
"\n",
"action_interpreter = SimpleActionInterpreter(n_value=10)"
]
},
@@ -151,12 +156,14 @@
"source": [
"from qlib.rl.reward import Reward\n",
"\n",
"\n",
"class SimpleReward(Reward[State]):\n",
" def reward(self, simulator_state: State) -> float:\n",
" # Use last_action to calculate reward. This is why it should be in the state.\n",
" rew = simulator_state.last_action / simulator_state.value\n",
" return rew\n",
" \n",
"\n",
"\n",
"reward = SimpleReward()"
]
},
@@ -180,6 +187,7 @@
"from torch import nn\n",
"from qlib.rl.order_execution import PPO\n",
"\n",
"\n",
"class SimpleFullyConnect(nn.Module):\n",
" def __init__(self, dims: List[int]) -> None:\n",
" super().__init__()\n",
@@ -195,7 +203,8 @@
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return self.fc(x)\n",
" \n",
"\n",
"\n",
"policy = PPO(\n",
" network=SimpleFullyConnect(dims=[16, 8]),\n",
" obs_space=state_interpreter.observation_space,\n",
@@ -221,6 +230,7 @@
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class SimpleDataset(Dataset):\n",
" def __init__(self, positions: List[float]) -> None:\n",
" self.positions = positions\n",
@@ -230,7 +240,8 @@
"\n",
" def __getitem__(self, index: int) -> float:\n",
" return self.positions[index]\n",
" \n",
"\n",
"\n",
"dataset = SimpleDataset(positions=[10.0, 50.0, 100.0])"
]
},
@@ -265,11 +276,13 @@
"trainer_kwargs = {\n",
" \"max_iters\": 10,\n",
" \"finite_env_type\": \"dummy\",\n",
" \"callbacks\": [Checkpoint(\n",
" dirpath=Path(\"./checkpoints\"),\n",
" every_n_iters=1,\n",
" save_latest=\"copy\",\n",
" )],\n",
" \"callbacks\": [\n",
" Checkpoint(\n",
" dirpath=Path(\"./checkpoints\"),\n",
" every_n_iters=1,\n",
" save_latest=\"copy\",\n",
" )\n",
" ],\n",
"}\n",
"vessel_kwargs = {\n",
" \"update_kwargs\": {\"batch_size\": 16, \"repeat\": 5},\n",

View File

@@ -0,0 +1,100 @@
# RL Example for Order Execution
This folder comprises an example of Reinforcement Learning (RL) workflows for order execution scenario, including both training workflows and backtest workflows.
## Data Processing
### Get Data
```
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
```
### Generate Pickle-Style Data
To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):
[//]: # (TODO: Instead of dumping dataframe with different format &#40;like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`&#41;, we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)
```
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
python scripts/gen_training_orders.py
python scripts/merge_orders.py
```
When finished, the structure under `data/` should be:
```
data
├── bin
├── orders
└── pickle
```
## Training
Each training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks:
- **PPO**: Method proposed by IJCAL 2020 paper "[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)".
- **OPDS**: Method proposed by AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)".
The main differece between these two methods is their reward functions. Please see their config files for details.
Take OPDS as an example, to run the training workflow, run:
```
python -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest
```
Metrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`).
## Backtest
Once the training workflow has completed, the trained model can be used for the backtesting workflow. Still taking OPDS as an example, once training is finished, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth`. To run backtest workflow:
1. Uncomment the `weight_file` parameter in `exp_configs/train_opds.yml` (it is commented by default). While it is possible to run the backtesting workflow without setting a checkpoint, this will lead to randomly initialized model results, thus making them meaningless.
2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`.
The backtest result is stored in `outputs/checkpoints/backtest_result.csv`.
In addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`.
### Gap between backtest and training pipeline's testing
It is worthy to notice that the results of the backtesting process may differ from the results of the testing process used during training.
This is because different simulators are used to simulate market conditions during training and backtesting.
In training pipeline, the simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons.
`SingleAssetOrderExecutionSimple` makes no restriction to trading amounts.
No matter what the amount of the order is, it can be completely executed.
However, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used.
It takes into account practical constraints in more real-world scenarios (for example, the trading volume must be a multiple of the smallest trading unit).
As a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed.
If you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you could run training pipeline with only backtest phrase.
In order to do this:
- Modify the training config. Add the path of the checkpoint you want to use (see following for an example).
- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training`
```yaml
...
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
weight_file: PATH/TO/CHECKPOINT
module_path: qlib.rl.order_execution.policy
...
```
## Benchmarks (TBD)
To accurately evaluate the performance of models using Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once, selecting the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use "Price Advantage (PA)" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the testing set is as follows:
| **Model** | **PA mean with std.** |
|-----------------------------|-----------------------|
| OPDS (with PPO policy) | 0.4785 ± 0.7815 |
| OPDS (with DQN policy) | -0.0114 ± 0.5780 |
| PPO | -1.0935 ± 0.0922 |
| TWAP | ≈ 0.0 ± 0.0 |
The table above also includes TWAP as a rule-based baseline. The ideal PA of TWAP should be 0.0, however, in this example, the order execution is divided into two steps: first, the order is split equally among each half hour, and then each five minutes within each half hour. Since trading is forbidden during the last five minutes of the day, this approach may slightly differ from traditional TWAP over the course of a full day (as there are 5 minutes missing in the last "half hour"). Therefore, the PA of TWAP can be considered as a number that is close to 0.0. To verify this, you may run a TWAP backtest and check the results.

View File

@@ -0,0 +1,53 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/opds/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/opds/

View File

@@ -0,0 +1,53 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/ppo/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/ppo/

View File

@@ -0,0 +1,21 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/twap/

View File

@@ -0,0 +1,66 @@
simulator:
data_granularity: 5
time_per_step: 30
vol_limit: null
env:
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 4
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
kwargs:
penalty: 4.0
scale: 0.01
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/orders
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
proc_data_dim: 5
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 500
repeat_per_collect: 25
earlystop_patience: 50
episode_per_collect: 10000
batch_size: 1024
val_every_n_epoch: 4
checkpoint_path: ./outputs/opds
checkpoint_every_n_iters: 1

View File

@@ -0,0 +1,67 @@
simulator:
data_granularity: 5
time_per_step: 30
vol_limit: null
env:
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 4
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PPOReward
kwargs:
max_step: 8
start_time_index: 0
end_time_index: 46 # 46 = (240 - 5) min / 5 min - 1
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/orders
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
proc_data_dim: 5
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 500
repeat_per_collect: 25
earlystop_patience: 50
episode_per_collect: 10000
batch_size: 1024
val_every_n_epoch: 4
checkpoint_path: ./outputs/ppo
checkpoint_every_n_iters: 1

View File

@@ -4,6 +4,7 @@
import yaml
import argparse
import os
import shutil
from copy import deepcopy
from qlib.contrib.data.highfreq_provider import HighFreqProvider
@@ -41,3 +42,5 @@ if __name__ == "__main__":
if args.split == "stock" or args.split == "both":
provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
shutil.rmtree("stat/", ignore_errors=True)

View File

@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import numpy as np
import pandas as pd
from pathlib import Path
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
OUTPUT_PATH = Path(os.path.join("data", "orders"))
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
df = dataset.handler.fetch(level=None).reset_index()
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
return False
df["date"] = df["datetime"].dt.date.astype("datetime64")
df = df.set_index(["instrument", "datetime", "date"])
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
order_all = order_all[order_all["amount"] > 0.0]
order_all["order_type"] = 0
order_all = order_all.drop(columns=["$volume0"])
order_train = order_all[order_all.index.get_level_values(0) <= pd.Timestamp("2021-06-30")]
order_test = order_all[order_all.index.get_level_values(0) > pd.Timestamp("2021-06-30")]
order_valid = order_test[order_test.index.get_level_values(0) <= pd.Timestamp("2021-09-30")]
order_test = order_test[order_test.index.get_level_values(0) > pd.Timestamp("2021-09-30")]
for order, tag in zip((order_train, order_valid, order_test, order_all), ("train", "valid", "test", "all")):
path = OUTPUT_PATH / tag
os.makedirs(path, exist_ok=True)
if len(order) > 0:
order.to_pickle(path / f"{stock}.pkl.target")
return True
np.random.seed(1234)
file_list = sorted(os.listdir(DATA_PATH))
stocks = [f.replace(".pkl", "") for f in file_list]
np.random.shuffle(stocks)
cnt = 0
for stock in stocks:
if generate_order(stock, 0, 240 // 5 - 1):
cnt += 1
if cnt == 100:
break

View File

@@ -0,0 +1,15 @@
import pickle
import os
import pandas as pd
from tqdm import tqdm
for tag in ["test", "valid"]:
files = os.listdir(os.path.join("data/orders/", tag))
dfs = []
for f in tqdm(files):
df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb"))
df = df.drop(["$close0"], axis=1)
dfs.append(df)
total_df = pd.concat(dfs)
pickle.dump(total_df, open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb"))

View File

@@ -1,15 +1,16 @@
# start & end time for training/validation/test datasets
start_time: !!str &start 2020-01-01
end_time: !!str &end 2020-07-31
train_end_time: !!str &tend 2020-03-31
valid_start_time: !!str &vstart 2020-04-01
valid_end_time: !!str &vend 2020-05-31
test_start_time: !!str &tstart 2020-06-01
end_time: !!str &end 2021-12-31
train_end_time: !!str &tend 2021-06-30
valid_start_time: !!str &vstart 2021-07-01
valid_end_time: !!str &vend 2021-09-30
test_start_time: !!str &tstart 2021-10-01
# the instrument set
instruments: &ins all
instruments: &ins csi300s19_22
# qlib related configuration
qlib_conf:
provider_uri: ./data/bin # path to generated qlib bin
provider_uri:
5min: ./data/bin # path to generated qlib bin
redis_port: 233
feature_conf:
path: ./data/pickle/feature.pkl # output path of feature
@@ -26,14 +27,23 @@ feature_conf:
fit_end_time: *tend
instruments: *ins
day_length: 240 # how many minutes in one trading day
freq: 5min
columns: ["$open", "$high", "$low", "$close"]
infer_processors:
- class: HighFreqNorm
module_path: qlib.contrib.data.highfreq_processor
kwargs:
feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization)
norm_groups:
price: 10
price: 8
volume: 2
inst_processors:
- class: TimeRangeFlt
module_path: qlib.data.dataset.processor
kwargs:
start_time: "2020-01-01"
end_time: "2021-12-31"
freq: 5min
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
@@ -51,7 +61,17 @@ backtest_conf:
end_time: *end
instruments: *ins
day_length: 240
freq: 5min
columns: ["$close", "$volume"]
inst_processors:
- class: TimeRangeFlt
module_path: qlib.data.dataset.processor
kwargs:
start_time: "2020-01-01"
end_time: "2021-12-31"
freq: 5min
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
test: !!python/tuple [*tstart, *end]
freq: 5min

View File

@@ -88,6 +88,7 @@
"outputs": [],
"source": [
"from qlib.tests.data import GetData\n",
"\n",
"GetData().qlib_data(exists_skip=True)"
]
},
@@ -99,6 +100,7 @@
"outputs": [],
"source": [
"import qlib\n",
"\n",
"qlib.init()"
]
},
@@ -134,7 +136,8 @@
"outputs": [],
"source": [
"from qlib.data import D\n",
"D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2] # calendar data"
"\n",
"print(D.calendar(start_time=\"2010-01-01\", end_time=\"2017-12-31\", freq=\"day\")[:2]) # calendar data"
]
},
{
@@ -152,7 +155,12 @@
"metadata": {},
"outputs": [],
"source": [
"df = D.features(['SH601216'], ['$open', '$high', '$low', '$close', '$factor'], start_time='2020-05-01', end_time='2020-05-31') "
"df = D.features(\n",
" [\"SH601216\"],\n",
" [\"$open\", \"$high\", \"$low\", \"$close\", \"$factor\"],\n",
" start_time=\"2020-05-01\",\n",
" end_time=\"2020-05-31\",\n",
")"
]
},
{
@@ -163,11 +171,18 @@
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
" open=df['$open'],\n",
" high=df['$high'],\n",
" low=df['$low'],\n",
" close=df['$close'])])\n",
"\n",
"fig = go.Figure(\n",
" data=[\n",
" go.Candlestick(\n",
" x=df.index.get_level_values(\"datetime\"),\n",
" open=df[\"$open\"],\n",
" high=df[\"$high\"],\n",
" low=df[\"$low\"],\n",
" close=df[\"$close\"],\n",
" )\n",
" ]\n",
")\n",
"fig.show()"
]
},
@@ -197,11 +212,18 @@
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
" open=df['$open'] / df['$factor'],\n",
" high=df['$high'] / df['$factor'],\n",
" low=df['$low'] / df['$factor'],\n",
" close=df['$close'] / df['$factor'])])\n",
"\n",
"fig = go.Figure(\n",
" data=[\n",
" go.Candlestick(\n",
" x=df.index.get_level_values(\"datetime\"),\n",
" open=df[\"$open\"] / df[\"$factor\"],\n",
" high=df[\"$high\"] / df[\"$factor\"],\n",
" low=df[\"$low\"] / df[\"$factor\"],\n",
" close=df[\"$close\"] / df[\"$factor\"],\n",
" )\n",
" ]\n",
")\n",
"fig.show()"
]
},
@@ -240,7 +262,7 @@
"outputs": [],
"source": [
"# dynamic universe\n",
"universe = D.list_instruments(D.instruments('csi100'), start_time='2010-01-01', end_time='2020-12-31')\n",
"universe = D.list_instruments(D.instruments(\"csi100\"), start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
"pprint(universe)"
]
},
@@ -271,8 +293,8 @@
"metadata": {},
"outputs": [],
"source": [
"df = D.features(D.instruments('csi100'), ['$close'], start_time='2010-01-01', end_time='2020-12-31') \n",
"df.groupby('datetime').size().plot()"
"df = D.features(D.instruments(\"csi100\"), [\"$close\"], start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
"df.groupby(\"datetime\").size().plot()"
]
},
{
@@ -313,8 +335,7 @@
" !cd ../../scripts/data_collector/pit/ && pip install -r requirements.txt\n",
" !cd ../../scripts/data_collector/pit/ && python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex \"^(600519|000725).*\"\n",
" !cd ../../scripts/data_collector/pit/ && python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized\n",
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly\n",
" pass"
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly"
]
},
{
@@ -338,7 +359,13 @@
"outputs": [],
"source": [
"instruments = [\"sh600519\"]\n",
"data = D.features(instruments, ['P($$roewa_q)'], start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")"
"data = D.features(\n",
" instruments,\n",
" [\"P($$roewa_q)\"],\n",
" start_time=\"2019-01-01\",\n",
" end_time=\"2019-07-19\",\n",
" freq=\"day\",\n",
")"
]
},
{
@@ -366,7 +393,10 @@
"metadata": {},
"outputs": [],
"source": [
"D.features([\"sh600519\"], ['(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'])"
"D.features(\n",
" [\"sh600519\"],\n",
" [\"(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close\"],\n",
")"
]
},
{
@@ -418,7 +448,7 @@
"metadata": {},
"outputs": [],
"source": [
"qdl = QlibDataLoader(config=(['$close / Ref($close, 10)'], ['RET10']))"
"qdl = QlibDataLoader(config=([\"$close / Ref($close, 10)\"], [\"RET10\"]))"
]
},
{
@@ -428,7 +458,7 @@
"metadata": {},
"outputs": [],
"source": [
"qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
"qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
]
},
{
@@ -456,7 +486,7 @@
"metadata": {},
"outputs": [],
"source": [
"df = qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
"df = qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
]
},
{
@@ -476,7 +506,7 @@
"metadata": {},
"outputs": [],
"source": [
"df.plot(kind='hist')"
"df.plot(kind=\"hist\")"
]
},
{
@@ -508,9 +538,16 @@
"source": [
"# NOTE: normally, the training & validation time range will be `fit_start_time` `fit_end_time`\n",
"# howeverall the components are decomposed, so the training & validation time range is unknown when preprocessing.\n",
"dh = DataHandlerLP(instruments=['sh600519'], start_time='20170101', end_time='20191231',\n",
" infer_processors=[ZScoreNorm(fit_start_time='20170101', fit_end_time='20181231'), Fillna()],\n",
" data_loader=qdl)"
"dh = DataHandlerLP(\n",
" instruments=[\"sh600519\"],\n",
" start_time=\"20170101\",\n",
" end_time=\"20191231\",\n",
" infer_processors=[\n",
" ZScoreNorm(fit_start_time=\"20170101\", fit_end_time=\"20181231\"),\n",
" Fillna(),\n",
" ],\n",
" data_loader=qdl,\n",
")"
]
},
{
@@ -550,7 +587,7 @@
"metadata": {},
"outputs": [],
"source": [
"df.plot(kind='hist')"
"df.plot(kind=\"hist\")"
]
},
{
@@ -586,7 +623,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds = DatasetH(dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})"
"ds = DatasetH(dh, segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")})"
]
},
{
@@ -596,7 +633,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds.prepare('train')"
"ds.prepare(\"train\")"
]
},
{
@@ -606,7 +643,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds.prepare('valid')"
"ds.prepare(\"valid\")"
]
},
{
@@ -628,8 +665,12 @@
"metadata": {},
"outputs": [],
"source": [
"ds = TSDatasetH(step_len=10, handler=dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})\n",
"train_sampler = ds.prepare('train')"
"ds = TSDatasetH(\n",
" step_len=10,\n",
" handler=dh,\n",
" segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")},\n",
")\n",
"train_sampler = ds.prepare(\"train\")"
]
},
{
@@ -649,7 +690,7 @@
"metadata": {},
"outputs": [],
"source": [
"train_sampler[0] # Retrieving the first example"
"train_sampler[0] # Retrieving the first example"
]
},
{
@@ -659,7 +700,7 @@
"metadata": {},
"outputs": [],
"source": [
"train_sampler['2018-01-08', 'sh600519'] # get the time series by <'timestamp', 'instrument_id'> index"
"train_sampler[\"2018-01-08\", \"sh600519\"] # get the time series by <'timestamp', 'instrument_id'> index"
]
},
{
@@ -682,11 +723,11 @@
"outputs": [],
"source": [
"handler_kwargs = {\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": MARKET,\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": MARKET,\n",
"}\n",
"handler_conf = {\n",
" \"class\": \"Alpha158\",\n",
@@ -735,6 +776,7 @@
"outputs": [],
"source": [
"from qlib.contrib.data.handler import Alpha158\n",
"\n",
"hd = Alpha158(**handler_kwargs)"
]
},
@@ -826,7 +868,7 @@
"metadata": {},
"outputs": [],
"source": [
"hd.process_type # appending type"
"hd.process_type # appending type"
]
},
{
@@ -857,16 +899,16 @@
"outputs": [],
"source": [
"dataset_conf = {\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": hd,\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": hd,\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" },\n",
"}"
]
},
@@ -908,7 +950,8 @@
"metadata": {},
"outputs": [],
"source": [
"model = init_instance_by_config({\n",
"model = init_instance_by_config(\n",
" {\n",
" \"class\": \"LGBModel\",\n",
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
" \"kwargs\": {\n",
@@ -922,7 +965,8 @@
" \"num_leaves\": 210,\n",
" \"num_threads\": 20,\n",
" },\n",
"})"
" }\n",
")"
]
},
{
@@ -938,7 +982,7 @@
" R.save_objects(trained_model=model)\n",
"\n",
" rec = R.get_recorder()\n",
" rid = rec.id # save the record id\n",
" rid = rec.id # save the record id\n",
"\n",
" # Inference and saving signal\n",
" sr = SignalRecord(model, dataset, rec)\n",
@@ -1001,12 +1045,11 @@
"\n",
"# backtest and analysis\n",
"with R.start(experiment_name=EXP_NAME, recorder_id=rid, resume=True):\n",
"\n",
" # signal-based analysis\n",
" rec = R.get_recorder()\n",
" sar = SigAnaRecord(rec)\n",
" sar.generate()\n",
" \n",
"\n",
" # portfolio-based analysis: backtest\n",
" par = PortAnaRecord(rec, port_analysis_config, \"day\")\n",
" par.generate()"
@@ -1137,7 +1180,7 @@
"outputs": [],
"source": [
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
"label_df.columns = ['label']"
"label_df.columns = [\"label\"]"
]
},
{

View File

@@ -38,7 +38,7 @@
" # install qlib\n",
" ! pip install --upgrade numpy\n",
" ! pip install pyqlib\n",
" if 'google.colab' in sys.modules:\n",
" if \"google.colab\" in sys.modules:\n",
" # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages\n",
" ! pip install pyyaml==5.4.1\n",
" # reload\n",
@@ -50,7 +50,8 @@
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
" import requests\n",
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
"\n",
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\", timeout=10) as resp:\n",
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
" fp.write(resp.content)"
]
@@ -61,14 +62,13 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.constant import REG_CN\n",
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
"from qlib.workflow import R\n",
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
"from qlib.utils import flatten_dict\n"
"from qlib.utils import flatten_dict"
]
},
{
@@ -86,6 +86,7 @@
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(scripts_dir))\n",
" from get_data import GetData\n",
"\n",
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
@@ -169,7 +170,7 @@
" R.log_params(**flatten_dict(task))\n",
" model.fit(dataset)\n",
" R.save_objects(trained_model=model)\n",
" rid = R.get_recorder().id\n"
" rid = R.get_recorder().id"
]
},
{
@@ -238,7 +239,7 @@
"\n",
" # backtest & analysis\n",
" par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
" par.generate()\n"
" par.generate()"
]
},
{
@@ -256,6 +257,7 @@
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"\n",
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
"print(recorder)\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
@@ -317,7 +319,7 @@
"outputs": [],
"source": [
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
"label_df.columns = ['label']"
"label_df.columns = [\"label\"]"
]
},
{

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.9.1"
__version__ = "0.9.1.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -40,8 +40,8 @@ def get_exchange(
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str, str], List[str]] = None,
limit_threshold: Union[Tuple[str, str], float, None] | None = None,
deal_price: Union[str, Tuple[str, str], List[str]] | None = None,
**kwargs: Any,
) -> Exchange:
"""get_exchange
@@ -284,7 +284,7 @@ def collect_data(
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
return_value: dict = None,
return_value: dict | None = None,
) -> Generator[object, None, None]:
"""initialize the strategy and executor, then collect the trade decision data for rl training

View File

@@ -152,7 +152,9 @@ class Account:
# trading related metrics(e.g. high-frequency trading)
self.indicator = Indicator()
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
def reset(
self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None
) -> None:
"""reset freq and report of account
Parameters

View File

@@ -55,7 +55,7 @@ def collect_data_loop(
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
return_value: dict = None,
return_value: dict | None = None,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
"""Generator for collecting the trade decision data for rl training

View File

@@ -254,7 +254,7 @@ class IdxTradeRange(TradeRange):
self._start_idx = start_idx
self._end_idx = end_idx
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]:
return self._start_idx, self._end_idx
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
@@ -315,7 +315,7 @@ class BaseTradeDecision(Generic[DecisionType]):
2. Same as `case 1.3`
"""
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None:
"""
Parameters
----------
@@ -554,7 +554,7 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
self,
order_list: List[Order],
strategy: BaseStrategy,
trade_range: Union[Tuple[int, int], TradeRange] = None,
trade_range: Union[Tuple[int, int], TradeRange, None] = None,
) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = cast(List[Order], order_list)

View File

@@ -41,10 +41,10 @@ class Exchange:
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str, str], List[str]] = None,
deal_price: Union[str, Tuple[str, str], List[str], None] = None,
subscribe_fields: list = [],
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold: Union[tuple, dict] = None,
volume_threshold: Union[tuple, dict, None] = None,
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
@@ -340,7 +340,7 @@ class Exchange:
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
direction: int | None = None,
) -> bool:
"""
Parameters
@@ -406,7 +406,7 @@ class Exchange:
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
direction: int | None = None,
) -> bool:
# check if stock can be traded
return not (
@@ -421,8 +421,8 @@ class Exchange:
def deal_order(
self,
order: Order,
trade_account: Account = None,
position: BasePosition = None,
trade_account: Account | None = None,
position: BasePosition | None = None,
dealt_order_amount: Dict[str, float] = defaultdict(float),
) -> Tuple[float, float, float]:
"""
@@ -586,7 +586,7 @@ class Exchange:
)
return amount_dict
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float:
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
@@ -712,8 +712,8 @@ class Exchange:
def _get_factor_or_raise_error(
self,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
@@ -728,8 +728,8 @@ class Exchange:
def get_amount_of_trade_unit(
self,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> Optional[float]:
@@ -762,8 +762,8 @@ class Exchange:
def round_amount_by_trade_unit(
self,
deal_amount: float,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:

View File

@@ -31,8 +31,8 @@ class BaseExecutor:
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: CommonInfrastructure = None,
trade_exchange: Exchange | None = None,
common_infra: CommonInfrastructure | None = None,
settle_type: str = BasePosition.ST_NO,
**kwargs: Any,
) -> None:
@@ -161,7 +161,7 @@ class BaseExecutor:
"""
return self.level_infra.get("trade_calendar")
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None:
"""
- reset `start_time` and `end_time`, used in trade calendar
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
@@ -227,7 +227,7 @@ class BaseExecutor:
def collect_data(
self,
trade_decision: BaseTradeDecision,
return_value: dict = None,
return_value: dict | None = None,
level: int = 0,
) -> Generator[Any, Any, List[object]]:
"""Generator for collecting the trade decision data for rl training
@@ -327,7 +327,7 @@ class NestedExecutor(BaseExecutor):
track_data: bool = False,
skip_empty_decision: bool = True,
align_range_limit: bool = True,
common_infra: CommonInfrastructure = None,
common_infra: CommonInfrastructure | None = None,
**kwargs: Any,
) -> None:
"""
@@ -534,7 +534,7 @@ class SimulatorExecutor(BaseExecutor):
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: CommonInfrastructure = None,
common_infra: CommonInfrastructure | None = None,
trade_type: str = TT_SERIAL,
**kwargs: Any,
) -> None:

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from datetime import timedelta
from typing import Any, Dict, List, Union
@@ -320,7 +321,7 @@ class Position(BasePosition):
self.position[stock]["price"] = price_dict[stock]
self.position["now_account_value"] = self.calculate_value()
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None:
"""
initialization the stock in current position

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import pathlib
from collections import OrderedDict
@@ -86,7 +87,7 @@ class PortfolioMetrics:
self.benches: dict = OrderedDict()
self.latest_pm_time: Optional[pd.TimeStamp] = None
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None:
if freq is not None:
self.freq = freq
self.benchmark_config = benchmark_config
@@ -149,15 +150,15 @@ class PortfolioMetrics:
self,
trade_start_time: Union[str, pd.Timestamp] = None,
trade_end_time: Union[str, pd.Timestamp] = None,
account_value: float = None,
cash: float = None,
return_rate: float = None,
total_turnover: float = None,
turnover_rate: float = None,
total_cost: float = None,
cost_rate: float = None,
stock_value: float = None,
bench_value: float = None,
account_value: float | None = None,
cash: float | None = None,
return_rate: float | None = None,
total_turnover: float | None = None,
turnover_rate: float | None = None,
total_cost: float | None = None,
cost_rate: float | None = None,
stock_value: float | None = None,
bench_value: float | None = None,
) -> None:
# check data
if None in [

View File

@@ -31,7 +31,7 @@ class TradeCalendarManager:
freq: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
level_infra: LevelInfrastructure = None,
level_infra: LevelInfrastructure | None = None,
) -> None:
"""
Parameters
@@ -99,7 +99,7 @@ class TradeCalendarManager:
def get_trade_step(self) -> int:
return self.trade_step
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
"""
Get the left and right endpoints of the trade_step'th trading interval

View File

@@ -147,6 +147,7 @@ _default_config = {
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
"redis_password": None,
# This value can be reset via qlib.init
"logging_level": logging.INFO,
# Global configuration of qlib log

111
qlib/contrib/analyzer.py Normal file
View File

@@ -0,0 +1,111 @@
import logging
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from ..log import get_module_logger
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
logger = get_module_logger("analysis", logging.INFO)
class AnalyzerTemp:
def __init__(self, recorder, output_dir=None, **kwargs):
self.recorder = recorder
self.output_dir = Path(output_dir) if output_dir else "./"
def load(self, name: str):
"""
It behaves the same as self.recorder.load_object.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
Parameters
----------
name : str
the name for the file to be load.
Return
------
The stored records.
"""
return self.recorder.load_object(name)
def analyse(self, **kwargs):
"""
Analyse data index, distribution .etc
Parameters
----------
Return
------
The handled data.
"""
raise NotImplementedError(f"Please implement the `analysis` method.")
class HFAnalyzer(AnalyzerTemp):
"""
This is the Signal Analysis class that generates the analysis results such as IC and IR.
default output image filename is "HFAnalyzerTable.jpeg"
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def analyse(self):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
"IC": ic.mean(),
"ICIR": ic.mean() / ic.std(),
"Rank IC": ric.mean(),
"Rank ICIR": ric.mean() / ric.std(),
"Long precision": long_pre.mean(),
"Short precision": short_pre.mean(),
}
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
metrics.update(
{
"Long-Short Average Return": long_short_r.mean(),
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
}
)
table = [[k, v] for (k, v) in metrics.items()]
plt.table(cellText=table, loc="center")
plt.axis("off")
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
plt.clf()
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
plt.title("HFAnalyzer")
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
return "HFAnalyzer.jpeg"
class SignalAnalyzer(AnalyzerTemp):
"""
This is the Signal Analysis class that generates the analysis results such as IC and IR.
default output image filename is "signalAnalysis.jpeg"
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def analyse(self, dataset=None, **kwargs):
label = self.load("label.pkl")
plt.hist(label)
plt.title("SignalAnalyzer")
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
return "signalAnalysis.jpeg"

View File

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from qlib.utils.data import update_config
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_callable_kwargs
@@ -56,13 +58,14 @@ class Alpha360(DataHandlerLP):
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
data_loader: Optional[dict] = None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
_data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
@@ -71,15 +74,17 @@ class Alpha360(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
"inst_processors": inst_processors,
},
}
if data_loader is not None:
update_config(_data_loader, data_loader)
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
data_loader=_data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
**kwargs
@@ -152,13 +157,14 @@ class Alpha158(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
data_loader: Optional[dict] = None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
_data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
@@ -167,14 +173,16 @@ class Alpha158(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
"inst_processors": inst_processors,
},
}
if data_loader is not None:
update_config(_data_loader, data_loader)
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
data_loader=_data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
process_type=process_type,

View File

@@ -44,7 +44,7 @@ class HighFreqHandler(DataHandlerLP):
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
template_paused = "Select(Gt($paused_num, 1.001), {0})"
def get_normalized_price_feature(price_field, shift=0):
# norm with the close price of 237th minute of yesterday.
@@ -115,6 +115,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
day_length=240,
freq="1min",
columns=["$open", "$high", "$low", "$close", "$vwap"],
inst_processors=None,
):
self.day_length = day_length
self.columns = columns
@@ -128,6 +129,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
"config": self.get_feature_config(),
"swap_level": False,
"freq": freq,
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -257,6 +259,7 @@ class HighFreqGeneralBacktestHandler(DataHandler):
day_length=240,
freq="1min",
columns=["$close", "$vwap", "$volume"],
inst_processors=None,
):
self.day_length = day_length
self.columns = set(columns)
@@ -266,6 +269,7 @@ class HighFreqGeneralBacktestHandler(DataHandler):
"config": self.get_feature_config(),
"swap_level": False,
"freq": freq,
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -311,6 +315,7 @@ class HighFreqOrderHandler(DataHandlerLP):
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
inst_processors=None,
drop_raw=True,
):
@@ -323,6 +328,7 @@ class HighFreqOrderHandler(DataHandlerLP):
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -482,7 +488,7 @@ class HighFreqBacktestOrderHandler(DataHandler):
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
template_paused = "Select(Gt($paused_num, 1.001), {0})"
template_fillnan = "FFillNan({0})"
fields += [
template_fillnan.format(template_paused.format("$close")),

View File

@@ -128,7 +128,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -137,11 +137,11 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
@@ -160,7 +160,7 @@ class HighFreqProvider:
with open(path[:-4] + "test.pkl", "wb") as f:
pkl.dump(testset, f)
res = [data[i] for i in datasets]
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
return res
def _gen_data(self, config, datasets=["train", "valid", "test"]):
@@ -170,7 +170,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -179,18 +179,18 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
res = dataset.prepare(datasets)
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
return res
def _gen_dataset(self, config):
@@ -200,21 +200,21 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
with open(path, "rb") as f:
dataset = pkl.load(f)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.prepare(["train", "valid", "test"])
self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
return dataset
@@ -227,15 +227,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")
@@ -268,15 +268,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")

View File

@@ -55,8 +55,10 @@ class InternalData:
# The handler is initialized for only once.
if not trainer.has_worker():
self.dh = init_task_handler(perf_task_tpl)
self.dh.config(dump_all=False) # in some cases, the data handler are saved to disk with `dump_all=True`
else:
self.dh = init_instance_by_config(perf_task_tpl["dataset"]["kwargs"]["handler"])
assert self.dh.dump_all is False # otherwise, it will save all the detailed data
seg = perf_task_tpl["dataset"]["kwargs"]["segments"]
@@ -77,7 +79,7 @@ class InternalData:
get_module_logger("Internal Data").info("the data has been initialized")
else:
# train new models
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData``"
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData`"
trainer.train(gen_task)
# 2) extract the similarity matrix
@@ -119,6 +121,7 @@ class MetaTaskDS(MetaTask):
def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PROC_MODE_FULL, fill_method="max"):
"""
The description of the processed data
time_perf: A array with shape <hist_step_n * step, data pieces> -> data piece performance
@@ -132,6 +135,10 @@ class MetaTaskDS(MetaTask):
[0., 0., 0., ..., 0., 0., 1.],
[0., 0., 0., ..., 0., 0., 1.]])
Parameters
----------
meta_info: pd.DataFrame
please refer to the docs of _prepare_meta_ipt for detailed explanation.
"""
super().__init__(task, meta_info)
self.fill_method = fill_method
@@ -180,12 +187,41 @@ class MetaTaskDS(MetaTask):
self.processed_meta_input = data_to_tensor(self.processed_meta_input)
def _get_processed_meta_info(self):
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0) # .fillna(0.)
if self.fill_method == "max":
meta_info_norm = meta_info_norm.T.fillna(
meta_info_norm.max(axis=1)
).T # fill it with row max to align with previous implementation
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0)
if self.fill_method.startswith("max"):
suffix = self.fill_method.lstrip("max")
if suffix == "seg":
fill_value = {}
for col in meta_info_norm.columns:
fill_value[col] = meta_info_norm.loc[meta_info_norm[col].isna(), :].dropna(axis=1).mean().max()
fill_value = pd.Series(fill_value).sort_index()
# The NaN Values are filled segment-wise. Below is an exampleof fill_value
# 2009-01-05 2009-02-06 0.145809
# 2009-02-09 2009-03-06 0.148005
# 2009-03-09 2009-04-03 0.090385
# 2009-04-07 2009-05-05 0.114318
# 2009-05-06 2009-06-04 0.119328
# ...
meta_info_norm = meta_info_norm.fillna(fill_value)
else:
if len(suffix) > 0:
get_module_logger("MetaTaskDS").warning(
f"fill_method={self.fill_method}; the info after can't be correctly parsed. Please check your parameters."
)
fill_value = meta_info_norm.max(axis=1)
# fill it with row max to align with previous implementation
# This will magnify the data similarity when data is in daily freq
# the fill value corresponds to data like this
# It get a performance value for each day.
# The performance value are get from other models on this day
# 2009-01-16 0.276320
# 2009-01-19 0.280603
# ...
# 2011-06-27 0.203773
meta_info_norm = meta_info_norm.T.fillna(fill_value).T
elif self.fill_method == "zero":
# It will fillna(0.0) at the end.
pass
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -286,7 +322,33 @@ class MetaDatasetDS(MetaTaskDataset):
logger.warning(f"ValueError: {e}")
assert len(self.meta_task_l) > 0, "No meta tasks found. Please check the data and setting"
def _prepare_meta_ipt(self, task):
def _prepare_meta_ipt(self, task) -> pd.DataFrame:
"""
Please refer to `self.internal_data.setup` for detailed information about `self.internal_data.data_ic_df`
Indices with format below can be successfully sliced by `ic_df.loc[:end, pd.IndexSlice[:, :end]]`
2021-06-21 2021-06-04 .. 2021-03-22 2021-03-08
2021-07-02 2021-06-18 .. 2021-04-02 None
Returns
-------
a pd.DataFrame with similar content below.
- each column corresponds to a trained model named by the training data range
- each row corresponds to a day of data tested by the models of the columns
- The rows cells that overlaps with the data used by columns are masked
2009-01-05 2009-02-09 ... 2011-04-27 2011-05-26
2009-02-06 2009-03-06 ... 2011-05-25 2011-06-23
datetime ...
2009-01-13 NaN 0.310639 ... -0.169057 0.137792
2009-01-14 NaN 0.261086 ... -0.143567 0.082581
... ... ... ... ... ...
2011-06-30 -0.054907 -0.020219 ... -0.023226 NaN
2011-07-01 -0.075762 -0.026626 ... -0.003167 NaN
"""
ic_df = self.internal_data.data_ic_df
segs = task["dataset"]["kwargs"]["segments"]
@@ -294,15 +356,19 @@ class MetaDatasetDS(MetaTaskDataset):
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
# meta data set focus on the **information** instead of preprocess
# 1) filter the future info
def mask_future(s):
"""mask future information"""
# from qlib.utils import get_date_by_shift
# 1) filter the overlap info
def mask_overlap(s):
"""
mask overlap information
data after self.name[end] with self.trunc_days that contains future info are also considered as overlap info
Approximately the diagnal + horizon length of data are masked.
"""
start, end = s.name
end = get_date_by_shift(trading_date=end, shift=self.trunc_days - 1, future=True)
return s.mask((s.index >= start) & (s.index <= end))
ic_df_avail = ic_df_avail.apply(mask_future) # apply to each col
ic_df_avail = ic_df_avail.apply(mask_overlap) # apply to each col
# 2) filter the info with too long periods
total_len = self.step * self.hist_step_n

View File

@@ -52,6 +52,7 @@ class MetaModelDS(MetaTaskModel):
lr=0.0001,
max_epoch=100,
seed=43,
alpha=0.0,
):
self.step = step
self.hist_step_n = hist_step_n
@@ -61,6 +62,7 @@ class MetaModelDS(MetaTaskModel):
self.lr = lr
self.max_epoch = max_epoch
self.fitted = False
self.alpha = alpha
torch.manual_seed(seed)
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
@@ -144,7 +146,11 @@ class MetaModelDS(MetaTaskModel):
) # debug: record when the test phase starts
self.tn = PredNet(
step=self.step, hist_step_n=self.hist_step_n, clip_weight=self.clip_weight, clip_method=self.clip_method
step=self.step,
hist_step_n=self.hist_step_n,
clip_weight=self.clip_weight,
clip_method=self.clip_method,
alpha=self.alpha,
)
opt = optim.Adam(self.tn.parameters(), lr=self.lr)

View File

@@ -41,11 +41,18 @@ class TimeWeightMeta(SingleMetaBase):
class PredNet(nn.Module):
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh"):
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh", alpha: float = 0.0):
"""
Parameters
----------
alpha : float
the regularization for sub model (useful when align meta model with linear submodel)
"""
super().__init__()
self.step = step
self.twm = TimeWeightMeta(hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method)
self.init_paramters(hist_step_n)
self.alpha = alpha
def get_sample_weights(self, X, time_perf, time_belong, ignore_weight=False):
weights = torch.from_numpy(np.ones(X.shape[0])).float().to(X.device)
@@ -59,7 +66,7 @@ class PredNet(nn.Module):
"""Please refer to the docs of MetaTaskDS for the description of the variables"""
weights = self.get_sample_weights(X, time_perf, time_belong, ignore_weight=ignore_weight)
X_w = X.T * weights.view(1, -1)
theta = torch.inverse(X_w @ X) @ X_w @ y
theta = torch.inverse(X_w @ X + self.alpha * torch.eye(X_w.shape[0])) @ X_w @ y
return X_test @ theta, weights
def init_paramters(self, hist_step_n):

View File

@@ -5,6 +5,9 @@ import numpy as np
import torch
from torch import nn
from qlib.constant import EPS
from qlib.log import get_module_logger
class ICLoss(nn.Module):
def forward(self, pred, y, idx, skip_size=50):
@@ -24,6 +27,7 @@ class ICLoss(nn.Module):
diff_point.append(i)
prev = date
diff_point.append(None)
# The lengths of diff_point will be one more larger then diff_point
ic_all = 0.0
skip_n = 0
@@ -34,13 +38,23 @@ class ICLoss(nn.Module):
skip_n += 1
continue
y_focus = y[start_i:end_i]
if pred_focus.std() < EPS or y_focus.std() < EPS:
# These cases often happend at the end of test data.
# Usually caused by fillna(0.)
skip_n += 1
continue
ic_day = torch.dot(
(pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(),
(y_focus - y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(),
)
ic_all += ic_day
if len(diff_point) - 1 - skip_n <= 0:
raise ValueError("No enough data for calculating iC")
raise ValueError("No enough data for calculating IC")
if skip_n > 0:
get_module_logger("ICLoss").info(
f"{skip_n} days are skipped due to zero std or small scale of valid samples."
)
ic_mean = ic_all / (len(diff_point) - 1 - skip_n)
return -ic_mean # ic loss

View File

@@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
from typing import Text, Union
from qlib.log import get_module_logger
from qlib.data.dataset.weight import Reweighter
from scipy.optimize import nnls
from sklearn.linear_model import LinearRegression, Ridge, Lasso
@@ -29,7 +30,7 @@ class LinearModel(Model):
RIDGE = "ridge"
LASSO = "lasso"
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False):
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_valid: bool = False):
"""
Parameters
----------
@@ -39,6 +40,9 @@ class LinearModel(Model):
l1 or l2 regularization parameter
fit_intercept : bool
whether fit intercept
include_valid: bool
Should the validation data be included for training?
The validation data should be included
"""
assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f"unsupported estimator `{estimator}`"
self.estimator = estimator
@@ -49,9 +53,16 @@ class LinearModel(Model):
self.fit_intercept = fit_intercept
self.coef_ = None
self.include_valid = include_valid
def fit(self, dataset: DatasetH, reweighter: Reweighter = None):
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if self.include_valid:
try:
df_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
df_train = pd.concat([df_train, df_valid])
except KeyError:
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
if df_train.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
if reweighter is not None:

View File

@@ -47,10 +47,6 @@ class DNNModelPytorch(Model):
layer sizes
lr : float
learning rate
lr_decay : float
learning rate decay
lr_decay_steps : int
learning rate decay steps
optimizer : str
optimizer name
GPU : int
@@ -64,8 +60,6 @@ class DNNModelPytorch(Model):
batch_size=2000,
early_stop_rounds=50,
eval_steps=20,
lr_decay=0.96,
lr_decay_steps=100,
optimizer="gd",
loss="mse",
GPU=0,
@@ -93,8 +87,6 @@ class DNNModelPytorch(Model):
self.batch_size = batch_size
self.early_stop_rounds = early_stop_rounds
self.eval_steps = eval_steps
self.lr_decay = lr_decay
self.lr_decay_steps = lr_decay_steps
self.optimizer = optimizer.lower()
self.loss_type = loss
if isinstance(GPU, str):
@@ -116,8 +108,6 @@ class DNNModelPytorch(Model):
f"\nbatch_size : {batch_size}"
f"\nearly_stop_rounds : {early_stop_rounds}"
f"\neval_steps : {eval_steps}"
f"\nlr_decay : {lr_decay}"
f"\nlr_decay_steps : {lr_decay_steps}"
f"\noptimizer : {optimizer}"
f"\nloss_type : {loss}"
f"\nseed : {seed}"

View File

@@ -70,7 +70,7 @@ class DayCumsum(ElemOperator):
Otherwise, the value is zero.
"""
def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
def __init__(self, feature, start: str = "9:30", end: str = "14:59", data_granularity: int = 1):
self.feature = feature
self.start = datetime.strptime(start, "%H:%M")
self.end = datetime.strptime(end, "%H:%M")
@@ -80,15 +80,17 @@ class DayCumsum(ElemOperator):
self.noon_open = datetime.strptime("13:00", "%H:%M")
self.noon_close = datetime.strptime("15:00", "%H:%M")
self.start_id = time_to_day_index(self.start)
self.end_id = time_to_day_index(self.end)
self.data_granularity = data_granularity
self.start_id = time_to_day_index(self.start) // self.data_granularity
self.end_id = time_to_day_index(self.end) // self.data_granularity
assert 240 % self.data_granularity == 0
def period_cusum(self, df):
df = df.copy()
assert len(df) == 240
assert len(df) == 240 // self.data_granularity
df.iloc[0 : self.start_id] = 0
df = df.cumsum()
df.iloc[self.end_id + 1 : 240] = 0
df.iloc[self.end_id + 1 : 240 // self.data_granularity] = 0
return df
def _load_internal(self, instrument, start_index, end_index, freq):

View File

@@ -635,7 +635,7 @@ class FileOrderStrategy(BaseStrategy):
self.order_df = file
else:
with get_io_object(file) as f:
self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
self.order_df = pd.read_csv(f, dtype={"datetime": str})
self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
self.order_df = self.order_df.set_index(["datetime", "instrument"])

View File

@@ -783,7 +783,7 @@ class LocalPITProvider(PITProvider):
index_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index"
data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data"
if not (index_path.exists() and data_path.exists()):
raise FileNotFoundError("No file is found. Raise exception and ")
raise FileNotFoundError("No file is found.")
# NOTE: The most significant performance loss is here.
# Does the acceleration that makes the program complicated really matters?
# - It makes parameters of the interface complicate
@@ -797,14 +797,14 @@ class LocalPITProvider(PITProvider):
cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day)
loc = np.searchsorted(data["date"], cur_time_int, side="right")
if loc <= 0:
return pd.Series()
return pd.Series(dtype=C.pit_record_type["value"])
last_period = data["period"][:loc].max() # return the latest quarter
first_period = data["period"][:loc].min()
period_list = get_period_list(first_period, last_period, quarterly)
if period is not None:
# NOTE: `period` has higher priority than `start_index` & `end_index`
if period not in period_list:
return pd.Series()
return pd.Series(dtype=C.pit_record_type["value"])
else:
period_list = [period]
else:
@@ -868,7 +868,7 @@ class LocalExpressionProvider(ExpressionProvider):
# Ensure that each column type is consistent
# FIXME:
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
# 2) The the precision should be configurable
# 2) The precision should be configurable
try:
series = series.astype(np.float32)
except ValueError:

View File

@@ -417,7 +417,7 @@ class TSDataSampler:
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.swaplevel()
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data)[0]]

View File

@@ -7,6 +7,7 @@ from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
from qlib.typehint import Literal
from ...log import get_module_logger, TimeInspector
from ...utils import init_instance_by_config
from ...utils.serial import Serializable
@@ -49,6 +50,8 @@ class DataHandler(Serializable):
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
"""
_data: pd.DataFrame # underlying data.
def __init__(
self,
instruments=None,
@@ -155,6 +158,11 @@ class DataHandler(Serializable):
"""
fetch data from underlying data source
Design motivation:
- providing a unified interface for underlying data.
- Potential to make the interface more friendly.
- User can improve performance when fetching data in this extra layer
Parameters
----------
selector : Union[pd.Timestamp, slice, str]
@@ -328,6 +336,9 @@ class DataHandler(Serializable):
yield cur_date, self.fetch(selector, **kwargs)
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
class DataHandlerLP(DataHandler):
"""
DataHandler with **(L)earnable (P)rocessor**
@@ -346,17 +357,28 @@ class DataHandlerLP(DataHandler):
- These processors only apply to the learning phase.
Tips to improve the performance of data handler
Tips for data handler
- To reduce the memory cost
- `drop_raw=True`: this will modify the data inplace on raw data;
- Please note processed data like `self._infer` or `self._learn` are concepts different from `segments` in Qlib's `Dataset` like "train" and "test"
- Processed data like `self._infer` or `self._learn` are underlying data processed with different processors
- `segments` in Qlib's `Dataset` like "train" and "test" are simply the time segmentations when querying data("train" are often before "test" in time-series).
- For example, you can query `data._infer` processed by `infer_processors` in the "train" time segmentation.
"""
# based on `self._data`, _infer and _learn are genrated after processors
_infer: pd.DataFrame # data for inference
_learn: pd.DataFrame # data for learning models
# data key
DK_R = "raw"
DK_I = "infer"
DK_L = "learn"
DK_R: DATA_KEY_TYPE = "raw"
DK_I: DATA_KEY_TYPE = "infer"
DK_L: DATA_KEY_TYPE = "learn"
# map data_key to attribute name
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
# process type
@@ -600,7 +622,7 @@ class DataHandlerLP(DataHandler):
# TODO: Be able to cache handler data. Save the memory for data processing
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
if data_key == self.DK_R and self.drop_raw:
raise AttributeError(
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
@@ -613,7 +635,7 @@ class DataHandlerLP(DataHandler):
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
data_key: DATA_KEY_TYPE = DK_I,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
@@ -647,7 +669,7 @@ class DataHandlerLP(DataHandler):
proc_func=proc_func,
)
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
"""
get the column names
@@ -655,7 +677,7 @@ class DataHandlerLP(DataHandler):
----------
col_set : str
select a set of meaningful columns.(e.g. features, columns).
data_key : str
data_key : DATA_KEY_TYPE
the data to fetch: DK_*.
Returns
@@ -698,3 +720,26 @@ class DataHandlerLP(DataHandler):
]:
setattr(new_hd, key, getattr(handler, key, None))
return new_hd
@classmethod
def from_df(cls, df: pd.DataFrame) -> "DataHandlerLP":
"""
Motivation:
- When user want to get a quick data handler.
The created data handler will have only one shared Dataframe without processors.
After creating the handler, user may often want to dump the handler for reuse
Here is a typical use case
.. code-block:: python
from qlib.data.dataset import DataHandlerLP
dh = DataHandlerLP.from_df(df)
dh.to_pickle(fname, dump_all=True)
TODO:
- The StaticDataLoader is quite slow. It don't have to copy the data again...
"""
loader = data_loader_module.StaticDataLoader(df)
return cls(data_loader=loader)

View File

@@ -153,7 +153,7 @@ class QlibDataLoader(DLWParser):
filter_pipe: List = None,
swap_level: bool = True,
freq: Union[str, dict] = "day",
inst_processor: dict = None,
inst_processors: Union[dict, list] = None,
):
"""
Parameters
@@ -167,16 +167,19 @@ class QlibDataLoader(DLWParser):
freq: dict or str
If type(config) == dict and type(freq) == str, load config data using freq.
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
inst_processor: dict
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
inst_processors: dict | list
If inst_processors is not None and type(config) == dict; load config[<group_name>] data using inst_processors[<group_name>]
If inst_processors is a list, then it will be applied to all groups.
"""
self.filter_pipe = filter_pipe
self.swap_level = swap_level
self.freq = freq
# sample
self.inst_processor = inst_processor if inst_processor is not None else {}
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
self.inst_processors = inst_processors if inst_processors is not None else {}
assert isinstance(
self.inst_processors, (dict, list)
), f"inst_processors(={self.inst_processors}) must be dict or list"
super().__init__(config)
@@ -187,8 +190,8 @@ class QlibDataLoader(DLWParser):
if _gp not in freq:
raise ValueError(f"freq(={freq}) missing group(={_gp})")
assert (
self.inst_processor
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
self.inst_processors
), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"
def load_group_df(
self,
@@ -208,9 +211,10 @@ class QlibDataLoader(DLWParser):
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
df = D.features(
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
inst_processors = (
self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])
)
df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)
df.columns = names
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
import abc
from typing import Union, Text
from typing import Union, Text, Optional
import numpy as np
import pandas as pd
@@ -11,6 +11,8 @@ from ...constant import EPS
from .utils import fetch_df_by_index
from ...utils.serial import Serializable
from ...utils.paral import datetime_groupby_apply
from qlib.data.inst_processor import InstProcessor
from qlib.data import D
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
@@ -378,3 +380,42 @@ class HashStockFormat(Processor):
from .storage import HashingStockStorage # pylint: disable=C0415
return HashingStockStorage.from_df(df)
class TimeRangeFlt(InstProcessor):
"""
This is a filter to filter stock.
Only keep the data that exist from start_time to end_time (the existence in the middle is not checked.)
WARNING: It may induce leakage!!!
"""
def __init__(
self,
start_time: Optional[Union[pd.Timestamp, str]] = None,
end_time: Optional[Union[pd.Timestamp, str]] = None,
freq: str = "day",
):
"""
Parameters
----------
start_time : Optional[Union[pd.Timestamp, str]]
The data must start earlier (or equal) than `start_time`
None indicates data will not be filtered based on `start_time`
end_time : Optional[Union[pd.Timestamp, str]]
similar to start_time
freq : str
The frequency of the calendar
"""
# Align to calendar before filtering
cal = D.calendar(start_time=start_time, end_time=end_time, freq=freq)
self.start_time = None if start_time is None else cal[0]
self.end_time = None if end_time is None else cal[-1]
def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs):
if (
df.empty
or (self.start_time is None or df.index.min() <= self.start_time)
and (self.end_time is None or df.index.max() >= self.end_time)
):
return df
return df.head(0)

View File

@@ -2,9 +2,8 @@
# Licensed under the MIT License.
from __future__ import annotations
import pandas as pd
from typing import Union, List
from typing import Union, List, TYPE_CHECKING
from qlib.utils import init_instance_by_config
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.data.dataset import DataHandler
@@ -121,7 +120,7 @@ def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datet
return df
def init_task_handler(task: dict) -> Union[DataHandler, None]:
def init_task_handler(task: dict) -> DataHandler:
"""
initialize the handler part of the task **inplace**
@@ -142,5 +141,6 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]:
if h_conf is not None:
handler = init_instance_by_config(h_conf, accept_types=DataHandler)
task["dataset"]["kwargs"]["handler"] = handler
return handler
else:
raise ValueError("The task does not contains a handler part.")

18
qlib/finco/.env.example Normal file
View File

@@ -0,0 +1,18 @@
OPENAI_API_KEY=your_api_key
# USE_AZURE=True
# AZURE_API_BASE=your_api_base
# AZURE_API_VERSION=your_api_version
# use gpt-4 means more token but more wait time
# MODEL=gpt-4
# MAX_TOKENS=1600
# MAX_RETRY=1000
MAX_TOKENS=1600
MAX_RETRY=120
CONTINOUS_MODE=True
DEBUG_MODE=True

22
qlib/finco/README.md Normal file
View File

@@ -0,0 +1,22 @@
# This is an experimental branch of "`FI`nancial `CO`pilot of `Qlib`"
## Installation
- To run this module, you need to first install Qlib following the instruction in [install-from-source](/README.md#install-from-source) or follow:
```python
python -m pip install git+https://github.com/microsoft/qlib.git@finco
```
- then you need to install other dependencies of finco:
```python
python -m pip install pydantic openai python-dotenv
```
## Quick run
To run this module, you can start the workflow easily with one command:
```sh
cd qlib/finco; python cli.py "your prompt"
```

13
qlib/finco/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
DIRNAME = Path(__file__).absolute().resolve().parent
def get_finco_path() -> Path:
"""
return the template path
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
"""
return DIRNAME

15
qlib/finco/cli.py Normal file
View File

@@ -0,0 +1,15 @@
import fire
from qlib.finco.workflow import WorkflowManager
from dotenv import load_dotenv
from qlib import auto_init
def main(prompt=None):
load_dotenv(verbose=True, override=True)
wm = WorkflowManager()
wm.run(prompt)
if __name__ == "__main__":
auto_init()
fire.Fire(main)

15
qlib/finco/cli_learn.py Normal file
View File

@@ -0,0 +1,15 @@
import fire
from qlib.finco.workflow import LearnManager
from dotenv import load_dotenv
from qlib import auto_init
def main(prompt=None):
load_dotenv(verbose=True, override=True)
lm = LearnManager()
lm.run(prompt)
if __name__ == "__main__":
auto_init()
fire.Fire(main)

32
qlib/finco/conf.py Normal file
View File

@@ -0,0 +1,32 @@
# TODO: use pydantic for other modules in Qlib
from pydantic import BaseSettings
from qlib.finco.utils import SingletonBaseClass
import os
class Config(SingletonBaseClass):
"""
This config is for fast demo purpose.
Please use BaseSettings insetead in the future
"""
def __init__(self):
self.use_azure = os.getenv("USE_AZURE") == "True"
self.temperature = 0.5 if os.getenv("TEMPERATURE") is None else float(os.getenv("TEMPERATURE"))
self.max_tokens = 800 if os.getenv("MAX_TOKENS") is None else int(os.getenv("MAX_TOKENS"))
self.openai_api_key = os.getenv("OPENAI_API_KEY")
self.use_azure = os.getenv("USE_AZURE") == "True"
self.azure_api_base = os.getenv("AZURE_API_BASE")
self.azure_api_version = os.getenv("AZURE_API_VERSION")
self.model = os.getenv("MODEL") or ("gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo")
self.max_retry = int(os.getenv("MAX_RETRY")) if os.getenv("MAX_RETRY") is not None else None
self.continuous_mode = (
os.getenv("CONTINOUS_MODE") == "True" if os.getenv("CONTINOUS_MODE") is not None else False
)
self.debug_mode = os.getenv("DEBUG_MODE") == "True" if os.getenv("DEBUG_MODE") is not None else False
self.workspace = os.getenv("WORKSPACE") if os.getenv("WORKSPACE") is not None else "./finco_workspace"
self.max_past_message_include = int(os.getenv("MAX_PAST_MESSAGE_INCLUDE") or 6) // 2 * 2

156
qlib/finco/knowledge.py Normal file
View File

@@ -0,0 +1,156 @@
from pathlib import Path
from jinja2 import Template
from typing import List
from qlib.workflow import R
from qlib.finco.log import FinCoLog
from qlib.finco.llm import APIBackend
class Knowledge:
"""
Use to handle knowledge in finCo such as experiment and outside domain information
"""
def __init__(self):
self.logger = FinCoLog()
def load(self, **kwargs):
"""
Load knowledge in memory
Parameters
----------
Return
------
"""
raise NotImplementedError(f"Please implement the `load` method.")
def brief(self, **kwargs):
"""
Return a brief summary of knowledge
Parameters
----------
Return
------
"""
raise NotImplementedError(f"Please implement the `load` method.")
class KnowledgeExperiment(Knowledge):
"""
Handle knowledge from experiments
"""
def __init__(self, exp_name, rec_id=None):
super().__init__()
self.exp_name = exp_name
self.exp = None
self.recs = []
self.load(exp_name=exp_name, rec_id=rec_id)
def load(self, exp_name, rec_id=None):
recs = []
self.exp = R.get_exp(experiment_name=exp_name)
for r in self.exp.list_recorders(rtype=self.exp.RT_L):
if rec_id is not None and r.id != rec_id:
continue
recs.append(r)
self.recs.extend(recs)
def brief(self):
docs = []
for recorder in self.recs:
docs.append({"exp_name": self.exp.name, "record_info": recorder.info,
"config": recorder.load_object("config"),
"context_summary": recorder.load_object("context_summary")})
return docs
class Topic:
def __init__(self, name: str, describe: Template):
self.name = name
self.describe = describe
self.docs = []
self.knowledge = None
self.logger = FinCoLog()
def summarize(self, docs: list):
self.logger.info(f"Summarize topic: \nname: {self.name}\ndescribe: {self.describe.module}")
prompt_workflow_selection = self.describe.render(docs=docs)
response = APIBackend().build_messages_and_create_chat_completion(
user_prompt=prompt_workflow_selection
)
self.knowledge = response
self.docs = docs
class KnowledgeBase:
"""
Load knowledge, offer brief information of knowledge and common handle interfaces
"""
def __init__(self, init_path=None, topics: List[Topic] = None):
self.logger = FinCoLog()
init_path = init_path if init_path else Path.cwd()
if not init_path.exists():
self.logger.warning(f"{init_path} not exist, create empty directory.")
Path.mkdir(init_path)
self.knowledge = self.load(path=init_path)
# todo: replace list with persistent storage strategy such as ES/pinecone to enable
# literal search/semantic search
self.docs = self.brief(knowledge=self.knowledge)
self.topics = topics if topics else []
def load(self, path) -> List:
if isinstance(path, str):
path = Path(path)
knowledge = []
path = path if path.name == "mlruns" else path.joinpath("mlruns")
R.set_uri(path.as_uri())
for exp_name in R.list_experiments():
knowledge.append(KnowledgeExperiment(exp_name=exp_name))
self.logger.plain_info(f"Load knowledge from: {path} finished.")
return knowledge
def update(self, path):
# note: only update new knowledge in future
knowledge = self.load(path)
self.knowledge = knowledge
self.docs = self.brief(self.knowledge)
self.logger.plain_info(f"Update knowledge finished.")
def brief(self, knowledge: List[Knowledge]) -> List:
docs = []
for k in knowledge:
docs.extend(k.brief())
self.logger.plain_info(f"Generate brief knowledge summary finished.")
return docs
def query(self, content: str = None):
# todo: query by DSL
return self.docs
def query_topics(self):
knowledge_of_topics = []
for topic in self.topics:
knowledge_of_topics.append({topic.name: topic.knowledge})
return knowledge_of_topics
def summarize_by_topic(self):
for topic in self.topics:
topic.summarize(self.docs)

111
qlib/finco/llm.py Normal file
View File

@@ -0,0 +1,111 @@
import os
import time
import openai
import json
from typing import Optional
from qlib.finco.conf import Config
from qlib.finco.utils import SingletonBaseClass
from qlib.finco.log import FinCoLog
class APIBackend(SingletonBaseClass):
def __init__(self):
self.cfg = Config()
openai.api_key = self.cfg.openai_api_key
if self.cfg.use_azure:
openai.api_type = "azure"
openai.api_base = self.cfg.azure_api_base
openai.api_version = self.cfg.azure_api_version
self.use_azure = self.cfg.use_azure
self.debug_mode = False
if self.cfg.debug_mode:
self.debug_mode = True
cwd = os.getcwd()
self.cache_file_location = os.path.join(cwd, "prompt_cache.json")
self.cache = (
json.load(open(self.cache_file_location, "r")) if os.path.exists(self.cache_file_location) else {}
)
def build_messages_and_create_chat_completion(self, user_prompt, system_prompt=None, former_messages=[], **kwargs):
"""build the messages to avoid implementing several redundant lines of code"""
cfg = Config()
# TODO: system prompt should always be provided. In development stage we can use default value
if system_prompt is None:
try:
system_prompt = cfg.system_prompt
except AttributeError:
FinCoLog().warning("system_prompt is not set, using default value.")
system_prompt = "You are an AI assistant who helps to answer user's questions about finance."
messages = [
{
"role": "system",
"content": system_prompt,
}
]
messages.extend(former_messages[-1*cfg.max_past_message_include:])
messages.append(
{
"role": "user",
"content": user_prompt,
}
)
fcl = FinCoLog()
response = self.try_create_chat_completion(messages=messages, **kwargs)
fcl.log_message(messages)
fcl.log_response(response)
return response
def try_create_chat_completion(self, max_retry=10, **kwargs):
max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry
for i in range(max_retry):
try:
response = self.create_chat_completion(**kwargs)
return response
except (openai.error.RateLimitError, openai.error.Timeout, openai.error.APIError) as e:
print(e)
print(f"Retrying {i+1}th time...")
time.sleep(1)
continue
except openai.InvalidRequestError as e:
print("Invalid request, will try to reduce the messages length and retry...")
if len(kwargs["messages"]) > 2:
kwargs["messages"] = kwargs["messages"][[0]] + kwargs["messages"][3:]
continue
raise e
raise Exception(f"Failed to create chat completion after {max_retry} retries.")
def create_chat_completion(
self,
messages,
model=None,
temperature: float = None,
max_tokens: Optional[int] = None,
) -> str:
if self.debug_mode:
key = json.dumps(messages)
if key in self.cache:
return self.cache[key]
if temperature is None:
temperature = self.cfg.temperature
if max_tokens is None:
max_tokens = self.cfg.max_tokens
if self.cfg.use_azure:
response = openai.ChatCompletion.create(
engine=self.cfg.model,
messages=messages,
max_tokens=self.cfg.max_tokens,
)
else:
response = openai.ChatCompletion.create(
model=self.cfg.model,
messages=messages,
)
resp = response.choices[0].message["content"]
if self.debug_mode:
self.cache[key] = resp
json.dump(self.cache, open(self.cache_file_location, "w"))
return resp

131
qlib/finco/log.py Normal file
View File

@@ -0,0 +1,131 @@
"""
This module will base on Qlib's logger module and provides some interactive functions.
"""
import logging
from typing import Dict, List
from qlib.finco.utils import SingletonBaseClass
from contextlib import contextmanager
class LogColors:
"""
ANSI color codes for use in console output.
"""
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
WHITE = "\033[97m"
GRAY = "\033[90m"
BLACK = "\033[30m"
BOLD = "\033[1m"
ITALIC = "\033[3m"
END = "\033[0m"
@classmethod
def get_all_colors(cls):
names = dir(cls)
names = [name for name in names if not name.startswith("__") and not callable(getattr(cls, name))]
var_values = [getattr(cls, name) for name in names]
return var_values
def render(self, text: str, color: str = "", style: str = ""):
"""
render text by input color and style. It's not recommend that input text is already rendered.
"""
# This method is called too frequently, which is not good.
colors = self.get_all_colors()
# Perhaps color and font should be distinguished here.
if color:
assert color in colors, f"color should be in: {colors} but now is: {color}"
if style:
assert style in colors, f"style should be in: {colors} but now is: {style}"
text = f"{color}{text}{self.END}"
text = f"{style}{text}{self.END}"
return text
@contextmanager
def formatting_log(logger, title="Info"):
"""
a context manager, print liens before and after a function
"""
length = {"Start": 120, "Task": 120, "Info": 60, "Interact": 60, "End": 120}.get(title, 60)
color, bold = (LogColors.YELLOW, LogColors.BOLD) \
if title in ["Start", "Task", "Info", "Interact", "End"] else (LogColors.CYAN, "")
logger.info("")
logger.info(f"{color}{bold}{'-'} {title} {'-' * (length - len(title))}{LogColors.END}")
yield
logger.info("")
class FinCoLog(SingletonBaseClass):
# TODO:
# - config to file logger and save it into workspace
def __init__(self) -> None:
self.logger = logging.Logger("interactive")
# TODO: merge these with Qlib's default logger.
# We can do the same thing by changing the default log dict of Qlib.
# Reference: https://github.com/microsoft/qlib/blob/main/qlib/config.py#L155
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(message)s"))
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def log_message(self, messages: List[Dict[str, str]]):
"""
messages is some info like this [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
]
"""
with formatting_log(self.logger, "GPT Messages"):
for m in messages:
self.logger.info(
f"{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END} "
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
+ f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n")
def log_response(self, response: str):
with formatting_log(self.logger, "GPT Response"):
self.logger.info(
f"{LogColors.CYAN}{response}{LogColors.END}\n")
# TODO:
# It looks wierd if we only have logger
def info(self, *args, plain=False, title="Info"):
if plain:
return self.plain_info(*args)
with formatting_log(self.logger, title):
for arg in args:
self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")
def plain_info(self, *args):
for arg in args:
self.logger.info(
f"{LogColors.YELLOW}{LogColors.BOLD}Info:{LogColors.END}{LogColors.WHITE}{arg}{LogColors.END}")
def warning(self, *args):
for arg in args:
self.logger.warning(
f"{LogColors.BLUE}{LogColors.BOLD}Warning:{LogColors.END}{arg}")
def error(self, *args):
for arg in args:
self.logger.error(
f"{LogColors.RED}{LogColors.BOLD}Error:{LogColors.END}{arg}")

View File

@@ -0,0 +1,32 @@
from typing import Union
from pathlib import Path
from jinja2 import Template
import yaml
from qlib.finco.utils import SingletonBaseClass
from qlib.finco import get_finco_path
class PromptTemplate(SingletonBaseClass):
def __init__(self) -> None:
super().__init__()
_template = yaml.load(open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"),
Loader=yaml.FullLoader)
for k, v in _template.items():
if k == "mods":
continue
self.__setattr__(k, Template(v))
def get(self, key: str):
return self.__dict__.get(key, Template(""))
def update(self, key: str, value):
self.__setattr__(key, value)
def save(self, file_path: Union[str, Path]):
if isinstance(file_path, str):
file_path = Path(file_path)
Path.mkdir(file_path.parent, exist_ok=True)
with open(file_path, 'w') as f:
yaml.dump(self.__dict__, f)

File diff suppressed because it is too large Load Diff

1110
qlib/finco/task.py Normal file

File diff suppressed because it is too large Load Diff

12
qlib/finco/tpl/README.md Normal file
View File

@@ -0,0 +1,12 @@
This is a set of templates that should be copied for a new project.
Here are the explanations for the templates folder.
| folder | explanations |
|--------|------------------------------------------------------------------|
| sl | Default configuration for supervised learning |
| sl-cfg | Like configuration in sl. But the dataset is highly configurable |
# TODO
- [ ] [Copier](https://copier.readthedocs.io/en/stable/#quick-start) may be useful if the generation process becomes complicated

View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
DIRNAME = Path(__file__).absolute().resolve().parent
def get_tpl_path() -> Path:
"""
return the template path
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
"""
return DIRNAME

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,73 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
experiment_name: finCo
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

38
qlib/finco/utils.py Normal file
View File

@@ -0,0 +1,38 @@
import json
from fuzzywuzzy import fuzz
class SingletonMeta(type):
_instance = None
def __call__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
return cls._instance
class SingletonBaseClass(metaclass=SingletonMeta):
"""
Because we try to support defining Singleton with `class A(SingletonBaseClass)` instead of `A(metaclass=SingletonMeta)`
This class becomes necessary
"""
# TODO: Add move this class to Qlib's general utils.
def parse_json(response):
try:
return json.loads(response)
except json.decoder.JSONDecodeError:
pass
raise Exception(f"Failed to parse response: {response}, please report it or help us to fix it.")
def similarity(text1, text2):
text1 = text1 if isinstance(text1, str) else ""
text2 = text2 if isinstance(text2, str) else ""
# Maybe we can use other similarity algorithm such as tfidf
return fuzz.ratio(text1, text2)

223
qlib/finco/workflow.py Normal file
View File

@@ -0,0 +1,223 @@
import sys
import copy
import shutil
from pathlib import Path
from typing import List
from qlib.finco.task import HighLevelPlanTask, SummarizeTask, TrainTask
from qlib.finco.prompt_template import PromptTemplate, Template
from qlib.finco.log import FinCoLog, LogColors
from qlib.finco.utils import similarity
from qlib.finco.llm import APIBackend
from qlib.finco.conf import Config
from qlib.finco.knowledge import KnowledgeBase, Topic
class WorkflowContextManager:
"""Context Manager stores the context of the workflow"""
"""All context are key value pairs which saves the input, output and status of the whole workflow"""
def __init__(self) -> None:
self.context = {}
self.logger = FinCoLog()
def set_context(self, key, value):
if key in self.context:
self.logger.warning("The key already exists in the context, the value will be overwritten")
self.context[key] = value
def get_context(self, key):
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
if key not in self.context:
self.logger.warning("The key doesn't exist in the context")
return None
return self.context[key]
def update_context(self, key, new_value):
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
if key not in self.context:
self.logger.warning("The key doesn't exist in the context")
self.context.update({key: new_value})
def get_all_context(self):
"""return a deep copy of the context"""
"""TODO: do we need to return a deep copy?"""
return copy.deepcopy(self.context)
def retrieve(self, query: str) -> dict:
if query in self.context.keys():
return {query: self.context.get(query)}
# Note: retrieve information from context by string similarity maybe abandon in future
scores = {}
for k, v in self.context.items():
scores.update({k: max(similarity(query, k), similarity(query, v))})
max_score_key = max(scores, key=scores.get)
return {max_score_key: self.context.get(max_score_key)}
def clear(self, reserve: list = None):
if reserve is None:
reserve = []
_context = {k: self.get_context(k) for k in reserve}
self.context = _context
class WorkflowManager:
"""This manage the whole task automation workflow including tasks and actions"""
def __init__(self, workspace=None) -> None:
self.logger = FinCoLog()
if workspace is None:
self._workspace = Path.cwd() / "finco_workspace"
else:
self._workspace = Path(workspace)
self.conf = Config()
self._confirm_and_rm()
self.prompt_template = PromptTemplate()
self.context = WorkflowContextManager()
self.context.set_context("workspace", self._workspace)
self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China A csi300. Please help to use lightgbm model."
def _confirm_and_rm(self):
# if workspace exists, please confirm and remove it. Otherwise exit.
if self._workspace.exists() and not self.conf.continuous_mode:
self.logger.info(title="Interact")
flag = input(
LogColors().render(
f"Will be deleted: \n\t{self._workspace}\n"
f"If you do not need to delete {self._workspace},"
f" please change the workspace dir or rename existing files\n"
f"Are you sure you want to delete, yes(Y/y), no (N/n):",
color=LogColors.WHITE)
)
if str(flag) not in ["Y", "y"]:
sys.exit()
else:
# remove self._workspace
shutil.rmtree(self._workspace)
elif self._workspace.exists() and self.conf.continuous_mode:
shutil.rmtree(self._workspace)
def set_context(self, key, value):
"""Direct call set_context method of the context manager"""
self.context.set_context(key, value)
def get_context(self) -> WorkflowContextManager:
return self.context
def run(self, prompt: str) -> Path:
"""
The workflow manager is supposed to generate a codebase based on the prompt
Parameters
----------
prompt: str
the prompt user gives
Returns
-------
Path
The workflow manager is expected to produce output that includes a codebase containing generated code, results, and reports in a designated location.
The path is returned
The output path should follow a specific format:
- TODO: design
There is a summarized report where user can start from.
"""
# NOTE: The following items are not designed to make the workflow very flexible.
# - The generated tasks can't be changed after geting new information from the execution retuls.
# - But it is required in some cases, if we want to build a external dataset, it maybe have to plan like autogpt...
# NOTE: default user prompt might be changed in the future and exposed to the user
if prompt is None:
self.set_context("user_prompt", self.default_user_prompt)
else:
self.set_context("user_prompt", prompt)
self.logger.info(f"user_prompt: {self.get_context().get_context('user_prompt')}", title="Start")
# NOTE: list may not be enough for general task list
task_list = [HighLevelPlanTask(), SummarizeTask()]
task_finished = []
while len(task_list):
task_list_info = [str(task) for task in task_list]
# task list is not long, so sort it is not a big problem
# TODO: sort the task list based on the priority of the task
# task_list = sorted(task_list, key=lambda x: x.task_type)
t = task_list.pop(0)
self.logger.info(f"Task finished: {[str(task) for task in task_finished]}",
f"Task in queue: {task_list_info}",
f"Executing task: {str(t)}",
title="Task")
t.assign_context_manager(self.context)
res = t.execute()
t.summarize()
task_finished.append(t)
self.context.set_context("task_finished", task_finished)
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
task_list = res + task_list
return self._workspace
class LearnManager:
__DEFAULT_TOPICS = ["IC", "MaxDropDown"]
def __init__(self):
self.epoch = 0
self.wm = WorkflowManager()
topics = [Topic(name=topic, describe=self.wm.prompt_template.get(f"Topic_{topic}")) for topic in
self.__DEFAULT_TOPICS]
self.knowledge_base = KnowledgeBase(init_path=Path.cwd().joinpath('knowledge'), topics=topics)
def run(self, prompt):
# todo: add early stop condition
for i in range(10):
self.wm.run(prompt)
self.knowledge_base.update(self.wm._workspace)
self.knowledge_base.summarize_by_topic()
self.learn()
self.epoch += 1
def learn(self):
workspace = self.wm.context.get_context("workspace")
def _drop_duplicate_task(_task: List):
unique_task = {}
for obj in _task:
task_name = obj.__class__.__name__
if task_name not in unique_task:
unique_task[task_name] = obj
return list(unique_task.values())
# one task maybe run several times in workflow
task_finished = _drop_duplicate_task(self.wm.context.get_context("task_finished"))
user_prompt = self.wm.context.get_context("user_prompt")
summary = self.wm.context.get_context("summary")
for task in task_finished:
prompt_workflow_selection = self.wm.prompt_template.get(f"{self.__class__.__name__}_user").render(
summary=summary, brief=self.knowledge_base.query_topics(),
task_finished=[str(t) for t in task_finished],
task=task.__class__.__name__, system=task.system.render(), user_prompt=user_prompt
)
response = APIBackend().build_messages_and_create_chat_completion(
user_prompt=prompt_workflow_selection,
system_prompt=self.wm.prompt_template.get(f"{self.__class__.__name__}_system").render()
)
# todo: response assertion
task.prompt_template.update(key=f"{task.__class__.__name__}_system", value=Template(response))
self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
self.wm.context.clear(reserve=["workspace"])

View File

@@ -28,14 +28,15 @@ from qlib.typehint import Literal
def _get_multi_level_executor_config(
strategy_config: dict,
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
data_granularity: str = "1min",
) -> dict:
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "1min",
"time_per_step": data_granularity,
"verbose": False,
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
"generate_report": generate_report,
@@ -127,7 +128,7 @@ def single_with_simulator(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
@@ -154,12 +155,7 @@ def single_with_simulator(
-------
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
init_qlib(backtest_config["qlib"])
stocks = orders.instrument.unique().tolist()
@@ -181,13 +177,14 @@ def single_with_simulator(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
data_granularity=backtest_config["data_granularity"],
)
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": backtest_config["data_granularity"],
}
)
@@ -202,7 +199,7 @@ def single_with_simulator(
reports.append(simulator.report_dict)
decisions += simulator.decisions
indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports]
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator_info)
assert records is None or not np.isnan(records["ffr"]).any()
@@ -226,7 +223,7 @@ def single_with_collect_data_loop(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with collect_data_loop.
@@ -253,12 +250,7 @@ def single_with_collect_data_loop(
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
init_qlib(backtest_config["qlib"])
trade_start_time = orders["datetime"].min()
trade_end_time = orders["datetime"].max()
@@ -280,13 +272,14 @@ def single_with_collect_data_loop(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
data_granularity=backtest_config["data_granularity"],
)
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": backtest_config["data_granularity"],
}
)
@@ -357,7 +350,10 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
if not output_path.exists():
os.makedirs(output_path)
res.to_csv(output_path / "summary.csv")
if "pa" in res.columns:
res["pa"] = res["pa"] * 10000.0 # align with training metrics
res.to_csv(output_path / "backtest_result.csv")
return res

View File

@@ -98,8 +98,9 @@ def get_backtest_config_fromfile(path: str) -> dict:
"debug_single_day": None,
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs/",
"output_dir": "outputs_backtest/",
"generate_report": False,
"data_granularity": "1min",
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)

View File

@@ -1,20 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import os
import random
import sys
import warnings
from pathlib import Path
from typing import cast, List, Optional
import numpy as np
import pandas as pd
import qlib
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.data.native import load_handler_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.reward import Reward
@@ -23,7 +26,6 @@ from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch import nn
from torch.utils.data import Dataset
@@ -49,19 +51,17 @@ def _read_orders(order_dir: Path) -> pd.DataFrame:
class LazyLoadDataset(Dataset):
def __init__(
self,
data_dir: str,
order_file_path: Path,
data_dir: Path,
default_start_time_index: int,
default_end_time_index: int,
) -> None:
self._default_start_time_index = default_start_time_index
self._default_end_time_index = default_end_time_index
self._order_file_path = order_file_path
self._order_df = _read_orders(order_file_path).reset_index()
self._data_dir = data_dir
self._ticks_index: Optional[pd.DatetimeIndex] = None
self._data_dir = Path(data_dir)
def __len__(self) -> int:
return len(self._order_df)
@@ -74,12 +74,17 @@ class LazyLoadDataset(Dataset):
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
# TODO: of all dates.
backtest_data = load_simple_intraday_backtest_data(
data = load_handler_intraday_processed_data(
data_dir=self._data_dir,
stock_id=row["instrument"],
date=date,
feature_columns_today=[],
feature_columns_yesterday=[],
backtest=True,
index_only=True,
)
self._ticks_index = [t - date for t in backtest_data.get_time_index()]
self._ticks_index = [t - date for t in data.today.index]
order = Order(
stock_id=row["instrument"],
@@ -101,10 +106,9 @@ def train_and_test(
action_interpreter: ActionInterpreter,
policy: BasePolicy,
reward: Reward,
run_training: bool,
run_backtest: bool,
) -> None:
qlib.init()
order_root_path = Path(data_config["source"]["order_dir"])
data_granularity = simulator_config.get("data_granularity", 1)
@@ -112,72 +116,78 @@ def train_and_test(
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
return SingleAssetOrderExecutionSimple(
order=order,
data_dir=Path(data_config["source"]["data_dir"]),
ticks_per_step=simulator_config["time_per_step"],
data_dir=data_config["source"]["feature_root_dir"],
feature_columns_today=data_config["source"]["feature_columns_today"],
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
data_granularity=data_granularity,
deal_price_type=data_config["source"].get("deal_price_column", "close"),
ticks_per_step=simulator_config["time_per_step"],
vol_threshold=simulator_config["vol_limit"],
)
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
train_dataset, valid_dataset, test_dataset = [
LazyLoadDataset(
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
if run_training:
train_dataset, valid_dataset = [
LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / tag,
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid")
]
callbacks: List[Callback] = []
if "checkpoint_path" in trainer_config:
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs={
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
},
vessel_kwargs={
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
},
)
if run_backtest:
test_dataset = LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / "test",
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid", "test")
]
if "checkpoint_path" in trainer_config:
callbacks: List[Callback] = []
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
trainer_kwargs = {
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
}
vessel_kwargs = {
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
}
train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs=trainer_kwargs,
vessel_kwargs=vessel_kwargs,
)
if run_backtest:
backtest(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
@@ -186,35 +196,42 @@ def train_and_test(
policy=policy,
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
reward=reward,
finite_env_type=trainer_kwargs["finite_env_type"],
concurrency=trainer_kwargs["concurrency"],
finite_env_type=env_config["parallel_mode"],
concurrency=env_config["concurrency"],
)
def main(config: dict, run_backtest: bool) -> None:
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if not run_training and not run_backtest:
warnings.warn("Skip the entire job since training and backtest are both skipped.")
return
if "seed" in config["runtime"]:
seed_everything(config["runtime"]["seed"])
state_config = config["state_interpreter"]
state_interpreter: StateInterpreter = init_instance_by_config(state_config)
for extra_module_path in config["env"].get("extra_module_paths", []):
sys.path.append(extra_module_path)
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
reward: Reward = init_instance_by_config(config["reward"])
additional_policy_kwargs = {
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
# Create torch network
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
network: nn.Module = init_instance_by_config(config["network"])
if "network" in config:
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
# Create policy
config["policy"]["kwargs"].update(
{
"network": network,
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
)
if "kwargs" not in config["policy"]:
config["policy"]["kwargs"] = {}
config["policy"]["kwargs"].update(additional_policy_kwargs)
policy: BasePolicy = init_instance_by_config(config["policy"])
use_cuda = config["runtime"].get("use_cuda", False)
@@ -230,22 +247,22 @@ def main(config: dict, run_backtest: bool) -> None:
state_interpreter=state_interpreter,
policy=policy,
reward=reward,
run_training=run_training,
run_backtest=run_backtest,
)
if __name__ == "__main__":
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
args = parser.parse_args()
with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream)
main(config, run_backtest=args.run_backtest)
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)

View File

@@ -8,48 +8,14 @@ TODO: The implementation here is kind of adhoc. It is better to design a more un
from __future__ import annotations
import pickle
from pathlib import Path
from typing import List
import cachetools
import numpy as np
import pandas as pd
import qlib
from qlib.constant import REG_CN
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
from qlib.data.dataset import DatasetH
dataset = None
class DataWrapper:
def __init__(
self,
feature_dataset: DatasetH,
backtest_dataset: DatasetH,
columns_today: List[str],
columns_yesterday: List[str],
_internal: bool = False,
):
assert _internal, "Init function of data wrapper is for internal use only."
self.feature_dataset = feature_dataset
self.backtest_dataset = backtest_dataset
self.columns_today = columns_today
self.columns_yesterday = columns_yesterday
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100),
key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
)
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
dataset = self.backtest_dataset if backtest else self.feature_dataset
return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
def init_qlib(qlib_config: dict, part: str = None) -> None:
def init_qlib(qlib_config: dict) -> None:
"""Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..
Parameters
@@ -72,20 +38,15 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
],
}
part
Identifying which part (stock / date) to load.
"""
global dataset # pylint: disable=W0603
def _convert_to_path(path: str | Path) -> Path:
return path if isinstance(path, Path) else Path(path)
provider_uri_map = {}
if "provider_uri_day" in qlib_config:
provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix()
if "provider_uri_1min" in qlib_config:
provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix()
for granularity in ["1min", "5min", "day"]:
if f"provider_uri_{granularity}" in qlib_config:
provider_uri_map[f"{granularity}"] = _convert_to_path(qlib_config[f"provider_uri_{granularity}"]).as_posix()
qlib.init(
region=REG_CN,
@@ -119,47 +80,3 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
redis_port=-1,
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
)
if part == "skip":
return
# this won't work if it's put outside in case of multiprocessing
from qlib.data import D # noqa pylint: disable=C0415,W0611
if part is None:
feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl"
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl"
else:
feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl")
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl")
with feature_path.open("rb") as f:
feature_dataset = pickle.load(f)
with backtest_path.open("rb") as f:
backtest_dataset = pickle.load(f)
dataset = DataWrapper(
feature_dataset,
backtest_dataset,
qlib_config["feature_columns_today"],
qlib_config["feature_columns_yesterday"],
_internal=True,
)
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
assert dataset is not None, "You must call init_qlib() before doing this."
if backtest:
fields = ["$close", "$volume"]
else:
fields = dataset.columns_yesterday if yesterday else dataset.columns_today
data = dataset.get(stock_id, date, backtest)
if data is None or len(data) == 0:
# create a fake index, but RL doesn't care about index
data = pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here
else:
data = data.rename(columns={c: c.rstrip("0") for c in data.columns})
data = data[fields]
return data

View File

@@ -2,17 +2,29 @@
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
from pathlib import Path
from typing import cast, List
import cachetools
import pandas as pd
import pickle
import os
from qlib.backtest import Exchange, Order
from qlib.backtest.decision import TradeRange, TradeRangeByTime
from qlib.rl.order_execution.utils import get_ticks_slice
from qlib.constant import EPS_T
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
from .integration import fetch_features
def get_ticks_slice(
ticks_index: pd.DatetimeIndex,
start: pd.Timestamp,
end: pd.Timestamp,
include_end: bool = False,
) -> pd.DatetimeIndex:
if not include_end:
end = end - EPS_T
return ticks_index[ticks_index.slice_indexer(start, end)]
class IntradayBacktestData(BaseIntradayBacktestData):
@@ -71,6 +83,31 @@ class IntradayBacktestData(BaseIntradayBacktestData):
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
class DataframeIntradayBacktestData(BaseIntradayBacktestData):
"""Backtest data from dataframe"""
def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None:
self.df = df
self.price_column = price_column
self.volume_column = volume_column
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.df})"
def __len__(self) -> int:
return len(self.df)
def get_deal_price(self) -> pd.Series:
return self.df[self.price_column]
def get_volume(self) -> pd.Series:
return self.df[self.volume_column]
def get_time_index(self) -> pd.DatetimeIndex:
return cast(pd.DatetimeIndex, self.df.index)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100),
key=lambda order, _, __: order.key_by_day,
@@ -103,13 +140,18 @@ def load_backtest_data(
return backtest_data
class NTIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle NT style data."""
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
def __init__(
self,
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> None:
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
df = df.reset_index()
@@ -117,8 +159,18 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
df = df.drop(columns=["instrument"])
return df.set_index(["datetime"])
self.today = _drop_stock_id(fetch_features(stock_id, date))
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
with open(path, "rb") as fstream:
dataset = pickle.load(fstream)
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
if index_only:
self.today = _drop_stock_id(data[[]])
self.yesterday = _drop_stock_id(data[[]])
else:
self.today = _drop_stock_id(data[feature_columns_today])
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
@@ -127,12 +179,42 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
stock_id,
date,
backtest,
index_only,
),
)
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
return NTIntradayProcessedData(stock_id, date)
def load_handler_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> HandlerIntradayProcessedData:
return HandlerIntradayProcessedData(
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
)
class NTProcessedDataProvider(ProcessedDataProvider):
class HandlerProcessedDataProvider(ProcessedDataProvider):
def __init__(
self,
data_dir: str,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> None:
super().__init__()
self.data_dir = Path(data_dir)
self.feature_columns_today = feature_columns_today
self.feature_columns_yesterday = feature_columns_yesterday
self.backtest = backtest
def get_data(
self,
stock_id: str,
@@ -140,4 +222,12 @@ class NTProcessedDataProvider(ProcessedDataProvider):
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return load_nt_intraday_processed_data(stock_id, date)
return load_handler_intraday_processed_data(
self.data_dir,
stock_id,
date,
self.feature_columns_today,
self.feature_columns_yesterday,
backtest=self.backtest,
index_only=False,
)

View File

@@ -104,7 +104,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int = None,
order_dir: int | None = None,
) -> None:
super(SimpleIntradayBacktestData, self).__init__()
@@ -158,8 +158,8 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
return cast(pd.DatetimeIndex, self.data.index)
class IntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
class PickleIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
def __init__(
self,
@@ -208,7 +208,7 @@ def load_simple_intraday_backtest_data(
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int = None,
order_dir: int | None = None,
) -> SimpleIntradayBacktestData:
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
@@ -217,14 +217,14 @@ def load_simple_intraday_backtest_data(
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
)
def load_pickled_intraday_processed_data(
def load_pickle_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
class PickleProcessedDataProvider(ProcessedDataProvider):
@@ -240,7 +240,7 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return load_pickled_intraday_processed_data(
return load_pickle_intraday_processed_data(
data_dir=self._data_dir,
stock_id=stock_id,
date=date,

View File

@@ -53,6 +53,18 @@ class FullHistoryObs(TypedDict):
position_history: Any
class DummyStateInterpreter(StateInterpreter[SAOEState, dict]):
"""Dummy interpreter for policies that do not need inputs (for example, AllOne)."""
def interpret(self, state: SAOEState) -> dict:
# TODO: A fake state, used to pass `check_nan_observation`. Find a better way in the future.
return {"DUMMY": _to_int32(1)}
@property
def observation_space(self) -> spaces.Dict:
return spaces.Dict({"DUMMY": spaces.Box(-np.inf, np.inf, shape=(), dtype=np.int32)})
class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
"""The observation of all the history, including today (until this moment), and yesterday.

View File

@@ -12,11 +12,11 @@ import torch
import torch.nn as nn
from gym.spaces import Discrete
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy
from qlib.rl.trainer.trainer import Trainer
__all__ = ["AllOne", "PPO"]
__all__ = ["AllOne", "PPO", "DQN"]
# baselines #
@@ -32,7 +32,7 @@ class NonLearnablePolicy(BasePolicy):
super().__init__()
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
pass
return {}
def process_fn(
self,
@@ -40,7 +40,7 @@ class NonLearnablePolicy(BasePolicy):
buffer: ReplayBuffer,
indices: np.ndarray,
) -> Batch:
pass
return Batch({})
class AllOne(NonLearnablePolicy):
@@ -49,13 +49,18 @@ class AllOne(NonLearnablePolicy):
Useful when implementing some baselines (e.g., TWAP).
"""
def __init__(self, obs_space: gym.Space, action_space: gym.Space, fill_value: float | int = 1.0) -> None:
super().__init__(obs_space, action_space)
self.fill_value = fill_value
def forward(
self,
batch: Batch,
state: dict | Batch | np.ndarray = None,
**kwargs: Any,
) -> Batch:
return Batch(act=np.full(len(batch), 1.0), state=state)
return Batch(act=np.full(len(batch), self.fill_value), state=state)
# ppo #
@@ -153,6 +158,56 @@ class PPO(PPOPolicy):
set_weight(self, Trainer.get_policy_state_dict(weight_file))
DQNModel = PPOActor # Reuse PPOActor.
class DQN(DQNPolicy):
"""A wrapper of tianshou DQNPolicy.
Differences:
- Auto-create model network. Supports discrete action space only.
- Support a ``weight_file`` that supports loading checkpoint.
"""
def __init__(
self,
network: nn.Module,
obs_space: gym.Space,
action_space: gym.Space,
lr: float,
weight_decay: float = 0.0,
discount_factor: float = 0.99,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
is_double: bool = True,
clip_loss_grad: bool = False,
weight_file: Optional[Path] = None,
) -> None:
assert isinstance(action_space, Discrete)
model = DQNModel(network, action_space.n)
optimizer = torch.optim.Adam(
model.parameters(),
lr=lr,
weight_decay=weight_decay,
)
super().__init__(
model,
optimizer,
discount_factor=discount_factor,
estimation_step=estimation_step,
target_update_freq=target_update_freq,
reward_normalization=reward_normalization,
is_double=is_double,
clip_loss_grad=clip_loss_grad,
)
if weight_file is not None:
set_weight(self, Trainer.get_policy_state_dict(weight_file))
# utilities: these should be put in a separate (common) file. #

View File

@@ -7,6 +7,7 @@ from typing import cast
import numpy as np
from qlib.backtest.decision import OrderDir
from qlib.rl.order_execution.state import SAOEMetrics, SAOEState
from qlib.rl.reward import Reward
@@ -47,3 +48,52 @@ class PAPenaltyReward(Reward[SAOEState]):
self.log("reward/pa", pa)
self.log("reward/penalty", penalty)
return reward * self.scale
class PPOReward(Reward[SAOEState]):
"""Reward proposed by paper "An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization".
Parameters
----------
max_step
Maximum number of steps.
start_time_index
First time index that allowed to trade.
end_time_index
Last time index that allowed to trade.
"""
def __init__(self, max_step: int, start_time_index: int = 0, end_time_index: int = 239) -> None:
self.max_step = max_step
self.start_time_index = start_time_index
self.end_time_index = end_time_index
def reward(self, simulator_state: SAOEState) -> float:
if simulator_state.cur_step == self.max_step - 1 or simulator_state.position < 1e-6:
if simulator_state.history_exec["deal_amount"].sum() == 0.0:
vwap_price = cast(
float,
np.average(simulator_state.history_exec["market_price"]),
)
else:
vwap_price = cast(
float,
np.average(
simulator_state.history_exec["market_price"],
weights=simulator_state.history_exec["deal_amount"],
),
)
twap_price = simulator_state.backtest_data.get_deal_price().mean()
if simulator_state.order.direction == OrderDir.SELL:
ratio = vwap_price / twap_price if twap_price != 0 else 1.0
else:
ratio = twap_price / vwap_price if vwap_price != 0 else 1.0
if ratio < 1.0:
return -1.0
elif ratio < 1.1:
return 0.0
else:
return 1.0
else:
return 0.0

View File

@@ -38,8 +38,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
order: Order,
executor_config: dict,
exchange_config: dict,
qlib_config: dict = None,
cash_limit: Optional[float] = None,
qlib_config: dict | None = None,
cash_limit: float | None = None,
) -> None:
super().__init__(initial=order)
@@ -63,11 +63,11 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
strategy_config: dict,
executor_config: dict,
exchange_config: dict,
qlib_config: dict = None,
qlib_config: dict | None = None,
cash_limit: Optional[float] = None,
) -> None:
if qlib_config is not None:
init_qlib(qlib_config, part="skip")
init_qlib(qlib_config)
strategy, self._executor = get_strategy_executor(
start_time=order.date,

View File

@@ -3,17 +3,19 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, cast, Optional
from typing import Any, cast, List, Optional
import numpy as np
import pandas as pd
from pathlib import Path
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS, EPS_T, float_or_ndarray
from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data
from qlib.rl.data.base import BaseIntradayBacktestData
from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.simulator import Simulator
from qlib.rl.utils import LogLevel
from .state import SAOEMetrics, SAOEState
__all__ = ["SingleAssetOrderExecutionSimple"]
@@ -36,12 +38,16 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
----------
order
The seed to start an SAOE simulator is an order.
data_dir
Path to load backtest data.
feature_columns_today
Columns of today's feature.
feature_columns_yesterday
Columns of yesterday's feature.
data_granularity
Number of ticks between consecutive data entries.
ticks_per_step
How many ticks per step.
data_dir
Path to load backtest data
vol_threshold
Maximum execution volume (divided by market execution volume).
"""
@@ -73,9 +79,10 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self,
order: Order,
data_dir: Path,
feature_columns_today: List[str] = [],
feature_columns_yesterday: List[str] = [],
data_granularity: int = 1,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: Optional[float] = None,
) -> None:
super().__init__(initial=order)
@@ -83,18 +90,13 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
assert ticks_per_step % data_granularity == 0
self.order = order
self.ticks_per_step: int = ticks_per_step // data_granularity
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
self.backtest_data = load_simple_intraday_backtest_data(
self.data_dir,
order.stock_id,
pd.Timestamp(order.start_time.date()),
self.deal_price_type,
order.direction,
)
self.feature_columns_today = feature_columns_today
self.feature_columns_yesterday = feature_columns_yesterday
self.ticks_per_step: int = ticks_per_step // data_granularity
self.vol_threshold = vol_threshold
self.backtest_data = self.get_backtest_data()
self.ticks_index = self.backtest_data.get_time_index()
# Get time index available for trading
@@ -118,6 +120,30 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self.market_vol: Optional[np.ndarray] = None
self.market_vol_limit: Optional[np.ndarray] = None
def get_backtest_data(self) -> BaseIntradayBacktestData:
try:
data = load_handler_intraday_processed_data(
data_dir=self.data_dir,
stock_id=self.order.stock_id,
date=pd.Timestamp(self.order.start_time.date()),
feature_columns_today=self.feature_columns_today,
feature_columns_yesterday=self.feature_columns_yesterday,
backtest=True,
index_only=False,
)
return DataframeIntradayBacktestData(data.today)
except (AttributeError, FileNotFoundError):
# TODO: For compatibility with older versions of test scripts (tests/rl/test_saoe_simple.py)
# TODO: In the future, we should modify the data format used by the test script,
# TODO: and then delete this branch.
return load_simple_intraday_backtest_data(
self.data_dir / "backtest",
self.order.stock_id,
pd.Timestamp(self.order.start_time.date()),
"close",
self.order.direction,
)
def step(self, amount: float) -> None:
"""Execute one step or SAOE.

View File

@@ -7,6 +7,7 @@ import collections
from types import GeneratorType
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union
import warnings
import numpy as np
import pandas as pd
import torch
@@ -89,6 +90,7 @@ class SAOEStateAdapter:
exchange: Exchange,
ticks_per_step: int,
backtest_data: IntradayBacktestData,
data_granularity: int = 1,
) -> None:
self.position = order.amount
self.order = order
@@ -106,11 +108,13 @@ class SAOEStateAdapter:
self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)
self.ticks_per_step = ticks_per_step
self.data_granularity = data_granularity
assert self.ticks_per_step % self.data_granularity == 0
def _next_time(self) -> pd.Timestamp:
current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time)
next_loc = current_loc + self.ticks_per_step
next_loc = next_loc - next_loc % self.ticks_per_step
next_loc = current_loc + (self.ticks_per_step // self.data_granularity)
next_loc = next_loc - next_loc % (self.ticks_per_step // self.data_granularity)
if (
next_loc < len(self.backtest_data.ticks_index)
and self.backtest_data.ticks_index[next_loc] < self.order.end_time
@@ -130,11 +134,16 @@ class SAOEStateAdapter:
exec_vol = np.zeros(last_step_size)
for order, _, __, ___ in execute_result:
idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
idx, _ = get_day_min_idx_range(order.start_time, order.end_time, f"{self.data_granularity}min", REG_CN)
exec_vol[idx - last_step_range[0]] = order.deal_amount
if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
if exec_vol.sum() > self.position + 1.0:
warnings.warn(
f"Sum of execution volume is {exec_vol.sum()} which is larger than "
f"position + 1.0 = {self.position} + 1.0 = {self.position + 1.0}. "
f"All execution volume is scaled down linearly to ensure that their sum does not position."
)
exec_vol *= self.position / (exec_vol.sum())
market_volume = cast(
@@ -168,7 +177,9 @@ class SAOEStateAdapter:
self.history_exec,
self._collect_multi_order_metric(
order=self.order,
datetime=_get_all_timestamps(start_time, end_time, include_end=True),
datetime=_get_all_timestamps(
start_time, end_time, include_end=True, granularity=ONE_MIN * self.data_granularity
),
market_vol=market_volume,
market_price=market_price,
exec_vol=exec_vol,
@@ -293,9 +304,10 @@ class SAOEStrategy(RLStrategy):
def __init__(
self,
policy: BasePolicy,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
data_granularity: int = 1,
**kwargs: Any,
) -> None:
super(SAOEStrategy, self).__init__(
@@ -306,6 +318,7 @@ class SAOEStrategy(RLStrategy):
**kwargs,
)
self._data_granularity = data_granularity
self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {}
self._last_step_range = (0, 0)
@@ -324,9 +337,10 @@ class SAOEStrategy(RLStrategy):
exchange=self.trade_exchange,
ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN),
backtest_data=backtest_data,
data_granularity=self._data_granularity,
)
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super(SAOEStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
self.adapter_dict = {}
@@ -366,7 +380,7 @@ class SAOEStrategy(RLStrategy):
def generate_trade_decision(
self,
execute_result: list = None,
execute_result: list | None = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
"""
For SAOEStrategy, we need to update the `self._last_step_range` every time a decision is generated.
@@ -385,7 +399,7 @@ class SAOEStrategy(RLStrategy):
def _generate_trade_decision(
self,
execute_result: list = None,
execute_result: list | None = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
raise NotImplementedError
@@ -399,14 +413,14 @@ class ProxySAOEStrategy(SAOEStrategy):
def __init__(
self,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
**kwargs: Any,
) -> None:
super().__init__(None, outer_trade_decision, level_infra, common_infra, **kwargs)
def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
def _generate_trade_decision(self, execute_result: list | None = None) -> Generator[Any, Any, BaseTradeDecision]:
# Once the following line is executed, this ProxySAOEStrategy (self) will be yielded to the outside
# of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,
# the item will be captured by `exec_vol`. The outside policy could communicate with the inner
@@ -418,7 +432,7 @@ class ProxySAOEStrategy(SAOEStrategy):
return TradeDecisionWO([order], self)
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
assert isinstance(outer_trade_decision, TradeDecisionWO)
@@ -437,9 +451,9 @@ class SAOEIntStrategy(SAOEStrategy):
state_interpreter: dict | StateInterpreter,
action_interpreter: dict | ActionInterpreter,
network: dict | torch.nn.Module | None = None,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
**kwargs: Any,
) -> None:
super(SAOEIntStrategy, self).__init__(
@@ -488,7 +502,7 @@ class SAOEIntStrategy(SAOEStrategy):
if self._policy is not None:
self._policy.eval()
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
@@ -508,7 +522,7 @@ class SAOEIntStrategy(SAOEStrategy):
trade_details[-1]["rl_action"] = a
return pd.DataFrame.from_records(trade_details)
def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
def _generate_trade_decision(self, execute_result: list | None = None) -> BaseTradeDecision:
states = []
obs_batch = []
for decision in self.outer_trade_decision.get_decision():

View File

@@ -10,18 +10,7 @@ import pandas as pd
from qlib.backtest.decision import OrderDir
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.constant import EPS_T, float_or_ndarray
def get_ticks_slice(
ticks_index: pd.DatetimeIndex,
start: pd.Timestamp,
end: pd.Timestamp,
include_end: bool = False,
) -> pd.DatetimeIndex:
if not include_end:
end = end - EPS_T
return ticks_index[ticks_index.slice_indexer(start, end)]
from qlib.constant import float_or_ndarray
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:

View File

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from qlib.backtest import Order
from qlib.backtest.decision import OrderHelper, TradeDecisionWO, TradeRange
from qlib.strategy.base import BaseStrategy
@@ -12,14 +14,14 @@ class SingleOrderStrategy(BaseStrategy):
def __init__(
self,
order: Order,
trade_range: TradeRange = None,
trade_range: TradeRange | None = None,
) -> None:
super().__init__()
self._order = order
self._trade_range = trade_range
def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
def generate_trade_decision(self, execute_result: list | None = None) -> TradeDecisionWO:
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
order_list = [
oh.create(

View File

@@ -4,6 +4,7 @@
from __future__ import annotations
import multiprocessing
from multiprocessing.sharedctypes import Synchronized
import os
import threading
import time
@@ -78,7 +79,9 @@ class DataQueue(Generic[T]):
self._activated: bool = False
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
self._done = multiprocessing.Value("i", 0)
# Mypy 0.981 brought '"SynchronizedBase[Any]" has no attribute "value" [attr-defined]' bug.
# Therefore, add this type casting to pass Mypy checking.
self._done = cast(Synchronized, multiprocessing.Value("i", 0))
def __enter__(self) -> DataQueue:
self.activate()
@@ -122,7 +125,7 @@ class DataQueue(Generic[T]):
if self._done.value:
raise StopIteration # pylint: disable=raise-missing-from
def put(self, obj: Any, block: bool = True, timeout: int = None) -> None:
def put(self, obj: Any, block: bool = True, timeout: int | None = None) -> None:
self._queue.put(obj, block=block, timeout=timeout)
def mark_as_done(self) -> None:

View File

@@ -99,9 +99,9 @@ class EnvWrapper(
state_interpreter: StateInterpreter[StateType, ObsType],
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
seed_iterator: Optional[Iterable[InitialStateType]],
reward_fn: Reward = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None,
logger: LogCollector = None,
reward_fn: Reward | None = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
logger: LogCollector | None = None,
) -> None:
# Assign weak reference to wrapper.
#

View File

@@ -397,7 +397,7 @@ class ConsoleWriter(LogWriter):
def __init__(
self,
log_every_n_episode: int = 20,
total_episodes: int = None,
total_episodes: int | None = None,
float_format: str = ":.4f",
counter_format: str = ":4d",
loglevel: int | LogLevel = LogLevel.PERIODIC,

Some files were not shown because too many files have changed in this diff Show More