1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-29 09:01:18 +08:00

Compare commits

..

13 Commits

Author SHA1 Message Date
you-n-g
8388c8be8c Update handler.py 2023-04-06 16:10:01 +08:00
saurabh dave
e6f9a94fc5 fix: removed extra blank link between sections (#1451) 2023-04-03 17:32:01 +08:00
Fivele-Li
73937863f1 Merge pull request #1475 from qianyun210603/bugfix
[BUGFIX] potential file// url parsing error
2023-03-24 11:22:57 +08:00
BookSword
d010219ba6 Merge branch 'main' into bugfix 2023-03-23 16:11:19 +08:00
BookSword
4fc8a5f25f merge 2023-03-23 16:05:09 +08:00
Linlang
0e8bfcb5d3 fix_pylint_w0719 (#1463)
* fix_pylint_w0719

* remove_fixme
2023-03-17 19:25:49 +08:00
you-n-g
e457ca8511 Improve annotation & documentation for handler (#1312)
* Improve annotation & documentation for handler

* Add type
2023-03-15 21:15:40 +08:00
Huoran Li
4dbb8ecb86 Remove (#1464) 2023-03-15 15:26:44 +08:00
Huoran Li
653c082e7a Order execution open source (#1447)
* Waiting for bin data

* Complete readme

* CI

* Add inst filter by time

* Update qlib/data/dataset/processor.py

* typo

* Fix time filter bug

* Add Filter and set Universe

* Complete data pipeline

* Fix Provider Logger Info Args

* Add DQN; a minor bugfix in ppo reward.

* update readme. modify assertion logic in strategy check.

* Fix Doc issues and fix black

* Fix pylint Error

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2023-03-13 12:06:28 +08:00
you-n-g
f98e04ca9d Fix Field Name Error 2023-03-03 16:28:47 +08:00
Cadenza-Li
76f2fb1a1a Add ipynb format check (#1439)
* Update test_qlib_from_source.yml

* add ipynb format check to workflow

* test ipynb CI

* modify nbqa check path

* add pylint flake8 mypy check to ipynb

* check ipynb with black and pylint

* reformat .ipynb files

* format line length

nbqa black . -l 120

* update nbqa .ipynb format CI

* format old ipynb files

* add nbconvert check to CI

* adjust CI order to avoid repeating download data
2023-02-21 09:23:22 +08:00
Huoran Li
5eb5ac1f1f RL backtest pipeline on 5-min data (#1417)
* Workflow runnable

* CI

* Slight changes to make the workflow runnable. The changes of handler/provider should be reverted before merging.

* Train experiment successful

* Refine handler & provider

* test passed

* Ready to test on server

* Minor

* Test passed

* TWAP training

* Add PPOReward

* Add a FIXME

* Refine PPO reward according to PR comments

* Minor

* Resolve PR comments

* CI issues

* CI issues

* CI issues
2023-02-13 12:43:22 +08:00
Young
6295939346 Update to Dev Version 2023-01-29 18:55:23 +08:00
61 changed files with 1156 additions and 654 deletions

View File

@@ -120,6 +120,11 @@ jobs:
run: |
mypy qlib --install-types --non-interactive || true
mypy qlib --verbose
- name: Check Qlib ipynb with nbqa
run: |
nbqa black . -l 120 --check --diff
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
- name: Test data downloads
run: |
@@ -138,6 +143,12 @@ jobs:
brew unlink libomp
brew install libomp.rb
# Run after data downloads
- name: Check Qlib ipynb with nbconvert
run: |
# add more ipynb files in future
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
- name: Test workflow by config (install from source)
run: |
python -m pip install numba

3
.gitignore vendored
View File

@@ -10,7 +10,6 @@ _build
build/
dist/
*.pkl
*.hd5
*.csv
@@ -27,6 +26,8 @@ examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
examples/rl/outputs/
examples/rl_order_execution/data/
examples/rl_order_execution/outputs/
*.egg-info/

View File

@@ -29,13 +29,13 @@ class Avg15minHandler(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = Avg15minLoader(
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors
)
super().__init__(
instruments=instruments,

View File

@@ -18,7 +18,7 @@ data_handler_config: &data_handler_config
label: day
feature: 1min
# with label as reference
inst_processor:
inst_processors:
feature:
- class: Resample1minProcessor
module_path: features_sample.py

View File

@@ -19,7 +19,7 @@ data_handler_config: &data_handler_config
feature_15min: 1min
feature_day: day
# with label as reference
inst_processor:
inst_processors:
feature_15min:
- class: ResampleNProcessor
module_path: features_resample_N.py

View File

@@ -25,59 +25,65 @@
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"sns.set(style='white')\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"\n",
"sns.set(style=\"white\")\n",
"matplotlib.rcParams[\"pdf.fonttype\"] = 42\n",
"matplotlib.rcParams[\"ps.fonttype\"] = 42\n",
"\n",
"from tqdm.auto import tqdm\n",
"from joblib import Parallel, delayed\n",
"\n",
"\n",
"def func(x, N=80):\n",
" ret = x.ret.copy()\n",
" x = x.rank(pct=True)\n",
" x['ret'] = ret\n",
" x[\"ret\"] = ret\n",
" diff = x.score.sub(x.label)\n",
" r = x.nlargest(N, columns='score').ret.mean()\n",
" r -= x.nsmallest(N, columns='score').ret.mean()\n",
" return pd.Series({\n",
" 'MSE': diff.pow(2).mean(), \n",
" 'MAE': diff.abs().mean(), \n",
" 'IC': x.score.corr(x.label),\n",
" 'R': r\n",
" })\n",
" \n",
" r = x.nlargest(N, columns=\"score\").ret.mean()\n",
" r -= x.nsmallest(N, columns=\"score\").ret.mean()\n",
" return pd.Series(\n",
" {\n",
" \"MSE\": diff.pow(2).mean(),\n",
" \"MAE\": diff.abs().mean(),\n",
" \"IC\": x.score.corr(x.label),\n",
" \"R\": r,\n",
" }\n",
" )\n",
"\n",
"\n",
"ret = pd.read_pickle(\"data/ret.pkl\").clip(-0.1, 0.1)\n",
"\n",
"\n",
"def backtest(fname, **kwargs):\n",
" pred = pd.read_pickle(fname).loc['2018-09-21':'2020-06-30'] # test period\n",
" pred['ret'] = ret\n",
" pred = pd.read_pickle(fname).loc[\"2018-09-21\":\"2020-06-30\"] # test period\n",
" pred[\"ret\"] = ret\n",
" dates = pred.index.unique(level=0)\n",
" res = Parallel(n_jobs=-1)(delayed(func)(pred.loc[d], **kwargs) for d in dates)\n",
" res = {\n",
" dates[i]: res[i]\n",
" for i in range(len(dates))\n",
" }\n",
" res = {dates[i]: res[i] for i in range(len(dates))}\n",
" res = pd.DataFrame(res).T\n",
" r = res['R'].copy()\n",
" r = res[\"R\"].copy()\n",
" r.index = pd.to_datetime(r.index)\n",
" r = r.reindex(pd.date_range(r.index[0], r.index[-1])).fillna(0) # paper use 365 days\n",
" return {\n",
" 'MSE': res['MSE'].mean(),\n",
" 'MAE': res['MAE'].mean(),\n",
" 'IC': res['IC'].mean(),\n",
" 'ICIR': res['IC'].mean()/res['IC'].std(),\n",
" 'AR': r.mean()*365,\n",
" 'AV': r.std()*365**0.5,\n",
" 'SR': r.mean()/r.std()*365**0.5,\n",
" 'MDD': (r.cumsum().cummax() - r.cumsum()).max()\n",
" \"MSE\": res[\"MSE\"].mean(),\n",
" \"MAE\": res[\"MAE\"].mean(),\n",
" \"IC\": res[\"IC\"].mean(),\n",
" \"ICIR\": res[\"IC\"].mean() / res[\"IC\"].std(),\n",
" \"AR\": r.mean() * 365,\n",
" \"AV\": r.std() * 365**0.5,\n",
" \"SR\": r.mean() / r.std() * 365**0.5,\n",
" \"MDD\": (r.cumsum().cummax() - r.cumsum()).max(),\n",
" }, r\n",
"\n",
"\n",
"def fmt(x, p=3, scale=1, std=False):\n",
" _fmt = '{:.%df}'%p\n",
" _fmt = \"{:.%df}\" % p\n",
" string = _fmt.format((x.mean() if not isinstance(x, (float, np.floating)) else x) * scale)\n",
" if std and len(x) > 1:\n",
" string += ' ('+_fmt.format(x.std()*scale)+')'\n",
" string += \" (\" + _fmt.format(x.std() * scale) + \")\"\n",
" return string\n",
"\n",
"\n",
"def backtest_multi(files, **kwargs):\n",
" res = []\n",
" pnl = []\n",
@@ -88,14 +94,14 @@
" res = pd.DataFrame(res)\n",
" pnl = pd.concat(pnl, axis=1)\n",
" return {\n",
" 'MSE': fmt(res['MSE'], std=True),\n",
" 'MAE': fmt(res['MAE'], std=True),\n",
" 'IC': fmt(res['IC']),\n",
" 'ICIR': fmt(res['ICIR']),\n",
" 'AR': fmt(res['AR'], scale=100, p=1)+'%',\n",
" 'VR': fmt(res['AV'], scale=100, p=1)+'%',\n",
" 'SR': fmt(res['SR']),\n",
" 'MDD': fmt(res['MDD'], scale=100, p=1)+'%'\n",
" \"MSE\": fmt(res[\"MSE\"], std=True),\n",
" \"MAE\": fmt(res[\"MAE\"], std=True),\n",
" \"IC\": fmt(res[\"IC\"]),\n",
" \"ICIR\": fmt(res[\"ICIR\"]),\n",
" \"AR\": fmt(res[\"AR\"], scale=100, p=1) + \"%\",\n",
" \"VR\": fmt(res[\"AV\"], scale=100, p=1) + \"%\",\n",
" \"SR\": fmt(res[\"SR\"]),\n",
" \"MDD\": fmt(res[\"MDD\"], scale=100, p=1) + \"%\",\n",
" }, pnl"
]
},
@@ -124,16 +130,20 @@
"outputs": [],
"source": [
"exps = {\n",
" 'Linear': ['output/Linear/pred.pkl'],\n",
" 'LightGBM': ['output/GBDT/lr0.05_leaves128/pred.pkl'],\n",
" 'MLP': glob.glob('output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl'),\n",
" 'SFM': glob.glob('output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl'),\n",
" 'ALSTM': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'Trans.': glob.glob('output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'ALSTM+TS':glob.glob('output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'Trans.+TS':glob.glob('output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
" 'ALSTM+TRA(Ours)': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'Trans.+TRA(Ours)': glob.glob('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl')\n",
" \"Linear\": [\"output/Linear/pred.pkl\"],\n",
" \"LightGBM\": [\"output/GBDT/lr0.05_leaves128/pred.pkl\"],\n",
" \"MLP\": glob.glob(\"output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl\"),\n",
" \"SFM\": glob.glob(\"output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl\"),\n",
" \"ALSTM\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"Trans.\": glob.glob(\"output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"ALSTM+TS\": glob.glob(\"output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"Trans.+TS\": glob.glob(\"output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
" \"ALSTM+TRA(Ours)\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"Trans.+TRA(Ours)\": glob.glob(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl\"\n",
" ),\n",
"}"
]
},
@@ -160,14 +170,8 @@
}
],
"source": [
"res = {\n",
" name: backtest_multi(exps[name])\n",
" for name in tqdm(exps)\n",
"}\n",
"report = pd.DataFrame({\n",
" k: v[0]\n",
" for k, v in res.items()\n",
"}).T"
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
]
},
{
@@ -385,24 +389,40 @@
}
],
"source": [
"df = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl')\n",
"code = 'SH600157'\n",
"date = '2018-09-28'\n",
"df = pd.read_pickle(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl\"\n",
")\n",
"code = \"SH600157\"\n",
"date = \"2018-09-28\"\n",
"lookbackperiod = 50\n",
"\n",
"prob = df.iloc[:, -3:].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
"pred = df.loc[:,[\"score_0\",\"score_1\",\"score_2\",\"label\"]].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
"e_all = pred.iloc[:,:-1].sub(pred.iloc[:,-1], axis=0).pow(2)\n",
"pred = (\n",
" df.loc[:, [\"score_0\", \"score_1\", \"score_2\", \"label\"]]\n",
" .loc(axis=0)[:, code]\n",
" .reset_index(level=1, drop=True)\n",
" .loc[date:]\n",
" .iloc[:lookbackperiod]\n",
")\n",
"e_all = pred.iloc[:, :-1].sub(pred.iloc[:, -1], axis=0).pow(2)\n",
"e_all = e_all.sub(e_all.min(axis=1), axis=0)\n",
"e_all.columns = [r'$\\theta_%d$'%d for d in range(1, 4)]\n",
"e_all.columns = [r\"$\\theta_%d$\" % d for d in range(1, 4)]\n",
"prob = pd.Series(np.argmax(prob.values, axis=1), index=prob.index).rolling(7).mean().round()\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(7, 3))\n",
"e_all.plot(ax=axes[0], xlabel='', rot=30)\n",
"prob.plot(ax=axes[1], xlabel='', rot=30, color='red', linestyle='None', marker='^', markersize=5)\n",
"e_all.plot(ax=axes[0], xlabel=\"\", rot=30)\n",
"prob.plot(\n",
" ax=axes[1],\n",
" xlabel=\"\",\n",
" rot=30,\n",
" color=\"red\",\n",
" linestyle=\"None\",\n",
" marker=\"^\",\n",
" markersize=5,\n",
")\n",
"plt.yticks(np.array([0, 1, 2]), e_all.columns.values)\n",
"axes[0].set_ylabel('Predictor Loss')\n",
"axes[1].set_ylabel('Router Selection')\n",
"axes[0].set_ylabel(\"Predictor Loss\")\n",
"axes[1].set_ylabel(\"Router Selection\")\n",
"plt.tight_layout()\n",
"# plt.savefig('select.pdf', bbox_inches='tight')\n",
"plt.show()"
@@ -428,10 +448,18 @@
"outputs": [],
"source": [
"exps = {\n",
" 'Random': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'LR': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'TPE': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
" 'LR+TPE': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl')\n",
" \"Random\": glob.glob(\n",
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"LR\": glob.glob(\n",
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"TPE\": glob.glob(\n",
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
" \"LR+TPE\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
" ),\n",
"}"
]
},
@@ -456,14 +484,8 @@
}
],
"source": [
"res = {\n",
" name: backtest_multi(exps[name])\n",
" for name in tqdm(exps)\n",
"}\n",
"report = pd.DataFrame({\n",
" k: v[0]\n",
" for k, v in res.items()\n",
"}).T"
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
]
},
{
@@ -597,18 +619,22 @@
}
],
"source": [
"a = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
"b = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
"a = pd.read_pickle(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
")\n",
"b = pd.read_pickle(\n",
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
")\n",
"a = a.iloc[:, -3:]\n",
"b = b.iloc[:, -3:]\n",
"b = np.eye(3)[b.values.argmax(axis=1)]\n",
"a = np.eye(3)[a.values.argmax(axis=1)]\n",
"\n",
"res = pd.DataFrame({\n",
" 'with OT': b.sum(axis=0) / b.sum(),\n",
" 'without OT': a.sum(axis=0)/ a.sum() \n",
"},index=[r'$\\theta_1$',r'$\\theta_2$',r'$\\theta_3$'])\n",
"res.plot.bar(rot=30, figsize=(5, 4), color=['b', 'g'])\n",
"res = pd.DataFrame(\n",
" {\"with OT\": b.sum(axis=0) / b.sum(), \"without OT\": a.sum(axis=0) / a.sum()},\n",
" index=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n",
")\n",
"res.plot.bar(rot=30, figsize=(5, 4), color=[\"b\", \"g\"])\n",
"del a, b"
]
},
@@ -633,11 +659,19 @@
"outputs": [],
"source": [
"exps = {\n",
" 'K=1': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json'),\n",
" 'K=3': glob.glob('output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
" 'K=5': glob.glob('output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
" 'K=10': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
" 'K=20': glob.glob('output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json')\n",
" \"K=1\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json\"),\n",
" \"K=3\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
" \"K=5\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
" \"K=10\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
" \"K=20\": glob.glob(\n",
" \"output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
" ),\n",
"}"
]
},
@@ -649,16 +683,11 @@
"source": [
"report = dict()\n",
"for k, v in exps.items():\n",
" \n",
" tmp = dict()\n",
" for fname in v:\n",
" with open(fname) as f:\n",
" info = json.load(f)\n",
" tmp[fname] = (\n",
" {\n",
" \"IC\":info[\"metric\"][\"IC\"],\n",
" \"MSE\":info[\"metric\"][\"MSE\"]\n",
" })\n",
" tmp[fname] = {\"IC\": info[\"metric\"][\"IC\"], \"MSE\": info[\"metric\"][\"MSE\"]}\n",
" tmp = pd.DataFrame(tmp).T\n",
" report[k] = tmp.mean()\n",
"report = pd.DataFrame(report).T"
@@ -681,13 +710,14 @@
}
],
"source": [
"fig, axes = plt.subplots(1, 2, figsize=(6,3)); axes = axes.flatten()\n",
"report['IC'].plot.bar(rot=30, ax=axes[0])\n",
"fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
"axes = axes.flatten()\n",
"report[\"IC\"].plot.bar(rot=30, ax=axes[0])\n",
"axes[0].set_ylim(0.045, 0.062)\n",
"axes[0].set_title('IC performance')\n",
"report['MSE'].astype(float).plot.bar(rot=30, ax=axes[1], color='green')\n",
"axes[0].set_title(\"IC performance\")\n",
"report[\"MSE\"].astype(float).plot.bar(rot=30, ax=axes[1], color=\"green\")\n",
"axes[1].set_ylim(0.155, 0.1585)\n",
"axes[1].set_title('MSE performance')\n",
"axes[1].set_title(\"MSE performance\")\n",
"plt.tight_layout()\n",
"# plt.savefig('sensitivity.pdf')"
]

