mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 09:01:18 +08:00
Compare commits
89 Commits
v0.9.0
...
xuyang1/su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2df211c320 | ||
|
|
effed382e9 | ||
|
|
86ffd1799d | ||
|
|
aef11536e3 | ||
|
|
8b0fdf1623 | ||
|
|
9a36f8da20 | ||
|
|
b7757d5008 | ||
|
|
ee5e5cfdd8 | ||
|
|
6cb87ecfd1 | ||
|
|
9119bcdd3c | ||
|
|
4fccf8112d | ||
|
|
73bd79ca1a | ||
|
|
7e84f3aae2 | ||
|
|
1326ac614d | ||
|
|
f12184cc0f | ||
|
|
a70386ad52 | ||
|
|
74619ed8d8 | ||
|
|
1a523df007 | ||
|
|
f9cc8a5aaa | ||
|
|
7762c5a1fd | ||
|
|
fa7ef29281 | ||
|
|
429c9a7c66 | ||
|
|
80fbc00792 | ||
|
|
01accec24c | ||
|
|
1d88830b0d | ||
|
|
ad7498e287 | ||
|
|
73d51f05b4 | ||
|
|
3b56b8e6c0 | ||
|
|
40e0c329ba | ||
|
|
e376648860 | ||
|
|
5f37f32184 | ||
|
|
d46b4c1ebf | ||
|
|
0515524b51 | ||
|
|
cda32d5703 | ||
|
|
e2332a004b | ||
|
|
08d9dbccc9 | ||
|
|
e7cd93a36d | ||
|
|
3919678028 | ||
|
|
421b1403b2 | ||
|
|
94102fb742 | ||
|
|
74a5d7c8af | ||
|
|
ce39b4b6f8 | ||
|
|
2af35d9c89 | ||
|
|
f37643550b | ||
|
|
55611aa43e | ||
|
|
f24253efd2 | ||
|
|
7c4f3b8a7d | ||
|
|
94268619c4 | ||
|
|
8d60a6a02b | ||
|
|
7234308651 | ||
|
|
acf5df27ce | ||
|
|
37a59f28d3 | ||
|
|
b084c352f5 | ||
|
|
9e22e5168b | ||
|
|
dceff7b471 | ||
|
|
7f1e8c5206 | ||
|
|
46264dfec9 | ||
|
|
754799ab05 | ||
|
|
32c3070b73 | ||
|
|
40de67265a | ||
|
|
e6f9a94fc5 | ||
|
|
73937863f1 | ||
|
|
d010219ba6 | ||
|
|
4fc8a5f25f | ||
|
|
0e8bfcb5d3 | ||
|
|
e457ca8511 | ||
|
|
4dbb8ecb86 | ||
|
|
653c082e7a | ||
|
|
f98e04ca9d | ||
|
|
76f2fb1a1a | ||
|
|
5eb5ac1f1f | ||
|
|
6295939346 | ||
|
|
5f3e322784 | ||
|
|
691b7f1f60 | ||
|
|
d8fc9aea6b | ||
|
|
d8764660dc | ||
|
|
7f08e6c7b3 | ||
|
|
0f3abfed74 | ||
|
|
44ce91ee9d | ||
|
|
ebb8ec34f3 | ||
|
|
4fe3ffccfd | ||
|
|
2f5ce3dc01 | ||
|
|
756bd0f65b | ||
|
|
667fb0e4d9 | ||
|
|
f326f83fae | ||
|
|
cbd69fb0ed | ||
|
|
5e3924d7a6 | ||
|
|
57f9813f85 | ||
|
|
26d24b5b23 |
6
.github/labeler.yml
vendored
Normal file
6
.github/labeler.yml
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
documentation:
|
||||
- 'docs/**/*'
|
||||
- '**/*.md'
|
||||
|
||||
waiting for triage:
|
||||
- any: ['**/*', '!docs/**/*', '!**/*.md']
|
||||
14
.github/workflows/labeler.yml
vendored
Normal file
14
.github/workflows/labeler.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
name: "Add label automatically"
|
||||
on:
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v4
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
2
.github/workflows/test_qlib_from_pip.yml
vendored
2
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Python 3.6 is not supported because annotations are not supported there: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
|
||||
26
.github/workflows/test_qlib_from_source.yml
vendored
26
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Python 3.6 is not supported because annotations are not supported there: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -28,8 +28,10 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
# pip released version 23.1 on Apr. 15, 2023, and CI failed to run. Please refer to #1495 for detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pip==23.0.1
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -37,15 +39,13 @@ jobs:
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Installing pytorch for ubuntu
|
||||
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
|
||||
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
- name: Installing pytorch for windows
|
||||
if: ${{ matrix.os == 'windows-latest' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Set up Python tools
|
||||
@@ -120,6 +120,11 @@ jobs:
|
||||
run: |
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib --verbose
|
||||
|
||||
- name: Check Qlib ipynb with nbqa
|
||||
run: |
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
@@ -138,12 +143,15 @@ jobs:
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
# Version 0.52.0 of numba must be installed manually in CI, otherwise it will cause incompatibility with the latest version of numpy.
|
||||
python -m pip install numba==0.52.0
|
||||
# You must update numpy manually, because when installing python tools, it will try to uninstall numpy and cause CI to fail.
|
||||
python -m pip install --upgrade numpy
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Python 3.6 is not supported because annotations are not supported there: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -28,9 +28,10 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
# pip released version 23.1 on Apr. 15, 2023, and CI failed to run. Please refer to #1495 for detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# python -m pip is necessary to upgrade pip.
|
||||
python -m pip install pip==23.0.1
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -10,7 +10,6 @@ _build
|
||||
build/
|
||||
dist/
|
||||
|
||||
|
||||
*.pkl
|
||||
*.hd5
|
||||
*.csv
|
||||
@@ -23,10 +22,13 @@ dist/
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
qlib/finco/prompt_cache.json
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
examples/rl/outputs/
|
||||
examples/rl_order_execution/data/
|
||||
examples/rl_order_execution/outputs/
|
||||
|
||||
*.egg-info/
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
| Qlib [notebook tutorial](https://github.com/microsoft/qlib/tree/main/examples/tutorial) | 📖 [Released](https://github.com/microsoft/qlib/pull/1037) on Apr 7, 2022 |
|
||||
@@ -41,13 +42,11 @@ Features released before 2021 are not listed here.
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
</p>
|
||||
|
||||
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
|
||||
|
||||
Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
|
||||
An increasing number of SOTA Quant research works/papers in diverse paradigms are being released in Qlib to collaboratively solve key challenges in quantitative investment. For example, 1) using supervised learning to mine the market's complex non-linear patterns from rich and heterogeneous financial data, 2) modeling the dynamic nature of the financial market using adaptive concept drift technology, and 3) using reinforcement learning to model continuous investment decisions and assist investors in optimizing their trading strategies.
|
||||
|
||||
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
|
||||
|
||||
With Qlib, users can easily try ideas to create better Quant investment strategies.
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
|
||||
|
||||
@@ -42,4 +42,8 @@ As you may have noticed, a training vessel itself holds all the required compone
|
||||
|
||||
With a training vessel, the trainer could finally launch the training pipeline by simple, Scikit-learn-like interfaces (i.e., ``trainer.fit()``).
|
||||
|
||||
The API for Trainer and TrainingVessel can be found `here <../../reference/api.html#module-qlib.rl.trainer>`__.
|
||||
The API for Trainer and TrainingVessel can be found `here <../../reference/api.html#module-qlib.rl.trainer>`__.
|
||||
|
||||
The RL module is designed in a loosely-coupled way. Currently, RL examples are integrated with concrete business logic.
|
||||
But the core part of RL is much simpler than what you see.
|
||||
To demonstrate the simple core of RL, `a dedicated notebook <https://github.com/microsoft/qlib/tree/main/examples/rl/simple_example.ipynb>`__ for RL without business loss is created.
|
||||
|
||||
@@ -29,13 +29,13 @@ class Avg15minHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
inst_processors=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
data_loader = Avg15minLoader(
|
||||
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
|
||||
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors
|
||||
)
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
|
||||
@@ -18,7 +18,7 @@ data_handler_config: &data_handler_config
|
||||
label: day
|
||||
feature: 1min
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
inst_processors:
|
||||
feature:
|
||||
- class: Resample1minProcessor
|
||||
module_path: features_sample.py
|
||||
|
||||
@@ -19,7 +19,7 @@ data_handler_config: &data_handler_config
|
||||
feature_15min: 1min
|
||||
feature_day: day
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
inst_processors:
|
||||
feature_15min:
|
||||
- class: ResampleNProcessor
|
||||
module_path: features_resample_N.py
|
||||
|
||||
@@ -64,8 +64,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
|
||||
@@ -64,8 +64,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
|
||||
@@ -52,8 +52,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
|
||||
@@ -52,8 +52,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
|
||||
@@ -25,59 +25,65 @@
|
||||
"import seaborn as sns\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib\n",
|
||||
"sns.set(style='white')\n",
|
||||
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
|
||||
"matplotlib.rcParams['ps.fonttype'] = 42\n",
|
||||
"\n",
|
||||
"sns.set(style=\"white\")\n",
|
||||
"matplotlib.rcParams[\"pdf.fonttype\"] = 42\n",
|
||||
"matplotlib.rcParams[\"ps.fonttype\"] = 42\n",
|
||||
"\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"from joblib import Parallel, delayed\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def func(x, N=80):\n",
|
||||
" ret = x.ret.copy()\n",
|
||||
" x = x.rank(pct=True)\n",
|
||||
" x['ret'] = ret\n",
|
||||
" x[\"ret\"] = ret\n",
|
||||
" diff = x.score.sub(x.label)\n",
|
||||
" r = x.nlargest(N, columns='score').ret.mean()\n",
|
||||
" r -= x.nsmallest(N, columns='score').ret.mean()\n",
|
||||
" return pd.Series({\n",
|
||||
" 'MSE': diff.pow(2).mean(), \n",
|
||||
" 'MAE': diff.abs().mean(), \n",
|
||||
" 'IC': x.score.corr(x.label),\n",
|
||||
" 'R': r\n",
|
||||
" })\n",
|
||||
" \n",
|
||||
" r = x.nlargest(N, columns=\"score\").ret.mean()\n",
|
||||
" r -= x.nsmallest(N, columns=\"score\").ret.mean()\n",
|
||||
" return pd.Series(\n",
|
||||
" {\n",
|
||||
" \"MSE\": diff.pow(2).mean(),\n",
|
||||
" \"MAE\": diff.abs().mean(),\n",
|
||||
" \"IC\": x.score.corr(x.label),\n",
|
||||
" \"R\": r,\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ret = pd.read_pickle(\"data/ret.pkl\").clip(-0.1, 0.1)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def backtest(fname, **kwargs):\n",
|
||||
" pred = pd.read_pickle(fname).loc['2018-09-21':'2020-06-30'] # test period\n",
|
||||
" pred['ret'] = ret\n",
|
||||
" pred = pd.read_pickle(fname).loc[\"2018-09-21\":\"2020-06-30\"] # test period\n",
|
||||
" pred[\"ret\"] = ret\n",
|
||||
" dates = pred.index.unique(level=0)\n",
|
||||
" res = Parallel(n_jobs=-1)(delayed(func)(pred.loc[d], **kwargs) for d in dates)\n",
|
||||
" res = {\n",
|
||||
" dates[i]: res[i]\n",
|
||||
" for i in range(len(dates))\n",
|
||||
" }\n",
|
||||
" res = {dates[i]: res[i] for i in range(len(dates))}\n",
|
||||
" res = pd.DataFrame(res).T\n",
|
||||
" r = res['R'].copy()\n",
|
||||
" r = res[\"R\"].copy()\n",
|
||||
" r.index = pd.to_datetime(r.index)\n",
|
||||
" r = r.reindex(pd.date_range(r.index[0], r.index[-1])).fillna(0) # paper use 365 days\n",
|
||||
" return {\n",
|
||||
" 'MSE': res['MSE'].mean(),\n",
|
||||
" 'MAE': res['MAE'].mean(),\n",
|
||||
" 'IC': res['IC'].mean(),\n",
|
||||
" 'ICIR': res['IC'].mean()/res['IC'].std(),\n",
|
||||
" 'AR': r.mean()*365,\n",
|
||||
" 'AV': r.std()*365**0.5,\n",
|
||||
" 'SR': r.mean()/r.std()*365**0.5,\n",
|
||||
" 'MDD': (r.cumsum().cummax() - r.cumsum()).max()\n",
|
||||
" \"MSE\": res[\"MSE\"].mean(),\n",
|
||||
" \"MAE\": res[\"MAE\"].mean(),\n",
|
||||
" \"IC\": res[\"IC\"].mean(),\n",
|
||||
" \"ICIR\": res[\"IC\"].mean() / res[\"IC\"].std(),\n",
|
||||
" \"AR\": r.mean() * 365,\n",
|
||||
" \"AV\": r.std() * 365**0.5,\n",
|
||||
" \"SR\": r.mean() / r.std() * 365**0.5,\n",
|
||||
" \"MDD\": (r.cumsum().cummax() - r.cumsum()).max(),\n",
|
||||
" }, r\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def fmt(x, p=3, scale=1, std=False):\n",
|
||||
" _fmt = '{:.%df}'%p\n",
|
||||
" _fmt = \"{:.%df}\" % p\n",
|
||||
" string = _fmt.format((x.mean() if not isinstance(x, (float, np.floating)) else x) * scale)\n",
|
||||
" if std and len(x) > 1:\n",
|
||||
" string += ' ('+_fmt.format(x.std()*scale)+')'\n",
|
||||
" string += \" (\" + _fmt.format(x.std() * scale) + \")\"\n",
|
||||
" return string\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def backtest_multi(files, **kwargs):\n",
|
||||
" res = []\n",
|
||||
" pnl = []\n",
|
||||
@@ -88,14 +94,14 @@
|
||||
" res = pd.DataFrame(res)\n",
|
||||
" pnl = pd.concat(pnl, axis=1)\n",
|
||||
" return {\n",
|
||||
" 'MSE': fmt(res['MSE'], std=True),\n",
|
||||
" 'MAE': fmt(res['MAE'], std=True),\n",
|
||||
" 'IC': fmt(res['IC']),\n",
|
||||
" 'ICIR': fmt(res['ICIR']),\n",
|
||||
" 'AR': fmt(res['AR'], scale=100, p=1)+'%',\n",
|
||||
" 'VR': fmt(res['AV'], scale=100, p=1)+'%',\n",
|
||||
" 'SR': fmt(res['SR']),\n",
|
||||
" 'MDD': fmt(res['MDD'], scale=100, p=1)+'%'\n",
|
||||
" \"MSE\": fmt(res[\"MSE\"], std=True),\n",
|
||||
" \"MAE\": fmt(res[\"MAE\"], std=True),\n",
|
||||
" \"IC\": fmt(res[\"IC\"]),\n",
|
||||
" \"ICIR\": fmt(res[\"ICIR\"]),\n",
|
||||
" \"AR\": fmt(res[\"AR\"], scale=100, p=1) + \"%\",\n",
|
||||
" \"VR\": fmt(res[\"AV\"], scale=100, p=1) + \"%\",\n",
|
||||
" \"SR\": fmt(res[\"SR\"]),\n",
|
||||
" \"MDD\": fmt(res[\"MDD\"], scale=100, p=1) + \"%\",\n",
|
||||
" }, pnl"
|
||||
]
|
||||
},
|
||||
@@ -124,16 +130,20 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exps = {\n",
|
||||
" 'Linear': ['output/Linear/pred.pkl'],\n",
|
||||
" 'LightGBM': ['output/GBDT/lr0.05_leaves128/pred.pkl'],\n",
|
||||
" 'MLP': glob.glob('output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl'),\n",
|
||||
" 'SFM': glob.glob('output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl'),\n",
|
||||
" 'ALSTM': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'Trans.': glob.glob('output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'ALSTM+TS':glob.glob('output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'Trans.+TS':glob.glob('output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'ALSTM+TRA(Ours)': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'Trans.+TRA(Ours)': glob.glob('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl')\n",
|
||||
" \"Linear\": [\"output/Linear/pred.pkl\"],\n",
|
||||
" \"LightGBM\": [\"output/GBDT/lr0.05_leaves128/pred.pkl\"],\n",
|
||||
" \"MLP\": glob.glob(\"output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl\"),\n",
|
||||
" \"SFM\": glob.glob(\"output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl\"),\n",
|
||||
" \"ALSTM\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"Trans.\": glob.glob(\"output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"ALSTM+TS\": glob.glob(\"output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"Trans.+TS\": glob.glob(\"output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"ALSTM+TRA(Ours)\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"Trans.+TRA(Ours)\": glob.glob(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -160,14 +170,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res = {\n",
|
||||
" name: backtest_multi(exps[name])\n",
|
||||
" for name in tqdm(exps)\n",
|
||||
"}\n",
|
||||
"report = pd.DataFrame({\n",
|
||||
" k: v[0]\n",
|
||||
" for k, v in res.items()\n",
|
||||
"}).T"
|
||||
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
|
||||
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -385,24 +389,40 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl')\n",
|
||||
"code = 'SH600157'\n",
|
||||
"date = '2018-09-28'\n",
|
||||
"df = pd.read_pickle(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl\"\n",
|
||||
")\n",
|
||||
"code = \"SH600157\"\n",
|
||||
"date = \"2018-09-28\"\n",
|
||||
"lookbackperiod = 50\n",
|
||||
"\n",
|
||||
"prob = df.iloc[:, -3:].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
|
||||
"pred = df.loc[:,[\"score_0\",\"score_1\",\"score_2\",\"label\"]].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
|
||||
"e_all = pred.iloc[:,:-1].sub(pred.iloc[:,-1], axis=0).pow(2)\n",
|
||||
"pred = (\n",
|
||||
" df.loc[:, [\"score_0\", \"score_1\", \"score_2\", \"label\"]]\n",
|
||||
" .loc(axis=0)[:, code]\n",
|
||||
" .reset_index(level=1, drop=True)\n",
|
||||
" .loc[date:]\n",
|
||||
" .iloc[:lookbackperiod]\n",
|
||||
")\n",
|
||||
"e_all = pred.iloc[:, :-1].sub(pred.iloc[:, -1], axis=0).pow(2)\n",
|
||||
"e_all = e_all.sub(e_all.min(axis=1), axis=0)\n",
|
||||
"e_all.columns = [r'$\\theta_%d$'%d for d in range(1, 4)]\n",
|
||||
"e_all.columns = [r\"$\\theta_%d$\" % d for d in range(1, 4)]\n",
|
||||
"prob = pd.Series(np.argmax(prob.values, axis=1), index=prob.index).rolling(7).mean().round()\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(7, 3))\n",
|
||||
"e_all.plot(ax=axes[0], xlabel='', rot=30)\n",
|
||||
"prob.plot(ax=axes[1], xlabel='', rot=30, color='red', linestyle='None', marker='^', markersize=5)\n",
|
||||
"e_all.plot(ax=axes[0], xlabel=\"\", rot=30)\n",
|
||||
"prob.plot(\n",
|
||||
" ax=axes[1],\n",
|
||||
" xlabel=\"\",\n",
|
||||
" rot=30,\n",
|
||||
" color=\"red\",\n",
|
||||
" linestyle=\"None\",\n",
|
||||
" marker=\"^\",\n",
|
||||
" markersize=5,\n",
|
||||
")\n",
|
||||
"plt.yticks(np.array([0, 1, 2]), e_all.columns.values)\n",
|
||||
"axes[0].set_ylabel('Predictor Loss')\n",
|
||||
"axes[1].set_ylabel('Router Selection')\n",
|
||||
"axes[0].set_ylabel(\"Predictor Loss\")\n",
|
||||
"axes[1].set_ylabel(\"Router Selection\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"# plt.savefig('select.pdf', bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
@@ -428,10 +448,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exps = {\n",
|
||||
" 'Random': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'LR': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'TPE': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'LR+TPE': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl')\n",
|
||||
" \"Random\": glob.glob(\n",
|
||||
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"LR\": glob.glob(\n",
|
||||
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"TPE\": glob.glob(\n",
|
||||
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"LR+TPE\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -456,14 +484,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res = {\n",
|
||||
" name: backtest_multi(exps[name])\n",
|
||||
" for name in tqdm(exps)\n",
|
||||
"}\n",
|
||||
"report = pd.DataFrame({\n",
|
||||
" k: v[0]\n",
|
||||
" for k, v in res.items()\n",
|
||||
"}).T"
|
||||
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
|
||||
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -597,18 +619,22 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
|
||||
"b = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
|
||||
"a = pd.read_pickle(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
|
||||
")\n",
|
||||
"b = pd.read_pickle(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
|
||||
")\n",
|
||||
"a = a.iloc[:, -3:]\n",
|
||||
"b = b.iloc[:, -3:]\n",
|
||||
"b = np.eye(3)[b.values.argmax(axis=1)]\n",
|
||||
"a = np.eye(3)[a.values.argmax(axis=1)]\n",
|
||||
"\n",
|
||||
"res = pd.DataFrame({\n",
|
||||
" 'with OT': b.sum(axis=0) / b.sum(),\n",
|
||||
" 'without OT': a.sum(axis=0)/ a.sum() \n",
|
||||
"},index=[r'$\\theta_1$',r'$\\theta_2$',r'$\\theta_3$'])\n",
|
||||
"res.plot.bar(rot=30, figsize=(5, 4), color=['b', 'g'])\n",
|
||||
"res = pd.DataFrame(\n",
|
||||
" {\"with OT\": b.sum(axis=0) / b.sum(), \"without OT\": a.sum(axis=0) / a.sum()},\n",
|
||||
" index=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n",
|
||||
")\n",
|
||||
"res.plot.bar(rot=30, figsize=(5, 4), color=[\"b\", \"g\"])\n",
|
||||
"del a, b"
|
||||
]
|
||||
},
|
||||
@@ -633,11 +659,19 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exps = {\n",
|
||||
" 'K=1': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json'),\n",
|
||||
" 'K=3': glob.glob('output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
|
||||
" 'K=5': glob.glob('output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
|
||||
" 'K=10': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
|
||||
" 'K=20': glob.glob('output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json')\n",
|
||||
" \"K=1\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json\"),\n",
|
||||
" \"K=3\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
" \"K=5\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
" \"K=10\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
" \"K=20\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -649,16 +683,11 @@
|
||||
"source": [
|
||||
"report = dict()\n",
|
||||
"for k, v in exps.items():\n",
|
||||
" \n",
|
||||
" tmp = dict()\n",
|
||||
" for fname in v:\n",
|
||||
" with open(fname) as f:\n",
|
||||
" info = json.load(f)\n",
|
||||
" tmp[fname] = (\n",
|
||||
" {\n",
|
||||
" \"IC\":info[\"metric\"][\"IC\"],\n",
|
||||
" \"MSE\":info[\"metric\"][\"MSE\"]\n",
|
||||
" })\n",
|
||||
" tmp[fname] = {\"IC\": info[\"metric\"][\"IC\"], \"MSE\": info[\"metric\"][\"MSE\"]}\n",
|
||||
" tmp = pd.DataFrame(tmp).T\n",
|
||||
" report[k] = tmp.mean()\n",
|
||||
"report = pd.DataFrame(report).T"
|
||||
@@ -681,13 +710,14 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(6,3)); axes = axes.flatten()\n",
|
||||
"report['IC'].plot.bar(rot=30, ax=axes[0])\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
|
||||
"axes = axes.flatten()\n",
|
||||
"report[\"IC\"].plot.bar(rot=30, ax=axes[0])\n",
|
||||
"axes[0].set_ylim(0.045, 0.062)\n",
|
||||
"axes[0].set_title('IC performance')\n",
|
||||
"report['MSE'].astype(float).plot.bar(rot=30, ax=axes[1], color='green')\n",
|
||||
"axes[0].set_title(\"IC performance\")\n",
|
||||
"report[\"MSE\"].astype(float).plot.bar(rot=30, ax=axes[1], color=\"green\")\n",
|
||||
"axes[1].set_ylim(0.155, 0.1585)\n",
|
||||
"axes[1].set_title('MSE performance')\n",
|
||||
"axes[1].set_title(\"MSE performance\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"# plt.savefig('sensitivity.pdf')"
|
||||
]
|
||||
|
||||
107
examples/benchmarks_dynamic/DDG-DA/vis_data.py
Normal file
107
examples/benchmarks_dynamic/DDG-DA/vis_data.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
sns.set(color_codes=True)
|
||||
plt.rcParams["font.sans-serif"] = "SimHei"
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# tqdm.pandas() # for progress_apply
|
||||
# %matplotlib inline
|
||||
# %load_ext autoreload
|
||||
|
||||
|
||||
# # Meta Input
|
||||
|
||||
# +
|
||||
with open("./internal_data_s20.pkl", "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
data.data_ic_df.columns.names = ["start_date", "end_date"]
|
||||
|
||||
data_sim = data.data_ic_df.droplevel(axis=1, level="end_date")
|
||||
|
||||
data_sim.index.name = "test datetime"
|
||||
# -
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(data_sim)
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(data_sim.rolling(20).mean())
|
||||
|
||||
# # Meta Model
|
||||
|
||||
from qlib import auto_init
|
||||
|
||||
auto_init()
|
||||
from qlib.workflow import R
|
||||
|
||||
exp = R.get_exp(experiment_name="DDG-DA")
|
||||
meta_rec = exp.list_recorders(rtype="list", max_results=1)[0]
|
||||
meta_m = meta_rec.load_object("model")
|
||||
|
||||
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].plot()
|
||||
|
||||
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean().plot()
|
||||
|
||||
# # Meta Output
|
||||
|
||||
# +
|
||||
with open("./tasks_s20.pkl", "rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
|
||||
task_df = {}
|
||||
for t in tasks:
|
||||
test_seg = t["dataset"]["kwargs"]["segments"]["test"]
|
||||
if None not in test_seg:
|
||||
# The last rolling is skipped.
|
||||
task_df[test_seg] = t["reweighter"].time_weight
|
||||
task_df = pd.concat(task_df)
|
||||
|
||||
task_df.index.names = ["OS_start", "OS_end", "IS_start", "IS_end"]
|
||||
task_df = task_df.droplevel(["OS_end", "IS_end"])
|
||||
task_df = task_df.unstack("OS_start")
|
||||
# -
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(task_df.T)
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(task_df.rolling(10).mean().T)
|
||||
|
||||
# # Sub Models
|
||||
#
|
||||
# NOTE:
|
||||
# - this section assumes that the model is Linear model!!
|
||||
# - Other models do not support this analysis
|
||||
|
||||
exp = R.get_exp(experiment_name="rolling_ds")
|
||||
|
||||
|
||||
def show_linear_weight(exp):
    """
    Draw a heatmap of the coefficients of every rolling sub-model stored in an experiment.

    Parameters
    ----------
    exp :
        the experiment whose recorders each hold a "task" object and a fitted model saved as "params.pkl".
        The model must expose `coef_` (this analysis assumes a linear model).
    """
    weights = {}
    for rec in exp.list_recorders("list"):
        test_seg = rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"]
        if None not in test_seg:
            # Recorders whose test segment contains None are left out.
            weights[test_seg] = pd.Series(rec.load_object("params.pkl").coef_)

    # The test-segment keys become the outer index levels; the Series index is the coefficient index.
    stacked = pd.concat(weights)
    stacked.index.names = ["test_start", "test_end", "coef_idx"]

    # After the transpose: one row per coefficient, one column per test start.
    heat = stacked.droplevel("test_end").unstack("coef_idx")
    heat = heat.T

    plt.figure(figsize=(40, 20))
    sns.heatmap(heat)
    plt.show()
|
||||
|
||||
|
||||
show_linear_weight(R.get_exp(experiment_name="rolling_ds"))
|
||||
|
||||
show_linear_weight(R.get_exp(experiment_name="rolling_models"))
|
||||
@@ -10,8 +10,10 @@ import pandas as pd
|
||||
import fire
|
||||
import sys
|
||||
import pickle
|
||||
from typing import Optional
|
||||
from qlib import auto_init
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.typehint import Literal
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
@@ -30,7 +32,33 @@ class DDGDA:
|
||||
- `rm -r mlruns`
|
||||
"""
|
||||
|
||||
def __init__(self, sim_task_model="linear", forecast_model="linear"):
|
||||
def __init__(
|
||||
self,
|
||||
sim_task_model: Literal["linear", "gbdt"] = "linear",
|
||||
forecast_model: Literal["linear", "gbdt"] = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
meta_1st_train_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
alpha: float = 0.0,
|
||||
proxy_hd: str = "handler_proxy.pkl",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
train_start: Optional[str]
|
||||
the start datetime for data. It is used as the training start time (for both tasks & meta learning)
|
||||
test_end: Optional[str]
|
||||
the end datetime for data. It is used in test end time
|
||||
meta_1st_train_end: Optional[str]
|
||||
the datetime of training end of the first meta_task
|
||||
alpha: float
|
||||
Setting the L2 regularization for ridge
|
||||
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
|
||||
"""
|
||||
self.step = 20
|
||||
# NOTE:
|
||||
# the horizon must match the meaning in the base task template
|
||||
@@ -38,10 +66,19 @@ class DDGDA:
|
||||
self.meta_exp_name = "DDG-DA"
|
||||
self.sim_task_model = sim_task_model # The model to capture the distribution of data.
|
||||
self.forecast_model = forecast_model # downstream forecasting models' type
|
||||
self.rb_kwargs = {
|
||||
"h_path": h_path,
|
||||
"test_end": test_end,
|
||||
"train_start": train_start,
|
||||
"task_ext_conf": task_ext_conf,
|
||||
}
|
||||
self.alpha = alpha
|
||||
self.meta_1st_train_end = meta_1st_train_end
|
||||
self.proxy_hd = proxy_hd
|
||||
|
||||
def get_feature_importance(self):
|
||||
# this must be lightGBM, because it needs to get the feature importance
|
||||
rb = RollingBenchmark(model_type="gbdt")
|
||||
rb = RollingBenchmark(model_type="gbdt", **self.rb_kwargs)
|
||||
task = rb.basic_task()
|
||||
|
||||
with R.start(experiment_name="feature_importance"):
|
||||
@@ -69,7 +106,7 @@ class DDGDA:
|
||||
fi = self.get_feature_importance()
|
||||
col_selected = fi.nlargest(topk)
|
||||
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
task = rb.basic_task()
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -96,7 +133,7 @@ class DDGDA:
|
||||
"kwargs": {"config": DIRNAME / "fea_label_df.pkl"},
|
||||
}
|
||||
)
|
||||
handler.to_pickle(DIRNAME / "handler_proxy.pkl", dump_all=True)
|
||||
handler.to_pickle(DIRNAME / self.proxy_hd, dump_all=True)
|
||||
|
||||
@property
|
||||
def _internal_data_path(self):
|
||||
@@ -108,7 +145,7 @@ class DDGDA:
|
||||
This function will dump the input data for meta model
|
||||
"""
|
||||
# According to the experiments, the choice of the model type is very important for achieving good results
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
|
||||
if self.sim_task_model == "gbdt":
|
||||
@@ -122,24 +159,27 @@ class DDGDA:
|
||||
with self._internal_data_path.open("wb") as f:
|
||||
pickle.dump(internal_data, f)
|
||||
|
||||
def train_meta_model(self):
|
||||
def train_meta_model(self, fill_method="max"):
|
||||
"""
|
||||
training a meta model based on a simplified linear proxy model;
|
||||
"""
|
||||
|
||||
# 1) leverage the simplified proxy forecasting model to train meta model.
|
||||
# - Only the dataset part is important, in current version of meta model will integrate the
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
|
||||
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
|
||||
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
proxy_forecast_model_task = {
|
||||
# "model": "qlib.contrib.model.linear.LinearModel",
|
||||
"dataset": {
|
||||
"class": "qlib.data.dataset.DatasetH",
|
||||
"kwargs": {
|
||||
"handler": f"file://{(DIRNAME / 'handler_proxy.pkl').absolute()}",
|
||||
"handler": f"file://{(DIRNAME / self.proxy_hd).absolute()}",
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2010-12-31"),
|
||||
"test": ("2011-01-01", sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
|
||||
"train": (train_start, train_end),
|
||||
"test": (test_start, sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -156,7 +196,7 @@ class DDGDA:
|
||||
segments=0.62, # keep test period consistent with the dataset yaml
|
||||
trunc_days=1 + self.horizon,
|
||||
hist_step_n=30,
|
||||
fill_method="max",
|
||||
fill_method=fill_method,
|
||||
rolling_ext_days=0,
|
||||
)
|
||||
# NOTE:
|
||||
@@ -165,12 +205,15 @@ class DDGDA:
|
||||
# So the misalignment will not affect the effectiveness of the method.
|
||||
with self._internal_data_path.open("rb") as f:
|
||||
internal_data = pickle.load(f)
|
||||
|
||||
md = MetaDatasetDS(exp_name=internal_data, **kwargs)
|
||||
|
||||
# 3) train and logging meta model
|
||||
with R.start(experiment_name=self.meta_exp_name):
|
||||
R.log_params(**kwargs)
|
||||
mm = MetaModelDS(step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=200, seed=43)
|
||||
mm = MetaModelDS(
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
|
||||
)
|
||||
mm.fit(md)
|
||||
R.save_objects(model=mm)
|
||||
|
||||
@@ -203,7 +246,7 @@ class DDGDA:
|
||||
hist_step_n = int(param["hist_step_n"])
|
||||
fill_method = param.get("fill_method", "max")
|
||||
|
||||
rb = RollingBenchmark(model_type=self.forecast_model)
|
||||
rb = RollingBenchmark(model_type=self.forecast_model, **self.rb_kwargs)
|
||||
task_l = rb.create_rolling_tasks()
|
||||
|
||||
# 2.2) create meta dataset for final dataset
|
||||
@@ -233,13 +276,13 @@ class DDGDA:
|
||||
"""
|
||||
with self._task_path.open("rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model)
|
||||
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model, **self.rb_kwargs)
|
||||
rb.train_rolling_tasks(tasks)
|
||||
rb.ens_rolling()
|
||||
rb.update_rolling_rec()
|
||||
|
||||
def run_all(self):
|
||||
# 1) file: handler_proxy.pkl
|
||||
# 1) file: handler_proxy.pkl (self.proxy_hd)
|
||||
self.dump_data_for_proxy_model()
|
||||
# 2)
|
||||
# file: internal_data_s20.pkl
|
||||
|
||||
@@ -4,15 +4,21 @@ So adapting the forecasting models/strategies to market dynamics is very importa
|
||||
|
||||
The table below shows the performances of different solutions on different forecasting models.
|
||||
|
||||
## Alpha158 dataset
|
||||
## Alpha158 Dataset
|
||||
Here is the [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
```
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------|---------|----|------|---------|-----------|-------------------|-------------------|--------------|
|
||||
| RR[Linear] |Alpha158 |0.088|0.570|0.102 |0.622 |0.077 |1.175 |-0.086 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.093|0.622|0.106 |0.670 |0.085 |1.213 |-0.093 |
|
||||
| RR[LightGBM] |Alpha158 |0.079|0.566|0.088 |0.592 |0.075 |1.226 |-0.096 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.084|0.639|0.093 |0.664 |0.099 |1.442 |-0.071 |
|
||||
| RR[Linear] |Alpha158 |0.089|0.577|0.102 |0.627 |0.093 |1.458 |-0.073 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.096|0.636|0.107 |0.677 |0.067 |0.996 |-0.091 |
|
||||
| RR[LightGBM] |Alpha158 |0.082|0.589|0.091 |0.626 |0.077 |1.320 |-0.091 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.085|0.658|0.094 |0.686 |0.115 |1.792 |-0.068 |
|
||||
|
||||
- The label horizon of the `Alpha158` dataset is set to 20.
|
||||
- The rolling time intervals are set to 20 trading days.
|
||||
- The test rolling periods are from January 2017 to August 2020.
|
||||
- The results are based on the crowd-sourced version. The Yahoo version of qlib data does not contain `VWAP`, so all related factors are missing and filled with 0, which leads to a rank-deficient matrix (a matrix that does not have full rank) and makes the lower-level optimization of DDG-DA unsolvable.
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from typing import Optional
|
||||
from qlib.model.ens.ensemble import RollingEnsemble
|
||||
from qlib.utils import init_instance_by_config
|
||||
import fire
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from qlib import auto_init
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.data import update_config
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
@@ -25,11 +29,40 @@ class RollingBenchmark:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, rolling_exp="rolling_models", model_type="linear") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
rolling_exp: str = "rolling_models",
|
||||
model_type: str = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rolling_exp : str
|
||||
The name for the experiments for rolling
|
||||
model_type : str
|
||||
The model to be boosted.
|
||||
h_path : Optional[str]
|
||||
the dumped data handler;
|
||||
test_end : Optional[str]
|
||||
the test end for the data. It is typically used together with the handler
|
||||
train_start : Optional[str]
|
||||
the train start for the data. It is typically used together with the handler.
|
||||
task_ext_conf : Optional[dict]
|
||||
some options used to update the task config (merged into the task template via `update_config`)
|
||||
"""
|
||||
self.step = 20
|
||||
self.horizon = 20
|
||||
self.rolling_exp = rolling_exp
|
||||
self.model_type = model_type
|
||||
self.h_path = h_path
|
||||
self.train_start = train_start
|
||||
self.test_end = test_end
|
||||
self.logger = get_module_logger("RollingBenchmark")
|
||||
self.task_ext_conf = task_ext_conf
|
||||
|
||||
def basic_task(self):
|
||||
"""For fast training rolling"""
|
||||
@@ -42,6 +75,10 @@ class RollingBenchmark:
|
||||
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
|
||||
else:
|
||||
raise AssertionError("Model type is not supported!")
|
||||
|
||||
if self.h_path is not None:
|
||||
h_path = Path(self.h_path)
|
||||
|
||||
with conf_path.open("r") as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
@@ -52,6 +89,9 @@ class RollingBenchmark:
|
||||
|
||||
task = conf["task"]
|
||||
|
||||
if self.task_ext_conf is not None:
|
||||
task = update_config(task, self.task_ext_conf)
|
||||
|
||||
if not h_path.exists():
|
||||
h_conf = task["dataset"]["kwargs"]["handler"]
|
||||
h = init_instance_by_config(h_conf)
|
||||
@@ -59,6 +99,15 @@ class RollingBenchmark:
|
||||
|
||||
task["dataset"]["kwargs"]["handler"] = f"file://{h_path}"
|
||||
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
|
||||
|
||||
if self.train_start is not None:
|
||||
seg = task["dataset"]["kwargs"]["segments"]["train"]
|
||||
task["dataset"]["kwargs"]["segments"]["train"] = pd.Timestamp(self.train_start), seg[1]
|
||||
|
||||
if self.test_end is not None:
|
||||
seg = task["dataset"]["kwargs"]["segments"]["test"]
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp(self.test_end)
|
||||
self.logger.info(task)
|
||||
return task
|
||||
|
||||
def create_rolling_tasks(self):
|
||||
@@ -93,7 +142,7 @@ class RollingBenchmark:
|
||||
"""
|
||||
Evaluate the combined rolling results
|
||||
"""
|
||||
for rid, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
|
||||
for _, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
|
||||
for rt_cls in SigAnaRecord, PortAnaRecord:
|
||||
rt = rt_cls(recorder=rec, skip_existing=True)
|
||||
rt.generate()
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
This folder contains a simple example of how to run Qlib RL. It contains:
|
||||
|
||||
```
|
||||
.
|
||||
├── experiment_config
|
||||
│ ├── backtest # Backtest config
|
||||
│ └── training # Training config
|
||||
├── README.md # Readme (the current file)
|
||||
└── scripts # Scripts for data pre-processing
|
||||
```
|
||||
|
||||
## Data preparation
|
||||
|
||||
Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:
|
||||
|
||||
```
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
|
||||
mv qlib_rl_example_data data
|
||||
```
|
||||
|
||||
The downloaded data will be placed at `./data`. The original data are in `data/csv`. To create all data needed by the case, run:
|
||||
|
||||
```
|
||||
bash scripts/data_pipeline.sh
|
||||
```
|
||||
|
||||
After the execution finishes, the `data/` directory should be like:
|
||||
|
||||
```
|
||||
data
|
||||
├── backtest_orders.csv
|
||||
├── bin
|
||||
├── csv
|
||||
├── pickle
|
||||
├── pickle_dataframe
|
||||
└── training_order_split
|
||||
```
|
||||
|
||||
## Run training
|
||||
|
||||
Run:
|
||||
|
||||
```
|
||||
python -m qlib.rl.contrib.train_onpolicy --config_path ./experiment_config/training/config.yml
|
||||
```
|
||||
|
||||
After training, checkpoints will be stored under `checkpoints/`.
|
||||
|
||||
## Run backtest
|
||||
|
||||
```
|
||||
python -m qlib.rl.contrib.backtest --config_path ./experiment_config/backtest/config.yml
|
||||
```
|
||||
|
||||
The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
|
||||
@@ -1,57 +0,0 @@
|
||||
order_file: ./data/backtest_orders.csv
|
||||
start_time: "9:45"
|
||||
end_time: "14:44"
|
||||
qlib:
|
||||
provider_uri_1min: ./data/bin
|
||||
feature_root_dir: ./data/pickle
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$volume",
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: ['$close == 0', '$close == 0']
|
||||
deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
|
||||
volume_threshold:
|
||||
all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
|
||||
buy: ["current", "$close"]
|
||||
sell: ["current", "$close"]
|
||||
strategies:
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
kwargs: {}
|
||||
1day:
|
||||
class: SAOEIntStrategy
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
kwargs:
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
kwargs:
|
||||
max_step: 8
|
||||
data_ticks: 240
|
||||
data_dim: 6
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
kwargs:
|
||||
values: 14
|
||||
max_step: 8
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
kwargs: {}
|
||||
policy:
|
||||
class: PPO
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
kwargs:
|
||||
lr: 1.0e-4
|
||||
weight_file: ./checkpoints/latest.pth
|
||||
concurrency: 5
|
||||
@@ -1,59 +0,0 @@
|
||||
simulator:
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 1
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
values: 14
|
||||
max_step: 8
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 6
|
||||
data_ticks: 240
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PAPenaltyReward
|
||||
kwargs:
|
||||
penalty: 100.0
|
||||
module_path: qlib.rl.order_execution.reward
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/training_order_split
|
||||
data_dir: ./data/pickle_dataframe/backtest
|
||||
total_time: 240
|
||||
default_start_time: 0
|
||||
default_end_time: 240
|
||||
proc_data_dim: 6
|
||||
num_workers: 0
|
||||
queue_size: 20
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
runtime:
|
||||
seed: 42
|
||||
use_cuda: false
|
||||
trainer:
|
||||
max_epoch: 2
|
||||
repeat_per_collect: 5
|
||||
earlystop_patience: 2
|
||||
episode_per_collect: 20
|
||||
batch_size: 16
|
||||
val_every_n_epoch: 1
|
||||
checkpoint_path: ./checkpoints
|
||||
checkpoint_every_n_iters: 1
|
||||
@@ -1,21 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
|
||||
|
||||
for tag in ("backtest", "feature"):
|
||||
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
|
||||
df = pd.concat(list(df.values())).reset_index()
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
instruments = sorted(set(df["instrument"]))
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
|
||||
for instrument in tqdm(instruments):
|
||||
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
|
||||
cur = cur.set_index(["instrument", "datetime", "date"])
|
||||
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
|
||||
@@ -1,14 +0,0 @@
|
||||
# Generate `bin` format data
|
||||
set -e
|
||||
python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min
|
||||
|
||||
# Generate pickle format data
|
||||
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
|
||||
if [ -e stat/ ]; then
|
||||
rm -r stat/
|
||||
fi
|
||||
python scripts/collect_pickle_dataframe.py
|
||||
|
||||
# Sample orders
|
||||
python scripts/gen_training_orders.py
|
||||
python scripts/gen_backtest_orders.py
|
||||
@@ -1,55 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--seed", type=int, default=20220926)
|
||||
parser.add_argument("--num_order", type=int, default=10)
|
||||
args = parser.parse_args()
|
||||
|
||||
np.random.seed(args.seed)
|
||||
|
||||
path = os.path.join("data", "pickle", "backtesttest.pkl")
|
||||
df = pickle.load(open(path, "rb")).reset_index()
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
|
||||
instruments = sorted(set(df["instrument"]))
|
||||
|
||||
# TODO: The example is expected to be able to handle data containing missing values.
|
||||
# TODO: Currently, we just simply skip dates that contain missing data. We will add
|
||||
# TODO: this feature in the future.
|
||||
skip_dates = {}
|
||||
for instrument in instruments:
|
||||
csv_df = pd.read_csv(os.path.join("data", "csv", f"{instrument}.csv"))
|
||||
csv_df = csv_df[csv_df["close"].isna()]
|
||||
dates = set([str(d).split(" ")[0] for d in csv_df["date"]])
|
||||
skip_dates[instrument] = dates
|
||||
|
||||
df_list = []
|
||||
for instrument in instruments:
|
||||
print(instrument)
|
||||
|
||||
cur_df = df[df["instrument"] == instrument]
|
||||
|
||||
dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]]))
|
||||
dates = [date for date in dates if date not in skip_dates[instrument]]
|
||||
|
||||
n = args.num_order
|
||||
df_list.append(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"date": sorted(np.random.choice(dates, size=n, replace=False)),
|
||||
"instrument": [instrument] * n,
|
||||
"amount": np.random.randint(low=3, high=11, size=n) * 100.0,
|
||||
"order_type": np.random.randint(low=0, high=2, size=n),
|
||||
}
|
||||
).set_index(["date", "instrument"]),
|
||||
)
|
||||
|
||||
total_df = pd.concat(df_list)
|
||||
total_df.to_csv("data/backtest_orders.csv")
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--seed", type=int, default=20220926)
|
||||
parser.add_argument("--stock", type=str, default="AAPL")
|
||||
parser.add_argument("--train_size", type=int, default=10)
|
||||
parser.add_argument("--valid_size", type=int, default=2)
|
||||
parser.add_argument("--test_size", type=int, default=2)
|
||||
args = parser.parse_args()
|
||||
|
||||
np.random.seed(args.seed)
|
||||
|
||||
os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)
|
||||
|
||||
for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
|
||||
path = os.path.join("data", "pickle", f"backtest{group}.pkl")
|
||||
df = pickle.load(open(path, "rb")).reset_index()
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
|
||||
dates = sorted(set([str(d).split(" ")[0] for d in df["date"]]))
|
||||
|
||||
data_df = pd.DataFrame(
|
||||
{
|
||||
"date": sorted(np.random.choice(dates, size=n, replace=False)),
|
||||
"instrument": [args.stock] * n,
|
||||
"amount": np.random.randint(low=3, high=11, size=n) * 100.0,
|
||||
"order_type": [0] * n,
|
||||
}
|
||||
).set_index(["date", "instrument"])
|
||||
|
||||
os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True)
|
||||
pickle.dump(data_df, open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb"))
|
||||
357
examples/rl/simple_example.ipynb
Normal file
357
examples/rl/simple_example.ipynb
Normal file
File diff suppressed because one or more lines are too long
100
examples/rl_order_execution/README.md
Normal file
100
examples/rl_order_execution/README.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# RL Example for Order Execution
|
||||
|
||||
This folder comprises an example of Reinforcement Learning (RL) workflows for the order execution scenario, including both training workflows and backtest workflows.
|
||||
|
||||
## Data Processing
|
||||
|
||||
### Get Data
|
||||
|
||||
```
|
||||
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
|
||||
```
|
||||
|
||||
### Generate Pickle-Style Data
|
||||
|
||||
To run the code in this example, we need data in pickle format. To achieve this, run the following commands (this might take a few minutes to finish):
|
||||
|
||||
[//]: # (TODO: Instead of dumping dataframe with different format (like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`), we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)
|
||||
|
||||
```
|
||||
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
|
||||
python scripts/gen_training_orders.py
|
||||
python scripts/merge_orders.py
|
||||
```
|
||||
|
||||
When finished, the structure under `data/` should be:
|
||||
|
||||
```
|
||||
data
|
||||
├── bin
|
||||
├── orders
|
||||
└── pickle
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
Each training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks:
|
||||
|
||||
- **PPO**: Method proposed by IJCAI 2020 paper "[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)".
|
||||
- **OPDS**: Method proposed by AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)".
|
||||
|
||||
The main difference between these two methods is their reward functions. Please see their config files for details.
|
||||
|
||||
Take OPDS as an example, to run the training workflow, run:
|
||||
|
||||
```
|
||||
python -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest
|
||||
```
|
||||
|
||||
Metrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`).
|
||||
|
||||
## Backtest
|
||||
|
||||
Once the training workflow has completed, the trained model can be used for the backtesting workflow. Still taking OPDS as an example, once training is finished, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth`. To run backtest workflow:
|
||||
|
||||
1. Uncomment the `weight_file` parameter in `exp_configs/backtest_opds.yml` (it is commented out by default). While it is possible to run the backtesting workflow without setting a checkpoint, the model would then be randomly initialized, which makes the results meaningless.
|
||||
2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`.
|
||||
|
||||
The backtest result is stored in `outputs/checkpoints/backtest_result.csv`.
|
||||
|
||||
In addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`.
|
||||
|
||||
### Gap between backtest and training pipeline's testing
|
||||
|
||||
It is worth noting that the results of the backtesting process may differ from the results of the testing process used during training.
|
||||
This is because different simulators are used to simulate market conditions during training and backtesting.
|
||||
In training pipeline, the simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons.
|
||||
`SingleAssetOrderExecutionSimple` makes no restriction to trading amounts.
|
||||
No matter what the amount of the order is, it can be completely executed.
|
||||
However, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used.
|
||||
It takes into account practical constraints in more real-world scenarios (for example, the trading volume must be a multiple of the smallest trading unit).
|
||||
As a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed.
|
||||
|
||||
If you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you could run the training pipeline with only the backtest phase.
|
||||
In order to do this:
|
||||
- Modify the training config. Add the path of the checkpoint you want to use (see the example below).
|
||||
- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training`
|
||||
|
||||
```yaml
|
||||
...
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
weight_file: PATH/TO/CHECKPOINT
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
...
|
||||
```
|
||||
|
||||
## Benchmarks (TBD)
|
||||
|
||||
To accurately evaluate the performance of models using Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once, selecting the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use "Price Advantage (PA)" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the testing set is as follows:
|
||||
|
||||
| **Model** | **PA mean with std.** |
|
||||
|-----------------------------|-----------------------|
|
||||
| OPDS (with PPO policy) | 0.4785 ± 0.7815 |
|
||||
| OPDS (with DQN policy) | -0.0114 ± 0.5780 |
|
||||
| PPO | -1.0935 ± 0.0922 |
|
||||
| TWAP | ≈ 0.0 ± 0.0 |
|
||||
|
||||
The table above also includes TWAP as a rule-based baseline. The ideal PA of TWAP should be 0.0, however, in this example, the order execution is divided into two steps: first, the order is split equally among each half hour, and then each five minutes within each half hour. Since trading is forbidden during the last five minutes of the day, this approach may slightly differ from traditional TWAP over the course of a full day (as there are 5 minutes missing in the last "half hour"). Therefore, the PA of TWAP can be considered as a number that is close to 0.0. To verify this, you may run a TWAP backtest and check the results.
|
||||
53
examples/rl_order_execution/exp_configs/backtest_opds.yml
Executable file
53
examples/rl_order_execution/exp_configs/backtest_opds.yml
Executable file
@@ -0,0 +1,53 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
volume_threshold: null
|
||||
strategies:
|
||||
1day:
|
||||
class: SAOEIntStrategy
|
||||
kwargs:
|
||||
data_granularity: 5
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
max_step: 8
|
||||
values: 4
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
network:
|
||||
class: Recurrent
|
||||
kwargs: {}
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
# Uncomment `weight_file` once the training workflow finishes, and point it at the checkpoint file you want to use.
|
||||
# weight_file: outputs/opds/checkpoints/latest.pth
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
concurrency: 16
|
||||
output_dir: outputs/opds/
|
||||
53
examples/rl_order_execution/exp_configs/backtest_ppo.yml
Executable file
53
examples/rl_order_execution/exp_configs/backtest_ppo.yml
Executable file
@@ -0,0 +1,53 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
volume_threshold: null
|
||||
strategies:
|
||||
1day:
|
||||
class: SAOEIntStrategy
|
||||
kwargs:
|
||||
data_granularity: 5
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
max_step: 8
|
||||
values: 4
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
network:
|
||||
class: Recurrent
|
||||
kwargs: {}
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
# Uncomment `weight_file` once the training workflow finishes, and point it at the checkpoint file you want to use.
|
||||
# weight_file: outputs/ppo/checkpoints/latest.pth
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
concurrency: 16
|
||||
output_dir: outputs/ppo/
|
||||
21
examples/rl_order_execution/exp_configs/backtest_twap.yml
Executable file
21
examples/rl_order_execution/exp_configs/backtest_twap.yml
Executable file
@@ -0,0 +1,21 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
volume_threshold: null
|
||||
strategies:
|
||||
1day:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
concurrency: 16
|
||||
output_dir: outputs/twap/
|
||||
66
examples/rl_order_execution/exp_configs/train_opds.yml
Executable file
66
examples/rl_order_execution/exp_configs/train_opds.yml
Executable file
@@ -0,0 +1,66 @@
|
||||
simulator:
|
||||
data_granularity: 5
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
values: 4
|
||||
max_step: 8
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PAPenaltyReward
|
||||
kwargs:
|
||||
penalty: 4.0
|
||||
scale: 0.01
|
||||
module_path: qlib.rl.order_execution.reward
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
proc_data_dim: 5
|
||||
num_workers: 0
|
||||
queue_size: 20
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
runtime:
|
||||
seed: 42
|
||||
use_cuda: false
|
||||
trainer:
|
||||
max_epoch: 500
|
||||
repeat_per_collect: 25
|
||||
earlystop_patience: 50
|
||||
episode_per_collect: 10000
|
||||
batch_size: 1024
|
||||
val_every_n_epoch: 4
|
||||
checkpoint_path: ./outputs/opds
|
||||
checkpoint_every_n_iters: 1
|
||||
67
examples/rl_order_execution/exp_configs/train_ppo.yml
Executable file
67
examples/rl_order_execution/exp_configs/train_ppo.yml
Executable file
@@ -0,0 +1,67 @@
|
||||
simulator:
|
||||
data_granularity: 5
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
values: 4
|
||||
max_step: 8
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PPOReward
|
||||
kwargs:
|
||||
max_step: 8
|
||||
start_time_index: 0
|
||||
end_time_index: 46 # 46 = (240 - 5) min / 5 min - 1
|
||||
module_path: qlib.rl.order_execution.reward
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
proc_data_dim: 5
|
||||
num_workers: 0
|
||||
queue_size: 20
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
runtime:
|
||||
seed: 42
|
||||
use_cuda: false
|
||||
trainer:
|
||||
max_epoch: 500
|
||||
repeat_per_collect: 25
|
||||
earlystop_patience: 50
|
||||
episode_per_collect: 10000
|
||||
batch_size: 1024
|
||||
val_every_n_epoch: 4
|
||||
checkpoint_path: ./outputs/ppo
|
||||
checkpoint_every_n_iters: 1
|
||||
@@ -4,6 +4,7 @@
|
||||
import yaml
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from copy import deepcopy
|
||||
|
||||
from qlib.contrib.data.highfreq_provider import HighFreqProvider
|
||||
@@ -41,3 +42,5 @@ if __name__ == "__main__":
|
||||
if args.split == "stock" or args.split == "both":
|
||||
provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
|
||||
provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
|
||||
|
||||
shutil.rmtree("stat/", ignore_errors=True)
|
||||
53
examples/rl_order_execution/scripts/gen_training_orders.py
Executable file
53
examples/rl_order_execution/scripts/gen_training_orders.py
Executable file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
|
||||
OUTPUT_PATH = Path(os.path.join("data", "orders"))
|
||||
|
||||
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
    """Generate synthetic daily orders for one stock and save them per data split.

    The stock's dataset is loaded from ``DATA_PATH / f"{stock}.pkl"``. For every trading
    day the rows at positions ``[start_idx, end_idx)`` are averaged, and the order amount
    of that day is its mean ``$volume0`` multiplied by a lognormal factor. The factor is
    drawn once per call, so all days of the same stock share it. The resulting orders are
    written to ``OUTPUT_PATH / <tag> / f"{stock}.pkl.target"`` for the tags ``train``,
    ``valid``, ``test`` and ``all``.

    Parameters
    ----------
    stock : str
        Instrument name; also the stem of the input pickle file.
    start_idx : int
        Position (inclusive) of the first intraday row used on each day.
    end_idx : int
        Position (exclusive) of the last intraday row used on each day.

    Returns
    -------
    bool
        ``False`` if the stock is skipped because its data is empty, contains NaN, or has
        a ``$volume0`` value below ``1e-5`` (nothing is written in that case);
        ``True`` otherwise.
    """
    dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
    df = dataset.handler.fetch(level=None).reset_index()
    if len(df) == 0 or df.isnull().values.any() or df["$volume0"].min() < 1e-5:
        return False

    # An explicit unit is required here: pandas >= 2.0 rejects the unit-less "datetime64" dtype,
    # while older versions silently resolved it to "datetime64[ns]".
    df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")
    df = df.set_index(["instrument", "datetime", "date"])
    df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)

    # Index levels are (instrument, datetime, date); group by (date, instrument) to get daily means.
    order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
    order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
    order_all = order_all[order_all["amount"] > 0.0]
    order_all["order_type"] = 0
    order_all = order_all.drop(columns=["$volume0"])

    # Split by date: train <= 2021-06-30 < valid <= 2021-09-30 < test.
    dates = order_all.index.get_level_values(0)
    train_end = pd.Timestamp("2021-06-30")
    valid_end = pd.Timestamp("2021-09-30")
    splits = {
        "train": order_all[dates <= train_end],
        "valid": order_all[(dates > train_end) & (dates <= valid_end)],
        "test": order_all[dates > valid_end],
        "all": order_all,
    }

    for tag, order in splits.items():
        path = OUTPUT_PATH / tag
        os.makedirs(path, exist_ok=True)
        # Empty splits are not written, but their directory is still created.
        if len(order) > 0:
            order.to_pickle(path / f"{stock}.pkl.target")
    return True
|
||||
|
||||
|
||||
# Fix the seed so that both the shuffled stock order and the sampled order amounts are reproducible.
np.random.seed(1234)
file_list = sorted(os.listdir(DATA_PATH))
stocks = [name.replace(".pkl", "") for name in file_list]
np.random.shuffle(stocks)

# Walk through the shuffled stocks until orders have been generated for 100 of them.
# `240 // 5 - 1` = 47, i.e. rows [0, 47) of each day are used.
cnt = 0
for stock in stocks:
    if not generate_order(stock, 0, 240 // 5 - 1):
        continue
    cnt += 1
    if cnt == 100:
        break
|
||||
15
examples/rl_order_execution/scripts/merge_orders.py
Executable file
15
examples/rl_order_execution/scripts/merge_orders.py
Executable file
@@ -0,0 +1,15 @@
|
||||
import pickle
|
||||
import os
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
# Merge the per-stock order files of each split into a single `data/orders/<tag>_orders.pkl`.
for tag in ["test", "valid"]:
    files = os.listdir(os.path.join("data/orders/", tag))
    dfs = []
    for f in tqdm(files):
        # Use a context manager so every file handle is closed as soon as it has been read.
        with open(os.path.join("data/orders/", tag, f), "rb") as fp:
            df = pickle.load(fp)
        df = df.drop(["$close0"], axis=1)
        dfs.append(df)

    total_df = pd.concat(dfs)
    with open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb") as fp:
        pickle.dump(total_df, fp)
|
||||
@@ -1,15 +1,16 @@
|
||||
# start & end time for training/validation/test datasets
|
||||
start_time: !!str &start 2020-01-01
|
||||
end_time: !!str &end 2020-07-31
|
||||
train_end_time: !!str &tend 2020-03-31
|
||||
valid_start_time: !!str &vstart 2020-04-01
|
||||
valid_end_time: !!str &vend 2020-05-31
|
||||
test_start_time: !!str &tstart 2020-06-01
|
||||
end_time: !!str &end 2021-12-31
|
||||
train_end_time: !!str &tend 2021-06-30
|
||||
valid_start_time: !!str &vstart 2021-07-01
|
||||
valid_end_time: !!str &vend 2021-09-30
|
||||
test_start_time: !!str &tstart 2021-10-01
|
||||
# the instrument set
|
||||
instruments: &ins all
|
||||
instruments: &ins csi300s19_22
|
||||
# qlib related configuration
|
||||
qlib_conf:
|
||||
provider_uri: ./data/bin # path to generated qlib bin
|
||||
provider_uri:
|
||||
5min: ./data/bin # path to generated qlib bin
|
||||
redis_port: 233
|
||||
feature_conf:
|
||||
path: ./data/pickle/feature.pkl # output path of feature
|
||||
@@ -26,14 +27,23 @@ feature_conf:
|
||||
fit_end_time: *tend
|
||||
instruments: *ins
|
||||
day_length: 240 # how many minutes in one trading day
|
||||
freq: 5min
|
||||
columns: ["$open", "$high", "$low", "$close"]
|
||||
infer_processors:
|
||||
- class: HighFreqNorm
|
||||
module_path: qlib.contrib.data.highfreq_processor
|
||||
kwargs:
|
||||
feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization)
|
||||
norm_groups:
|
||||
price: 10
|
||||
price: 8
|
||||
volume: 2
|
||||
inst_processors:
|
||||
- class: TimeRangeFlt
|
||||
module_path: qlib.data.dataset.processor
|
||||
kwargs:
|
||||
start_time: "2020-01-01"
|
||||
end_time: "2021-12-31"
|
||||
freq: 5min
|
||||
segments:
|
||||
train: !!python/tuple [*start, *tend]
|
||||
valid: !!python/tuple [*vstart, *vend]
|
||||
@@ -51,7 +61,17 @@ backtest_conf:
|
||||
end_time: *end
|
||||
instruments: *ins
|
||||
day_length: 240
|
||||
freq: 5min
|
||||
columns: ["$close", "$volume"]
|
||||
inst_processors:
|
||||
- class: TimeRangeFlt
|
||||
module_path: qlib.data.dataset.processor
|
||||
kwargs:
|
||||
start_time: "2020-01-01"
|
||||
end_time: "2021-12-31"
|
||||
freq: 5min
|
||||
segments:
|
||||
train: !!python/tuple [*start, *tend]
|
||||
valid: !!python/tuple [*vstart, *vend]
|
||||
test: !!python/tuple [*tstart, *end]
|
||||
freq: 5min
|
||||
@@ -88,6 +88,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.tests.data import GetData\n",
|
||||
"\n",
|
||||
"GetData().qlib_data(exists_skip=True)"
|
||||
]
|
||||
},
|
||||
@@ -99,6 +100,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import qlib\n",
|
||||
"\n",
|
||||
"qlib.init()"
|
||||
]
|
||||
},
|
||||
@@ -134,7 +136,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.data import D\n",
|
||||
"D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2] # calendar data"
|
||||
"\n",
|
||||
"print(D.calendar(start_time=\"2010-01-01\", end_time=\"2017-12-31\", freq=\"day\")[:2]) # calendar data"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -152,7 +155,12 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = D.features(['SH601216'], ['$open', '$high', '$low', '$close', '$factor'], start_time='2020-05-01', end_time='2020-05-31') "
|
||||
"df = D.features(\n",
|
||||
" [\"SH601216\"],\n",
|
||||
" [\"$open\", \"$high\", \"$low\", \"$close\", \"$factor\"],\n",
|
||||
" start_time=\"2020-05-01\",\n",
|
||||
" end_time=\"2020-05-31\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -163,11 +171,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import plotly.graph_objects as go\n",
|
||||
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df['$open'],\n",
|
||||
" high=df['$high'],\n",
|
||||
" low=df['$low'],\n",
|
||||
" close=df['$close'])])\n",
|
||||
"\n",
|
||||
"fig = go.Figure(\n",
|
||||
" data=[\n",
|
||||
" go.Candlestick(\n",
|
||||
" x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df[\"$open\"],\n",
|
||||
" high=df[\"$high\"],\n",
|
||||
" low=df[\"$low\"],\n",
|
||||
" close=df[\"$close\"],\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
@@ -197,11 +212,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import plotly.graph_objects as go\n",
|
||||
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df['$open'] / df['$factor'],\n",
|
||||
" high=df['$high'] / df['$factor'],\n",
|
||||
" low=df['$low'] / df['$factor'],\n",
|
||||
" close=df['$close'] / df['$factor'])])\n",
|
||||
"\n",
|
||||
"fig = go.Figure(\n",
|
||||
" data=[\n",
|
||||
" go.Candlestick(\n",
|
||||
" x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df[\"$open\"] / df[\"$factor\"],\n",
|
||||
" high=df[\"$high\"] / df[\"$factor\"],\n",
|
||||
" low=df[\"$low\"] / df[\"$factor\"],\n",
|
||||
" close=df[\"$close\"] / df[\"$factor\"],\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
@@ -240,7 +262,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# dynamic universe\n",
|
||||
"universe = D.list_instruments(D.instruments('csi100'), start_time='2010-01-01', end_time='2020-12-31')\n",
|
||||
"universe = D.list_instruments(D.instruments(\"csi100\"), start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
|
||||
"pprint(universe)"
|
||||
]
|
||||
},
|
||||
@@ -271,8 +293,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = D.features(D.instruments('csi100'), ['$close'], start_time='2010-01-01', end_time='2020-12-31') \n",
|
||||
"df.groupby('datetime').size().plot()"
|
||||
"df = D.features(D.instruments(\"csi100\"), [\"$close\"], start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
|
||||
"df.groupby(\"datetime\").size().plot()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -313,8 +335,7 @@
|
||||
" !cd ../../scripts/data_collector/pit/ && pip install -r requirements.txt\n",
|
||||
" !cd ../../scripts/data_collector/pit/ && python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex \"^(600519|000725).*\"\n",
|
||||
" !cd ../../scripts/data_collector/pit/ && python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized\n",
|
||||
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly\n",
|
||||
" pass"
|
||||
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -338,7 +359,13 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"instruments = [\"sh600519\"]\n",
|
||||
"data = D.features(instruments, ['P($$roewa_q)'], start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")"
|
||||
"data = D.features(\n",
|
||||
" instruments,\n",
|
||||
" [\"P($$roewa_q)\"],\n",
|
||||
" start_time=\"2019-01-01\",\n",
|
||||
" end_time=\"2019-07-19\",\n",
|
||||
" freq=\"day\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -366,7 +393,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"D.features([\"sh600519\"], ['(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'])"
|
||||
"D.features(\n",
|
||||
" [\"sh600519\"],\n",
|
||||
" [\"(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -418,7 +448,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qdl = QlibDataLoader(config=(['$close / Ref($close, 10)'], ['RET10']))"
|
||||
"qdl = QlibDataLoader(config=([\"$close / Ref($close, 10)\"], [\"RET10\"]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -428,7 +458,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
|
||||
"qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -456,7 +486,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
|
||||
"df = qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -476,7 +506,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.plot(kind='hist')"
|
||||
"df.plot(kind=\"hist\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -508,9 +538,16 @@
|
||||
"source": [
|
||||
"# NOTE: normally, the training & validation time range will be `fit_start_time` , `fit_end_time`\n",
|
||||
"# however, all the components are decomposed, so the training & validation time range is unknown when preprocessing.\n",
|
||||
"dh = DataHandlerLP(instruments=['sh600519'], start_time='20170101', end_time='20191231',\n",
|
||||
" infer_processors=[ZScoreNorm(fit_start_time='20170101', fit_end_time='20181231'), Fillna()],\n",
|
||||
" data_loader=qdl)"
|
||||
"dh = DataHandlerLP(\n",
|
||||
" instruments=[\"sh600519\"],\n",
|
||||
" start_time=\"20170101\",\n",
|
||||
" end_time=\"20191231\",\n",
|
||||
" infer_processors=[\n",
|
||||
" ZScoreNorm(fit_start_time=\"20170101\", fit_end_time=\"20181231\"),\n",
|
||||
" Fillna(),\n",
|
||||
" ],\n",
|
||||
" data_loader=qdl,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -550,7 +587,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.plot(kind='hist')"
|
||||
"df.plot(kind=\"hist\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -586,7 +623,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds = DatasetH(dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})"
|
||||
"ds = DatasetH(dh, segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -596,7 +633,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds.prepare('train')"
|
||||
"ds.prepare(\"train\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -606,7 +643,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds.prepare('valid')"
|
||||
"ds.prepare(\"valid\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -628,8 +665,12 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds = TSDatasetH(step_len=10, handler=dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})\n",
|
||||
"train_sampler = ds.prepare('train')"
|
||||
"ds = TSDatasetH(\n",
|
||||
" step_len=10,\n",
|
||||
" handler=dh,\n",
|
||||
" segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")},\n",
|
||||
")\n",
|
||||
"train_sampler = ds.prepare(\"train\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -649,7 +690,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_sampler[0] # Retrieving the first example"
|
||||
"train_sampler[0] # Retrieving the first example"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -659,7 +700,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_sampler['2018-01-08', 'sh600519'] # get the time series by <'timestamp', 'instrument_id'> index"
|
||||
"train_sampler[\"2018-01-08\", \"sh600519\"] # get the time series by <'timestamp', 'instrument_id'> index"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -682,11 +723,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"handler_kwargs = {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": MARKET,\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": MARKET,\n",
|
||||
"}\n",
|
||||
"handler_conf = {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
@@ -735,6 +776,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"\n",
|
||||
"hd = Alpha158(**handler_kwargs)"
|
||||
]
|
||||
},
|
||||
@@ -826,7 +868,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hd.process_type # appending type"
|
||||
"hd.process_type # appending type"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -857,16 +899,16 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_conf = {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": hd,\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": hd,\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -908,7 +950,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = init_instance_by_config({\n",
|
||||
"model = init_instance_by_config(\n",
|
||||
" {\n",
|
||||
" \"class\": \"LGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
@@ -922,7 +965,8 @@
|
||||
" \"num_leaves\": 210,\n",
|
||||
" \"num_threads\": 20,\n",
|
||||
" },\n",
|
||||
"})"
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -938,7 +982,7 @@
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
"\n",
|
||||
" rec = R.get_recorder()\n",
|
||||
" rid = rec.id # save the record id\n",
|
||||
" rid = rec.id # save the record id\n",
|
||||
"\n",
|
||||
" # Inference and saving signal\n",
|
||||
" sr = SignalRecord(model, dataset, rec)\n",
|
||||
@@ -1001,12 +1045,11 @@
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=EXP_NAME, recorder_id=rid, resume=True):\n",
|
||||
"\n",
|
||||
" # signal-based analysis\n",
|
||||
" rec = R.get_recorder()\n",
|
||||
" sar = SigAnaRecord(rec)\n",
|
||||
" sar.generate()\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # portfolio-based analysis: backtest\n",
|
||||
" par = PortAnaRecord(rec, port_analysis_config, \"day\")\n",
|
||||
" par.generate()"
|
||||
@@ -1137,7 +1180,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
|
||||
"label_df.columns = ['label']"
|
||||
"label_df.columns = [\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" if 'google.colab' in sys.modules:\n",
|
||||
" if \"google.colab\" in sys.modules:\n",
|
||||
" # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages\n",
|
||||
" ! pip install pyyaml==5.4.1\n",
|
||||
" # reload\n",
|
||||
@@ -50,7 +50,8 @@
|
||||
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
|
||||
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" import requests\n",
|
||||
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
|
||||
"\n",
|
||||
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\", timeout=10) as resp:\n",
|
||||
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
|
||||
" fp.write(resp.content)"
|
||||
]
|
||||
@@ -61,14 +62,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.constant import REG_CN\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.utils import flatten_dict\n"
|
||||
"from qlib.utils import flatten_dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -86,6 +86,7 @@
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(scripts_dir))\n",
|
||||
" from get_data import GetData\n",
|
||||
"\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
@@ -169,7 +170,7 @@
|
||||
" R.log_params(**flatten_dict(task))\n",
|
||||
" model.fit(dataset)\n",
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
" rid = R.get_recorder().id\n"
|
||||
" rid = R.get_recorder().id"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -238,7 +239,7 @@
|
||||
"\n",
|
||||
" # backtest & analysis\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
|
||||
" par.generate()\n"
|
||||
" par.generate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -256,6 +257,7 @@
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"\n",
|
||||
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"print(recorder)\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
@@ -317,7 +319,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
|
||||
"label_df.columns = ['label']"
|
||||
"label_df.columns = [\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.0"
|
||||
__version__ = "0.9.1.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
@@ -40,8 +40,8 @@ def get_exchange(
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] | None = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
@@ -284,7 +284,7 @@ def collect_data(
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
return_value: dict | None = None,
|
||||
) -> Generator[object, None, None]:
|
||||
"""initialize the strategy and executor, then collect the trade decision data for rl training
|
||||
|
||||
|
||||
@@ -152,7 +152,9 @@ class Account:
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
|
||||
def reset(
|
||||
self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None
|
||||
) -> None:
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -55,7 +55,7 @@ def collect_data_loop(
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict = None,
|
||||
return_value: dict | None = None,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
|
||||
@@ -254,7 +254,7 @@ class IdxTradeRange(TradeRange):
|
||||
self._start_idx = start_idx
|
||||
self._end_idx = end_idx
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]:
|
||||
return self._start_idx, self._end_idx
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
@@ -315,7 +315,7 @@ class BaseTradeDecision(Generic[DecisionType]):
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -554,7 +554,7 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
self,
|
||||
order_list: List[Order],
|
||||
strategy: BaseStrategy,
|
||||
trade_range: Union[Tuple[int, int], TradeRange] = None,
|
||||
trade_range: Union[Tuple[int, int], TradeRange, None] = None,
|
||||
) -> None:
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = cast(List[Order], order_list)
|
||||
|
||||
@@ -18,7 +18,7 @@ import pandas as pd
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from ..config import C
|
||||
from ..constant import REG_CN
|
||||
from ..constant import REG_CN, REG_TW
|
||||
from ..data.data import D
|
||||
from ..log import get_module_logger
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
@@ -41,10 +41,10 @@ class Exchange:
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str], None] = None,
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
volume_threshold: Union[tuple, dict, None] = None,
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
@@ -148,10 +148,10 @@ class Exchange:
|
||||
# It is just for performance consideration.
|
||||
self.limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_threshold is None:
|
||||
if C.region == REG_CN:
|
||||
if C.region in [REG_CN, REG_TW]:
|
||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
|
||||
if C.region == REG_CN:
|
||||
if C.region in [REG_CN, REG_TW]:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
if isinstance(deal_price, str):
|
||||
@@ -340,7 +340,7 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
direction: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Parameters
|
||||
@@ -406,7 +406,7 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
direction: int | None = None,
|
||||
) -> bool:
|
||||
# check if stock can be traded
|
||||
return not (
|
||||
@@ -421,8 +421,8 @@ class Exchange:
|
||||
def deal_order(
|
||||
self,
|
||||
order: Order,
|
||||
trade_account: Account = None,
|
||||
position: BasePosition = None,
|
||||
trade_account: Account | None = None,
|
||||
position: BasePosition | None = None,
|
||||
dealt_order_amount: Dict[str, float] = defaultdict(float),
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
@@ -586,7 +586,7 @@ class Exchange:
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float:
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
:param current_amount:
|
||||
@@ -712,8 +712,8 @@ class Exchange:
|
||||
|
||||
def _get_factor_or_raise_error(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
factor: float | None = None,
|
||||
stock_id: str | None = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
@@ -728,8 +728,8 @@ class Exchange:
|
||||
|
||||
def get_amount_of_trade_unit(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
factor: float | None = None,
|
||||
stock_id: str | None = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Optional[float]:
|
||||
@@ -762,8 +762,8 @@ class Exchange:
|
||||
def round_amount_by_trade_unit(
|
||||
self,
|
||||
deal_amount: float,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
factor: float | None = None,
|
||||
stock_id: str | None = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
|
||||
@@ -31,8 +31,8 @@ class BaseExecutor:
|
||||
generate_portfolio_metrics: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_exchange: Exchange | None = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
settle_type: str = BasePosition.ST_NO,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@@ -161,7 +161,7 @@ class BaseExecutor:
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
|
||||
def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -227,7 +227,7 @@ class BaseExecutor:
|
||||
def collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
return_value: dict = None,
|
||||
return_value: dict | None = None,
|
||||
level: int = 0,
|
||||
) -> Generator[Any, Any, List[object]]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
@@ -327,7 +327,7 @@ class NestedExecutor(BaseExecutor):
|
||||
track_data: bool = False,
|
||||
skip_empty_decision: bool = True,
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -534,7 +534,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
generate_portfolio_metrics: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Union
|
||||
@@ -320,7 +321,7 @@ class Position(BasePosition):
|
||||
self.position[stock]["price"] = price_dict[stock]
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
|
||||
def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None:
|
||||
"""
|
||||
initialization the stock in current position
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pathlib
|
||||
from collections import OrderedDict
|
||||
@@ -86,7 +87,7 @@ class PortfolioMetrics:
|
||||
self.benches: dict = OrderedDict()
|
||||
self.latest_pm_time: Optional[pd.TimeStamp] = None
|
||||
|
||||
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
|
||||
def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None:
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
self.benchmark_config = benchmark_config
|
||||
@@ -149,15 +150,15 @@ class PortfolioMetrics:
|
||||
self,
|
||||
trade_start_time: Union[str, pd.Timestamp] = None,
|
||||
trade_end_time: Union[str, pd.Timestamp] = None,
|
||||
account_value: float = None,
|
||||
cash: float = None,
|
||||
return_rate: float = None,
|
||||
total_turnover: float = None,
|
||||
turnover_rate: float = None,
|
||||
total_cost: float = None,
|
||||
cost_rate: float = None,
|
||||
stock_value: float = None,
|
||||
bench_value: float = None,
|
||||
account_value: float | None = None,
|
||||
cash: float | None = None,
|
||||
return_rate: float | None = None,
|
||||
total_turnover: float | None = None,
|
||||
turnover_rate: float | None = None,
|
||||
total_cost: float | None = None,
|
||||
cost_rate: float | None = None,
|
||||
stock_value: float | None = None,
|
||||
bench_value: float | None = None,
|
||||
) -> None:
|
||||
# check data
|
||||
if None in [
|
||||
|
||||
@@ -31,7 +31,7 @@ class TradeCalendarManager:
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
level_infra: LevelInfrastructure | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -99,7 +99,7 @@ class TradeCalendarManager:
|
||||
def get_trade_step(self) -> int:
|
||||
return self.trade_step
|
||||
|
||||
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Get the left and right endpoints of the trade_step'th trading interval
|
||||
|
||||
|
||||
@@ -75,7 +75,8 @@ class Config:
|
||||
def set_conf_from_C(self, config_c):
|
||||
self.update(**config_c.__dict__["_config"])
|
||||
|
||||
def register_from_C(self, config, skip_register=True):
|
||||
@staticmethod
|
||||
def register_from_C(config, skip_register=True):
|
||||
from .utils import set_log_with_config # pylint: disable=C0415
|
||||
|
||||
if C.registered and skip_register:
|
||||
@@ -146,6 +147,7 @@ _default_config = {
|
||||
"redis_host": "127.0.0.1",
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
"redis_password": None,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": logging.INFO,
|
||||
# Global configuration of qlib log
|
||||
@@ -202,7 +204,7 @@ _default_config = {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
},
|
||||
# Shift minute for highfreq minite data, used in backtest
|
||||
# Shift minute for highfreq minute data, used in backtest
|
||||
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:00, 2:59]
|
||||
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:00, 2:59] - shift*minute
|
||||
"min_data_shift": 0,
|
||||
|
||||
111
qlib/contrib/analyzer.py
Normal file
111
qlib/contrib/analyzer.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
logger = get_module_logger("analysis", logging.INFO)
|
||||
|
||||
|
||||
class AnalyzerTemp:
    """
    Base class for analyzers that read stored artifacts from a recorder and
    write their reports into an output directory.

    Parameters
    ----------
    recorder :
        object whose ``load_object`` method is used to fetch stored artifacts.
    output_dir : str or Path, optional
        directory in which subclasses save their output files. The current
        working directory is used when it is not given.
    **kwargs :
        accepted so subclasses can forward arbitrary keyword arguments; ignored here.
    """

    def __init__(self, recorder, output_dir=None, **kwargs):
        self.recorder = recorder
        # Always keep a Path here: subclasses call `self.output_dir.joinpath(...)`,
        # which a plain "./" string does not provide.
        self.output_dir = Path(output_dir) if output_dir else Path("./")

    def load(self, name: str):
        """
        It behaves the same as self.recorder.load_object.
        But it is an easier interface because users don't have to care about `get_path` and `artifact_path`

        Parameters
        ----------
        name : str
            the name for the file to be load.

        Return
        ------
        The stored records.
        """
        return self.recorder.load_object(name)

    def analyse(self, **kwargs):
        """
        Analyse data index, distribution .etc

        Subclasses must override this method; the base implementation always raises.

        Parameters
        ----------
        **kwargs :
            analyzer specific options.

        Return
        ------
        The handled data.

        Raises
        ------
        NotImplementedError
            always, because the base class provides no analysis.
        """
        raise NotImplementedError("Please implement the `analyse` method.")