View File

@@ -1,60 +0,0 @@
This folder contains a simple example of how to run Qlib RL. It contains:
```
.
├── experiment_config
│ ├── backtest # Backtest config
│ └── training # Training config
├── README.md # Readme (the current file)
└── scripts # Scripts for data pre-processing
```
## Data preparation
Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:
```
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
mv qlib_rl_example_data data
```
The downloaded data will be placed at `./data`. The original data are in `data/csv`. To create all data needed by the case, run:
```
bash scripts/data_pipeline.sh
```
After the execution finishes, the `data/` directory should be like:
```
data
├── backtest_orders.csv
├── bin
├── csv
├── pickle
├── pickle_dataframe
└── training_order_split
```
## Run training
Run:
```
python -m qlib.rl.contrib.train_onpolicy --config_path ./experiment_config/training/config.yml
```
After training, checkpoints will be stored under `checkpoints/`.
## Run backtest
```
python -m qlib.rl.contrib.backtest --config_path ./experiment_config/backtest/config.yml
```
The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
## Others
The RL module is designed in a loosely-coupled way. Currently, RL examples are integrated with concrete business logic.
But the core part of RL is much simpler than what you see.
To demonstrate the simple core of RL, [a dedicated notebook](./simple_example.ipynb) for RL without business logic is created.

View File

@@ -1,57 +0,0 @@
order_file: ./data/backtest_orders.csv
start_time: "9:45"
end_time: "14:44"
qlib:
provider_uri_1min: ./data/bin
feature_root_dir: ./data/pickle
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$volume",
]
feature_columns_yesterday: [
"$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
]
exchange:
limit_threshold: ['$close == 0', '$close == 0']
deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
volume_threshold:
all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
buy: ["current", "$close"]
sell: ["current", "$close"]
strategies:
30min:
class: TWAPStrategy
module_path: qlib.contrib.strategy.rule_strategy
kwargs: {}
1day:
class: SAOEIntStrategy
module_path: qlib.rl.order_execution.strategy
kwargs:
state_interpreter:
class: FullHistoryStateInterpreter
module_path: qlib.rl.order_execution.interpreter
kwargs:
max_step: 8
data_ticks: 240
data_dim: 6
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
action_interpreter:
class: CategoricalActionInterpreter
module_path: qlib.rl.order_execution.interpreter
kwargs:
values: 14
max_step: 8
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
kwargs: {}
policy:
class: PPO
module_path: qlib.rl.order_execution.policy
kwargs:
lr: 1.0e-4
weight_file: ./checkpoints/latest.pth
concurrency: 5

View File

@@ -1,14 +0,0 @@
# Data pipeline for the RL example: runs three stages in order (bin dump, pickle
# generation, order sampling). Paths are relative, so the script depends on the
# directory it is launched from.
# Generate `bin` format data
# Abort the whole pipeline as soon as any command fails, so later stages never
# run against partially generated data.
set -e
python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min
# Generate pickle format data
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
# Remove the stat/ directory if it exists before collecting dataframes.
# NOTE(review): presumably a leftover produced by the previous step — confirm.
if [ -e stat/ ]; then
rm -r stat/
fi
python scripts/collect_pickle_dataframe.py
# Sample orders
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py

View File

@@ -1,55 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import pandas as pd
import numpy as np
import pickle
# Sample random backtest orders for every instrument found in the pickled
# backtest data and write them to ``data/backtest_orders.csv``.
#
# Command-line options:
#   --seed       seed for NumPy's global RNG (makes the sampling reproducible)
#   --num_order  number of orders (distinct dates) sampled per instrument
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--num_order", type=int, default=10)
args = parser.parse_args()

np.random.seed(args.seed)

path = os.path.join("data", "pickle", "backtesttest.pkl")
# Use a context manager so the file handle is closed deterministically.
# NOTE(review): pickle must only be loaded from trusted files; this one is
# presumably produced by the local data pipeline — confirm.
with open(path, "rb") as f:
    df = pickle.load(f).reset_index()
# An explicit unit is required: pandas >= 2.0 rejects the unit-less "datetime64" cast.
df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")

instruments = sorted(set(df["instrument"]))

# TODO: The example is expected to be able to handle data containing missing values.
# TODO: Currently, we just simply skip dates that contain missing data. We will add
# TODO: this feature in the future.
# Map each instrument to the set of dates (as "YYYY-MM-DD" strings) on which its
# CSV has at least one row with a missing `close` value.
skip_dates = {}
for instrument in instruments:
    csv_df = pd.read_csv(os.path.join("data", "csv", f"{instrument}.csv"))
    csv_df = csv_df[csv_df["close"].isna()]
    skip_dates[instrument] = {str(d).split(" ")[0] for d in csv_df["date"]}

df_list = []
for instrument in instruments:
    print(instrument)
    cur_df = df[df["instrument"] == instrument]
    # Candidate dates: every date present for this instrument minus the skipped ones.
    dates = sorted({str(d).split(" ")[0] for d in cur_df["date"]})
    dates = [date for date in dates if date not in skip_dates[instrument]]

    n = args.num_order
    # np.random.choice(..., replace=False) raises ValueError when fewer than `n`
    # candidate dates remain. The RNG call order (dates, amount, order_type) is
    # kept as-is so a given seed keeps producing the same orders.
    df_list.append(
        pd.DataFrame(
            {
                "date": sorted(np.random.choice(dates, size=n, replace=False)),
                "instrument": [instrument] * n,
                # Amounts are multiples of 100 in the range [300, 1000].
                "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
                # Randomly 0 or 1.
                "order_type": np.random.randint(low=0, high=2, size=n),
            }
        ).set_index(["date", "instrument"]),
    )