|
||||
|
||||
|
||||
class HFAnalyzer(AnalyzerTemp):
    """
    Analyzer that summarises prediction quality (IC, rank IC, long/short precision
    and long-short return) from the stored "pred.pkl" and "label.pkl" artifacts.

    Two images are written into `output_dir`:

    - "HFAnalyzerTable.jpeg": a table of the computed metrics
    - "HFAnalyzer.jpeg": a scatter plot of the predictions and the labels
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def analyse(self):
        """
        Compute the metrics, save the two images and return the scatter plot filename.

        Only the first column of the loaded prediction and label objects is used.

        Return
        ------
        str
            "HFAnalyzer.jpeg", the filename (relative to `output_dir`) of the scatter plot.
        """
        pred = self.load("pred.pkl")
        label = self.load("label.pkl")

        pred_col = pred.iloc[:, 0]
        label_col = label.iloc[:, 0]

        long_pre, short_pre = calc_long_short_prec(pred_col, label_col, is_alpha=True)
        ic, ric = calc_ic(pred_col, label_col)
        metrics = {
            "IC": ic.mean(),
            "ICIR": ic.mean() / ic.std(),
            "Rank IC": ric.mean(),
            "Rank ICIR": ric.mean() / ric.std(),
            "Long precision": long_pre.mean(),
            "Short precision": short_pre.mean(),
        }

        # The second return value (long average return) is not reported.
        long_short_r, _ = calc_long_short_return(pred_col, label_col)
        metrics.update(
            {
                "Long-Short Average Return": long_short_r.mean(),
                "Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
            }
        )

        table = [[k, v] for k, v in metrics.items()]
        plt.table(cellText=table, loc="center")
        plt.axis("off")
        plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
        # Clear the implicit pyplot figure so the table is not drawn under the scatter plot.
        plt.clf()

        plt.scatter(np.arange(0, len(pred)), pred_col)
        plt.scatter(np.arange(0, len(label)), label_col)
        plt.title("HFAnalyzer")
        plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
        # Clear again so this scatter plot does not leak into whatever is plotted
        # next on the shared pyplot figure (e.g. another analyzer in the same process).
        plt.clf()
        return "HFAnalyzer.jpeg"
|
||||
|
||||
|
||||
class SignalAnalyzer(AnalyzerTemp):
    """
    Analyzer that plots a histogram of the stored "label.pkl" artifact.

    default output image filename is "signalAnalysis.jpeg"
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def analyse(self, dataset=None, **kwargs):
        """
        Save a histogram of the stored label and return the image filename.

        Parameters
        ----------
        dataset :
            not used by this implementation; kept for interface compatibility.
        **kwargs :
            not used by this implementation.

        Return
        ------
        str
            "signalAnalysis.jpeg", the filename (relative to `output_dir`) of the histogram.
        """
        label = self.load("label.pkl")

        plt.hist(label)
        plt.title("SignalAnalyzer")
        plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
        # Clear the implicit pyplot figure so repeated calls (or other analyzers
        # running in the same process) do not draw on top of this histogram.
        plt.clf()

        return "signalAnalysis.jpeg"
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
from qlib.utils.data import update_config
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
@@ -56,13 +58,14 @@ class Alpha360(DataHandlerLP):
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
_data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -71,15 +74,17 @@ class Alpha360(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
data_loader=_data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
@@ -152,13 +157,14 @@ class Alpha158(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
_data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -167,14 +173,16 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
data_loader=_data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
|
||||
@@ -44,7 +44,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
@@ -113,8 +113,12 @@ class HighFreqGeneralHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
day_length=240,
|
||||
freq="1min",
|
||||
columns=["$open", "$high", "$low", "$close", "$vwap"],
|
||||
inst_processors=None,
|
||||
):
|
||||
self.day_length = day_length
|
||||
self.columns = columns
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -124,7 +128,8 @@ class HighFreqGeneralHandler(DataHandlerLP):
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
"freq": freq,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
@@ -160,19 +165,13 @@ class HighFreqGeneralHandler(DataHandlerLP):
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_price_feature("$vwap", 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
for column_name in self.columns:
|
||||
fields.append(get_normalized_price_feature(column_name, 0))
|
||||
names.append(column_name)
|
||||
|
||||
fields += [get_normalized_price_feature("$open", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$high", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$low", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$close", self.day_length)]
|
||||
fields += [get_normalized_price_feature("$vwap", self.day_length)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
for column_name in self.columns:
|
||||
fields.append(get_normalized_price_feature(column_name, self.day_length))
|
||||
names.append(column_name + "_1")
|
||||
|
||||
# calculate and fill nan with 0
|
||||
fields += [
|
||||
@@ -258,14 +257,19 @@ class HighFreqGeneralBacktestHandler(DataHandler):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
day_length=240,
|
||||
freq="1min",
|
||||
columns=["$close", "$vwap", "$volume"],
|
||||
inst_processors=None,
|
||||
):
|
||||
self.day_length = day_length
|
||||
self.columns = set(columns)
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
"freq": freq,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
@@ -279,21 +283,24 @@ class HighFreqGeneralBacktestHandler(DataHandler):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
fields += [
|
||||
template_paused.format(template_fillnan.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
if "$close" in self.columns:
|
||||
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
fields += [
|
||||
template_paused.format(template_fillnan.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
if "$vwap" in self.columns:
|
||||
fields += [
|
||||
template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
if "$volume" in self.columns:
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
@@ -308,6 +315,7 @@ class HighFreqOrderHandler(DataHandlerLP):
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
inst_processors=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
|
||||
@@ -320,6 +328,7 @@ class HighFreqOrderHandler(DataHandlerLP):
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
@@ -479,7 +488,7 @@ class HighFreqBacktestOrderHandler(DataHandler):
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
|
||||
@@ -28,6 +28,7 @@ class HighFreqProvider:
|
||||
feature_conf: dict,
|
||||
label_conf: Optional[dict] = None,
|
||||
backtest_conf: dict = None,
|
||||
freq: str = "1min",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.start_time = start_time
|
||||
@@ -42,6 +43,7 @@ class HighFreqProvider:
|
||||
self.backtest_conf = backtest_conf
|
||||
self.qlib_conf = qlib_conf
|
||||
self.logger = get_module_logger("HighFreqProvider")
|
||||
self.freq = freq
|
||||
|
||||
def get_pre_datasets(self):
|
||||
"""Generate the training, validation and test datasets for prediction
|
||||
@@ -116,8 +118,8 @@ class HighFreqProvider:
|
||||
# This code used the copy-on-write feature of Linux
|
||||
# to avoid calculating the calendar multiple times in the subprocess.
|
||||
# This code may accelerate, but may be not useful on Windows and Mac Os
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
Cal.calendar(freq=self.freq)
|
||||
get_calendar_day(freq=self.freq)
|
||||
|
||||
def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
@@ -126,7 +128,7 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
@@ -135,11 +137,11 @@ class HighFreqProvider:
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
@@ -158,7 +160,7 @@ class HighFreqProvider:
|
||||
with open(path[:-4] + "test.pkl", "wb") as f:
|
||||
pkl.dump(testset, f)
|
||||
res = [data[i] for i in datasets]
|
||||
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
|
||||
return res
|
||||
|
||||
def _gen_data(self, config, datasets=["train", "valid", "test"]):
|
||||
@@ -168,7 +170,7 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
@@ -177,18 +179,18 @@ class HighFreqProvider:
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
res = dataset.prepare(datasets)
|
||||
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
|
||||
return res
|
||||
|
||||
def _gen_dataset(self, config):
|
||||
@@ -198,21 +200,21 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
|
||||
with open(path, "rb") as f:
|
||||
dataset = pkl.load(f)
|
||||
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
|
||||
dataset.prepare(["train", "valid", "test"])
|
||||
self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}")
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
return dataset
|
||||
@@ -225,22 +227,22 @@ class HighFreqProvider:
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240]
|
||||
|
||||
def generate_dataset(times):
|
||||
if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
|
||||
@@ -266,15 +268,15 @@ class HighFreqProvider:
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
self.logger.info("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
self.logger.info("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
@@ -283,7 +285,7 @@ class HighFreqProvider:
|
||||
|
||||
instruments = D.instruments(market="all")
|
||||
stock_list = D.list_instruments(
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True
|
||||
)
|
||||
|
||||
def generate_dataset(stock):
|
||||
|
||||
@@ -55,8 +55,10 @@ class InternalData:
|
||||
# The handler is initialized for only once.
|
||||
if not trainer.has_worker():
|
||||
self.dh = init_task_handler(perf_task_tpl)
|
||||
self.dh.config(dump_all=False) # in some cases, the data handler are saved to disk with `dump_all=True`
|
||||
else:
|
||||
self.dh = init_instance_by_config(perf_task_tpl["dataset"]["kwargs"]["handler"])
|
||||
assert self.dh.dump_all is False # otherwise, it will save all the detailed data
|
||||
|
||||
seg = perf_task_tpl["dataset"]["kwargs"]["segments"]
|
||||
|
||||
@@ -77,7 +79,7 @@ class InternalData:
|
||||
get_module_logger("Internal Data").info("the data has been initialized")
|
||||
else:
|
||||
# train new models
|
||||
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData``"
|
||||
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData`"
|
||||
trainer.train(gen_task)
|
||||
|
||||
# 2) extract the similarity matrix
|
||||
@@ -119,6 +121,7 @@ class MetaTaskDS(MetaTask):
|
||||
|
||||
def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PROC_MODE_FULL, fill_method="max"):
|
||||
"""
|
||||
|
||||
The description of the processed data
|
||||
|
||||
time_perf: A array with shape <hist_step_n * step, data pieces> -> data piece performance
|
||||
@@ -132,6 +135,10 @@ class MetaTaskDS(MetaTask):
|
||||
[0., 0., 0., ..., 0., 0., 1.],
|
||||
[0., 0., 0., ..., 0., 0., 1.]])
|
||||
|
||||
Parameters
|
||||
----------
|
||||
meta_info: pd.DataFrame
|
||||
please refer to the docs of _prepare_meta_ipt for detailed explanation.
|
||||
"""
|
||||
super().__init__(task, meta_info)
|
||||
self.fill_method = fill_method
|
||||
@@ -180,12 +187,41 @@ class MetaTaskDS(MetaTask):
|
||||
self.processed_meta_input = data_to_tensor(self.processed_meta_input)
|
||||
|
||||
def _get_processed_meta_info(self):
|
||||
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0) # .fillna(0.)
|
||||
if self.fill_method == "max":
|
||||
meta_info_norm = meta_info_norm.T.fillna(
|
||||
meta_info_norm.max(axis=1)
|
||||
).T # fill it with row max to align with previous implementation
|
||||
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0)
|
||||
if self.fill_method.startswith("max"):
|
||||
suffix = self.fill_method.lstrip("max")
|
||||
if suffix == "seg":
|
||||
fill_value = {}
|
||||
for col in meta_info_norm.columns:
|
||||
fill_value[col] = meta_info_norm.loc[meta_info_norm[col].isna(), :].dropna(axis=1).mean().max()
|
||||
fill_value = pd.Series(fill_value).sort_index()
|
||||
# The NaN Values are filled segment-wise. Below is an exampleof fill_value
|
||||
# 2009-01-05 2009-02-06 0.145809
|
||||
# 2009-02-09 2009-03-06 0.148005
|
||||
# 2009-03-09 2009-04-03 0.090385
|
||||
# 2009-04-07 2009-05-05 0.114318
|
||||
# 2009-05-06 2009-06-04 0.119328
|
||||
# ...
|
||||
meta_info_norm = meta_info_norm.fillna(fill_value)
|
||||
else:
|
||||
if len(suffix) > 0:
|
||||
get_module_logger("MetaTaskDS").warning(
|
||||
f"fill_method={self.fill_method}; the info after can't be correctly parsed. Please check your parameters."
|
||||
)
|
||||
fill_value = meta_info_norm.max(axis=1)
|
||||
# fill it with row max to align with previous implementation
|
||||
# This will magnify the data similarity when data is in daily freq
|
||||
|
||||
# the fill value corresponds to data like this
|
||||
# It get a performance value for each day.
|
||||
# The performance value are get from other models on this day
|
||||
# 2009-01-16 0.276320
|
||||
# 2009-01-19 0.280603
|
||||
# ...
|
||||
# 2011-06-27 0.203773
|
||||
meta_info_norm = meta_info_norm.T.fillna(fill_value).T
|
||||
elif self.fill_method == "zero":
|
||||
# It will fillna(0.0) at the end.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -286,7 +322,33 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
logger.warning(f"ValueError: {e}")
|
||||
assert len(self.meta_task_l) > 0, "No meta tasks found. Please check the data and setting"
|
||||
|
||||
def _prepare_meta_ipt(self, task):
|
||||
def _prepare_meta_ipt(self, task) -> pd.DataFrame:
|
||||
"""
|
||||
Please refer to `self.internal_data.setup` for detailed information about `self.internal_data.data_ic_df`
|
||||
|
||||
Indices with format below can be successfully sliced by `ic_df.loc[:end, pd.IndexSlice[:, :end]]`
|
||||
|
||||
2021-06-21 2021-06-04 .. 2021-03-22 2021-03-08
|
||||
2021-07-02 2021-06-18 .. 2021-04-02 None
|
||||
|
||||
Returns
|
||||
-------
|
||||
a pd.DataFrame with similar content below.
|
||||
- each column corresponds to a trained model named by the training data range
|
||||
- each row corresponds to a day of data tested by the models of the columns
|
||||
- The rows cells that overlaps with the data used by columns are masked
|
||||
|
||||
|
||||
2009-01-05 2009-02-09 ... 2011-04-27 2011-05-26
|
||||
2009-02-06 2009-03-06 ... 2011-05-25 2011-06-23
|
||||
datetime ...
|
||||
2009-01-13 NaN 0.310639 ... -0.169057 0.137792
|
||||
2009-01-14 NaN 0.261086 ... -0.143567 0.082581
|
||||
... ... ... ... ... ...
|
||||
2011-06-30 -0.054907 -0.020219 ... -0.023226 NaN
|
||||
2011-07-01 -0.075762 -0.026626 ... -0.003167 NaN
|
||||
|
||||
"""
|
||||
ic_df = self.internal_data.data_ic_df
|
||||
|
||||
segs = task["dataset"]["kwargs"]["segments"]
|
||||
@@ -294,15 +356,19 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
|
||||
|
||||
# meta data set focus on the **information** instead of preprocess
|
||||
# 1) filter the future info
|
||||
def mask_future(s):
|
||||
"""mask future information"""
|
||||
# from qlib.utils import get_date_by_shift
|
||||
# 1) filter the overlap info
|
||||
def mask_overlap(s):
|
||||
"""
|
||||
mask overlap information
|
||||
data after self.name[end] with self.trunc_days that contains future info are also considered as overlap info
|
||||
|
||||
Approximately the diagnal + horizon length of data are masked.
|
||||
"""
|
||||
start, end = s.name
|
||||
end = get_date_by_shift(trading_date=end, shift=self.trunc_days - 1, future=True)
|
||||
return s.mask((s.index >= start) & (s.index <= end))
|
||||
|
||||
ic_df_avail = ic_df_avail.apply(mask_future) # apply to each col
|
||||
ic_df_avail = ic_df_avail.apply(mask_overlap) # apply to each col
|
||||
|
||||
# 2) filter the info with too long periods
|
||||
total_len = self.step * self.hist_step_n
|
||||
|
||||
@@ -52,6 +52,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
lr=0.0001,
|
||||
max_epoch=100,
|
||||
seed=43,
|
||||
alpha=0.0,
|
||||
):
|
||||
self.step = step
|
||||
self.hist_step_n = hist_step_n
|
||||
@@ -61,6 +62,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
self.lr = lr
|
||||
self.max_epoch = max_epoch
|
||||
self.fitted = False
|
||||
self.alpha = alpha
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
|
||||
@@ -144,7 +146,11 @@ class MetaModelDS(MetaTaskModel):
|
||||
) # debug: record when the test phase starts
|
||||
|
||||
self.tn = PredNet(
|
||||
step=self.step, hist_step_n=self.hist_step_n, clip_weight=self.clip_weight, clip_method=self.clip_method
|
||||
step=self.step,
|
||||
hist_step_n=self.hist_step_n,
|
||||
clip_weight=self.clip_weight,
|
||||
clip_method=self.clip_method,
|
||||
alpha=self.alpha,
|
||||
)
|
||||
|
||||
opt = optim.Adam(self.tn.parameters(), lr=self.lr)
|
||||
|
||||
@@ -41,11 +41,18 @@ class TimeWeightMeta(SingleMetaBase):
|
||||
|
||||
|
||||
class PredNet(nn.Module):
|
||||
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh"):
|
||||
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh", alpha: float = 0.0):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
alpha : float
|
||||
the regularization for sub model (useful when align meta model with linear submodel)
|
||||
"""
|
||||
super().__init__()
|
||||
self.step = step
|
||||
self.twm = TimeWeightMeta(hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method)
|
||||
self.init_paramters(hist_step_n)
|
||||
self.alpha = alpha
|
||||
|
||||
def get_sample_weights(self, X, time_perf, time_belong, ignore_weight=False):
|
||||
weights = torch.from_numpy(np.ones(X.shape[0])).float().to(X.device)
|
||||
@@ -59,7 +66,7 @@ class PredNet(nn.Module):
|
||||
"""Please refer to the docs of MetaTaskDS for the description of the variables"""
|
||||
weights = self.get_sample_weights(X, time_perf, time_belong, ignore_weight=ignore_weight)
|
||||
X_w = X.T * weights.view(1, -1)
|
||||
theta = torch.inverse(X_w @ X) @ X_w @ y
|
||||
theta = torch.inverse(X_w @ X + self.alpha * torch.eye(X_w.shape[0])) @ X_w @ y
|
||||
return X_test @ theta, weights
|
||||
|
||||
def init_paramters(self, hist_step_n):
|
||||
|
||||
@@ -5,6 +5,9 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
@@ -24,6 +27,7 @@ class ICLoss(nn.Module):
|
||||
diff_point.append(i)
|
||||
prev = date
|
||||
diff_point.append(None)
|
||||
# The lengths of diff_point will be one more larger then diff_point
|
||||
|
||||
ic_all = 0.0
|
||||
skip_n = 0
|
||||
@@ -34,13 +38,23 @@ class ICLoss(nn.Module):
|
||||
skip_n += 1
|
||||
continue
|
||||
y_focus = y[start_i:end_i]
|
||||
if pred_focus.std() < EPS or y_focus.std() < EPS:
|
||||
# These cases often happend at the end of test data.
|
||||
# Usually caused by fillna(0.)
|
||||
skip_n += 1
|
||||
continue
|
||||
|
||||
ic_day = torch.dot(
|
||||
(pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(),
|
||||
(y_focus - y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(),
|
||||
)
|
||||
ic_all += ic_day
|
||||
if len(diff_point) - 1 - skip_n <= 0:
|
||||
raise ValueError("No enough data for calculating iC")
|
||||
raise ValueError("No enough data for calculating IC")
|
||||
if skip_n > 0:
|
||||
get_module_logger("ICLoss").info(
|
||||
f"{skip_n} days are skipped due to zero std or small scale of valid samples."
|
||||
)
|
||||
ic_mean = ic_all / (len(diff_point) - 1 - skip_n)
|
||||
return -ic_mean # ic loss
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from scipy.optimize import nnls
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
@@ -29,7 +30,7 @@ class LinearModel(Model):
|
||||
RIDGE = "ridge"
|
||||
LASSO = "lasso"
|
||||
|
||||
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False):
|
||||
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_valid: bool = False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -39,6 +40,9 @@ class LinearModel(Model):
|
||||
l1 or l2 regularization parameter
|
||||
fit_intercept : bool
|
||||
whether fit intercept
|
||||
include_valid: bool
|
||||
Should the validation data be included for training?
|
||||
The validation data should be included
|
||||
"""
|
||||
assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f"unsupported estimator `{estimator}`"
|
||||
self.estimator = estimator
|
||||
@@ -49,9 +53,16 @@ class LinearModel(Model):
|
||||
self.fit_intercept = fit_intercept
|
||||
|
||||
self.coef_ = None
|
||||
self.include_valid = include_valid
|
||||
|
||||
def fit(self, dataset: DatasetH, reweighter: Reweighter = None):
|
||||
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if self.include_valid:
|
||||
try:
|
||||
df_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
df_train = pd.concat([df_train, df_valid])
|
||||
except KeyError:
|
||||
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
if reweighter is not None:
|
||||
|
||||
@@ -56,7 +56,7 @@ class ADARNN(Model):
|
||||
n_splits=2,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**_
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADARNN")
|
||||
@@ -81,7 +81,7 @@ class ADARNN(Model):
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_splits = n_splits
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -213,7 +213,8 @@ class ADARNN(Model):
|
||||
weight_mat = self.transform_type(out_weight_list)
|
||||
return weight_mat, None
|
||||
|
||||
def calc_all_metrics(self, pred):
|
||||
@staticmethod
|
||||
def calc_all_metrics(pred):
|
||||
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
|
||||
res = {}
|
||||
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
|
||||
@@ -259,8 +260,6 @@ class ADARNN(Model):
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
@@ -400,7 +399,7 @@ class AdaRNN(nn.Module):
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
@@ -499,7 +498,8 @@ class AdaRNN(nn.Module):
|
||||
res = self.softmax(weight).squeeze()
|
||||
return res
|
||||
|
||||
def get_features(self, output_list):
|
||||
@staticmethod
|
||||
def get_features(output_list):
|
||||
fea_list_src, fea_list_tar = [], []
|
||||
for fea in output_list:
|
||||
fea_list_src.append(fea[0 : fea.size(0) // 2])
|
||||
@@ -561,7 +561,7 @@ class TransferLoss:
|
||||
"""
|
||||
self.loss_type = loss_type
|
||||
self.input_dim = input_dim
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
|
||||
def compute(self, X, Y):
|
||||
"""Compute adaptation loss
|
||||
@@ -676,7 +676,8 @@ class MMD_loss(nn.Module):
|
||||
self.fix_sigma = None
|
||||
self.kernel_type = kernel_type
|
||||
|
||||
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
@staticmethod
|
||||
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
n_samples = int(source.size()[0]) + int(target.size()[0])
|
||||
total = torch.cat([source, target], dim=0)
|
||||
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||
@@ -691,7 +692,8 @@ class MMD_loss(nn.Module):
|
||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||
return sum(kernel_val)
|
||||
|
||||
def linear_mmd(self, X, Y):
|
||||
@staticmethod
|
||||
def linear_mmd(X, Y):
|
||||
delta = X.mean(axis=0) - Y.mean(axis=0)
|
||||
loss = delta.dot(delta.T)
|
||||
return loss
|
||||
|
||||
@@ -47,10 +47,6 @@ class DNNModelPytorch(Model):
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
lr_decay : float
|
||||
learning rate decay
|
||||
lr_decay_steps : int
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
@@ -64,8 +60,6 @@ class DNNModelPytorch(Model):
|
||||
batch_size=2000,
|
||||
early_stop_rounds=50,
|
||||
eval_steps=20,
|
||||
lr_decay=0.96,
|
||||
lr_decay_steps=100,
|
||||
optimizer="gd",
|
||||
loss="mse",
|
||||
GPU=0,
|
||||
@@ -93,8 +87,6 @@ class DNNModelPytorch(Model):
|
||||
self.batch_size = batch_size
|
||||
self.early_stop_rounds = early_stop_rounds
|
||||
self.eval_steps = eval_steps
|
||||
self.lr_decay = lr_decay
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
if isinstance(GPU, str):
|
||||
@@ -116,8 +108,6 @@ class DNNModelPytorch(Model):
|
||||
f"\nbatch_size : {batch_size}"
|
||||
f"\nearly_stop_rounds : {early_stop_rounds}"
|
||||
f"\neval_steps : {eval_steps}"
|
||||
f"\nlr_decay : {lr_decay}"
|
||||
f"\nlr_decay_steps : {lr_decay_steps}"
|
||||
f"\noptimizer : {optimizer}"
|
||||
f"\nloss_type : {loss}"
|
||||
f"\nseed : {seed}"
|
||||
|
||||
@@ -70,7 +70,7 @@ class DayCumsum(ElemOperator):
|
||||
Otherwise, the value is zero.
|
||||
"""
|
||||
|
||||
def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
|
||||
def __init__(self, feature, start: str = "9:30", end: str = "14:59", data_granularity: int = 1):
|
||||
self.feature = feature
|
||||
self.start = datetime.strptime(start, "%H:%M")
|
||||
self.end = datetime.strptime(end, "%H:%M")
|
||||
@@ -80,15 +80,17 @@ class DayCumsum(ElemOperator):
|
||||
self.noon_open = datetime.strptime("13:00", "%H:%M")
|
||||
self.noon_close = datetime.strptime("15:00", "%H:%M")
|
||||
|
||||
self.start_id = time_to_day_index(self.start)
|
||||
self.end_id = time_to_day_index(self.end)
|
||||
self.data_granularity = data_granularity
|
||||
self.start_id = time_to_day_index(self.start) // self.data_granularity
|
||||
self.end_id = time_to_day_index(self.end) // self.data_granularity
|
||||
assert 240 % self.data_granularity == 0
|
||||
|
||||
def period_cusum(self, df):
|
||||
df = df.copy()
|
||||
assert len(df) == 240
|
||||
assert len(df) == 240 // self.data_granularity
|
||||
df.iloc[0 : self.start_id] = 0
|
||||
df = df.cumsum()
|
||||
df.iloc[self.end_id + 1 : 240] = 0
|
||||
df.iloc[self.end_id + 1 : 240 // self.data_granularity] = 0
|
||||
return df
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -10,7 +11,11 @@ import matplotlib.pyplot as plt
|
||||
|
||||
from scipy import stats
|
||||
|
||||
from typing import Sequence
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from ..graph import ScatterGraph, SubplotsGraph, BarGraph, HeatmapGraph
|
||||
from ..utils import guess_plotly_rangebreaks
|
||||
|
||||
|
||||
def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs) -> tuple:
|
||||
@@ -48,12 +53,13 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
|
||||
t_df["long-average"] = t_df["Group1"] - pred_label.groupby(level="datetime")["label"].mean()
|
||||
|
||||
t_df = t_df.dropna(how="all") # for days which does not contain label
|
||||
# FIXME: support HIGH-FREQ
|
||||
t_df.index = t_df.index.strftime("%Y-%m-%d")
|
||||
# Cumulative Return By Group
|
||||
group_scatter_figure = ScatterGraph(
|
||||
t_df.cumsum(),
|
||||
layout=dict(title="Cumulative Return", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Cumulative Return",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(t_df.index))),
|
||||
),
|
||||
).figure
|
||||
|
||||
t_df = t_df.loc[:, ["long-short", "long-average"]]
|
||||
@@ -110,22 +116,36 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
|
||||
return fig
|
||||
|
||||
|
||||
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:
|
||||
def _pred_ic(
|
||||
pred_label: pd.DataFrame = None, methods: Sequence[Literal["IC", "Rank IC"]] = ("IC", "Rank IC"), **kwargs
|
||||
) -> tuple:
|
||||
"""
|
||||
|
||||
:param pred_label:
|
||||
:param rank:
|
||||
:param pred_label: pd.DataFrame
|
||||
must contain one column of realized return with name `label` and one column of predicted score names `score`.
|
||||
:param methods: Sequence[Literal["IC", "Rank IC"]]
|
||||
IC series to plot.
|
||||
IC is sectional pearson correlation between label and score
|
||||
Rank IC is the spearman correlation between label and score
|
||||
For the Monthly IC, IC histogram, IC Q-Q plot. Only the first type of IC will be plotted.
|
||||
:return:
|
||||
"""
|
||||
if rank:
|
||||
ic = pred_label.groupby(level="datetime").apply(
|
||||
lambda x: x["label"].rank(pct=True).corr(x["score"].rank(pct=True))
|
||||
)
|
||||
else:
|
||||
ic = pred_label.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"]))
|
||||
_methods_mapping = {"IC": "pearson", "Rank IC": "spearman"}
|
||||
|
||||
_index = ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
|
||||
_monthly_ic = ic.groupby(_index).mean()
|
||||
def _corr_series(x, method):
|
||||
return x["label"].corr(x["score"], method=method)
|
||||
|
||||
ic_df = pd.concat(
|
||||
[
|
||||
pred_label.groupby(level="datetime").apply(partial(_corr_series, method=_methods_mapping[m])).rename(m)
|
||||
for m in methods
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
_ic = ic_df.iloc(axis=1)[0]
|
||||
|
||||
_index = _ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
|
||||
_monthly_ic = _ic.groupby(_index).mean()
|
||||
_monthly_ic.index = pd.MultiIndex.from_arrays(
|
||||
[_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)],
|
||||
names=["year", "month"],
|
||||
@@ -148,27 +168,27 @@ def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> t
|
||||
|
||||
_monthly_ic = _monthly_ic.reindex(fill_index)
|
||||
|
||||
_ic_df = ic.to_frame("ic")
|
||||
ic_bar_figure = ic_figure(_ic_df, kwargs.get("show_nature_day", True))
|
||||
ic_bar_figure = ic_figure(ic_df, kwargs.get("show_nature_day", False))
|
||||
|
||||
ic_heatmap_figure = HeatmapGraph(
|
||||
_monthly_ic.unstack(),
|
||||
layout=dict(title="Monthly IC", yaxis=dict(tickformat=",d")),
|
||||
layout=dict(title="Monthly IC", xaxis=dict(dtick=1), yaxis=dict(tickformat="04d", dtick=1)),
|
||||
graph_kwargs=dict(xtype="array", ytype="array"),
|
||||
).figure
|
||||
|
||||
dist = stats.norm
|
||||
_qqplot_fig = _plot_qq(ic, dist)
|
||||
_qqplot_fig = _plot_qq(_ic, dist)
|
||||
|
||||
if isinstance(dist, stats.norm.__class__):
|
||||
dist_name = "Normal"
|
||||
else:
|
||||
dist_name = "Unknown"
|
||||
|
||||
_ic_df = _ic.to_frame("IC")
|
||||
_bin_size = ((_ic_df.max() - _ic_df.min()) / 20).min()
|
||||
_sub_graph_data = [
|
||||
(
|
||||
"ic",
|
||||
"IC",
|
||||
dict(
|
||||
row=1,
|
||||
col=1,
|
||||
@@ -202,12 +222,13 @@ def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:
|
||||
pred = pred_label.copy()
|
||||
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
|
||||
ac = pred.groupby(level="datetime").apply(lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True)))
|
||||
# FIXME: support HIGH-FREQ
|
||||
_df = ac.to_frame("value")
|
||||
_df.index = _df.index.strftime("%Y-%m-%d")
|
||||
ac_figure = ScatterGraph(
|
||||
_df,
|
||||
layout=dict(title="Auto Correlation", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Auto Correlation",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_df.index))),
|
||||
),
|
||||
).figure
|
||||
return (ac_figure,)
|
||||
|
||||
@@ -233,32 +254,33 @@ def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:
|
||||
"Bottom": bottom,
|
||||
}
|
||||
)
|
||||
# FIXME: support HIGH-FREQ
|
||||
r_df.index = r_df.index.strftime("%Y-%m-%d")
|
||||
turnover_figure = ScatterGraph(
|
||||
r_df,
|
||||
layout=dict(title="Top-Bottom Turnover", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Top-Bottom Turnover",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(r_df.index))),
|
||||
),
|
||||
).figure
|
||||
return (turnover_figure,)
|
||||
|
||||
|
||||
def ic_figure(ic_df: pd.DataFrame, show_nature_day=True, **kwargs) -> go.Figure:
|
||||
"""IC figure
|
||||
r"""IC figure
|
||||
|
||||
:param ic_df: ic DataFrame
|
||||
:param show_nature_day: whether to display the abscissa of non-trading day
|
||||
:param \*\*kwargs: contains some parameters to control plot style in plotly. Currently, supports
|
||||
- `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
|
||||
:return: plotly.graph_objs.Figure
|
||||
"""
|
||||
if show_nature_day:
|
||||
date_index = pd.date_range(ic_df.index.min(), ic_df.index.max())
|
||||
ic_df = ic_df.reindex(date_index)
|
||||
# FIXME: support HIGH-FREQ
|
||||
ic_df.index = ic_df.index.strftime("%Y-%m-%d")
|
||||
ic_bar_figure = BarGraph(
|
||||
ic_df,
|
||||
layout=dict(
|
||||
title="Information Coefficient (IC)",
|
||||
xaxis=dict(type="category", tickangle=45),
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(ic_df.index))),
|
||||
),
|
||||
).figure
|
||||
return ic_bar_figure
|
||||
@@ -272,9 +294,10 @@ def model_performance_graph(
|
||||
rank=False,
|
||||
graph_names: list = ["group_return", "pred_ic", "pred_autocorr"],
|
||||
show_notebook: bool = True,
|
||||
show_nature_day=True,
|
||||
show_nature_day: bool = False,
|
||||
**kwargs,
|
||||
) -> [list, tuple]:
|
||||
"""Model performance
|
||||
r"""Model performance
|
||||
|
||||
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**.
|
||||
It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1").
|
||||
@@ -297,17 +320,14 @@ def model_performance_graph(
|
||||
:param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'].
|
||||
:param show_notebook: whether to display graphics in notebook, the default is `True`.
|
||||
:param show_nature_day: whether to display the abscissa of non-trading day.
|
||||
:param \*\*kwargs: contains some parameters to control plot style in plotly. Currently, supports
|
||||
- `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
|
||||
:return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list.
|
||||
"""
|
||||
figure_list = []
|
||||
for graph_name in graph_names:
|
||||
fun_res = eval(f"_{graph_name}")(
|
||||
pred_label=pred_label,
|
||||
lag=lag,
|
||||
N=N,
|
||||
reverse=reverse,
|
||||
rank=rank,
|
||||
show_nature_day=show_nature_day,
|
||||
pred_label=pred_label, lag=lag, N=N, reverse=reverse, rank=rank, show_nature_day=show_nature_day, **kwargs
|
||||
)
|
||||
figure_list += fun_res
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]:
|
||||
_figure = SubplotsGraph(
|
||||
_get_all_risk_analysis(analysis_df),
|
||||
kind_map=dict(kind="BarGraph", kwargs={}),
|
||||
subplots_kwargs={"rows": 4, "cols": 1},
|
||||
subplots_kwargs={"rows": 1, "cols": 4},
|
||||
).figure
|
||||
return (_figure,)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import pandas as pd
|
||||
|
||||
from ..graph import ScatterGraph
|
||||
from ..utils import guess_plotly_rangebreaks
|
||||
|
||||
|
||||
def _get_score_ic(pred_label: pd.DataFrame):
|
||||
@@ -19,7 +20,7 @@ def _get_score_ic(pred_label: pd.DataFrame):
|
||||
return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic})
|
||||
|
||||
|
||||
def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [list, tuple]:
|
||||
def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True, **kwargs) -> [list, tuple]:
|
||||
"""score IC
|
||||
|
||||
Example:
|
||||
@@ -53,11 +54,13 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
|
||||
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
|
||||
"""
|
||||
_ic_df = _get_score_ic(pred_label)
|
||||
# FIXME: support HIGH-FREQ
|
||||
_ic_df.index = _ic_df.index.strftime("%Y-%m-%d")
|
||||
|
||||
_figure = ScatterGraph(
|
||||
_ic_df,
|
||||
layout=dict(title="Score IC", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Score IC",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_ic_df.index))),
|
||||
),
|
||||
graph_kwargs={"mode": "lines+markers"},
|
||||
).figure
|
||||
if show_notebook:
|
||||
|
||||
@@ -139,8 +139,8 @@ class FeaACAna(FeaAnalyser):
|
||||
|
||||
class FeaSkewTurt(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._skew = datetime_groupby_apply(self._dataset, "skew", skip_group=True)
|
||||
self._kurt = datetime_groupby_apply(self._dataset, pd.DataFrame.kurt, skip_group=True)
|
||||
self._skew = datetime_groupby_apply(self._dataset, "skew")
|
||||
self._kurt = datetime_groupby_apply(self._dataset, pd.DataFrame.kurt)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._skew[col].plot(ax=ax, label="skew")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
|
||||
@@ -43,3 +44,31 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
|
||||
res = res.item()
|
||||
yield res
|
||||
plt.show()
|
||||
|
||||
|
||||
def guess_plotly_rangebreaks(dt_index: pd.DatetimeIndex):
|
||||
"""
|
||||
This function `guesses` the rangebreaks required to remove gaps in datetime index.
|
||||
It basically calculates the difference between a `continuous` datetime index and index given.
|
||||
|
||||
For more details on `rangebreaks` params in plotly, see
|
||||
https://plotly.com/python/reference/layout/xaxis/#layout-xaxis-rangebreaks
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dt_index: pd.DatetimeIndex
|
||||
The datetimes of the data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
the `rangebreaks` to be passed into plotly axis.
|
||||
|
||||
"""
|
||||
dt_idx = dt_index.sort_values()
|
||||
gaps = dt_idx[1:] - dt_idx[:-1]
|
||||
min_gap = gaps.min()
|
||||
gaps_to_break = {}
|
||||
for gap, d in zip(gaps, dt_idx[:-1]):
|
||||
if gap > min_gap:
|
||||
gaps_to_break.setdefault(gap - min_gap, []).append(d + min_gap)
|
||||
return [dict(values=v, dvalue=int(k.total_seconds() * 1000)) for k, v in gaps_to_break.items()]
|
||||
|
||||
@@ -635,7 +635,7 @@ class FileOrderStrategy(BaseStrategy):
|
||||
self.order_df = file
|
||||
else:
|
||||
with get_io_object(file) as f:
|
||||
self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
|
||||
self.order_df = pd.read_csv(f, dtype={"datetime": str})
|
||||
|
||||
self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
|
||||
self.order_df = self.order_df.set_index(["datetime", "instrument"])
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from abc import ABC
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset import Dataset
|
||||
@@ -17,11 +18,11 @@ from qlib.backtest.signal import Signal, create_signal_from
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import get_pre_trading_date, load_dataset
|
||||
from qlib.contrib.strategy.order_generator import OrderGenWOInteract
|
||||
from qlib.contrib.strategy.order_generator import OrderGenerator, OrderGenWOInteract
|
||||
from qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer
|
||||
|
||||
|
||||
class BaseSignalStrategy(BaseStrategy):
|
||||
class BaseSignalStrategy(BaseStrategy, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -47,7 +48,7 @@ class BaseSignalStrategy(BaseStrategy):
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allowes different trade_exchanges is used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it runs faster.
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
|
||||
"""
|
||||
@@ -64,7 +65,7 @@ class BaseSignalStrategy(BaseStrategy):
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Return the proportion of your total value you will use in investment.
|
||||
Dynamically risk_degree will result in Market timing.
|
||||
"""
|
||||
# It will use 95% amount of your total value by default
|
||||
@@ -76,6 +77,7 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
# 4. Regenerate results with forbid_all_trade_at_limit set to false and flip the default to false, as it is consistent with reality.
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -85,6 +87,7 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
method_buy="top",
|
||||
hold_thresh=1,
|
||||
only_tradable=False,
|
||||
forbid_all_trade_at_limit=True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -111,6 +114,17 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
else:
|
||||
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
forbid_all_trade_at_limit : bool
|
||||
if forbid all trades when limit_up or limit_down reached.
|
||||
|
||||
if forbid_all_trade_at_limit:
|
||||
|
||||
strategy will not do any trade when price reaches limit up/down, even not sell at limit up nor buy at
|
||||
limit down, though allowed in reality.
|
||||
|
||||
else:
|
||||
|
||||
strategy will sell at limit up and buy ad limit down.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.topk = topk
|
||||
@@ -119,6 +133,7 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
self.method_buy = method_buy
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
self.forbid_all_trade_at_limit = forbid_all_trade_at_limit
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
@@ -161,7 +176,7 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
]
|
||||
|
||||
else:
|
||||
# Otherwise, the stock will make decision with out the stock tradable info
|
||||
# Otherwise, the stock will make decision without the stock tradable info
|
||||
def get_first_n(li, n):
|
||||
return list(li)[:n]
|
||||
|
||||
@@ -171,7 +186,7 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
def filter_stock(li):
|
||||
return li
|
||||
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
current_temp: Position = copy.deepcopy(self.trade_position)
|
||||
# generate order list for this adjust date
|
||||
sell_order_list = []
|
||||
buy_order_list = []
|
||||
@@ -216,7 +231,10 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
buy = today[: len(sell) + self.topk - len(last)]
|
||||
for code in current_stock_list:
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
stock_id=code,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=None if self.forbid_all_trade_at_limit else OrderDir.SELL,
|
||||
):
|
||||
continue
|
||||
if code in sell:
|
||||
@@ -244,7 +262,7 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
cash += trade_val - trade_cost
|
||||
# buy new stock
|
||||
# note the current has been changed
|
||||
current_stock_list = current_temp.get_stock_list()
|
||||
# current_stock_list = current_temp.get_stock_list()
|
||||
value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
|
||||
|
||||
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
|
||||
@@ -253,7 +271,10 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
for code in buy:
|
||||
# check is stock suspended
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
stock_id=code,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=None if self.forbid_all_trade_at_limit else OrderDir.BUY,
|
||||
):
|
||||
continue
|
||||
# buy order
|
||||
@@ -296,15 +317,15 @@ class WeightStrategyBase(BaseSignalStrategy):
|
||||
- It allowes different trade_exchanges is used in different executions.
|
||||
- For example:
|
||||
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it runs faster.
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
self.order_generator: OrderGenerator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
self.order_generator: OrderGenerator = order_generator_cls_or_obj
|
||||
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
|
||||
"""
|
||||
@@ -316,9 +337,8 @@ class WeightStrategyBase(BaseSignalStrategy):
|
||||
pred score for this trade date, index is stock_id, contain 'score' column.
|
||||
current : Position()
|
||||
current position.
|
||||
trade_exchange : Exchange()
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
trade_start_time: pd.Timestamp
|
||||
trade_end_time: pd.Timestamp
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -428,7 +448,7 @@ class EnhancedIndexingStrategy(WeightStrategyBase):
|
||||
specific_risk = load_dataset(root + "/" + self.specific_risk_path, index_col=[0])
|
||||
|
||||
if not factor_exp.index.equals(specific_risk.index):
|
||||
# NOTE: for stocks missing specific_risk, we always assume it have the highest volatility
|
||||
# NOTE: for stocks missing specific_risk, we always assume it has the highest volatility
|
||||
specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max())
|
||||
|
||||
universe = factor_exp.index.tolist()
|
||||
|
||||
@@ -783,7 +783,7 @@ class LocalPITProvider(PITProvider):
|
||||
index_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index"
|
||||
data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data"
|
||||
if not (index_path.exists() and data_path.exists()):
|
||||
raise FileNotFoundError("No file is found. Raise exception and ")
|
||||
raise FileNotFoundError("No file is found.")
|
||||
# NOTE: The most significant performance loss is here.
|
||||
# Does the acceleration that makes the program complicated really matters?
|
||||
# - It makes parameters of the interface complicate
|
||||
@@ -797,14 +797,14 @@ class LocalPITProvider(PITProvider):
|
||||
cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day)
|
||||
loc = np.searchsorted(data["date"], cur_time_int, side="right")
|
||||
if loc <= 0:
|
||||
return pd.Series()
|
||||
return pd.Series(dtype=C.pit_record_type["value"])
|
||||
last_period = data["period"][:loc].max() # return the latest quarter
|
||||
first_period = data["period"][:loc].min()
|
||||
period_list = get_period_list(first_period, last_period, quarterly)
|
||||
if period is not None:
|
||||
# NOTE: `period` has higher priority than `start_index` & `end_index`
|
||||
if period not in period_list:
|
||||
return pd.Series()
|
||||
return pd.Series(dtype=C.pit_record_type["value"])
|
||||
else:
|
||||
period_list = [period]
|
||||
else:
|
||||
@@ -868,7 +868,7 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
# Ensure that each column type is consistent
|
||||
# FIXME:
|
||||
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
|
||||
# 2) The the precision should be configurable
|
||||
# 2) The precision should be configurable
|
||||
try:
|
||||
series = series.astype(np.float32)
|
||||
except ValueError:
|
||||
|
||||
@@ -417,7 +417,7 @@ class TSDataSampler:
|
||||
# NOTE: bool(np.nan) is True !!!!!!!!
|
||||
# make sure reindex comes first. Otherwise extra NaN may appear.
|
||||
flt_data = flt_data.swaplevel()
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data)[0]]
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.typehint import Literal
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
@@ -49,6 +50,8 @@ class DataHandler(Serializable):
|
||||
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
|
||||
"""
|
||||
|
||||
_data: pd.DataFrame # underlying data.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instruments=None,
|
||||
@@ -155,6 +158,11 @@ class DataHandler(Serializable):
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
|
||||
Design motivation:
|
||||
- providing a unified interface for underlying data.
|
||||
- Potential to make the interface more friendly.
|
||||
- User can improve performance when fetching data in this extra layer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
@@ -328,6 +336,9 @@ class DataHandler(Serializable):
|
||||
yield cur_date, self.fetch(selector, **kwargs)
|
||||
|
||||
|
||||
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
|
||||
|
||||
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
@@ -346,17 +357,28 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
- These processors only apply to the learning phase.
|
||||
|
||||
Tips to improve the performance of data handler
|
||||
Tips for data handler
|
||||
|
||||
- To reduce the memory cost
|
||||
|
||||
- `drop_raw=True`: this will modify the data inplace on raw data;
|
||||
|
||||
- Please note processed data like `self._infer` or `self._learn` are concepts different from `segments` in Qlib's `Dataset` like "train" and "test"
|
||||
|
||||
- Processed data like `self._infer` or `self._learn` are underlying data processed with different processors
|
||||
- `segments` in Qlib's `Dataset` like "train" and "test" are simply the time segmentations when querying data("train" are often before "test" in time-series).
|
||||
- For example, you can query `data._infer` processed by `infer_processors` in the "train" time segmentation.
|
||||
"""
|
||||
|
||||
# based on `self._data`, `_infer` and `_learn` are generated by applying the processors
|
||||
_infer: pd.DataFrame # data for inference
|
||||
_learn: pd.DataFrame # data for learning models
|
||||
|
||||
# data key
|
||||
DK_R = "raw"
|
||||
DK_I = "infer"
|
||||
DK_L = "learn"
|
||||
DK_R: DATA_KEY_TYPE = "raw"
|
||||
DK_I: DATA_KEY_TYPE = "infer"
|
||||
DK_L: DATA_KEY_TYPE = "learn"
|
||||
# map data_key to attribute name
|
||||
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
|
||||
|
||||
# process type
|
||||
@@ -600,7 +622,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
|
||||
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
|
||||
if data_key == self.DK_R and self.drop_raw:
|
||||
raise AttributeError(
|
||||
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
|
||||
@@ -613,7 +635,7 @@ class DataHandlerLP(DataHandler):
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
data_key: DATA_KEY_TYPE = DK_I,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
@@ -647,7 +669,7 @@ class DataHandlerLP(DataHandler):
|
||||
proc_func=proc_func,
|
||||
)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
|
||||
"""
|
||||
get the column names
|
||||
|
||||
@@ -655,7 +677,7 @@ class DataHandlerLP(DataHandler):
|
||||
----------
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns).
|
||||
data_key : str
|
||||
data_key : DATA_KEY_TYPE
|
||||
the data to fetch: DK_*.
|
||||
|
||||
Returns
|
||||
@@ -698,3 +720,26 @@ class DataHandlerLP(DataHandler):
|
||||
]:
|
||||
setattr(new_hd, key, getattr(handler, key, None))
|
||||
return new_hd
|
||||
|
||||
@classmethod
def from_df(cls, df: pd.DataFrame) -> "DataHandlerLP":
    """
    Build a data handler directly from an existing DataFrame.

    Motivation:
    - When users want to get a quick data handler.

    The created data handler will have only one shared Dataframe without processors.
    After creating the handler, users may often want to dump the handler for reuse.
    Here is a typical use case

    .. code-block:: python

        from qlib.data.dataset import DataHandlerLP
        dh = DataHandlerLP.from_df(df)
        dh.to_pickle(fname, dump_all=True)

    TODO:
    - The StaticDataLoader is quite slow. It doesn't have to copy the data again...

    Parameters
    ----------
    df : pd.DataFrame
        the data to be wrapped by the handler.

    Returns
    -------
    DataHandlerLP
        a handler whose data loader is a `StaticDataLoader` built on `df`.
    """
    # The frame is served through a static loader; no other handler argument is set
    return cls(data_loader=data_loader_module.StaticDataLoader(df))
|
||||
|
||||
@@ -153,7 +153,7 @@ class QlibDataLoader(DLWParser):
|
||||
filter_pipe: List = None,
|
||||
swap_level: bool = True,
|
||||
freq: Union[str, dict] = "day",
|
||||
inst_processor: dict = None,
|
||||
inst_processors: Union[dict, list] = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
@@ -167,16 +167,19 @@ class QlibDataLoader(DLWParser):
|
||||
freq: dict or str
|
||||
If type(config) == dict and type(freq) == str, load config data using freq.
|
||||
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
|
||||
inst_processor: dict
|
||||
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
|
||||
inst_processors: dict | list
|
||||
If inst_processors is not None and type(config) == dict; load config[<group_name>] data using inst_processors[<group_name>]
|
||||
If inst_processors is a list, then it will be applied to all groups.
|
||||
"""
|
||||
self.filter_pipe = filter_pipe
|
||||
self.swap_level = swap_level
|
||||
self.freq = freq
|
||||
|
||||
# sample
|
||||
self.inst_processor = inst_processor if inst_processor is not None else {}
|
||||
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
|
||||
self.inst_processors = inst_processors if inst_processors is not None else {}
|
||||
assert isinstance(
|
||||
self.inst_processors, (dict, list)
|
||||
), f"inst_processors(={self.inst_processors}) must be dict or list"
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
@@ -187,8 +190,8 @@ class QlibDataLoader(DLWParser):
|
||||
if _gp not in freq:
|
||||
raise ValueError(f"freq(={freq}) missing group(={_gp})")
|
||||
assert (
|
||||
self.inst_processor
|
||||
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
|
||||
self.inst_processors
|
||||
), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"
|
||||
|
||||
def load_group_df(
|
||||
self,
|
||||
@@ -208,9 +211,10 @@ class QlibDataLoader(DLWParser):
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
|
||||
df = D.features(
|
||||
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
|
||||
inst_processors = (
|
||||
self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])
|
||||
)
|
||||
df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)
|
||||
df.columns = names
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from typing import Union, Text
|
||||
from typing import Union, Text, Optional
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -11,6 +11,8 @@ from ...constant import EPS
|
||||
from .utils import fetch_df_by_index
|
||||
from ...utils.serial import Serializable
|
||||
from ...utils.paral import datetime_groupby_apply
|
||||
from qlib.data.inst_processor import InstProcessor
|
||||
from qlib.data import D
|
||||
|
||||
|
||||
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
|
||||
@@ -211,16 +213,19 @@ class MinMaxNorm(Processor):
|
||||
self.min_val = np.nanmin(df[cols].values, axis=0)
|
||||
self.max_val = np.nanmax(df[cols].values, axis=0)
|
||||
self.ignore = self.min_val == self.max_val
|
||||
# To improve the speed, we set the value of `min_val` to `0` for the columns that do not need to be processed,
|
||||
# and the value of `max_val` to `1`, when using `(x - min_val) / (max_val - min_val)` for uniform calculation,
|
||||
# the columns that do not need to be processed will be calculated by `(x - 0) / (1 - 0)`,
|
||||
# as you can see, the columns that do not need to be processed, will not be affected.
|
||||
for _i, _con in enumerate(self.ignore):
|
||||
if _con:
|
||||
self.min_val[_i] = 0
|
||||
self.max_val[_i] = 1
|
||||
self.cols = cols
|
||||
|
||||
def __call__(self, df):
|
||||
def normalize(x, min_val=self.min_val, max_val=self.max_val, ignore=self.ignore):
|
||||
if (~ignore).all():
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - min_val) / (max_val - min_val)
|
||||
return x
|
||||
def normalize(x, min_val=self.min_val, max_val=self.max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
@@ -242,16 +247,19 @@ class ZScoreNorm(Processor):
|
||||
self.mean_train = np.nanmean(df[cols].values, axis=0)
|
||||
self.std_train = np.nanstd(df[cols].values, axis=0)
|
||||
self.ignore = self.std_train == 0
|
||||
# To improve the speed, we set the value of `std_train` to `1` for the columns that do not need to be processed,
|
||||
# and the value of `mean_train` to `0`, when using `(x - mean_train) / std_train` for uniform calculation,
|
||||
# the columns that do not need to be processed will be calculated by `(x - 0) / 1`,
|
||||
# as you can see, the columns that do not need to be processed, will not be affected.
|
||||
for _i, _con in enumerate(self.ignore):
|
||||
if _con:
|
||||
self.std_train[_i] = 1
|
||||
self.mean_train[_i] = 0
|
||||
self.cols = cols
|
||||
|
||||
def __call__(self, df):
|
||||
def normalize(x, mean_train=self.mean_train, std_train=self.std_train, ignore=self.ignore):
|
||||
if (~ignore).all():
|
||||
return (x - mean_train) / std_train
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - mean_train) / std_train
|
||||
return x
|
||||
def normalize(x, mean_train=self.mean_train, std_train=self.std_train):
|
||||
return (x - mean_train) / std_train
|
||||
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
@@ -361,7 +369,7 @@ class CSZFillna(Processor):
|
||||
|
||||
def __call__(self, df):
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
|
||||
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(lambda x: x.fillna(x.mean()))
|
||||
return df
|
||||
|
||||
|
||||
@@ -372,3 +380,42 @@ class HashStockFormat(Processor):
|
||||
from .storage import HashingStockStorage # pylint: disable=C0415
|
||||
|
||||
return HashingStockStorage.from_df(df)
|
||||
|
||||
|
||||
class TimeRangeFlt(InstProcessor):
    """
    An instrument-level filter for stocks.
    An instrument is kept only when its data exist from start_time to end_time
    (the existence in the middle is not checked.)
    WARNING: It may induce leakage!!!
    """

    def __init__(
        self,
        start_time: Optional[Union[pd.Timestamp, str]] = None,
        end_time: Optional[Union[pd.Timestamp, str]] = None,
        freq: str = "day",
    ):
        """
        Parameters
        ----------
        start_time : Optional[Union[pd.Timestamp, str]]
            The data must start earlier (or equal) than `start_time`
            None indicates data will not be filtered based on `start_time`
        end_time : Optional[Union[pd.Timestamp, str]]
            similar to start_time
        freq : str
            The frequency of the calendar
        """
        # Both boundaries are snapped onto the calendar so that they can be compared with the data index
        calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq)
        self.start_time = calendar[0] if start_time is not None else None
        self.end_time = calendar[-1] if end_time is not None else None

    def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs):
        """
        Return `df` untouched when it is empty or covers the configured range;
        otherwise return an empty frame with the same columns.
        """
        # Nothing to check for an empty frame
        if df.empty:
            return df

        # The first record must not be later than the (aligned) start boundary
        if self.start_time is not None and not df.index.min() <= self.start_time:
            return df.head(0)

        # The last record must not be earlier than the (aligned) end boundary
        if self.end_time is not None and not df.index.max() >= self.end_time:
            return df.head(0)

        return df
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
from typing import Union, List
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
from qlib.utils import init_instance_by_config
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.data.dataset import DataHandler
|
||||
@@ -121,7 +120,7 @@ def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datet
|
||||
return df
|
||||
|
||||
|
||||
def init_task_handler(task: dict) -> Union[DataHandler, None]:
|
||||
def init_task_handler(task: dict) -> DataHandler:
|
||||
"""
|
||||
initialize the handler part of the task **inplace**
|
||||
|
||||
@@ -142,5 +141,6 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]:
|
||||
if h_conf is not None:
|
||||
handler = init_instance_by_config(h_conf, accept_types=DataHandler)
|
||||
task["dataset"]["kwargs"]["handler"] = handler
|
||||
|
||||
return handler
|
||||
else:
|
||||
raise ValueError("The task does not contains a handler part.")
|
||||
|
||||
18
qlib/finco/.env.example
Normal file
18
qlib/finco/.env.example
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
OPENAI_API_KEY=your_api_key
|
||||
|
||||
# USE_AZURE=True
|
||||
# AZURE_API_BASE=your_api_base
|
||||
# AZURE_API_VERSION=your_api_version
|
||||
|
||||
# gpt-4 supports more tokens but responds more slowly
|
||||
# MODEL=gpt-4
|
||||
# MAX_TOKENS=1600
|
||||
# MAX_RETRY=1000
|
||||
|
||||
|
||||
MAX_TOKENS=1600
|
||||
MAX_RETRY=120
|
||||
|
||||
CONTINOUS_MODE=True
|
||||
DEBUG_MODE=True
|
||||
22
qlib/finco/README.md
Normal file
22
qlib/finco/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# This is an experimental branch of "`FI`nancial `CO`pilot of `Qlib`"
|
||||
|
||||
## Installation
|
||||
|
||||
- To run this module, you first need to install Qlib by following the instructions in [install-from-source](/README.md#install-from-source), or by running:
|
||||
|
||||
```sh
|
||||
python -m pip install git+https://github.com/microsoft/qlib.git@finco
|
||||
```
|
||||
|
||||
- Then install the other dependencies of finco:
|
||||
```sh
|
||||
python -m pip install pydantic openai python-dotenv
|
||||
```
|
||||
|
||||
## Quick run
|
||||
|
||||
To run this module, you can start the workflow easily with one command:
|
||||
|
||||
```sh
|
||||
cd qlib/finco; python cli.py "your prompt"
|
||||
```
|
||||
13
qlib/finco/__init__.py
Normal file
13
qlib/finco/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
# Directory that contains this module, computed once at import time
DIRNAME = Path(__file__).absolute().resolve().parent


def get_finco_path() -> Path:
    """
    Return the directory in which this module is located.

    The install location of the package is not known in advance, so the path is derived from
    `__file__` of this module. (The original note says it is used as the template path.)
    """
    return DIRNAME
|
||||
15
qlib/finco/cli.py
Normal file
15
qlib/finco/cli.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import WorkflowManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
    """
    Command line entry: load the settings from the `.env` file and run the workflow.

    Parameters
    ----------
    prompt : optional
        the user prompt forwarded to `WorkflowManager.run`.
    """
    # The environment must be populated before the manager is created
    load_dotenv(verbose=True, override=True)
    WorkflowManager().run(prompt)


if __name__ == "__main__":
    # Initialize Qlib first, then expose `main` as a command line interface
    auto_init()
    fire.Fire(main)
|
||||
15
qlib/finco/cli_learn.py
Normal file
15
qlib/finco/cli_learn.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import LearnManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
    """
    Command line entry: load the settings from the `.env` file and run the learning workflow.

    Parameters
    ----------
    prompt : optional
        the user prompt forwarded to `LearnManager.run`.
    """
    # The environment must be populated before the manager is created
    load_dotenv(verbose=True, override=True)
    LearnManager().run(prompt)


if __name__ == "__main__":
    # Initialize Qlib first, then expose `main` as a command line interface
    auto_init()
    fire.Fire(main)
|
||||
32
qlib/finco/conf.py
Normal file
32
qlib/finco/conf.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# TODO: use pydantic for other modules in Qlib
|
||||
from pydantic import BaseSettings
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class Config(SingletonBaseClass):
    """
    Settings of finco, all of them read from environment variables.

    This config is for fast demo purpose.
    Please use BaseSettings instead in the future
    """

    def __init__(self):
        env = os.getenv

        # Backend selection and sampling parameters (defaults apply when the variable is unset)
        self.use_azure = env("USE_AZURE") == "True"
        self.temperature = float(env("TEMPERATURE", "0.5"))
        self.max_tokens = int(env("MAX_TOKENS", "800"))

        # Credentials and endpoints; they stay `None` when not configured
        self.openai_api_key = env("OPENAI_API_KEY")
        self.azure_api_base = env("AZURE_API_BASE")
        self.azure_api_version = env("AZURE_API_VERSION")
        # The default model name differs between the Azure and the OpenAI endpoints
        default_model = "gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo"
        self.model = env("MODEL") or default_model

        # `None` when MAX_RETRY is not set
        max_retry = env("MAX_RETRY")
        self.max_retry = None if max_retry is None else int(max_retry)

        # Flags are enabled only by the exact string "True"
        # NOTE: the variable name is spelled `CONTINOUS_MODE` (sic)
        self.continuous_mode = env("CONTINOUS_MODE") == "True"
        self.debug_mode = env("DEBUG_MODE") == "True"
        self.workspace = env("WORKSPACE", "./finco_workspace")
        # Rounded down to an even number
        self.max_past_message_include = int(env("MAX_PAST_MESSAGE_INCLUDE") or 6) // 2 * 2
|
||||
156
qlib/finco/knowledge.py
Normal file
156
qlib/finco/knowledge.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
from typing import List
|
||||
|
||||
from qlib.workflow import R
|
||||
from qlib.finco.log import FinCoLog
|
||||
from qlib.finco.llm import APIBackend
|
||||
|
||||
|
||||
class Knowledge:
    """
    Use to handle knowledge in finCo such as experiment and outside domain information

    This is the base class: subclasses are expected to implement `load` and `brief`.
    """

    def __init__(self):
        self.logger = FinCoLog()

    def load(self, **kwargs):
        """
        Load knowledge in memory

        Parameters
        ----------
        kwargs :
            defined by the subclass.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        raise NotImplementedError("Please implement the `load` method.")

    def brief(self, **kwargs):
        """
        Return a brief summary of knowledge

        Parameters
        ----------
        kwargs :
            defined by the subclass.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        # The message must name this method (it used to point at `load`)
        raise NotImplementedError("Please implement the `brief` method.")
|
||||
|
||||
|
||||
class KnowledgeExperiment(Knowledge):
    """
    Knowledge that comes from the recorders of one experiment
    """

    def __init__(self, exp_name, rec_id=None):
        super().__init__()
        self.exp_name = exp_name
        self.exp = None
        self.recs = []

        # The recorders are collected right away
        self.load(exp_name=exp_name, rec_id=rec_id)

    def load(self, exp_name, rec_id=None):
        """
        Fetch the experiment and append its recorders to `self.recs`

        Parameters
        ----------
        exp_name :
            name of the experiment to read.
        rec_id : optional
            when given, only the recorder with this id is kept.
        """
        self.exp = R.get_exp(experiment_name=exp_name)
        selected = [
            rec for rec in self.exp.list_recorders(rtype=self.exp.RT_L) if rec_id is None or rec.id == rec_id
        ]
        self.recs.extend(selected)

    def brief(self):
        """
        Return one summary dict per loaded recorder, holding the experiment name, the recorder info
        and the recorder's saved `config` and `context_summary` objects.
        """
        return [
            {
                "exp_name": self.exp.name,
                "record_info": rec.info,
                "config": rec.load_object("config"),
                "context_summary": rec.load_object("context_summary"),
            }
            for rec in self.recs
        ]
|
||||
|
||||
|
||||
class Topic:
    """
    A named topic whose knowledge is obtained by sending a rendered prompt template to the LLM backend
    """

    def __init__(self, name: str, describe: Template):
        self.logger = FinCoLog()
        self.name = name
        # Prompt template; it is rendered with the documents in `summarize`
        self.describe = describe
        # Both are filled in by `summarize`
        self.knowledge = None
        self.docs = []

    def summarize(self, docs: list):
        """
        Render the topic template with `docs`, ask the backend for a completion, and store the
        response in `self.knowledge` together with the documents used.

        Parameters
        ----------
        docs : list
            the documents passed to the template as `docs`.
        """
        self.logger.info(f"Summarize topic: \nname: {self.name}\ndescribe: {self.describe.module}")
        rendered_prompt = self.describe.render(docs=docs)
        self.knowledge = APIBackend().build_messages_and_create_chat_completion(user_prompt=rendered_prompt)
        self.docs = docs
|
||||
|
||||
|
||||
class KnowledgeBase:
    """
    Load knowledge, offer brief information of knowledge and common handle interfaces
    """

    def __init__(self, init_path=None, topics: List[Topic] = None):
        """
        Parameters
        ----------
        init_path : Union[str, Path], optional
            directory to load knowledge from; the current working directory is used when omitted.
            The directory is created when it does not exist.
        topics : List[Topic], optional
            topics to be summarized from the loaded documents.
        """
        self.logger = FinCoLog()
        # `load` accepts both `str` and `Path`, so the constructor must accept both as well
        init_path = Path(init_path) if init_path else Path.cwd()

        if not init_path.exists():
            self.logger.warning(f"{init_path} not exist, create empty directory.")
            # Missing parent directories are created too instead of failing
            init_path.mkdir(parents=True, exist_ok=True)

        self.knowledge = self.load(path=init_path)

        # todo: replace list with persistent storage strategy such as ES/pinecone to enable
        # literal search/semantic search
        self.docs = self.brief(knowledge=self.knowledge)

        self.topics = topics if topics else []

    def load(self, path) -> List:
        """
        Load one `KnowledgeExperiment` per experiment found under the `mlruns` folder of `path`

        Parameters
        ----------
        path : Union[str, Path]
            either the `mlruns` folder itself or its parent directory.

        Returns
        -------
        List
            the loaded knowledge objects.
        """
        path = Path(path)
        path = path if path.name == "mlruns" else path.joinpath("mlruns")
        # `as_uri` only works on absolute paths, so relative paths are anchored to the cwd first
        R.set_uri(path.absolute().as_uri())
        knowledge = [KnowledgeExperiment(exp_name=exp_name) for exp_name in R.list_experiments()]

        self.logger.plain_info(f"Load knowledge from: {path} finished.")
        return knowledge

    def update(self, path):
        """Reload all knowledge from `path` and rebuild the brief documents."""
        # note: only update new knowledge in future
        self.knowledge = self.load(path)
        self.docs = self.brief(self.knowledge)
        self.logger.plain_info("Update knowledge finished.")

    def brief(self, knowledge: List[Knowledge]) -> List:
        """
        Concatenate the brief summaries of every knowledge object into one list

        Parameters
        ----------
        knowledge : List[Knowledge]
            the knowledge objects to summarize.
        """
        docs = []
        for k in knowledge:
            docs.extend(k.brief())

        self.logger.plain_info("Generate brief knowledge summary finished.")
        return docs

    def query(self, content: str = None):
        """Return all brief documents; `content` is not used yet."""
        # todo: query by DSL
        return self.docs

    def query_topics(self):
        """Return a list of single-entry dicts mapping each topic name to its summarized knowledge."""
        return [{topic.name: topic.knowledge} for topic in self.topics]

    def summarize_by_topic(self):
        """Let every topic summarize the current brief documents."""
        for topic in self.topics:
            topic.summarize(self.docs)
|
||||
111
qlib/finco/llm.py
Normal file
111
qlib/finco/llm.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import time
|
||||
import openai
|
||||
import json
|
||||
from typing import Optional
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco.log import FinCoLog
|
||||
|
||||
|
||||
class APIBackend(SingletonBaseClass):
    """Singleton wrapper around ``openai.ChatCompletion`` configured from ``Config``.

    When ``Config.debug_mode`` is set, replies are cached in ``prompt_cache.json`` in the
    current working directory, keyed by the JSON dump of the message list.
    """

    def __init__(self):
        self.cfg = Config()
        openai.api_key = self.cfg.openai_api_key
        if self.cfg.use_azure:
            openai.api_type = "azure"
            openai.api_base = self.cfg.azure_api_base
            openai.api_version = self.cfg.azure_api_version
        self.use_azure = self.cfg.use_azure

        self.debug_mode = False
        if self.cfg.debug_mode:
            self.debug_mode = True
            cwd = os.getcwd()
            self.cache_file_location = os.path.join(cwd, "prompt_cache.json")
            if os.path.exists(self.cache_file_location):
                # Close the handle deterministically instead of leaking it.
                with open(self.cache_file_location, "r") as f:
                    self.cache = json.load(f)
            else:
                self.cache = {}

    def build_messages_and_create_chat_completion(self, user_prompt, system_prompt=None, former_messages=None, **kwargs):
        """build the messages to avoid implementing several redundant lines of code

        Parameters
        ----------
        user_prompt
            Content of the final ``user`` message.
        system_prompt
            Content of the ``system`` message; falls back to ``Config().system_prompt`` and then
            to a built-in default.
        former_messages
            Earlier messages of the conversation; at most the last
            ``Config().max_past_message_include`` of them are sent.
        **kwargs
            Forwarded to ``try_create_chat_completion``.

        Returns
        -------
        The content of the model's reply.
        """
        cfg = Config()
        # TODO: system prompt should always be provided. In development stage we can use default value
        if system_prompt is None:
            try:
                system_prompt = cfg.system_prompt
            except AttributeError:
                FinCoLog().warning("system_prompt is not set, using default value.")
                system_prompt = "You are an AI assistant who helps to answer user's questions about finance."

        messages = [{"role": "system", "content": system_prompt}]
        history_limit = cfg.max_past_message_include
        # Guard the slice: ``lst[-0:]`` is the whole list, so a limit of 0 would include everything.
        if former_messages and history_limit > 0:
            messages.extend(former_messages[-history_limit:])
        messages.append({"role": "user", "content": user_prompt})

        fcl = FinCoLog()
        response = self.try_create_chat_completion(messages=messages, **kwargs)
        fcl.log_message(messages)
        fcl.log_response(response)
        return response

    def try_create_chat_completion(self, max_retry=10, **kwargs):
        """Call ``create_chat_completion`` with retries.

        Rate-limit, timeout and API errors are retried after a one-second pause. On an invalid
        request the oldest history messages are dropped and the call is retried; when nothing is
        left to drop the error is re-raised.

        Raises
        ------
        Exception
            When no attempt succeeded within ``max_retry`` tries (``Config.max_retry`` takes
            precedence over the argument when it is not None).
        """
        max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry
        for i in range(max_retry):
            try:
                return self.create_chat_completion(**kwargs)
            except (openai.error.RateLimitError, openai.error.Timeout, openai.error.APIError) as e:
                print(e)
                print(f"Retrying {i+1}th time...")
                time.sleep(1)
                continue
            except openai.error.InvalidRequestError:
                print("Invalid request, will try to reduce the messages length and retry...")
                messages = kwargs["messages"]
                if len(messages) > 2:
                    # Keep the system message (first) and the current user prompt (last); drop up
                    # to two of the oldest history messages in between.
                    kwargs["messages"] = messages[:1] + messages[1:-1][2:] + messages[-1:]
                    continue
                raise
        raise Exception(f"Failed to create chat completion after {max_retry} retries.")

    def create_chat_completion(
        self,
        messages,
        model=None,
        temperature: float = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Send ``messages`` to the chat-completion endpoint and return the reply content.

        Arguments left as ``None`` keep the request as it was before they were honored: the
        model comes from ``Config.model``, no temperature is sent, and ``max_tokens`` is sent
        only for Azure (taken from ``Config.max_tokens``). Explicit values are forwarded.

        In debug mode a cached reply for the same messages is returned without calling the API,
        and every new reply is written back to the cache file.
        """
        if self.debug_mode:
            key = json.dumps(messages)
            if key in self.cache:
                return self.cache[key]

        request = {"messages": messages}
        if temperature is not None:
            request["temperature"] = temperature
        if self.cfg.use_azure:
            # Azure deployments are addressed through ``engine``.
            request["engine"] = model if model is not None else self.cfg.model
            request["max_tokens"] = max_tokens if max_tokens is not None else self.cfg.max_tokens
        else:
            request["model"] = model if model is not None else self.cfg.model
            if max_tokens is not None:
                request["max_tokens"] = max_tokens
        response = openai.ChatCompletion.create(**request)

        resp = response.choices[0].message["content"]
        if self.debug_mode:
            self.cache[key] = resp
            # ``with`` guarantees the cache is flushed and closed before returning.
            with open(self.cache_file_location, "w") as f:
                json.dump(self.cache, f)
        return resp
|
||||
131
qlib/finco/log.py
Normal file
131
qlib/finco/log.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
This module builds on Qlib's logger module and provides some interactive logging helpers.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from typing import Dict, List
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class LogColors:
    """
    ANSI escape sequences (colors, text styles and the reset code) for console output.
    """

    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"
    WHITE = "\033[97m"
    GRAY = "\033[90m"
    BLACK = "\033[30m"

    BOLD = "\033[1m"
    ITALIC = "\033[3m"

    END = "\033[0m"

    @classmethod
    def get_all_colors(cls):
        """Return the values of all non-dunder, non-callable attributes of the class."""
        candidates = (getattr(cls, attr) for attr in dir(cls) if not attr.startswith("__"))
        return [value for value in candidates if not callable(value)]

    def render(self, text: str, color: str = "", style: str = ""):
        """
        Wrap ``text`` in the given color and style codes. Passing text that is already
        rendered is not recommended.
        """
        # This method is called too frequently, which is not good.
        known_codes = self.get_all_colors()
        # Perhaps color and font should be distinguished here.
        if color:
            assert color in known_codes, f"color should be in: {known_codes} but now is: {color}"
        if style:
            assert style in known_codes, f"style should be in: {known_codes} but now is: {style}"

        # The reset code is appended once for the color layer and once for the style layer.
        return f"{style}{color}{text}{self.END}{self.END}"
|
||||
|
||||
|
||||
@contextmanager
def formatting_log(logger, title="Info"):
    """
    Context manager that logs a blank line and a titled rule before the wrapped block,
    and a blank line after it.
    """
    rule_lengths = {"Start": 120, "Task": 120, "Info": 60, "Interact": 60, "End": 120}
    if title in rule_lengths:
        # Known section titles get their own rule length and a bold yellow header.
        length, color, bold = rule_lengths[title], LogColors.YELLOW, LogColors.BOLD
    else:
        length, color, bold = 60, LogColors.CYAN, ""

    logger.info("")
    logger.info(f"{color}{bold}- {title} {'-' * (length - len(title))}{LogColors.END}")
    yield
    logger.info("")
|
||||
|
||||
|
||||
class FinCoLog(SingletonBaseClass):
    """Singleton console logger that prints colored, sectioned messages through a stream handler."""

    # TODO:
    # - config to file logger and save it into workspace
    def __init__(self) -> None:
        # TODO: merge these with Qlib's default logger.
        # We can do the same thing by changing the default log dict of Qlib.
        # Reference: https://github.com/microsoft/qlib/blob/main/qlib/config.py#L155
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter("%(message)s"))

        self.logger = logging.Logger("interactive")
        self.logger.addHandler(stream_handler)
        self.logger.setLevel(logging.INFO)

    def log_message(self, messages: List[Dict[str, str]]):
        """
        Log every chat message under a "GPT Messages" section.

        ``messages`` is a list of dicts with a ``"role"`` and a ``"content"`` key, e.g. [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        """
        with formatting_log(self.logger, "GPT Messages"):
            for m in messages:
                role_part = (
                    f"{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END} "
                    f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
                )
                content_part = (
                    f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
                    f"{LogColors.CYAN}{m['content']}{LogColors.END}\n"
                )
                self.logger.info(role_part + content_part)

    def log_response(self, response: str):
        """Log the model reply under a "GPT Response" section."""
        with formatting_log(self.logger, "GPT Response"):
            self.logger.info(f"{LogColors.CYAN}{response}{LogColors.END}\n")

    # TODO:
    # It looks weird if we only have logger
    def info(self, *args, plain=False, title="Info"):
        """Log each argument in white inside a ``title`` section, or via ``plain_info`` when ``plain`` is set."""
        if plain:
            return self.plain_info(*args)
        with formatting_log(self.logger, title):
            for arg in args:
                self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")

    def plain_info(self, *args):
        """Log each argument on its own line, prefixed with a bold yellow ``Info:``."""
        prefix = f"{LogColors.YELLOW}{LogColors.BOLD}Info:{LogColors.END}"
        for arg in args:
            self.logger.info(f"{prefix}{LogColors.WHITE}{arg}{LogColors.END}")

    def warning(self, *args):
        """Log each argument at WARNING level, prefixed with a bold blue ``Warning:``."""
        prefix = f"{LogColors.BLUE}{LogColors.BOLD}Warning:{LogColors.END}"
        for arg in args:
            self.logger.warning(f"{prefix}{arg}")

    def error(self, *args):
        """Log each argument at ERROR level, prefixed with a bold red ``Error:``."""
        prefix = f"{LogColors.RED}{LogColors.BOLD}Error:{LogColors.END}"
        for arg in args:
            self.logger.error(f"{prefix}{arg}")
|
||||
32
qlib/finco/prompt_template.py
Normal file
32
qlib/finco/prompt_template.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
import yaml
|
||||
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco import get_finco_path
|
||||
|
||||
|
||||
class PromptTemplate(SingletonBaseClass):
    """Singleton holding the prompt templates of ``prompt_template.yaml`` as attributes.

    Every top-level key of the YAML file except ``mods`` becomes an attribute whose value is a
    ``Template`` built from the corresponding entry.
    """

    def __init__(self) -> None:
        super().__init__()
        template_file = Path.joinpath(get_finco_path(), "prompt_template.yaml")
        # Use a context manager so the file handle is closed right after parsing.
        with open(template_file, "r") as f:
            _template = yaml.load(f, Loader=yaml.FullLoader)
        for k, v in _template.items():
            if k == "mods":
                continue
            self.__setattr__(k, Template(v))

    def get(self, key: str):
        """Return the template stored under ``key``, or an empty ``Template`` when it is unknown."""
        return self.__dict__.get(key, Template(""))

    def update(self, key: str, value):
        """Store ``value`` under ``key``, replacing any existing template."""
        self.__setattr__(key, value)

    def save(self, file_path: Union[str, Path]):
        """Dump all stored attributes as YAML to ``file_path``, creating parent directories as needed."""
        file_path = Path(file_path)
        # parents=True: the target may sit in a directory tree that does not exist yet.
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # NOTE(review): the dumped values are Template objects, not their source text;
        # confirm that the resulting YAML is what consumers of the checkpoint expect.
        with open(file_path, 'w') as f:
            yaml.dump(self.__dict__, f)
|
||||
1012
qlib/finco/prompt_template.yaml
Normal file
1012
qlib/finco/prompt_template.yaml
Normal file
File diff suppressed because it is too large
Load Diff
1110
qlib/finco/task.py
Normal file
1110
qlib/finco/task.py
Normal file
File diff suppressed because it is too large
Load Diff
12
qlib/finco/tpl/README.md
Normal file
12
qlib/finco/tpl/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
This is a set of templates that should be copied for a new project.
|
||||
|
||||
Here are the explanations for the templates folder.
|
||||
|
||||
| folder | explanations |
|
||||
|--------|------------------------------------------------------------------|
|
||||
| sl | Default configuration for supervised learning |
|
||||
| sl-cfg | Like the sl configuration, but the dataset is highly configurable |
|
||||
|
||||
|
||||
# TODO
|
||||
- [ ] [Copier](https://copier.readthedocs.io/en/stable/#quick-start) may be useful if the generation process becomes complicated
|
||||
13
qlib/finco/tpl/__init__.py
Normal file
13
qlib/finco/tpl/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
def get_tpl_path() -> Path:
    """
    Return the directory that contains the templates.

    The templates live next to this module, and the install location is not known in advance,
    so the directory is derived from this module's ``__file__`` (see ``DIRNAME``).
    """
    return DIRNAME
|
||||
83
qlib/finco/tpl/sl-cfg/workflow_config.yaml
Normal file
83
qlib/finco/tpl/sl-cfg/workflow_config.yaml
Normal file
File diff suppressed because one or more lines are too long
73
qlib/finco/tpl/sl/workflow_config.yaml
Normal file
73
qlib/finco/tpl/sl/workflow_config.yaml
Normal file
@@ -0,0 +1,73 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
experiment_name: finCo
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
38
qlib/finco/utils.py
Normal file
38
qlib/finco/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
|
||||
from fuzzywuzzy import fuzz
|
||||
|
||||
|
||||
class SingletonMeta(type):
    """Metaclass that makes every class using it a singleton.

    The first call of a class creates its instance; later calls return that same instance and
    ignore their arguments.
    """

    _instance = None

    def __call__(cls, *args, **kwargs):
        # Look only in the class's own namespace. A plain ``cls._instance`` lookup walks the
        # MRO, so a subclass of an already-instantiated singleton would get its parent's
        # instance instead of one of its own.
        if cls.__dict__.get("_instance") is None:
            cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
        return cls.__dict__["_instance"]
|
||||
|
||||
|
||||
class SingletonBaseClass(metaclass=SingletonMeta):
    """
    Base class that lets a singleton be declared as `class A(SingletonBaseClass)` instead of
    `class A(metaclass=SingletonMeta)`.
    """
    # TODO: move this class to Qlib's general utils.
|
||||
|
||||
|
||||
def parse_json(response):
    """Parse ``response`` as JSON and return the decoded object.

    Raises
    ------
    Exception
        When ``response`` is not valid JSON. The underlying ``JSONDecodeError`` is attached as
        the cause so the failing position stays visible in the traceback.
    """
    try:
        return json.loads(response)
    except json.decoder.JSONDecodeError as err:
        raise Exception(
            f"Failed to parse response: {response}, please report it or help us to fix it."
        ) from err
|
||||
|
||||
|
||||
def similarity(text1, text2):
    """Return ``fuzz.ratio`` of the two texts; any argument that is not a ``str`` counts as ``""``."""
    left, right = (value if isinstance(value, str) else "" for value in (text1, text2))

    # Maybe we can use other similarity algorithm such as tfidf
    return fuzz.ratio(left, right)
|
||||
223
qlib/finco/workflow.py
Normal file
223
qlib/finco/workflow.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import sys
|
||||
import copy
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from qlib.finco.task import HighLevelPlanTask, SummarizeTask, TrainTask
|
||||
from qlib.finco.prompt_template import PromptTemplate, Template
|
||||
from qlib.finco.log import FinCoLog, LogColors
|
||||
from qlib.finco.utils import similarity
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.knowledge import KnowledgeBase, Topic
|
||||
|
||||
|
||||
class WorkflowContextManager:
    """Context Manager stores the context of the workflow

    All context are key value pairs which saves the input, output and status of the whole workflow
    """

    def __init__(self) -> None:
        self.context = {}
        self.logger = FinCoLog()

    def set_context(self, key, value):
        """Store ``value`` under ``key``; an existing value is overwritten after a warning."""
        if key in self.context:
            self.logger.warning("The key already exists in the context, the value will be overwritten")
        self.context[key] = value

    def get_context(self, key):
        """Return the value stored under ``key``, or ``None`` (after a warning) when it is missing."""
        # NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
        if key not in self.context:
            self.logger.warning("The key doesn't exist in the context")
            return None
        return self.context[key]

    def update_context(self, key, new_value):
        """Set ``key`` to ``new_value``; a warning is logged when the key did not exist yet."""
        # NOTE: a missing key is only warned about and is still created. In the future, we may
        # raise an error to detect abnormal behavior
        if key not in self.context:
            self.logger.warning("The key doesn't exist in the context")
        self.context[key] = new_value

    def get_all_context(self):
        """return a deep copy of the context"""
        # TODO: do we need to return a deep copy?
        return copy.deepcopy(self.context)

    def retrieve(self, query: str) -> dict:
        """Return a single-entry dict with the context item that best matches ``query``.

        An exact key match wins. Otherwise the entry whose key or value is most similar to
        ``query`` (per ``similarity``) is returned. An empty context yields an empty dict.
        """
        if query in self.context:
            return {query: self.context.get(query)}

        # Nothing to compare against; ``max`` below would raise ValueError on an empty mapping.
        if not self.context:
            return {}

        # Note: retrieve information from context by string similarity maybe abandon in future
        scores = {k: max(similarity(query, k), similarity(query, v)) for k, v in self.context.items()}
        max_score_key = max(scores, key=scores.get)
        return {max_score_key: self.context.get(max_score_key)}

    def clear(self, reserve: list = None):
        """Drop every context entry except the keys listed in ``reserve``."""
        if reserve is None:
            reserve = []

        self.context = {k: self.get_context(k) for k in reserve}
|
||||
|
||||
|
||||
class WorkflowManager:
    """This manage the whole task automation workflow including tasks and actions"""

    def __init__(self, workspace=None) -> None:
        self.logger = FinCoLog()

        # Default to ``finco_workspace`` under the current working directory.
        self._workspace = Path.cwd() / "finco_workspace" if workspace is None else Path(workspace)
        self.conf = Config()
        self._confirm_and_rm()

        self.prompt_template = PromptTemplate()
        self.context = WorkflowContextManager()
        self.context.set_context("workspace", self._workspace)
        self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China A csi300. Please help to use lightgbm model."

    def _confirm_and_rm(self):
        """Remove an existing workspace; outside continuous mode the user must confirm first, otherwise exit."""
        if not self._workspace.exists():
            return

        if not self.conf.continuous_mode:
            self.logger.info(title="Interact")
            question = LogColors().render(
                f"Will be deleted: \n\t{self._workspace}\n"
                f"If you do not need to delete {self._workspace},"
                f" please change the workspace dir or rename existing files\n"
                f"Are you sure you want to delete, yes(Y/y), no (N/n):",
                color=LogColors.WHITE,
            )
            answer = input(question)
            if str(answer) not in ["Y", "y"]:
                sys.exit()

        # Reached in continuous mode, or after the user confirmed the deletion.
        shutil.rmtree(self._workspace)

    def set_context(self, key, value):
        """Direct call set_context method of the context manager"""
        self.context.set_context(key, value)

    def get_context(self) -> WorkflowContextManager:
        """Return the context manager shared by all tasks of this workflow."""
        return self.context

    def run(self, prompt: str) -> Path:
        """
        The workflow manager is supposed to generate a codebase based on the prompt

        Parameters
        ----------
        prompt: str
            the prompt user gives; ``None`` selects the default user prompt

        Returns
        -------
        Path
            The workflow manager is expected to produce output that includes a codebase containing generated code, results, and reports in a designated location.
            The path is returned

            The output path should follow a specific format:
            - TODO: design
            There is a summarized report where user can start from.
        """

        # NOTE: The following items are not designed to make the workflow very flexible.
        # - The generated tasks can't be changed after getting new information from the execution results.
        # - But it is required in some cases, if we want to build a external dataset, it maybe have to plan like autogpt...

        # NOTE: default user prompt might be changed in the future and exposed to the user
        effective_prompt = self.default_user_prompt if prompt is None else prompt
        self.set_context("user_prompt", effective_prompt)
        self.logger.info(f"user_prompt: {self.get_context().get_context('user_prompt')}", title="Start")

        # NOTE: list may not be enough for general task list
        task_list = [HighLevelPlanTask(), SummarizeTask()]
        task_finished = []
        while task_list:
            # Snapshot of the queue taken before the current task is removed from it.
            queued_names = [str(task) for task in task_list]

            # task list is not long, so sort it is not a big problem
            # TODO: sort the task list based on the priority of the task
            # task_list = sorted(task_list, key=lambda x: x.task_type)
            current = task_list.pop(0)
            finished_names = [str(task) for task in task_finished]
            self.logger.info(
                f"Task finished: {finished_names}",
                f"Task in queue: {queued_names}",
                f"Executing task: {str(current)}",
                title="Task",
            )

            current.assign_context_manager(self.context)
            follow_ups = current.execute()
            current.summarize()
            task_finished.append(current)
            self.context.set_context("task_finished", task_finished)
            self.logger.plain_info(f"{str(current)} finished.\n\n\n")

            # Tasks produced by the current one run before the rest of the queue.
            task_list = follow_ups + task_list

        return self._workspace
|
||||
|
||||
|
||||
class LearnManager:
    """Runs the workflow repeatedly and, after each run, asks the LLM to rewrite the tasks' system prompts."""

    __DEFAULT_TOPICS = ["IC", "MaxDropDown"]

    def __init__(self):
        self.epoch = 0
        self.wm = WorkflowManager()

        topics = []
        for topic_name in self.__DEFAULT_TOPICS:
            describe = self.wm.prompt_template.get(f"Topic_{topic_name}")
            topics.append(Topic(name=topic_name, describe=describe))
        self.knowledge_base = KnowledgeBase(init_path=Path.cwd().joinpath('knowledge'), topics=topics)

    def run(self, prompt):
        """Run ten rounds of: workflow, knowledge update, topic summaries, prompt learning."""
        # todo: add early stop condition
        for _ in range(10):
            self.wm.run(prompt)
            self.knowledge_base.update(self.wm._workspace)
            self.knowledge_base.summarize_by_topic()
            self.learn()
            self.epoch += 1

    def learn(self):
        """Update each finished task's system prompt from the LLM reply, save a checkpoint, reset the context."""
        context = self.wm.context
        templates = self.wm.prompt_template
        workspace = context.get_context("workspace")

        # one task maybe run several times in workflow: keep the first instance of each task class
        first_of_class = {}
        for finished in context.get_context("task_finished"):
            first_of_class.setdefault(finished.__class__.__name__, finished)
        task_finished = list(first_of_class.values())

        user_prompt = context.get_context("user_prompt")
        summary = context.get_context("summary")
        prompt_prefix = self.__class__.__name__

        for task in task_finished:
            task_name = task.__class__.__name__
            prompt_workflow_selection = templates.get(f"{prompt_prefix}_user").render(
                summary=summary,
                brief=self.knowledge_base.query_topics(),
                task_finished=[str(t) for t in task_finished],
                task=task_name,
                system=task.system.render(),
                user_prompt=user_prompt,
            )
            system_prompt = templates.get(f"{prompt_prefix}_system").render()

            response = APIBackend().build_messages_and_create_chat_completion(
                user_prompt=prompt_workflow_selection,
                system_prompt=system_prompt,
            )

            # todo: response assertion
            task.prompt_template.update(key=f"{task_name}_system", value=Template(response))

        self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
        # Only the workspace location is carried over to the next round.
        context.clear(reserve=["workspace"])
|
||||
@@ -18,7 +18,7 @@ class StructuredCovEstimator(RiskModel):
|
||||
`B` is the regression coefficients matrix for all observations (row) on
|
||||
all factors (columns), and `U` is the residual matrix with shape like `X`.
|
||||
|
||||
Therefore the structured covariance can be estimated by
|
||||
Therefore, the structured covariance can be estimated by
|
||||
cov(X.T) = F @ cov(B.T) @ F.T + diag(var(U))
|
||||
|
||||
In finance domain, there are mainly three methods to design `F` [1][2]:
|
||||
|
||||
@@ -28,14 +28,15 @@ from qlib.typehint import Literal
|
||||
|
||||
def _get_multi_level_executor_config(
|
||||
strategy_config: dict,
|
||||
cash_limit: float = None,
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
data_granularity: str = "1min",
|
||||
) -> dict:
|
||||
executor_config = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "1min",
|
||||
"time_per_step": data_granularity,
|
||||
"verbose": False,
|
||||
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
|
||||
"generate_report": generate_report,
|
||||
@@ -127,7 +128,7 @@ def single_with_simulator(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float = None,
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
|
||||
@@ -154,12 +155,7 @@ def single_with_simulator(
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
init_qlib(backtest_config["qlib"], part=day)
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
@@ -181,13 +177,14 @@ def single_with_simulator(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": "1min",
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -202,7 +199,7 @@ def single_with_simulator(
|
||||
reports.append(simulator.report_dict)
|
||||
decisions += simulator.decisions
|
||||
|
||||
indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
|
||||
indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports]
|
||||
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator_info)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
@@ -226,7 +223,7 @@ def single_with_collect_data_loop(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float = None,
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with collect_data_loop.
|
||||
@@ -253,12 +250,7 @@ def single_with_collect_data_loop(
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
init_qlib(backtest_config["qlib"], part=day)
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
trade_start_time = orders["datetime"].min()
|
||||
trade_end_time = orders["datetime"].max()
|
||||
@@ -280,13 +272,14 @@ def single_with_collect_data_loop(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": "1min",
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -357,7 +350,10 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
|
||||
|
||||
if not output_path.exists():
|
||||
os.makedirs(output_path)
|
||||
res.to_csv(output_path / "summary.csv")
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@@ -98,8 +98,9 @@ def get_backtest_config_fromfile(path: str) -> dict:
|
||||
"debug_single_day": None,
|
||||
"concurrency": -1,
|
||||
"multiplier": 1.0,
|
||||
"output_dir": "outputs/",
|
||||
"output_dir": "outputs_backtest/",
|
||||
"generate_report": False,
|
||||
"data_granularity": "1min",
|
||||
}
|
||||
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import cast, List, Optional
|
||||
|
||||
@@ -13,14 +17,15 @@ import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.data.native import load_handler_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, train
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
@@ -46,19 +51,17 @@ def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_file_path: Path,
|
||||
data_dir: Path,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_file_path = order_file_path
|
||||
self._order_df = _read_orders(order_file_path).reset_index()
|
||||
|
||||
self._data_dir = data_dir
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
@@ -71,12 +74,17 @@ class LazyLoadDataset(Dataset):
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
backtest_data = load_simple_intraday_backtest_data(
|
||||
|
||||
data = load_handler_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in backtest_data.get_time_index()]
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
@@ -98,93 +106,132 @@ def train_and_test(
|
||||
action_interpreter: ActionInterpreter,
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
order_root_path = Path(data_config["source"]["order_dir"])
|
||||
|
||||
data_granularity = simulator_config.get("data_granularity", 1)
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
return SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
feature_columns_today=data_config["source"]["feature_columns_today"],
|
||||
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
|
||||
data_granularity=data_granularity,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
deal_price_type=data_config["source"].get("deal_price_column", "close"),
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
|
||||
train_dataset = LazyLoadDataset(
|
||||
order_file_path=order_root_path / "train",
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time"],
|
||||
default_end_time_index=data_config["source"]["default_end_time"],
|
||||
)
|
||||
valid_dataset = LazyLoadDataset(
|
||||
order_file_path=order_root_path / "valid",
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time"],
|
||||
default_end_time_index=data_config["source"]["default_end_time"],
|
||||
)
|
||||
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
|
||||
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
|
||||
|
||||
callbacks = []
|
||||
if "checkpoint_path" in trainer_config:
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
dirpath=Path(trainer_config["checkpoint_path"]),
|
||||
every_n_iters=trainer_config["checkpoint_every_n_iters"],
|
||||
save_latest="copy",
|
||||
),
|
||||
if run_training:
|
||||
train_dataset, valid_dataset = [
|
||||
LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / tag,
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
for tag in ("train", "valid")
|
||||
]
|
||||
|
||||
callbacks: List[Callback] = []
|
||||
if "checkpoint_path" in trainer_config:
|
||||
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
|
||||
save_latest="copy",
|
||||
),
|
||||
)
|
||||
if "earlystop_patience" in trainer_config:
|
||||
callbacks.append(
|
||||
EarlyStopping(
|
||||
patience=trainer_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
)
|
||||
)
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Order], train_dataset),
|
||||
trainer_kwargs={
|
||||
"max_iters": trainer_config["max_epoch"],
|
||||
"finite_env_type": env_config["parallel_mode"],
|
||||
"concurrency": env_config["concurrency"],
|
||||
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": trainer_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": trainer_config["batch_size"],
|
||||
"repeat": trainer_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": valid_dataset,
|
||||
},
|
||||
)
|
||||
|
||||
trainer_kwargs = {
|
||||
"max_iters": trainer_config["max_epoch"],
|
||||
"finite_env_type": env_config["parallel_mode"],
|
||||
"concurrency": env_config["concurrency"],
|
||||
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
|
||||
"callbacks": callbacks,
|
||||
}
|
||||
vessel_kwargs = {
|
||||
"episode_per_iter": trainer_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": trainer_config["batch_size"],
|
||||
"repeat": trainer_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": valid_dataset,
|
||||
}
|
||||
if run_backtest:
|
||||
test_dataset = LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / "test",
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Order], train_dataset),
|
||||
trainer_kwargs=trainer_kwargs,
|
||||
vessel_kwargs=vessel_kwargs,
|
||||
)
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=test_dataset,
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=env_config["parallel_mode"],
|
||||
concurrency=env_config["concurrency"],
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict) -> None:
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
if "seed" in config["runtime"]:
|
||||
seed_everything(config["runtime"]["seed"])
|
||||
|
||||
state_config = config["state_interpreter"]
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(state_config)
|
||||
for extra_module_path in config["env"].get("extra_module_paths", []):
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
|
||||
reward: Reward = init_instance_by_config(config["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
|
||||
# Create torch network
|
||||
if "kwargs" not in config["network"]:
|
||||
config["network"]["kwargs"] = {}
|
||||
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
|
||||
network: nn.Module = init_instance_by_config(config["network"])
|
||||
if "network" in config:
|
||||
if "kwargs" not in config["network"]:
|
||||
config["network"]["kwargs"] = {}
|
||||
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
|
||||
|
||||
# Create policy
|
||||
config["policy"]["kwargs"].update(
|
||||
{
|
||||
"network": network,
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
)
|
||||
if "kwargs" not in config["policy"]:
|
||||
config["policy"]["kwargs"] = {}
|
||||
config["policy"]["kwargs"].update(additional_policy_kwargs)
|
||||
policy: BasePolicy = init_instance_by_config(config["policy"])
|
||||
|
||||
use_cuda = config["runtime"].get("use_cuda", False)
|
||||
@@ -200,20 +247,22 @@ def main(config: dict) -> None:
|
||||
state_interpreter=state_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
|
||||
main(config)
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
|
||||
@@ -8,48 +8,14 @@ TODO: The implementation here is kind of adhoc. It is better to design a more un
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
dataset = None
|
||||
|
||||
|
||||
class DataWrapper:
    """Hold the feature dataset, the backtest dataset and the configured column lists.

    The constructor is guarded by the ``_internal`` flag: it refuses to run unless the
    caller explicitly passes ``_internal=True``.
    """

    def __init__(
        self,
        feature_dataset: DatasetH,
        backtest_dataset: DatasetH,
        columns_today: List[str],
        columns_yesterday: List[str],
        _internal: bool = False,
    ):
        assert _internal, "Init function of data wrapper is for internal use only."

        self.feature_dataset, self.backtest_dataset = feature_dataset, backtest_dataset
        self.columns_today, self.columns_yesterday = columns_today, columns_yesterday

    # The cache key ignores ``self`` and normalises the timestamp to midnight, so every
    # lookup for the same (stock, day, backtest flag) shares one cached frame.
    @cachetools.cached(  # type: ignore
        cache=cachetools.LRUCache(100),
        key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
    )
    def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
        """Fetch one stock's rows for the calendar day containing ``date``.

        Parameters
        ----------
        stock_id
            Instrument to select.
        date
            Any timestamp within the wanted day; the window spans 00:00:00 to 23:59:59.
        backtest
            Read from the backtest dataset when true, otherwise from the feature dataset.
        """
        source = self.backtest_dataset if backtest else self.feature_dataset
        day_begin = date.replace(hour=0, minute=0, second=0)
        day_end = date.replace(hour=23, minute=59, second=59)
        return source.handler.fetch(pd.IndexSlice[stock_id, day_begin:day_end], level=None)
|
||||
|
||||
|
||||
def init_qlib(qlib_config: dict, part: str = None) -> None:
|
||||
def init_qlib(qlib_config: dict) -> None:
|
||||
"""Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..
|
||||
|
||||
Parameters
|
||||
@@ -72,20 +38,15 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
|
||||
],
|
||||
}
|
||||
part
|
||||
Identifying which part (stock / date) to load.
|
||||
"""
|
||||
|
||||
global dataset # pylint: disable=W0603
|
||||
|
||||
def _convert_to_path(path: str | Path) -> Path:
|
||||
return path if isinstance(path, Path) else Path(path)
|
||||
|
||||
provider_uri_map = {}
|
||||
if "provider_uri_day" in qlib_config:
|
||||
provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix()
|
||||
if "provider_uri_1min" in qlib_config:
|
||||
provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix()
|
||||
for granularity in ["1min", "5min", "day"]:
|
||||
if f"provider_uri_{granularity}" in qlib_config:
|
||||
provider_uri_map[f"{granularity}"] = _convert_to_path(qlib_config[f"provider_uri_{granularity}"]).as_posix()
|
||||
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
@@ -119,47 +80,3 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
)
|
||||
|
||||
if part == "skip":
|
||||
return
|
||||
|
||||
# this won't work if it's put outside in case of multiprocessing
|
||||
from qlib.data import D # noqa pylint: disable=C0415,W0611
|
||||
|
||||
if part is None:
|
||||
feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl"
|
||||
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl"
|
||||
else:
|
||||
feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl")
|
||||
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl")
|
||||
|
||||
with feature_path.open("rb") as f:
|
||||
feature_dataset = pickle.load(f)
|
||||
with backtest_path.open("rb") as f:
|
||||
backtest_dataset = pickle.load(f)
|
||||
|
||||
dataset = DataWrapper(
|
||||
feature_dataset,
|
||||
backtest_dataset,
|
||||
qlib_config["feature_columns_today"],
|
||||
qlib_config["feature_columns_yesterday"],
|
||||
_internal=True,
|
||||
)
|
||||
|
||||
|
||||
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
    """Return the selected columns of one stock's data for one day.

    Parameters
    ----------
    stock_id
        Instrument to look up.
    date
        Day to look up.
    yesterday
        When not in backtest mode, choose the "yesterday" column list instead of the
        "today" one.
    backtest
        When true, read the backtest data and select ``$close`` and ``$volume``.

    Returns
    -------
    pd.DataFrame
        The requested columns, or an all-zero placeholder frame when no data is found.
    """
    assert dataset is not None, "You must call init_qlib() before doing this."

    if backtest:
        wanted = ["$close", "$volume"]
    elif yesterday:
        wanted = dataset.columns_yesterday
    else:
        wanted = dataset.columns_today

    frame = dataset.get(stock_id, date, backtest)
    if frame is None or len(frame) == 0:
        # create a fake index, but RL doesn't care about index
        return pd.DataFrame(0.0, index=np.arange(240), columns=wanted, dtype=np.float32)  # FIXME: hardcode here

    # NOTE(review): rstrip("0") removes *every* trailing "0" from a column name, not just
    # one character -- confirm no column legitimately ends in a zero digit.
    renamed = frame.rename(columns={name: name.rstrip("0") for name in frame.columns})
    return renamed[wanted]
|
||||
|
||||
@@ -2,17 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from pathlib import Path
|
||||
from typing import cast, List
|
||||
|
||||
import cachetools
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.rl.order_execution.utils import get_ticks_slice
|
||||
|
||||
from qlib.constant import EPS_T
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
from .integration import fetch_features
|
||||
|
||||
|
||||
def get_ticks_slice(
    ticks_index: pd.DatetimeIndex,
    start: pd.Timestamp,
    end: pd.Timestamp,
    include_end: bool = False,
) -> pd.DatetimeIndex:
    """Return the part of ``ticks_index`` lying between ``start`` and ``end``.

    Parameters
    ----------
    ticks_index
        Index of tick timestamps to cut from.
    start
        Lower bound of the slice.
    end
        Upper bound of the slice.
    include_end
        When false (the default), ``EPS_T`` is subtracted from ``end`` before slicing;
        presumably this keeps a tick exactly at ``end`` out of the result.
    """
    upper_bound = end if include_end else end - EPS_T
    selection = ticks_index.slice_indexer(start, upper_bound)
    return ticks_index[selection]
|
||||
|
||||
|
||||
class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
@@ -71,6 +83,31 @@ class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
|
||||
|
||||
|
||||
class DataframeIntradayBacktestData(BaseIntradayBacktestData):
    """Backtest data served straight from a dataframe.

    Deal price and volume are read from two named columns of the frame, and the frame's
    index is exposed as the time index.
    """

    def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None:
        self.df = df
        self.price_column, self.volume_column = price_column, volume_column

    def __repr__(self) -> str:
        # Temporarily switch pandas display options while the frame is formatted.
        display_options = ("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info")
        with pd.option_context(*display_options):
            return f"{self.__class__.__name__}({self.df})"

    def __len__(self) -> int:
        frame = self.df
        return len(frame)

    def get_deal_price(self) -> pd.Series:
        """Return the column named by ``price_column``."""
        column = self.price_column
        return self.df[column]

    def get_volume(self) -> pd.Series:
        """Return the column named by ``volume_column``."""
        column = self.volume_column
        return self.df[column]

    def get_time_index(self) -> pd.DatetimeIndex:
        """Return the frame's index, typed as a ``DatetimeIndex``."""
        frame_index = self.df.index
        return cast(pd.DatetimeIndex, frame_index)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100),
|
||||
key=lambda order, _, __: order.key_by_day,
|
||||
@@ -103,13 +140,18 @@ def load_backtest_data(
|
||||
return backtest_data
|
||||
|
||||
|
||||
class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle NT style data."""
|
||||
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = df.reset_index()
|
||||
@@ -117,8 +159,18 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
df = df.drop(columns=["instrument"])
|
||||
return df.set_index(["datetime"])
|
||||
|
||||
self.today = _drop_stock_id(fetch_features(stock_id, date))
|
||||
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
self.today = _drop_stock_id(data[[]])
|
||||
self.yesterday = _drop_stock_id(data[[]])
|
||||
else:
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
@@ -127,12 +179,42 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
|
||||
stock_id,
|
||||
date,
|
||||
backtest,
|
||||
index_only,
|
||||
),
|
||||
)
|
||||
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
|
||||
return NTIntradayProcessedData(stock_id, date)
|
||||
def load_handler_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> HandlerIntradayProcessedData:
|
||||
return HandlerIntradayProcessedData(
|
||||
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
|
||||
)
|
||||
|
||||
|
||||
class NTProcessedDataProvider(ProcessedDataProvider):
|
||||
class HandlerProcessedDataProvider(ProcessedDataProvider):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.data_dir = Path(data_dir)
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.backtest = backtest
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
stock_id: str,
|
||||
@@ -140,4 +222,12 @@ class NTProcessedDataProvider(ProcessedDataProvider):
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return load_nt_intraday_processed_data(stock_id, date)
|
||||
return load_handler_intraday_processed_data(
|
||||
self.data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
self.feature_columns_today,
|
||||
self.feature_columns_yesterday,
|
||||
backtest=self.backtest,
|
||||
index_only=False,
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user