total_df = pd.concat(df_list)
total_df.to_csv("data/backtest_orders.csv")

View File

@@ -1,39 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import pandas as pd
import numpy as np
import pickle
# Sample random training orders for a single stock and store them, split into
# train / valid / test groups, under ``data/training_order_split/<group>/``.
#
# Command-line options:
#   --seed        seed for NumPy's global RNG (makes the sampling reproducible)
#   --stock       instrument name written into every generated order
#   --train_size  number of orders sampled for the "train" group
#   --valid_size  number of orders sampled for the "valid" group
#   --test_size   number of orders sampled for the "test" group
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--stock", type=str, default="AAPL")
parser.add_argument("--train_size", type=int, default=10)
parser.add_argument("--valid_size", type=int, default=2)
parser.add_argument("--test_size", type=int, default=2)
args = parser.parse_args()

np.random.seed(args.seed)

os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)

for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
    path = os.path.join("data", "pickle", f"backtest{group}.pkl")
    # Use a context manager so the file handle is closed deterministically.
    # NOTE(review): pickle must only be loaded from trusted files; this one is
    # presumably produced by the local data pipeline — confirm.
    with open(path, "rb") as f:
        df = pickle.load(f).reset_index()
    # An explicit unit is required: pandas >= 2.0 rejects the unit-less "datetime64" cast.
    df["date"] = df["datetime"].dt.date.astype("datetime64[ns]")

    # Unique dates of this group as "YYYY-MM-DD" strings.
    dates = sorted({str(d).split(" ")[0] for d in df["date"]})

    # np.random.choice(..., replace=False) raises ValueError when the group has
    # fewer than `n` distinct dates.
    data_df = pd.DataFrame(
        {
            "date": sorted(np.random.choice(dates, size=n, replace=False)),
            "instrument": [args.stock] * n,
            # Amounts are multiples of 100 in the range [300, 1000].
            "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
            # Every training order uses order_type 0.
            "order_type": [0] * n,
        }
    ).set_index(["date", "instrument"])

    out_dir = os.path.join("data", "training_order_split", group)
    os.makedirs(out_dir, exist_ok=True)
    # Close the output file explicitly so the pickle is fully flushed to disk.
    with open(os.path.join(out_dir, f"{args.stock}.pkl"), "wb") as f:
        pickle.dump(data_df, f)

View File

@@ -41,6 +41,7 @@
"\n",
"State = namedtuple(\"State\", [\"value\", \"last_action\"])\n",
"\n",
"\n",
"class SimpleSimulator(Simulator[float, State, float]):\n",
" def __init__(self, initial: float, nsteps: int, **kwargs: Any) -> None:\n",
" super().__init__(initial)\n",
@@ -92,6 +93,7 @@
"from gym import spaces\n",
"from qlib.rl.interpreter import StateInterpreter\n",
"\n",
"\n",
"class SimpleStateInterpreter(StateInterpreter[Tuple[float, float], np.ndarray]):\n",
" def interpret(self, state: State) -> np.ndarray:\n",
" # Convert state.value to a 1D Numpy array\n",
@@ -101,7 +103,8 @@
" @property\n",
" def observation_space(self) -> spaces.Box:\n",
" return spaces.Box(0, np.inf, shape=(1,), dtype=np.float32)\n",
" \n",
"\n",
"\n",
"state_interpreter = SimpleStateInterpreter()"
]
},
@@ -120,6 +123,7 @@
"source": [
"from qlib.rl.interpreter import ActionInterpreter\n",
"\n",
"\n",
"class SimpleActionInterpreter(ActionInterpreter[State, int, float]):\n",
" def __init__(self, n_value: int) -> None:\n",
" self.n_value = n_value\n",
@@ -132,7 +136,8 @@
" assert 0 <= action <= self.n_value\n",
" # simulator_state.value is used as the denominator\n",
" return simulator_state.value * (action / self.n_value)\n",
" \n",
"\n",
"\n",
"action_interpreter = SimpleActionInterpreter(n_value=10)"
]
},
@@ -151,12 +156,14 @@
"source": [
"from qlib.rl.reward import Reward\n",
"\n",
"\n",
"class SimpleReward(Reward[State]):\n",
" def reward(self, simulator_state: State) -> float:\n",
" # Use last_action to calculate reward. This is why it should be in the state.\n",
" rew = simulator_state.last_action / simulator_state.value\n",
" return rew\n",
" \n",
"\n",
"\n",
"reward = SimpleReward()"
]
},
@@ -180,6 +187,7 @@
"from torch import nn\n",
"from qlib.rl.order_execution import PPO\n",
"\n",
"\n",
"class SimpleFullyConnect(nn.Module):\n",
" def __init__(self, dims: List[int]) -> None:\n",
" super().__init__()\n",
@@ -195,7 +203,8 @@
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return self.fc(x)\n",
" \n",
"\n",
"\n",
"policy = PPO(\n",
" network=SimpleFullyConnect(dims=[16, 8]),\n",
" obs_space=state_interpreter.observation_space,\n",
@@ -221,6 +230,7 @@
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class SimpleDataset(Dataset):\n",
" def __init__(self, positions: List[float]) -> None:\n",
" self.positions = positions\n",
@@ -230,7 +240,8 @@
"\n",
" def __getitem__(self, index: int) -> float:\n",
" return self.positions[index]\n",
" \n",
"\n",
"\n",
"dataset = SimpleDataset(positions=[10.0, 50.0, 100.0])"
]
},
@@ -265,11 +276,13 @@
"trainer_kwargs = {\n",
" \"max_iters\": 10,\n",
" \"finite_env_type\": \"dummy\",\n",
" \"callbacks\": [Checkpoint(\n",
" dirpath=Path(\"./checkpoints\"),\n",
" every_n_iters=1,\n",
" save_latest=\"copy\",\n",
" )],\n",
" \"callbacks\": [\n",
" Checkpoint(\n",
" dirpath=Path(\"./checkpoints\"),\n",
" every_n_iters=1,\n",
" save_latest=\"copy\",\n",
" )\n",
" ],\n",
"}\n",
"vessel_kwargs = {\n",
" \"update_kwargs\": {\"batch_size\": 16, \"repeat\": 5},\n",

View File

@@ -0,0 +1,100 @@
# RL Example for Order Execution
This folder contains an example of Reinforcement Learning (RL) workflows for the order execution scenario, including both training and backtest workflows.
## Data Processing
### Get Data
```
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
```
### Generate Pickle-Style Data
To run the code in this example, we need data in pickle format. To generate it, run the following commands (this might take a few minutes to finish):
```
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
python scripts/collect_pickle_dataframe.py
python scripts/gen_training_orders.py
python scripts/merge_orders.py
```
When finished, the structure under `data/` should be:
```
data
├── bin
├── orders
├── pickle
└── pickle_dataframe
```
## Training
Each training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks:
- **PPO**: Method proposed by IJCAI 2020 paper "[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)".
- **OPDS**: Method proposed by AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)".
The main difference between these two methods is their reward functions. Please see their config files for details.
Taking OPDS as an example, run the training workflow with:
```
python -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest
```
Metrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`).
## Backtest
Once the training workflow has completed, the trained model can be used for the backtesting workflow. Still taking OPDS as an example, once training is finished, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth`. To run backtest workflow:
1. Uncomment the `weight_file` parameter in `exp_configs/backtest_opds.yml` (it is commented out by default). While it is possible to run the backtesting workflow without setting a checkpoint, the model would then be randomly initialized, making the results meaningless.
2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`.
The backtest result is stored in `outputs/checkpoints/backtest_result.csv`.
In addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`.
### Gap between backtest and training pipeline's testing
It is worth noting that the results of the backtesting process may differ from the results of the testing process used during training.
This is because different simulators are used to simulate market conditions during training and backtesting.
In the training pipeline, a simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons.
`SingleAssetOrderExecutionSimple` makes no restriction to trading amounts.
No matter what the amount of the order is, it can be completely executed.
However, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used.
It takes into account practical constraints in more real-world scenarios (for example, the trading volume must be a multiple of the smallest trading unit).
As a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed.
If you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you could run the training pipeline with only the backtest phase.
To do this:
- Modify the training config. Add the path of the checkpoint you want to use (see below for an example).
- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training`
```yaml
...
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
weight_file: PATH/TO/CHECKPOINT
module_path: qlib.rl.order_execution.policy
...
```
## Benchmarks (TBD)
To accurately evaluate the performance of models using Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once, selecting the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use "Price Advantage (PA)" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the testing set is as follows:
| **Model** | **PA mean with std.** |
|-----------------------------|-----------------------|
| OPDS (with PPO policy) | 0.4785 ± 0.7815 |
| OPDS (with DQN policy) | -0.0114 ± 0.5780 |
| PPO | -1.0935 ± 0.0922 |
| TWAP | ≈ 0.0 ± 0.0 |
The table above also includes TWAP as a rule-based baseline. The ideal PA of TWAP should be 0.0. However, in this example, the order execution is divided into two steps: first, the order is split equally among each half hour, and then among each five minutes within each half hour. Since trading is forbidden during the last five minutes of the day, this approach may slightly differ from traditional TWAP over the course of a full day (as there are 5 minutes missing in the last "half hour"). Therefore, the PA of TWAP can be considered as a number that is close to 0.0. To verify this, you may run a TWAP backtest and check the results.

View File

@@ -0,0 +1,59 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/opds/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/opds/

View File

@@ -0,0 +1,59 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/ppo/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/ppo/

View File

@@ -0,0 +1,29 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/twap/

View File

@@ -1,20 +1,21 @@
simulator:
data_granularity: 5
time_per_step: 30
vol_limit: null
env:
concurrency: 1
parallel_mode: dummy
concurrency: 48
parallel_mode: shmem
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 14
values: 4
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 6
data_ticks: 240
data_dim: 5
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
@@ -25,23 +26,24 @@ state_interpreter:
reward:
class: PAPenaltyReward
kwargs:
penalty: 100.0
penalty: 4.0
scale: 0.01
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/training_order_split
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
total_time: 240
default_start_time: 0
default_end_time: 240
proc_data_dim: 6
default_start_time_index: 0
default_end_time_index: 235
proc_data_dim: 5
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
@@ -49,11 +51,11 @@ runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 2
repeat_per_collect: 5
earlystop_patience: 2
episode_per_collect: 20
batch_size: 16
val_every_n_epoch: 1
checkpoint_path: ./checkpoints
max_epoch: 500
repeat_per_collect: 25
earlystop_patience: 50
episode_per_collect: 10000
batch_size: 1024
val_every_n_epoch: 4
checkpoint_path: ./outputs/opds
checkpoint_every_n_iters: 1

View File

@@ -0,0 +1,62 @@
simulator:
data_granularity: 5
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 4
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.order_execution.interpreter
reward:
class: PPOReward
kwargs:
max_step: 8
start_time_index: 0
end_time_index: 46 # 46 = (240 - 5) min / 5 min - 1
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
proc_data_dim: 5
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 500
repeat_per_collect: 25
earlystop_patience: 50
episode_per_collect: 10000
batch_size: 1024
val_every_n_epoch: 4
checkpoint_path: ./outputs/ppo
checkpoint_every_n_iters: 1

View File

@@ -4,10 +4,17 @@
import os
import pickle
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
    """Dump the rows of a single instrument to its own pickle file.

    The rows of ``df`` whose ``instrument`` column equals ``instrument`` are
    sorted by ``datetime``, indexed by ``(instrument, datetime, date)`` and
    written to ``data/pickle_dataframe/<tag>/<instrument>.pkl``.

    Parameters
    ----------
    df : pd.DataFrame
        Flat frame that contains at least the columns ``instrument``,
        ``datetime`` and ``date``.
    instrument : str
        Instrument whose rows are extracted.
    tag : str
        Name of the sub-directory (under ``data/pickle_dataframe``) to write into.
        The directory must already exist.
    """
    cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
    cur = cur.set_index(["instrument", "datetime", "date"])
    # Use a context manager so the file handle is flushed and closed deterministically,
    # instead of relying on garbage collection of an anonymous file object.
    with open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb") as fp:
        pickle.dump(cur, fp)
for tag in ("backtest", "feature"):
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
df = pd.concat(list(df.values())).reset_index()
@@ -15,7 +22,5 @@ for tag in ("backtest", "feature"):
instruments = sorted(set(df["instrument"]))
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
for instrument in tqdm(instruments):
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
cur = cur.set_index(["instrument", "datetime", "date"])
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)

View File

@@ -4,6 +4,7 @@
import yaml
import argparse
import os
import shutil
from copy import deepcopy
from qlib.contrib.data.highfreq_provider import HighFreqProvider
@@ -41,3 +42,5 @@ if __name__ == "__main__":
if args.split == "stock" or args.split == "both":
provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
shutil.rmtree("stat/", ignore_errors=True)

View File

@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
OUTPUT_PATH = Path(os.path.join("data", "orders"))
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
    """Generate synthetic orders for one stock and save them per data split.

    The per-stock frame is read from ``DATA_PATH / f"{stock}.pkl"``. For every
    ``date`` group only the rows at positions ``[start_idx, end_idx)`` are kept,
    and the kept rows are averaged per index levels ``(2, 0)``
    (presumably ``(date, instrument)`` -- verify against the script that writes
    the pickles). The order amount is that averaged ``$volume0`` multiplied by a
    log-normal factor.

    The resulting orders are split by date into ``train`` (<= 2021-06-30),
    ``valid`` (2021-07-01 .. 2021-09-30) and ``test`` (> 2021-09-30); together
    with the unsplit ``all`` set, each non-empty split is written to
    ``OUTPUT_PATH / <split> / f"{stock}.pkl.target"``.

    Parameters
    ----------
    stock : str
        Stock identifier; also the stem of the input pickle file name.
    start_idx : int
        Position (inclusive) of the first intraday row to keep for each date.
    end_idx : int
        Position (exclusive) of the last intraday row to keep for each date.
    """
    df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
    # Keep only the requested intraday window of every trading day.
    df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)

    # One row per (level 2, level 0) pair: the mean of every column over the kept window.
    order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
    # NOTE: a single log-normal factor is drawn per call, so it is shared by all rows of this stock.
    order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
    order_all = order_all[order_all["amount"] > 0.0]
    order_all["order_type"] = 0
    order_all = order_all.drop(columns=["$volume0"])

    # Level 0 of the index is compared against timestamps to split the orders chronologically.
    order_train = order_all[order_all.index.get_level_values(0) <= pd.Timestamp("2021-06-30")]
    order_test = order_all[order_all.index.get_level_values(0) > pd.Timestamp("2021-06-30")]
    order_valid = order_test[order_test.index.get_level_values(0) <= pd.Timestamp("2021-09-30")]
    order_test = order_test[order_test.index.get_level_values(0) > pd.Timestamp("2021-09-30")]

    for order, tag in zip((order_train, order_valid, order_test, order_all), ("train", "valid", "test", "all")):
        path = OUTPUT_PATH / tag
        os.makedirs(path, exist_ok=True)
        # Splits without any order produce no file.
        if len(order) > 0:
            order.to_pickle(path / f"{stock}.pkl.target")
# Seed the global NumPy RNG once: it drives both the stock sampling below and the
# log-normal order amounts drawn inside `generate_order`.
np.random.seed(1234)

# Candidate stocks are the pickle file names (in sorted file-name order) without the ".pkl" part.
stocks = [fname.replace(".pkl", "") for fname in sorted(os.listdir(DATA_PATH))]
# Sample 100 distinct stocks and process them in sorted order.
stocks = sorted(np.random.choice(stocks, size=100, replace=False))

# 240 // 5 - 1 == 47: end position (exclusive) of the intraday rows kept per day.
last_idx = 240 // 5 - 1
for stock in tqdm(stocks):
    generate_order(stock, 0, last_idx)

View File

@@ -0,0 +1,15 @@
import pickle
import os
import pandas as pd
from tqdm import tqdm
# Merge the per-stock order files of each split into one pickle file
# (data/orders/<tag>_orders.pkl).
for tag in ["test", "valid"]:
    files = os.listdir(os.path.join("data/orders/", tag))
    dfs = []
    for f in tqdm(files):
        # Context manager: close every per-stock file as soon as it has been read,
        # so no file handles are left open while looping over the files.
        with open(os.path.join("data/orders/", tag, f), "rb") as fp:
            df = pickle.load(fp)
        # The "$close0" column is dropped from the merged order file.
        df = df.drop(["$close0"], axis=1)
        dfs.append(df)

    total_df = pd.concat(dfs)
    with open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb") as fp:
        pickle.dump(total_df, fp)

View File

@@ -1,15 +1,16 @@
# start & end time for training/validation/test datasets
start_time: !!str &start 2020-01-01
end_time: !!str &end 2020-07-31
train_end_time: !!str &tend 2020-03-31
valid_start_time: !!str &vstart 2020-04-01
valid_end_time: !!str &vend 2020-05-31
test_start_time: !!str &tstart 2020-06-01
end_time: !!str &end 2021-12-31
train_end_time: !!str &tend 2021-06-30
valid_start_time: !!str &vstart 2021-07-01
valid_end_time: !!str &vend 2021-09-30
test_start_time: !!str &tstart 2021-10-01
# the instrument set
instruments: &ins all
instruments: &ins csi300s19_22
# qlib related configuration
qlib_conf:
provider_uri: ./data/bin # path to generated qlib bin
provider_uri:
5min: ./data/bin # path to generated qlib bin
redis_port: 233
feature_conf:
path: ./data/pickle/feature.pkl # output path of feature
@@ -26,14 +27,23 @@ feature_conf:
fit_end_time: *tend
instruments: *ins
day_length: 240 # how many minutes in one trading day
freq: 5min
columns: ["$open", "$high", "$low", "$close"]
infer_processors:
- class: HighFreqNorm
module_path: qlib.contrib.data.highfreq_processor
kwargs:
feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization)
norm_groups:
price: 10
price: 8
volume: 2
inst_processors:
- class: TimeRangeFlt
module_path: qlib.data.dataset.processor
kwargs:
start_time: "2020-01-01"
end_time: "2021-12-31"
freq: 5min
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
@@ -51,7 +61,17 @@ backtest_conf:
end_time: *end
instruments: *ins
day_length: 240
freq: 5min
columns: ["$close", "$volume"]
inst_processors:
- class: TimeRangeFlt
module_path: qlib.data.dataset.processor
kwargs:
start_time: "2020-01-01"
end_time: "2021-12-31"
freq: 5min
segments:
train: !!python/tuple [*start, *tend]
valid: !!python/tuple [*vstart, *vend]
test: !!python/tuple [*tstart, *end]
freq: 5min

View File

@@ -88,6 +88,7 @@
"outputs": [],
"source": [
"from qlib.tests.data import GetData\n",
"\n",
"GetData().qlib_data(exists_skip=True)"
]
},
@@ -99,6 +100,7 @@
"outputs": [],
"source": [
"import qlib\n",
"\n",
"qlib.init()"
]
},
@@ -134,7 +136,8 @@
"outputs": [],
"source": [
"from qlib.data import D\n",
"D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2] # calendar data"
"\n",
"print(D.calendar(start_time=\"2010-01-01\", end_time=\"2017-12-31\", freq=\"day\")[:2]) # calendar data"
]
},
{
@@ -152,7 +155,12 @@
"metadata": {},
"outputs": [],
"source": [
"df = D.features(['SH601216'], ['$open', '$high', '$low', '$close', '$factor'], start_time='2020-05-01', end_time='2020-05-31') "
"df = D.features(\n",
" [\"SH601216\"],\n",
" [\"$open\", \"$high\", \"$low\", \"$close\", \"$factor\"],\n",
" start_time=\"2020-05-01\",\n",
" end_time=\"2020-05-31\",\n",
")"
]
},
{
@@ -163,11 +171,18 @@
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
" open=df['$open'],\n",
" high=df['$high'],\n",
" low=df['$low'],\n",
" close=df['$close'])])\n",
"\n",
"fig = go.Figure(\n",
" data=[\n",
" go.Candlestick(\n",
" x=df.index.get_level_values(\"datetime\"),\n",
" open=df[\"$open\"],\n",
" high=df[\"$high\"],\n",
" low=df[\"$low\"],\n",
" close=df[\"$close\"],\n",
" )\n",
" ]\n",
")\n",
"fig.show()"
]
},
@@ -197,11 +212,18 @@
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
" open=df['$open'] / df['$factor'],\n",
" high=df['$high'] / df['$factor'],\n",
" low=df['$low'] / df['$factor'],\n",
" close=df['$close'] / df['$factor'])])\n",
"\n",
"fig = go.Figure(\n",
" data=[\n",
" go.Candlestick(\n",
" x=df.index.get_level_values(\"datetime\"),\n",
" open=df[\"$open\"] / df[\"$factor\"],\n",
" high=df[\"$high\"] / df[\"$factor\"],\n",
" low=df[\"$low\"] / df[\"$factor\"],\n",
" close=df[\"$close\"] / df[\"$factor\"],\n",
" )\n",
" ]\n",
")\n",
"fig.show()"
]
},
@@ -240,7 +262,7 @@
"outputs": [],
"source": [
"# dynamic universe\n",
"universe = D.list_instruments(D.instruments('csi100'), start_time='2010-01-01', end_time='2020-12-31')\n",
"universe = D.list_instruments(D.instruments(\"csi100\"), start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
"pprint(universe)"
]
},
@@ -271,8 +293,8 @@
"metadata": {},
"outputs": [],
"source": [
"df = D.features(D.instruments('csi100'), ['$close'], start_time='2010-01-01', end_time='2020-12-31') \n",
"df.groupby('datetime').size().plot()"
"df = D.features(D.instruments(\"csi100\"), [\"$close\"], start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
"df.groupby(\"datetime\").size().plot()"
]
},
{
@@ -313,8 +335,7 @@
" !cd ../../scripts/data_collector/pit/ && pip install -r requirements.txt\n",
" !cd ../../scripts/data_collector/pit/ && python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex \"^(600519|000725).*\"\n",
" !cd ../../scripts/data_collector/pit/ && python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized\n",
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly\n",
" pass"
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly"
]
},
{
@@ -338,7 +359,13 @@
"outputs": [],
"source": [
"instruments = [\"sh600519\"]\n",
"data = D.features(instruments, ['P($$roewa_q)'], start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")"
"data = D.features(\n",
" instruments,\n",
" [\"P($$roewa_q)\"],\n",
" start_time=\"2019-01-01\",\n",
" end_time=\"2019-07-19\",\n",
" freq=\"day\",\n",
")"
]
},
{
@@ -366,7 +393,10 @@
"metadata": {},
"outputs": [],
"source": [
"D.features([\"sh600519\"], ['(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'])"
"D.features(\n",
" [\"sh600519\"],\n",
" [\"(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close\"],\n",
")"
]
},
{
@@ -418,7 +448,7 @@
"metadata": {},
"outputs": [],
"source": [
"qdl = QlibDataLoader(config=(['$close / Ref($close, 10)'], ['RET10']))"
"qdl = QlibDataLoader(config=([\"$close / Ref($close, 10)\"], [\"RET10\"]))"
]
},
{
@@ -428,7 +458,7 @@
"metadata": {},
"outputs": [],
"source": [
"qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
"qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
]
},
{
@@ -456,7 +486,7 @@
"metadata": {},
"outputs": [],
"source": [
"df = qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
"df = qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
]
},
{
@@ -476,7 +506,7 @@
"metadata": {},
"outputs": [],
"source": [
"df.plot(kind='hist')"
"df.plot(kind=\"hist\")"
]
},
{
@@ -508,9 +538,16 @@
"source": [
"# NOTE: normally, the training & validation time range will be `fit_start_time` `fit_end_time`\n",
"# howeverall the components are decomposed, so the training & validation time range is unknown when preprocessing.\n",
"dh = DataHandlerLP(instruments=['sh600519'], start_time='20170101', end_time='20191231',\n",
" infer_processors=[ZScoreNorm(fit_start_time='20170101', fit_end_time='20181231'), Fillna()],\n",
" data_loader=qdl)"
"dh = DataHandlerLP(\n",
" instruments=[\"sh600519\"],\n",
" start_time=\"20170101\",\n",
" end_time=\"20191231\",\n",
" infer_processors=[\n",
" ZScoreNorm(fit_start_time=\"20170101\", fit_end_time=\"20181231\"),\n",
" Fillna(),\n",
" ],\n",
" data_loader=qdl,\n",
")"
]
},
{
@@ -550,7 +587,7 @@
"metadata": {},
"outputs": [],
"source": [
"df.plot(kind='hist')"
"df.plot(kind=\"hist\")"
]
},
{
@@ -586,7 +623,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds = DatasetH(dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})"
"ds = DatasetH(dh, segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")})"
]
},
{
@@ -596,7 +633,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds.prepare('train')"
"ds.prepare(\"train\")"
]
},
{
@@ -606,7 +643,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds.prepare('valid')"
"ds.prepare(\"valid\")"
]
},
{
@@ -628,8 +665,12 @@
"metadata": {},
"outputs": [],
"source": [
"ds = TSDatasetH(step_len=10, handler=dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})\n",
"train_sampler = ds.prepare('train')"
"ds = TSDatasetH(\n",
" step_len=10,\n",
" handler=dh,\n",
" segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")},\n",
")\n",
"train_sampler = ds.prepare(\"train\")"
]
},
{
@@ -649,7 +690,7 @@
"metadata": {},
"outputs": [],
"source": [
"train_sampler[0] # Retrieving the first example"
"train_sampler[0] # Retrieving the first example"
]
},
{
@@ -659,7 +700,7 @@
"metadata": {},
"outputs": [],
"source": [
"train_sampler['2018-01-08', 'sh600519'] # get the time series by <'timestamp', 'instrument_id'> index"
"train_sampler[\"2018-01-08\", \"sh600519\"] # get the time series by <'timestamp', 'instrument_id'> index"
]
},
{
@@ -682,11 +723,11 @@
"outputs": [],
"source": [
"handler_kwargs = {\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": MARKET,\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": MARKET,\n",
"}\n",
"handler_conf = {\n",
" \"class\": \"Alpha158\",\n",
@@ -735,6 +776,7 @@
"outputs": [],
"source": [
"from qlib.contrib.data.handler import Alpha158\n",
"\n",
"hd = Alpha158(**handler_kwargs)"
]
},
@@ -826,7 +868,7 @@
"metadata": {},
"outputs": [],
"source": [
"hd.process_type # appending type"
"hd.process_type # appending type"
]
},
{
@@ -857,16 +899,16 @@
"outputs": [],
"source": [
"dataset_conf = {\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": hd,\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": hd,\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" },\n",
"}"
]
},
@@ -908,7 +950,8 @@
"metadata": {},
"outputs": [],
"source": [
"model = init_instance_by_config({\n",
"model = init_instance_by_config(\n",
" {\n",
" \"class\": \"LGBModel\",\n",
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
" \"kwargs\": {\n",
@@ -922,7 +965,8 @@
" \"num_leaves\": 210,\n",
" \"num_threads\": 20,\n",
" },\n",
"})"
" }\n",
")"
]
},
{
@@ -938,7 +982,7 @@
" R.save_objects(trained_model=model)\n",
"\n",
" rec = R.get_recorder()\n",
" rid = rec.id # save the record id\n",
" rid = rec.id # save the record id\n",
"\n",
" # Inference and saving signal\n",
" sr = SignalRecord(model, dataset, rec)\n",
@@ -1001,12 +1045,11 @@
"\n",
"# backtest and analysis\n",
"with R.start(experiment_name=EXP_NAME, recorder_id=rid, resume=True):\n",
"\n",
" # signal-based analysis\n",
" rec = R.get_recorder()\n",
" sar = SigAnaRecord(rec)\n",
" sar.generate()\n",
" \n",
"\n",
" # portfolio-based analysis: backtest\n",
" par = PortAnaRecord(rec, port_analysis_config, \"day\")\n",
" par.generate()"
@@ -1137,7 +1180,7 @@
"outputs": [],
"source": [
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
"label_df.columns = ['label']"
"label_df.columns = [\"label\"]"
]
},
{

View File

@@ -38,7 +38,7 @@
" # install qlib\n",
" ! pip install --upgrade numpy\n",
" ! pip install pyqlib\n",
" if 'google.colab' in sys.modules:\n",
" if \"google.colab\" in sys.modules:\n",
" # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages\n",
" ! pip install pyyaml==5.4.1\n",
" # reload\n",
@@ -50,7 +50,8 @@
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
" import requests\n",
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
"\n",
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\", timeout=10) as resp:\n",
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
" fp.write(resp.content)"
]
@@ -61,14 +62,13 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.constant import REG_CN\n",
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
"from qlib.workflow import R\n",
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
"from qlib.utils import flatten_dict\n"
"from qlib.utils import flatten_dict"
]
},
{
@@ -86,6 +86,7 @@
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(scripts_dir))\n",
" from get_data import GetData\n",
"\n",
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
@@ -169,7 +170,7 @@
" R.log_params(**flatten_dict(task))\n",
" model.fit(dataset)\n",
" R.save_objects(trained_model=model)\n",
" rid = R.get_recorder().id\n"
" rid = R.get_recorder().id"
]
},
{
@@ -238,7 +239,7 @@
"\n",
" # backtest & analysis\n",
" par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
" par.generate()\n"
" par.generate()"
]
},
{
@@ -256,6 +257,7 @@
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"\n",
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
"print(recorder)\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
@@ -317,7 +319,7 @@
"outputs": [],
"source": [
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
"label_df.columns = ['label']"
"label_df.columns = [\"label\"]"
]
},
{

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.9.1"
__version__ = "0.9.1.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -40,8 +40,8 @@ def get_exchange(
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str, str], List[str]] = None,
limit_threshold: Union[Tuple[str, str], float, None] | None = None,
deal_price: Union[str, Tuple[str, str], List[str]] | None = None,
**kwargs: Any,
) -> Exchange:
"""get_exchange
@@ -284,7 +284,7 @@ def collect_data(
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
return_value: dict = None,
return_value: dict | None = None,
) -> Generator[object, None, None]:
"""initialize the strategy and executor, then collect the trade decision data for rl training

View File

@@ -152,7 +152,9 @@ class Account:
# trading related metrics(e.g. high-frequency trading)
self.indicator = Indicator()
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
def reset(
self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None
) -> None:
"""reset freq and report of account
Parameters

View File

@@ -55,7 +55,7 @@ def collect_data_loop(
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
return_value: dict = None,
return_value: dict | None = None,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
"""Generator for collecting the trade decision data for rl training

View File

@@ -254,7 +254,7 @@ class IdxTradeRange(TradeRange):
self._start_idx = start_idx
self._end_idx = end_idx
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]:
return self._start_idx, self._end_idx
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
@@ -315,7 +315,7 @@ class BaseTradeDecision(Generic[DecisionType]):
2. Same as `case 1.3`
"""
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None:
"""
Parameters
----------
@@ -554,7 +554,7 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
self,
order_list: List[Order],
strategy: BaseStrategy,
trade_range: Union[Tuple[int, int], TradeRange] = None,
trade_range: Union[Tuple[int, int], TradeRange, None] = None,
) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = cast(List[Order], order_list)

View File

@@ -41,10 +41,10 @@ class Exchange:
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str, str], List[str]] = None,
deal_price: Union[str, Tuple[str, str], List[str], None] = None,
subscribe_fields: list = [],
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold: Union[tuple, dict] = None,
volume_threshold: Union[tuple, dict, None] = None,
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
@@ -340,7 +340,7 @@ class Exchange:
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
direction: int | None = None,
) -> bool:
"""
Parameters
@@ -406,7 +406,7 @@ class Exchange:
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
direction: int | None = None,
) -> bool:
# check if stock can be traded
return not (
@@ -421,8 +421,8 @@ class Exchange:
def deal_order(
self,
order: Order,
trade_account: Account = None,
position: BasePosition = None,
trade_account: Account | None = None,
position: BasePosition | None = None,
dealt_order_amount: Dict[str, float] = defaultdict(float),
) -> Tuple[float, float, float]:
"""
@@ -586,7 +586,7 @@ class Exchange:
)
return amount_dict
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float:
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
@@ -712,8 +712,8 @@ class Exchange:
def _get_factor_or_raise_error(
self,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
@@ -728,8 +728,8 @@ class Exchange:
def get_amount_of_trade_unit(
self,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> Optional[float]:
@@ -762,8 +762,8 @@ class Exchange:
def round_amount_by_trade_unit(
self,
deal_amount: float,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:

View File

@@ -31,8 +31,8 @@ class BaseExecutor:
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: CommonInfrastructure = None,
trade_exchange: Exchange | None = None,
common_infra: CommonInfrastructure | None = None,
settle_type: str = BasePosition.ST_NO,
**kwargs: Any,
) -> None:
@@ -161,7 +161,7 @@ class BaseExecutor:
"""
return self.level_infra.get("trade_calendar")
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None:
"""
- reset `start_time` and `end_time`, used in trade calendar
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
@@ -227,7 +227,7 @@ class BaseExecutor:
def collect_data(
self,
trade_decision: BaseTradeDecision,
return_value: dict = None,
return_value: dict | None = None,
level: int = 0,
) -> Generator[Any, Any, List[object]]:
"""Generator for collecting the trade decision data for rl training
@@ -327,7 +327,7 @@ class NestedExecutor(BaseExecutor):
track_data: bool = False,
skip_empty_decision: bool = True,
align_range_limit: bool = True,
common_infra: CommonInfrastructure = None,
common_infra: CommonInfrastructure | None = None,
**kwargs: Any,
) -> None:
"""
@@ -534,7 +534,7 @@ class SimulatorExecutor(BaseExecutor):
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: CommonInfrastructure = None,
common_infra: CommonInfrastructure | None = None,
trade_type: str = TT_SERIAL,
**kwargs: Any,
) -> None:

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from datetime import timedelta
from typing import Any, Dict, List, Union
@@ -320,7 +321,7 @@ class Position(BasePosition):
self.position[stock]["price"] = price_dict[stock]
self.position["now_account_value"] = self.calculate_value()
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None:
"""
initialize the stock in the current position

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import pathlib
from collections import OrderedDict
@@ -86,7 +87,7 @@ class PortfolioMetrics:
self.benches: dict = OrderedDict()
self.latest_pm_time: Optional[pd.TimeStamp] = None
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None:
if freq is not None:
self.freq = freq
self.benchmark_config = benchmark_config
@@ -149,15 +150,15 @@ class PortfolioMetrics:
self,
trade_start_time: Union[str, pd.Timestamp] = None,
trade_end_time: Union[str, pd.Timestamp] = None,
account_value: float = None,
cash: float = None,
return_rate: float = None,
total_turnover: float = None,
turnover_rate: float = None,
total_cost: float = None,
cost_rate: float = None,
stock_value: float = None,
bench_value: float = None,
account_value: float | None = None,
cash: float | None = None,
return_rate: float | None = None,
total_turnover: float | None = None,
turnover_rate: float | None = None,
total_cost: float | None = None,
cost_rate: float | None = None,
stock_value: float | None = None,
bench_value: float | None = None,
) -> None:
# check data
if None in [

View File

@@ -31,7 +31,7 @@ class TradeCalendarManager:
freq: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
level_infra: LevelInfrastructure = None,
level_infra: LevelInfrastructure | None = None,
) -> None:
"""
Parameters
@@ -99,7 +99,7 @@ class TradeCalendarManager:
def get_trade_step(self) -> int:
return self.trade_step
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
"""
Get the left and right endpoints of the trade_step'th trading interval

View File

@@ -56,7 +56,7 @@ class Alpha360(DataHandlerLP):
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -71,7 +71,7 @@ class Alpha360(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
"inst_processors": inst_processors,
},
}
@@ -152,7 +152,7 @@ class Alpha158(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -167,7 +167,7 @@ class Alpha158(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
"inst_processors": inst_processors,
},
}
super().__init__(

View File

@@ -44,7 +44,7 @@ class HighFreqHandler(DataHandlerLP):
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
template_paused = "Select(Gt($paused_num, 1.001), {0})"
def get_normalized_price_feature(price_field, shift=0):
# norm with the close price of 237th minute of yesterday.
@@ -115,6 +115,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
day_length=240,
freq="1min",
columns=["$open", "$high", "$low", "$close", "$vwap"],
inst_processors=None,
):
self.day_length = day_length
self.columns = columns
@@ -128,6 +129,7 @@ class HighFreqGeneralHandler(DataHandlerLP):
"config": self.get_feature_config(),
"swap_level": False,
"freq": freq,
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -257,6 +259,7 @@ class HighFreqGeneralBacktestHandler(DataHandler):
day_length=240,
freq="1min",
columns=["$close", "$vwap", "$volume"],
inst_processors=None,
):
self.day_length = day_length
self.columns = set(columns)
@@ -266,6 +269,7 @@ class HighFreqGeneralBacktestHandler(DataHandler):
"config": self.get_feature_config(),
"swap_level": False,
"freq": freq,
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -311,6 +315,7 @@ class HighFreqOrderHandler(DataHandlerLP):
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
inst_processors=None,
drop_raw=True,
):
@@ -323,6 +328,7 @@ class HighFreqOrderHandler(DataHandlerLP):
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
"inst_processors": inst_processors,
},
}
super().__init__(
@@ -482,7 +488,7 @@ class HighFreqBacktestOrderHandler(DataHandler):
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
template_paused = "Select(Gt($paused_num, 1.001), {0})"
template_fillnan = "FFillNan({0})"
fields += [
template_fillnan.format(template_paused.format("$close")),

View File

@@ -128,7 +128,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -137,11 +137,11 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
@@ -160,7 +160,7 @@ class HighFreqProvider:
with open(path[:-4] + "test.pkl", "wb") as f:
pkl.dump(testset, f)
res = [data[i] for i in datasets]
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
return res
def _gen_data(self, config, datasets=["train", "valid", "test"]):
@@ -170,7 +170,7 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
# res = dataset.prepare(['train', 'valid', 'test'])
with open(path, "rb") as f:
@@ -179,18 +179,18 @@ class HighFreqProvider:
res = [data[i] for i in datasets]
else:
res = data.prepare(datasets)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
start_time = time.time()
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
res = dataset.prepare(datasets)
self.logger.info(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
return res
def _gen_dataset(self, config):
@@ -200,21 +200,21 @@ class HighFreqProvider:
raise ValueError("Must specify the path to save the dataset.") from e
if os.path.isfile(path):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
with open(path, "rb") as f:
dataset = pkl.load(f)
self.logger.info(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.prepare(["train", "valid", "test"])
self.logger.info(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=True, recursive=True)
dataset.to_pickle(path)
return dataset
@@ -227,15 +227,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")
@@ -268,15 +268,15 @@ class HighFreqProvider:
if os.path.isfile(path + "tmp_dataset.pkl"):
start = time.time()
self.logger.info("Dataset exists, load from disk.", __name__)
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
else:
start = time.time()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
self.logger.info("Generating dataset", __name__)
self.logger.info(f"[{__name__}]Generating dataset")
self._prepare_calender_cache()
dataset = init_instance_by_config(config)
self.logger.info(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
dataset.config(dump_all=False, recursive=True)
dataset.to_pickle(path + "tmp_dataset.pkl")

View File

@@ -70,7 +70,7 @@ class DayCumsum(ElemOperator):
Otherwise, the value is zero.
"""
def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
def __init__(self, feature, start: str = "9:30", end: str = "14:59", data_granularity: int = 1):
self.feature = feature
self.start = datetime.strptime(start, "%H:%M")
self.end = datetime.strptime(end, "%H:%M")
@@ -80,15 +80,17 @@ class DayCumsum(ElemOperator):
self.noon_open = datetime.strptime("13:00", "%H:%M")
self.noon_close = datetime.strptime("15:00", "%H:%M")
self.start_id = time_to_day_index(self.start)
self.end_id = time_to_day_index(self.end)
self.data_granularity = data_granularity
self.start_id = time_to_day_index(self.start) // self.data_granularity
self.end_id = time_to_day_index(self.end) // self.data_granularity
assert 240 % self.data_granularity == 0
def period_cusum(self, df):
df = df.copy()
assert len(df) == 240
assert len(df) == 240 // self.data_granularity
df.iloc[0 : self.start_id] = 0
df = df.cumsum()
df.iloc[self.end_id + 1 : 240] = 0
df.iloc[self.end_id + 1 : 240 // self.data_granularity] = 0
return df
def _load_internal(self, instrument, start_index, end_index, freq):

View File

@@ -7,6 +7,7 @@ from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
from qlib.typehint import Literal
from ...log import get_module_logger, TimeInspector
from ...utils import init_instance_by_config
from ...utils.serial import Serializable
@@ -49,6 +50,8 @@ class DataHandler(Serializable):
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
"""
_data: pd.DataFrame # underlying data.
def __init__(
self,
instruments=None,
@@ -155,6 +158,11 @@ class DataHandler(Serializable):
"""
fetch data from underlying data source
Design motivation:
- providing a unified interface for underlying data.
- Potential to make the interface more friendly.
- User can improve performance when fetching data in this extra layer
Parameters
----------
selector : Union[pd.Timestamp, slice, str]
@@ -328,6 +336,9 @@ class DataHandler(Serializable):
yield cur_date, self.fetch(selector, **kwargs)
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
class DataHandlerLP(DataHandler):
"""
DataHandler with **(L)earnable (P)rocessor**
@@ -346,17 +357,28 @@ class DataHandlerLP(DataHandler):
- These processors only apply to the learning phase.
Tips to improve the performance of data handler
Tips for data handler
- To reduce the memory cost
- `drop_raw=True`: this will modify the data inplace on raw data;
- Please note processed data like `self._infer` or `self._learn` are concepts different from `segments` in Qlib's `Dataset` like "train" and "test"
- Processed data like `self._infer` or `self._learn` are underlying data processed with different processors
- `segments` in Qlib's `Dataset` like "train" and "test" are simply the time segmentations when querying data("train" are often before "test" in time-series).
- For example, you can query `data._infer` processed by `infer_processors` in the "train" time segmentation.
"""
# based on `self._data`, _infer and _learn are generated after processors
_infer: pd.DataFrame # data for inference
_learn: pd.DataFrame # data for learning models
# data key
DK_R = "raw"
DK_I = "infer"
DK_L = "learn"
DK_R: DATA_KEY_TYPE = "raw"
DK_I: DATA_KEY_TYPE = "infer"
DK_L: DATA_KEY_TYPE = "learn"
# map data_key to attribute name
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
# process type
@@ -600,7 +622,7 @@ class DataHandlerLP(DataHandler):
# TODO: Be able to cache handler data. Save the memory for data processing
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
if data_key == self.DK_R and self.drop_raw:
raise AttributeError(
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
@@ -613,7 +635,7 @@ class DataHandlerLP(DataHandler):
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
data_key: DATA_KEY_TYPE = DK_I,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
@@ -647,7 +669,7 @@ class DataHandlerLP(DataHandler):
proc_func=proc_func,
)
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
"""
get the column names
@@ -655,7 +677,7 @@ class DataHandlerLP(DataHandler):
----------
col_set : str
select a set of meaningful columns.(e.g. features, columns).
data_key : str
data_key : DATA_KEY_TYPE
the data to fetch: DK_*.
Returns

View File

@@ -153,7 +153,7 @@ class QlibDataLoader(DLWParser):
filter_pipe: List = None,
swap_level: bool = True,
freq: Union[str, dict] = "day",
inst_processor: dict = None,
inst_processors: Union[dict, list] = None,
):
"""
Parameters
@@ -167,16 +167,19 @@ class QlibDataLoader(DLWParser):
freq: dict or str
If type(config) == dict and type(freq) == str, load config data using freq.
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
inst_processor: dict
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
inst_processors: dict | list
If inst_processors is not None and type(config) == dict; load config[<group_name>] data using inst_processors[<group_name>]
If inst_processors is a list, then it will be applied to all groups.
"""
self.filter_pipe = filter_pipe
self.swap_level = swap_level
self.freq = freq
# sample
self.inst_processor = inst_processor if inst_processor is not None else {}
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
self.inst_processors = inst_processors if inst_processors is not None else {}
assert isinstance(
self.inst_processors, (dict, list)
), f"inst_processors(={self.inst_processors}) must be dict or list"
super().__init__(config)
@@ -187,8 +190,8 @@ class QlibDataLoader(DLWParser):
if _gp not in freq:
raise ValueError(f"freq(={freq}) missing group(={_gp})")
assert (
self.inst_processor
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
self.inst_processors
), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"
def load_group_df(
self,
@@ -208,9 +211,10 @@ class QlibDataLoader(DLWParser):
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
df = D.features(
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
inst_processors = (
self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])
)
df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)
df.columns = names
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
import abc
from typing import Union, Text
from typing import Union, Text, Optional
import numpy as np
import pandas as pd
@@ -11,6 +11,8 @@ from ...constant import EPS
from .utils import fetch_df_by_index
from ...utils.serial import Serializable
from ...utils.paral import datetime_groupby_apply
from qlib.data.inst_processor import InstProcessor
from qlib.data import D
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
@@ -378,3 +380,42 @@ class HashStockFormat(Processor):
from .storage import HashingStockStorage # pylint: disable=C0415
return HashingStockStorage.from_df(df)
class TimeRangeFlt(InstProcessor):
    """
    Instrument-level filter based on the time range covered by the data.

    The data of an instrument is kept only when its first index value is not later than `start_time`
    and its last index value is not earlier than `end_time` (whether data exists in between is not checked).
    Empty data is passed through unchanged.
    WARNING: It may induce leakage!!!
    """

    def __init__(
        self,
        start_time: Optional[Union[pd.Timestamp, str]] = None,
        end_time: Optional[Union[pd.Timestamp, str]] = None,
        freq: str = "day",
    ):
        """
        Parameters
        ----------
        start_time : Optional[Union[pd.Timestamp, str]]
            The data must start no later than `start_time`.
            None indicates data will not be filtered based on `start_time`
        end_time : Optional[Union[pd.Timestamp, str]]
            The data must end no earlier than `end_time`.
            None indicates data will not be filtered based on `end_time`
        freq : str
            The frequency of the calendar
        """
        # Align both boundaries to the calendar before filtering: the stored boundaries are the first and
        # last entries of the calendar returned for the requested range.
        calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq)
        self.start_time = calendar[0] if start_time is not None else None
        self.end_time = calendar[-1] if end_time is not None else None

    def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs):
        """Return `df` unchanged if it is empty or covers the configured range; otherwise return `df.head(0)`."""
        # Nothing to check on empty data.
        if df.empty:
            return df

        # The data has to exist at (or before) the aligned start time ...
        starts_early_enough = self.start_time is None or df.index.min() <= self.start_time
        if not starts_early_enough:
            return df.head(0)

        # ... and still exist at (or after) the aligned end time.
        ends_late_enough = self.end_time is None or df.index.max() >= self.end_time
        return df if ends_late_enough else df.head(0)

View File

@@ -28,14 +28,14 @@ from qlib.typehint import Literal
def _get_multi_level_executor_config(
strategy_config: dict,
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> dict:
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "1min",
"time_per_step": "5min", # FIXME: move this into config
"verbose": False,
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
"generate_report": generate_report,
@@ -127,7 +127,7 @@ def single_with_simulator(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
@@ -187,7 +187,7 @@ def single_with_simulator(
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": "5min", # FIXME: move this into config
}
)
@@ -226,7 +226,7 @@ def single_with_collect_data_loop(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with collect_data_loop.
@@ -286,7 +286,7 @@ def single_with_collect_data_loop(
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": "5min", # FIXME: move this into config
}
)
@@ -357,7 +357,10 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
if not output_path.exists():
os.makedirs(output_path)
res.to_csv(output_path / "summary.csv")
if "pa" in res.columns:
res["pa"] = res["pa"] * 10000.0 # align with training metrics
res.to_csv(output_path / "backtest_result.csv")
return res

View File

@@ -98,7 +98,7 @@ def get_backtest_config_fromfile(path: str) -> dict:
"debug_single_day": None,
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs/",
"output_dir": "outputs_backtest/",
"generate_report": False,
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)

View File

@@ -3,6 +3,7 @@
import argparse
import os
import random
import warnings
from pathlib import Path
from typing import cast, List, Optional
@@ -23,7 +24,6 @@ from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch import nn
from torch.utils.data import Dataset
@@ -101,6 +101,7 @@ def train_and_test(
action_interpreter: ActionInterpreter,
policy: BasePolicy,
reward: Reward,
run_training: bool,
run_backtest: bool,
) -> None:
qlib.init()
@@ -122,62 +123,67 @@ def train_and_test(
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
train_dataset, valid_dataset, test_dataset = [
LazyLoadDataset(
order_file_path=order_root_path / tag,
if run_training:
train_dataset, valid_dataset = [
LazyLoadDataset(
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid")
]
callbacks: List[Callback] = []
if "checkpoint_path" in trainer_config:
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs={
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
},
vessel_kwargs={
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
},
)
if run_backtest:
test_dataset = LazyLoadDataset(
order_file_path=order_root_path / "test",
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid", "test")
]
if "checkpoint_path" in trainer_config:
callbacks: List[Callback] = []
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
trainer_kwargs = {
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
}
vessel_kwargs = {
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
}
train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs=trainer_kwargs,
vessel_kwargs=vessel_kwargs,
)
if run_backtest:
backtest(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
@@ -186,35 +192,39 @@ def train_and_test(
policy=policy,
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
reward=reward,
finite_env_type=trainer_kwargs["finite_env_type"],
concurrency=trainer_kwargs["concurrency"],
finite_env_type=env_config["parallel_mode"],
concurrency=env_config["concurrency"],
)
def main(config: dict, run_backtest: bool) -> None:
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if not run_training and not run_backtest:
warnings.warn("Skip the entire job since training and backtest are both skipped.")
return
if "seed" in config["runtime"]:
seed_everything(config["runtime"]["seed"])
state_config = config["state_interpreter"]
state_interpreter: StateInterpreter = init_instance_by_config(state_config)
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
reward: Reward = init_instance_by_config(config["reward"])
additional_policy_kwargs = {
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
# Create torch network
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
network: nn.Module = init_instance_by_config(config["network"])
if "network" in config:
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
# Create policy
config["policy"]["kwargs"].update(
{
"network": network,
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
)
if "kwargs" not in config["policy"]:
config["policy"]["kwargs"] = {}
config["policy"]["kwargs"].update(additional_policy_kwargs)
policy: BasePolicy = init_instance_by_config(config["policy"])
use_cuda = config["runtime"].get("use_cuda", False)
@@ -230,22 +240,22 @@ def main(config: dict, run_backtest: bool) -> None:
state_interpreter=state_interpreter,
policy=policy,
reward=reward,
run_training=run_training,
run_backtest=run_backtest,
)
if __name__ == "__main__":
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
args = parser.parse_args()
with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream)
main(config, run_backtest=args.run_backtest)
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)

View File

@@ -49,7 +49,7 @@ class DataWrapper:
return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
def init_qlib(qlib_config: dict, part: str = None) -> None:
def init_qlib(qlib_config: dict, part: str | None = None) -> None:
"""Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..
Parameters
@@ -82,10 +82,9 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
return path if isinstance(path, Path) else Path(path)
provider_uri_map = {}
if "provider_uri_day" in qlib_config:
provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix()
if "provider_uri_1min" in qlib_config:
provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix()
for granularity in ["1min", "5min", "day"]:
if f"provider_uri_{granularity}" in qlib_config:
provider_uri_map[f"{granularity}"] = _convert_to_path(qlib_config[f"provider_uri_{granularity}"]).as_posix()
qlib.init(
region=REG_CN,

View File

@@ -104,7 +104,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int = None,
order_dir: int | None = None,
) -> None:
super(SimpleIntradayBacktestData, self).__init__()
@@ -208,7 +208,7 @@ def load_simple_intraday_backtest_data(
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int = None,
order_dir: int | None = None,
) -> SimpleIntradayBacktestData:
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)

View File

@@ -53,6 +53,18 @@ class FullHistoryObs(TypedDict):
position_history: Any
class DummyStateInterpreter(StateInterpreter[SAOEState, dict]):
    """Placeholder interpreter for policies that do not need any input (for example, AllOne).

    Whatever the state is, the observation is a dict with the single key "DUMMY".
    """

    def interpret(self, state: SAOEState) -> dict:
        """Return a constant fake observation; `state` is not read."""
        # TODO: A fake state, used to pass `check_nan_observation`. Find a better way in the future.
        placeholder = _to_int32(1)
        return {"DUMMY": placeholder}

    @property
    def observation_space(self) -> spaces.Dict:
        """Space describing the fake observation: a scalar int32 box stored under "DUMMY"."""
        dummy_box = spaces.Box(-np.inf, np.inf, shape=(), dtype=np.int32)
        return spaces.Dict({"DUMMY": dummy_box})
class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
"""The observation of all the history, including today (until this moment), and yesterday.

View File

@@ -12,11 +12,11 @@ import torch
import torch.nn as nn
from gym.spaces import Discrete
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy
from qlib.rl.trainer.trainer import Trainer
__all__ = ["AllOne", "PPO"]
__all__ = ["AllOne", "PPO", "DQN"]
# baselines #
@@ -32,7 +32,7 @@ class NonLearnablePolicy(BasePolicy):
super().__init__()
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
pass
return {}
def process_fn(
self,
@@ -40,7 +40,7 @@ class NonLearnablePolicy(BasePolicy):
buffer: ReplayBuffer,
indices: np.ndarray,
) -> Batch:
pass
return Batch({})
class AllOne(NonLearnablePolicy):
@@ -49,13 +49,18 @@ class AllOne(NonLearnablePolicy):
Useful when implementing some baselines (e.g., TWAP).
"""
def __init__(self, obs_space: gym.Space, action_space: gym.Space, fill_value: float | int = 1.0) -> None:
super().__init__(obs_space, action_space)
self.fill_value = fill_value
def forward(
self,
batch: Batch,
state: dict | Batch | np.ndarray = None,
**kwargs: Any,
) -> Batch:
return Batch(act=np.full(len(batch), 1.0), state=state)
return Batch(act=np.full(len(batch), self.fill_value), state=state)
# ppo #
@@ -153,6 +158,56 @@ class PPO(PPOPolicy):
set_weight(self, Trainer.get_policy_state_dict(weight_file))
DQNModel = PPOActor  # Reuse PPOActor.


class DQN(DQNPolicy):
    """Thin wrapper around tianshou's ``DQNPolicy``.

    Compared with using ``DQNPolicy`` directly, this class:

    - Builds the model itself by wrapping ``network`` in ``DQNModel``.
      Only discrete action spaces are accepted.
    - Builds an Adam optimizer over the model's parameters.
    - Optionally restores a checkpoint from ``weight_file``.

    Parameters
    ----------
    network
        Module wrapped by ``DQNModel`` together with ``action_space.n``.
    obs_space
        Observation space. Accepted for interface symmetry; it is not read
        inside this constructor.
    action_space
        Action space; must be a ``gym.spaces.Discrete`` instance.
    lr
        Learning rate passed to ``torch.optim.Adam``.
    weight_decay
        Weight decay passed to ``torch.optim.Adam``.
    discount_factor, estimation_step, target_update_freq, reward_normalization, is_double, clip_loss_grad
        Forwarded unchanged to ``DQNPolicy.__init__`` as keyword arguments.
    weight_file
        When not ``None``, the policy state dict read from this file is
        applied to the freshly built policy.
    """

    def __init__(
        self,
        network: nn.Module,
        obs_space: gym.Space,
        action_space: gym.Space,
        lr: float,
        weight_decay: float = 0.0,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        is_double: bool = True,
        clip_loss_grad: bool = False,
        weight_file: Optional[Path] = None,
    ) -> None:
        assert isinstance(action_space, Discrete)

        q_network = DQNModel(network, action_space.n)
        adam = torch.optim.Adam(q_network.parameters(), lr=lr, weight_decay=weight_decay)

        # Hyper-parameters that are handed to the tianshou base class untouched.
        policy_kwargs = {
            "discount_factor": discount_factor,
            "estimation_step": estimation_step,
            "target_update_freq": target_update_freq,
            "reward_normalization": reward_normalization,
            "is_double": is_double,
            "clip_loss_grad": clip_loss_grad,
        }
        super().__init__(q_network, adam, **policy_kwargs)

        if weight_file is None:
            return
        # Restore a previously saved policy state on top of the new instance.
        set_weight(self, Trainer.get_policy_state_dict(weight_file))
# utilities: these should be put in a separate (common) file. #

View File

@@ -7,6 +7,7 @@ from typing import cast
import numpy as np
from qlib.backtest.decision import OrderDir
from qlib.rl.order_execution.state import SAOEMetrics, SAOEState
from qlib.rl.reward import Reward
@@ -47,3 +48,52 @@ class PAPenaltyReward(Reward[SAOEState]):
self.log("reward/pa", pa)
self.log("reward/penalty", penalty)
return reward * self.scale
class PPOReward(Reward[SAOEState]):
    """Terminal reward from the paper "An End-to-End Optimal Trade Execution Framework
    based on Proximal Policy Optimization".

    The reward is 0 on every step except the terminating one, where the achieved
    volume-weighted price is compared against the mean deal price of the backtest data.

    Parameters
    ----------
    max_step
        Maximum number of steps.
    start_time_index
        First time index at which trading is allowed.
    end_time_index
        Last time index at which trading is allowed.
    """

    def __init__(self, max_step: int, start_time_index: int = 0, end_time_index: int = 239) -> None:
        self.max_step = max_step
        # NOTE(review): the two indices below are stored but not read by `reward` in this class.
        self.start_time_index = start_time_index
        self.end_time_index = end_time_index

    def reward(self, simulator_state: SAOEState) -> float:
        """Return -1.0, 0.0 or 1.0 on the terminating step, and 0.0 on every other step.

        A step is terminating when it is step ``max_step - 1`` or when the
        remaining position is below ``1e-6``.
        """
        is_last_step = simulator_state.cur_step == self.max_step - 1
        if not (is_last_step or simulator_state.position < 1e-6):
            return 0.0

        history = simulator_state.history_exec
        if history["deal_amount"].sum() == 0.0:
            # Nothing was dealt: fall back to the unweighted mean of market prices.
            vwap_price = cast(float, np.average(history["market_price"]))
        else:
            vwap_price = cast(
                float,
                np.average(history["market_price"], weights=history["deal_amount"]),
            )

        twap_price = simulator_state.backtest_data.get_deal_price().mean()

        # Orient the ratio so that a larger value is better for the order's direction.
        if simulator_state.order.direction == OrderDir.SELL:
            numerator, denominator = vwap_price, twap_price
        else:
            numerator, denominator = twap_price, vwap_price
        ratio = numerator / denominator if denominator != 0 else 1.0

        # NOTE(review): a NaN ratio fails both comparisons below and therefore yields 1.0;
        # confirm whether that is intended.
        if ratio < 1.0:
            return -1.0
        if ratio < 1.1:
            return 0.0
        return 1.0

View File

@@ -38,8 +38,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
order: Order,
executor_config: dict,
exchange_config: dict,
qlib_config: dict = None,
cash_limit: Optional[float] = None,
qlib_config: dict | None = None,
cash_limit: float | None = None,
) -> None:
super().__init__(initial=order)
@@ -63,7 +63,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
strategy_config: dict,
executor_config: dict,
exchange_config: dict,
qlib_config: dict = None,
qlib_config: dict | None = None,
cash_limit: Optional[float] = None,
) -> None:
if qlib_config is not None:

View File

@@ -7,6 +7,7 @@ import collections
from types import GeneratorType
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union
import warnings
import numpy as np
import pandas as pd
import torch
@@ -89,6 +90,7 @@ class SAOEStateAdapter:
exchange: Exchange,
ticks_per_step: int,
backtest_data: IntradayBacktestData,
data_granularity: int = 1,
) -> None:
self.position = order.amount
self.order = order
@@ -106,11 +108,13 @@ class SAOEStateAdapter:
self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time)
self.ticks_per_step = ticks_per_step
self.data_granularity = data_granularity
assert self.ticks_per_step % self.data_granularity == 0
def _next_time(self) -> pd.Timestamp:
current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time)
next_loc = current_loc + self.ticks_per_step
next_loc = next_loc - next_loc % self.ticks_per_step
next_loc = current_loc + (self.ticks_per_step // self.data_granularity)
next_loc = next_loc - next_loc % (self.ticks_per_step // self.data_granularity)
if (
next_loc < len(self.backtest_data.ticks_index)
and self.backtest_data.ticks_index[next_loc] < self.order.end_time
@@ -130,11 +134,16 @@ class SAOEStateAdapter:
exec_vol = np.zeros(last_step_size)
for order, _, __, ___ in execute_result:
idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN)
idx, _ = get_day_min_idx_range(order.start_time, order.end_time, f"{self.data_granularity}min", REG_CN)
exec_vol[idx - last_step_range[0]] = order.deal_amount
if exec_vol.sum() > self.position and exec_vol.sum() > 0.0:
assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large"
if exec_vol.sum() > self.position + 1.0:
warnings.warn(
f"Sum of execution volume is {exec_vol.sum()} which is larger than "
f"position + 1.0 = {self.position} + 1.0 = {self.position + 1.0}. "
f"All execution volume is scaled down linearly to ensure that their sum does not position."
)
exec_vol *= self.position / (exec_vol.sum())
market_volume = cast(
@@ -168,7 +177,9 @@ class SAOEStateAdapter:
self.history_exec,
self._collect_multi_order_metric(
order=self.order,
datetime=_get_all_timestamps(start_time, end_time, include_end=True),
datetime=_get_all_timestamps(
start_time, end_time, include_end=True, granularity=ONE_MIN * self.data_granularity
),
market_vol=market_volume,
market_price=market_price,
exec_vol=exec_vol,
@@ -293,9 +304,10 @@ class SAOEStrategy(RLStrategy):
def __init__(
self,
policy: BasePolicy,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
data_granularity: int = 1,
**kwargs: Any,
) -> None:
super(SAOEStrategy, self).__init__(
@@ -306,6 +318,7 @@ class SAOEStrategy(RLStrategy):
**kwargs,
)
self._data_granularity = data_granularity
self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {}
self._last_step_range = (0, 0)
@@ -324,9 +337,10 @@ class SAOEStrategy(RLStrategy):
exchange=self.trade_exchange,
ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN),
backtest_data=backtest_data,
data_granularity=self._data_granularity,
)
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super(SAOEStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
self.adapter_dict = {}
@@ -366,7 +380,7 @@ class SAOEStrategy(RLStrategy):
def generate_trade_decision(
self,
execute_result: list = None,
execute_result: list | None = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
"""
For SAOEStrategy, we need to update the `self._last_step_range` every time a decision is generated.
@@ -385,7 +399,7 @@ class SAOEStrategy(RLStrategy):
def _generate_trade_decision(
self,
execute_result: list = None,
execute_result: list | None = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
raise NotImplementedError
@@ -399,14 +413,14 @@ class ProxySAOEStrategy(SAOEStrategy):
def __init__(
self,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
**kwargs: Any,
) -> None:
super().__init__(None, outer_trade_decision, level_infra, common_infra, **kwargs)
def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
def _generate_trade_decision(self, execute_result: list | None = None) -> Generator[Any, Any, BaseTradeDecision]:
# Once the following line is executed, this ProxySAOEStrategy (self) will be yielded to the outside
# of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,
# the item will be captured by `exec_vol`. The outside policy could communicate with the inner
@@ -418,7 +432,7 @@ class ProxySAOEStrategy(SAOEStrategy):
return TradeDecisionWO([order], self)
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
assert isinstance(outer_trade_decision, TradeDecisionWO)
@@ -437,9 +451,9 @@ class SAOEIntStrategy(SAOEStrategy):
state_interpreter: dict | StateInterpreter,
action_interpreter: dict | ActionInterpreter,
network: dict | torch.nn.Module | None = None,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
**kwargs: Any,
) -> None:
super(SAOEIntStrategy, self).__init__(
@@ -488,7 +502,7 @@ class SAOEIntStrategy(SAOEStrategy):
if self._policy is not None:
self._policy.eval()
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
@@ -508,7 +522,7 @@ class SAOEIntStrategy(SAOEStrategy):
trade_details[-1]["rl_action"] = a
return pd.DataFrame.from_records(trade_details)
def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
def _generate_trade_decision(self, execute_result: list | None = None) -> BaseTradeDecision:
states = []
obs_batch = []
for decision in self.outer_trade_decision.get_decision():

View File

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from qlib.backtest import Order
from qlib.backtest.decision import OrderHelper, TradeDecisionWO, TradeRange
from qlib.strategy.base import BaseStrategy
@@ -12,14 +14,14 @@ class SingleOrderStrategy(BaseStrategy):
def __init__(
self,
order: Order,
trade_range: TradeRange = None,
trade_range: TradeRange | None = None,
) -> None:
super().__init__()
self._order = order
self._trade_range = trade_range
def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
def generate_trade_decision(self, execute_result: list | None = None) -> TradeDecisionWO:
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
order_list = [
oh.create(

View File

@@ -4,6 +4,7 @@
from __future__ import annotations
import multiprocessing
from multiprocessing.sharedctypes import Synchronized
import os
import threading
import time
@@ -78,7 +79,9 @@ class DataQueue(Generic[T]):
self._activated: bool = False
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
self._done = multiprocessing.Value("i", 0)
# Mypy 0.981 brought '"SynchronizedBase[Any]" has no attribute "value" [attr-defined]' bug.
# Therefore, add this type casting to pass Mypy checking.
self._done = cast(Synchronized, multiprocessing.Value("i", 0))
def __enter__(self) -> DataQueue:
self.activate()
@@ -122,7 +125,7 @@ class DataQueue(Generic[T]):
if self._done.value:
raise StopIteration # pylint: disable=raise-missing-from
def put(self, obj: Any, block: bool = True, timeout: int = None) -> None:
def put(self, obj: Any, block: bool = True, timeout: int | None = None) -> None:
self._queue.put(obj, block=block, timeout=timeout)
def mark_as_done(self) -> None:

View File

@@ -99,9 +99,9 @@ class EnvWrapper(
state_interpreter: StateInterpreter[StateType, ObsType],
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
seed_iterator: Optional[Iterable[InitialStateType]],
reward_fn: Reward = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None,
logger: LogCollector = None,
reward_fn: Reward | None = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
logger: LogCollector | None = None,
) -> None:
# Assign weak reference to wrapper.
#

View File

@@ -397,7 +397,7 @@ class ConsoleWriter(LogWriter):
def __init__(
self,
log_every_n_episode: int = 20,
total_episodes: int = None,
total_episodes: int | None = None,
float_format: str = ":.4f",
counter_format: str = ":4d",
loglevel: int | LogLevel = LogLevel.PERIODIC,

View File

@@ -224,7 +224,7 @@ def requests_with_retry(url, retry=5, **kwargs):
except Exception as e:
log.warning("exception encountered {}".format(e))
continue
raise Exception("ERROR: requests failed!")
raise TimeoutError("ERROR: requests failed!")
#################### Parse ####################
@@ -426,7 +426,8 @@ def init_instance_by_config(
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
with open(pr_path, "rb") as f:
return pickle.load(f)
else:
with config.open("rb") as f:

View File

@@ -333,7 +333,7 @@ class MLflowExperiment(Experiment):
recorder = self._get_recorder(recorder_name=recorder_name)
self._client.delete_run(recorder.id)
except MlflowException as e:
raise Exception(
raise ValueError(
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
) from e

View File

@@ -415,7 +415,7 @@ class MLflowExpManager(ExpManager):
raise MlflowException("No valid experiment has been found.")
self.client.delete_experiment(experiment.experiment_id)
except MlflowException as e:
raise Exception(
raise ValueError(
f"Error: {e}. Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct."
) from e

View File

@@ -324,7 +324,7 @@ class MLflowRecorder(Recorder):
raise RuntimeError("This recorder is not saved in the local file system.")
else:
raise Exception(
raise ValueError(
"Please make sure the recorder has been created and started properly before getting artifact uri."
)
@@ -464,7 +464,7 @@ class MLflowRecorder(Recorder):
if self.artifact_uri is not None:
return self.artifact_uri
else:
raise Exception(
raise ValueError(
"Please make sure the recorder has been created and started properly before getting artifact uri."
)

View File

@@ -146,6 +146,9 @@ setup(
# References: https://github.com/python/typeshed/issues/8799
"mypy<0.981",
"flake8",
"nbqa",
"jupyter",
"nbconvert",
# The 5.0.0 version of importlib-metadata removed the deprecated endpoint,
# which prevented flake8 from working properly, so we restricted the version of importlib-metadata.
# To help ensure the dependencies of flake8 https://github.com/python/importlib_metadata/issues/406