diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index d4a4b075e..220453d60 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -121,6 +121,11 @@ jobs: run: | mypy qlib --install-types --non-interactive || true mypy qlib --verbose + + - name: Check Qlib ipynb with nbqa + run: | + nbqa black . -l 120 --check --diff + nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$' - name: Test data downloads run: | @@ -139,6 +144,12 @@ jobs: brew unlink libomp brew install libomp.rb + # Run after data downloads + - name: Check Qlib ipynb with nbconvert + run: | + # add more ipynb files in future + jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb + - name: Test workflow by config (install from source) run: | python -m pip install numba diff --git a/examples/benchmarks/TRA/Reports.ipynb b/examples/benchmarks/TRA/Reports.ipynb index ee172d97e..bd1534433 100644 --- a/examples/benchmarks/TRA/Reports.ipynb +++ b/examples/benchmarks/TRA/Reports.ipynb @@ -25,59 +25,65 @@ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", - "sns.set(style='white')\n", - "matplotlib.rcParams['pdf.fonttype'] = 42\n", - "matplotlib.rcParams['ps.fonttype'] = 42\n", + "\n", + "sns.set(style=\"white\")\n", + "matplotlib.rcParams[\"pdf.fonttype\"] = 42\n", + "matplotlib.rcParams[\"ps.fonttype\"] = 42\n", "\n", "from tqdm.auto import tqdm\n", "from joblib import Parallel, delayed\n", "\n", + "\n", "def func(x, N=80):\n", " ret = x.ret.copy()\n", " x = x.rank(pct=True)\n", - " x['ret'] = ret\n", + " x[\"ret\"] = ret\n", " diff = x.score.sub(x.label)\n", - " r = x.nlargest(N, columns='score').ret.mean()\n", - " r -= x.nsmallest(N, columns='score').ret.mean()\n", - " return pd.Series({\n", - " 'MSE': diff.pow(2).mean(), \n", - " 'MAE': diff.abs().mean(), \n", - " 'IC': x.score.corr(x.label),\n", - " 'R': r\n", - " })\n", - " \n", + " r = x.nlargest(N, columns=\"score\").ret.mean()\n", + " r -= x.nsmallest(N, columns=\"score\").ret.mean()\n", + " return pd.Series(\n", + " {\n", + " \"MSE\": diff.pow(2).mean(),\n", + " \"MAE\": diff.abs().mean(),\n", + " \"IC\": x.score.corr(x.label),\n", + " \"R\": r,\n", + " }\n", + " )\n", + "\n", + "\n", "ret = pd.read_pickle(\"data/ret.pkl\").clip(-0.1, 0.1)\n", + "\n", + "\n", "def backtest(fname, **kwargs):\n", - " pred = pd.read_pickle(fname).loc['2018-09-21':'2020-06-30'] # test period\n", - " pred['ret'] = ret\n", + " pred = pd.read_pickle(fname).loc[\"2018-09-21\":\"2020-06-30\"] # test period\n", + " pred[\"ret\"] = ret\n", " dates = pred.index.unique(level=0)\n", " res = Parallel(n_jobs=-1)(delayed(func)(pred.loc[d], **kwargs) for d in dates)\n", - " res = {\n", - " dates[i]: res[i]\n", - " for i in range(len(dates))\n", - " }\n", + " res = {dates[i]: res[i] for i in range(len(dates))}\n", " res = pd.DataFrame(res).T\n", - " r = res['R'].copy()\n", + " r = res[\"R\"].copy()\n", " r.index = pd.to_datetime(r.index)\n", " r = r.reindex(pd.date_range(r.index[0], r.index[-1])).fillna(0) # paper use 365 days\n", " return {\n", - " 'MSE': res['MSE'].mean(),\n", - " 'MAE': res['MAE'].mean(),\n", - " 'IC': res['IC'].mean(),\n", - " 'ICIR': res['IC'].mean()/res['IC'].std(),\n", - " 'AR': r.mean()*365,\n", - " 'AV': r.std()*365**0.5,\n", - " 'SR': r.mean()/r.std()*365**0.5,\n", - " 'MDD': (r.cumsum().cummax() - r.cumsum()).max()\n", + " \"MSE\": res[\"MSE\"].mean(),\n", + " \"MAE\": res[\"MAE\"].mean(),\n", + " \"IC\": res[\"IC\"].mean(),\n", + " \"ICIR\": res[\"IC\"].mean() / res[\"IC\"].std(),\n", + " \"AR\": r.mean() * 365,\n", + " \"AV\": r.std() * 365**0.5,\n", + " \"SR\": r.mean() / r.std() * 365**0.5,\n", + " \"MDD\": (r.cumsum().cummax() - r.cumsum()).max(),\n", " }, r\n", "\n", + "\n", "def fmt(x, p=3, scale=1, std=False):\n", - " _fmt = '{:.%df}'%p\n", + " _fmt = \"{:.%df}\" % p\n", " string = _fmt.format((x.mean() if not isinstance(x, (float, np.floating)) else x) * scale)\n", " if std and len(x) > 1:\n", - " string += ' ('+_fmt.format(x.std()*scale)+')'\n", + " string += \" (\" + _fmt.format(x.std() * scale) + \")\"\n", " return string\n", "\n", + "\n", "def backtest_multi(files, **kwargs):\n", " res = []\n", " pnl = []\n", @@ -88,14 +94,14 @@ " res = pd.DataFrame(res)\n", " pnl = pd.concat(pnl, axis=1)\n", " return {\n", - " 'MSE': fmt(res['MSE'], std=True),\n", - " 'MAE': fmt(res['MAE'], std=True),\n", - " 'IC': fmt(res['IC']),\n", - " 'ICIR': fmt(res['ICIR']),\n", - " 'AR': fmt(res['AR'], scale=100, p=1)+'%',\n", - " 'VR': fmt(res['AV'], scale=100, p=1)+'%',\n", - " 'SR': fmt(res['SR']),\n", - " 'MDD': fmt(res['MDD'], scale=100, p=1)+'%'\n", + " \"MSE\": fmt(res[\"MSE\"], std=True),\n", + " \"MAE\": fmt(res[\"MAE\"], std=True),\n", + " \"IC\": fmt(res[\"IC\"]),\n", + " \"ICIR\": fmt(res[\"ICIR\"]),\n", + " \"AR\": fmt(res[\"AR\"], scale=100, p=1) + \"%\",\n", + " \"VR\": fmt(res[\"AV\"], scale=100, p=1) + \"%\",\n", + " \"SR\": fmt(res[\"SR\"]),\n", + " \"MDD\": fmt(res[\"MDD\"], scale=100, p=1) + \"%\",\n", " }, pnl" ] }, @@ -124,16 +130,20 @@ "outputs": [], "source": [ "exps = {\n", - " 'Linear': ['output/Linear/pred.pkl'],\n", - " 'LightGBM': ['output/GBDT/lr0.05_leaves128/pred.pkl'],\n", - " 'MLP': glob.glob('output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl'),\n", - " 'SFM': glob.glob('output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl'),\n", - " 'ALSTM': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", - " 'Trans.': glob.glob('output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", - " 'ALSTM+TS':glob.glob('output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", - " 'Trans.+TS':glob.glob('output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", - " 'ALSTM+TRA(Ours)': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", - " 'Trans.+TRA(Ours)': glob.glob('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl')\n", + " \"Linear\": [\"output/Linear/pred.pkl\"],\n", + " \"LightGBM\": [\"output/GBDT/lr0.05_leaves128/pred.pkl\"],\n", + " \"MLP\": glob.glob(\"output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl\"),\n", + " \"SFM\": glob.glob(\"output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl\"),\n", + " \"ALSTM\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n", + " \"Trans.\": glob.glob(\"output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n", + " \"ALSTM+TS\": glob.glob(\"output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n", + " \"Trans.+TS\": glob.glob(\"output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n", + " \"ALSTM+TRA(Ours)\": glob.glob(\n", + " \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n", + " ),\n", + " \"Trans.+TRA(Ours)\": glob.glob(\n", + " \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl\"\n", + " ),\n", "}" ] }, @@ -160,14 +170,8 @@ } ], "source": [ - "res = {\n", - " name: backtest_multi(exps[name])\n", - " for name in tqdm(exps)\n", - "}\n", - "report = pd.DataFrame({\n", - " k: v[0]\n", - " for k, v in res.items()\n", - "}).T" + "res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n", + "report = pd.DataFrame({k: v[0] for k, v in res.items()}).T" ] }, { @@ -385,24 +389,40 @@ } ], "source": [ - "df = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl')\n", - "code = 'SH600157'\n", - "date = '2018-09-28'\n", + "df = pd.read_pickle(\n", + " \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl\"\n", + ")\n", + "code = \"SH600157\"\n", + "date = \"2018-09-28\"\n", "lookbackperiod = 50\n", "\n", "prob = df.iloc[:, -3:].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n", - "pred = df.loc[:,[\"score_0\",\"score_1\",\"score_2\",\"label\"]].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n", - "e_all = pred.iloc[:,:-1].sub(pred.iloc[:,-1], axis=0).pow(2)\n", + "pred = (\n", + " df.loc[:, [\"score_0\", \"score_1\", \"score_2\", \"label\"]]\n", + " .loc(axis=0)[:, code]\n", + " .reset_index(level=1, drop=True)\n", + " .loc[date:]\n", + " .iloc[:lookbackperiod]\n", + ")\n", + "e_all = pred.iloc[:, :-1].sub(pred.iloc[:, -1], axis=0).pow(2)\n", "e_all = e_all.sub(e_all.min(axis=1), axis=0)\n", - "e_all.columns = [r'$\\theta_%d$'%d for d in range(1, 4)]\n", + "e_all.columns = [r\"$\\theta_%d$\" % d for d in range(1, 4)]\n", "prob = pd.Series(np.argmax(prob.values, axis=1), index=prob.index).rolling(7).mean().round()\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(7, 3))\n", - "e_all.plot(ax=axes[0], xlabel='', rot=30)\n", - "prob.plot(ax=axes[1], xlabel='', rot=30, color='red', linestyle='None', marker='^', markersize=5)\n", + "e_all.plot(ax=axes[0], xlabel=\"\", rot=30)\n", + "prob.plot(\n", + " ax=axes[1],\n", + " xlabel=\"\",\n", + " rot=30,\n", + " color=\"red\",\n", + " linestyle=\"None\",\n", + " marker=\"^\",\n", + " markersize=5,\n", + ")\n", "plt.yticks(np.array([0, 1, 2]), e_all.columns.values)\n", - "axes[0].set_ylabel('Predictor Loss')\n", - "axes[1].set_ylabel('Router Selection')\n", + "axes[0].set_ylabel(\"Predictor Loss\")\n", + "axes[1].set_ylabel(\"Router Selection\")\n", "plt.tight_layout()\n", "# plt.savefig('select.pdf', bbox_inches='tight')\n", "plt.show()" @@ -428,10 +448,18 @@ "outputs": [], "source": [ "exps = {\n", - " 'Random': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", - " 'LR': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", - " 'TPE': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", - " 'LR+TPE': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl')\n", + " \"Random\": glob.glob(\n", + " \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n", + " ),\n", + " \"LR\": glob.glob(\n", + " \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n", + " ),\n", + " \"TPE\": glob.glob(\n", + " \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n", + " ),\n", + " \"LR+TPE\": glob.glob(\n", + " \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n", + " ),\n", "}" ] }, @@ -456,14 +484,8 @@ } ], "source": [ - "res = {\n", - " name: backtest_multi(exps[name])\n", - " for name in tqdm(exps)\n", - "}\n", - "report = pd.DataFrame({\n", - " k: v[0]\n", - " for k, v in res.items()\n", - "}).T" + "res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n", + "report = pd.DataFrame({k: v[0] for k, v in res.items()}).T" ] }, { @@ -597,18 +619,22 @@ } ], "source": [ - "a = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n", - "b = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n", + "a = pd.read_pickle(\n", + " \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n", + ")\n", + "b = pd.read_pickle(\n", + " \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n", + ")\n", "a = a.iloc[:, -3:]\n", "b = b.iloc[:, -3:]\n", "b = np.eye(3)[b.values.argmax(axis=1)]\n", "a = np.eye(3)[a.values.argmax(axis=1)]\n", "\n", - "res = pd.DataFrame({\n", - " 'with OT': b.sum(axis=0) / b.sum(),\n", - " 'without OT': a.sum(axis=0)/ a.sum() \n", - "},index=[r'$\\theta_1$',r'$\\theta_2$',r'$\\theta_3$'])\n", - "res.plot.bar(rot=30, figsize=(5, 4), color=['b', 'g'])\n", + "res = pd.DataFrame(\n", + " {\"with OT\": b.sum(axis=0) / b.sum(), \"without OT\": a.sum(axis=0) / a.sum()},\n", + " index=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n", + ")\n", + "res.plot.bar(rot=30, figsize=(5, 4), color=[\"b\", \"g\"])\n", "del a, b" ] }, @@ -633,11 +659,19 @@ "outputs": [], "source": [ "exps = {\n", - " 'K=1': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json'),\n", - " 'K=3': glob.glob('output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n", - " 'K=5': glob.glob('output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n", - " 'K=10': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n", - " 'K=20': glob.glob('output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json')\n", + " \"K=1\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json\"),\n", + " \"K=3\": glob.glob(\n", + " \"output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n", + " ),\n", + " \"K=5\": glob.glob(\n", + " \"output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n", + " ),\n", + " \"K=10\": glob.glob(\n", + " \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n", + " ),\n", + " \"K=20\": glob.glob(\n", + " \"output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n", + " ),\n", "}" ] }, @@ -649,16 +683,11 @@ "source": [ "report = dict()\n", "for k, v in exps.items():\n", - " \n", " tmp = dict()\n", " for fname in v:\n", " with open(fname) as f:\n", " info = json.load(f)\n", - " tmp[fname] = (\n", - " {\n", - " \"IC\":info[\"metric\"][\"IC\"],\n", - " \"MSE\":info[\"metric\"][\"MSE\"]\n", - " })\n", + " tmp[fname] = {\"IC\": info[\"metric\"][\"IC\"], \"MSE\": info[\"metric\"][\"MSE\"]}\n", " tmp = pd.DataFrame(tmp).T\n", " report[k] = tmp.mean()\n", "report = pd.DataFrame(report).T" @@ -681,13 +710,14 @@ } ], "source": [ - "fig, axes = plt.subplots(1, 2, figsize=(6,3)); axes = axes.flatten()\n", - "report['IC'].plot.bar(rot=30, ax=axes[0])\n", + "fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n", + "axes = axes.flatten()\n", + "report[\"IC\"].plot.bar(rot=30, ax=axes[0])\n", "axes[0].set_ylim(0.045, 0.062)\n", - "axes[0].set_title('IC performance')\n", - "report['MSE'].astype(float).plot.bar(rot=30, ax=axes[1], color='green')\n", + "axes[0].set_title(\"IC performance\")\n", + "report[\"MSE\"].astype(float).plot.bar(rot=30, ax=axes[1], color=\"green\")\n", "axes[1].set_ylim(0.155, 0.1585)\n", - "axes[1].set_title('MSE performance')\n", + "axes[1].set_title(\"MSE performance\")\n", "plt.tight_layout()\n", "# plt.savefig('sensitivity.pdf')" ] diff --git a/examples/rl/simple_example.ipynb b/examples/rl/simple_example.ipynb index c2c771772..1e655ff18 100644 --- a/examples/rl/simple_example.ipynb +++ b/examples/rl/simple_example.ipynb @@ -41,6 +41,7 @@ "\n", "State = namedtuple(\"State\", [\"value\", \"last_action\"])\n", "\n", + "\n", "class SimpleSimulator(Simulator[float, State, float]):\n", " def __init__(self, initial: float, nsteps: int, **kwargs: Any) -> None:\n", " super().__init__(initial)\n", @@ -92,6 +93,7 @@ "from gym import spaces\n", "from qlib.rl.interpreter import StateInterpreter\n", "\n", + "\n", "class SimpleStateInterpreter(StateInterpreter[Tuple[float, float], np.ndarray]):\n", " def interpret(self, state: State) -> np.ndarray:\n", " # Convert state.value to a 1D Numpy array\n", @@ -101,7 +103,8 @@ " @property\n", " def observation_space(self) -> spaces.Box:\n", " return spaces.Box(0, np.inf, shape=(1,), dtype=np.float32)\n", - " \n", + "\n", + "\n", "state_interpreter = SimpleStateInterpreter()" ] }, @@ -120,6 +123,7 @@ "source": [ "from qlib.rl.interpreter import ActionInterpreter\n", "\n", + "\n", "class SimpleActionInterpreter(ActionInterpreter[State, int, float]):\n", " def __init__(self, n_value: int) -> None:\n", " self.n_value = n_value\n", @@ -132,7 +136,8 @@ " assert 0 <= action <= self.n_value\n", " # simulator_state.value is used as the denominator\n", " return simulator_state.value * (action / self.n_value)\n", - " \n", + "\n", + "\n", "action_interpreter = SimpleActionInterpreter(n_value=10)" ] }, @@ -151,12 +156,14 @@ "source": [ "from qlib.rl.reward import Reward\n", "\n", + "\n", "class SimpleReward(Reward[State]):\n", " def reward(self, simulator_state: State) -> float:\n", " # Use last_action to calculate reward. This is why it should be in the state.\n", " rew = simulator_state.last_action / simulator_state.value\n", " return rew\n", - " \n", + "\n", + "\n", "reward = SimpleReward()" ] }, @@ -180,6 +187,7 @@ "from torch import nn\n", "from qlib.rl.order_execution import PPO\n", "\n", + "\n", "class SimpleFullyConnect(nn.Module):\n", " def __init__(self, dims: List[int]) -> None:\n", " super().__init__()\n", @@ -195,7 +203,8 @@ "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.fc(x)\n", - " \n", + "\n", + "\n", "policy = PPO(\n", " network=SimpleFullyConnect(dims=[16, 8]),\n", " obs_space=state_interpreter.observation_space,\n", @@ -221,6 +230,7 @@ "source": [ "from torch.utils.data import Dataset\n", "\n", + "\n", "class SimpleDataset(Dataset):\n", " def __init__(self, positions: List[float]) -> None:\n", " self.positions = positions\n", @@ -230,7 +240,8 @@ "\n", " def __getitem__(self, index: int) -> float:\n", " return self.positions[index]\n", - " \n", + "\n", + "\n", "dataset = SimpleDataset(positions=[10.0, 50.0, 100.0])" ] }, @@ -265,11 +276,13 @@ "trainer_kwargs = {\n", " \"max_iters\": 10,\n", " \"finite_env_type\": \"dummy\",\n", - " \"callbacks\": [Checkpoint(\n", - " dirpath=Path(\"./checkpoints\"),\n", - " every_n_iters=1,\n", - " save_latest=\"copy\",\n", - " )],\n", + " \"callbacks\": [\n", + " Checkpoint(\n", + " dirpath=Path(\"./checkpoints\"),\n", + " every_n_iters=1,\n", + " save_latest=\"copy\",\n", + " )\n", + " ],\n", "}\n", "vessel_kwargs = {\n", " \"update_kwargs\": {\"batch_size\": 16, \"repeat\": 5},\n", diff --git a/examples/tutorial/detailed_workflow.ipynb b/examples/tutorial/detailed_workflow.ipynb index 3d73c37f5..925a1fec6 100644 --- a/examples/tutorial/detailed_workflow.ipynb +++ b/examples/tutorial/detailed_workflow.ipynb @@ -88,6 +88,7 @@ "outputs": [], "source": [ "from qlib.tests.data import GetData\n", + "\n", "GetData().qlib_data(exists_skip=True)" ] }, @@ -99,6 +100,7 @@ "outputs": [], "source": [ "import qlib\n", + "\n", "qlib.init()" ] }, @@ -134,7 +136,8 @@ "outputs": [], "source": [ "from qlib.data import D\n", - "D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2] # calendar data" + "\n", + "print(D.calendar(start_time=\"2010-01-01\", end_time=\"2017-12-31\", freq=\"day\")[:2]) # calendar data" ] }, { @@ -152,7 +155,12 @@ "metadata": {}, "outputs": [], "source": [ - "df = D.features(['SH601216'], ['$open', '$high', '$low', '$close', '$factor'], start_time='2020-05-01', end_time='2020-05-31') " + "df = D.features(\n", + " [\"SH601216\"],\n", + " [\"$open\", \"$high\", \"$low\", \"$close\", \"$factor\"],\n", + " start_time=\"2020-05-01\",\n", + " end_time=\"2020-05-31\",\n", + ")" ] }, { @@ -163,11 +171,18 @@ "outputs": [], "source": [ "import plotly.graph_objects as go\n", - "fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n", - " open=df['$open'],\n", - " high=df['$high'],\n", - " low=df['$low'],\n", - " close=df['$close'])])\n", + "\n", + "fig = go.Figure(\n", + " data=[\n", + " go.Candlestick(\n", + " x=df.index.get_level_values(\"datetime\"),\n", + " open=df[\"$open\"],\n", + " high=df[\"$high\"],\n", + " low=df[\"$low\"],\n", + " close=df[\"$close\"],\n", + " )\n", + " ]\n", + ")\n", "fig.show()" ] }, @@ -197,11 +212,18 @@ "outputs": [], "source": [ "import plotly.graph_objects as go\n", - "fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n", - " open=df['$open'] / df['$factor'],\n", - " high=df['$high'] / df['$factor'],\n", - " low=df['$low'] / df['$factor'],\n", - " close=df['$close'] / df['$factor'])])\n", + "\n", + "fig = go.Figure(\n", + " data=[\n", + " go.Candlestick(\n", + " x=df.index.get_level_values(\"datetime\"),\n", + " open=df[\"$open\"] / df[\"$factor\"],\n", + " high=df[\"$high\"] / df[\"$factor\"],\n", + " low=df[\"$low\"] / df[\"$factor\"],\n", + " close=df[\"$close\"] / df[\"$factor\"],\n", + " )\n", + " ]\n", + ")\n", "fig.show()" ] }, @@ -240,7 +262,7 @@ "outputs": [], "source": [ "# dynamic universe\n", - "universe = D.list_instruments(D.instruments('csi100'), start_time='2010-01-01', end_time='2020-12-31')\n", + "universe = D.list_instruments(D.instruments(\"csi100\"), start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n", "pprint(universe)" ] }, @@ -271,8 +293,8 @@ "metadata": {}, "outputs": [], "source": [ - "df = D.features(D.instruments('csi100'), ['$close'], start_time='2010-01-01', end_time='2020-12-31') \n", - "df.groupby('datetime').size().plot()" + "df = D.features(D.instruments(\"csi100\"), [\"$close\"], start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n", + "df.groupby(\"datetime\").size().plot()" ] }, { @@ -313,8 +335,7 @@ " !cd ../../scripts/data_collector/pit/ && pip install -r requirements.txt\n", " !cd ../../scripts/data_collector/pit/ && python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex \"^(600519|000725).*\"\n", " !cd ../../scripts/data_collector/pit/ && python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized\n", - " !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly\n", - " pass" + " !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly" ] }, { @@ -338,7 +359,13 @@ "outputs": [], "source": [ "instruments = [\"sh600519\"]\n", - "data = D.features(instruments, ['P($$roewa_q)'], start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")" + "data = D.features(\n", + " instruments,\n", + " [\"P($$roewa_q)\"],\n", + " start_time=\"2019-01-01\",\n", + " end_time=\"2019-07-19\",\n", + " freq=\"day\",\n", + ")" ] }, { @@ -366,7 +393,10 @@ "metadata": {}, "outputs": [], "source": [ - "D.features([\"sh600519\"], ['(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'])" + "D.features(\n", + " [\"sh600519\"],\n", + " [\"(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close\"],\n", + ")" ] }, { @@ -418,7 +448,7 @@ "metadata": {}, "outputs": [], "source": [ - "qdl = QlibDataLoader(config=(['$close / Ref($close, 10)'], ['RET10']))" + "qdl = QlibDataLoader(config=([\"$close / Ref($close, 10)\"], [\"RET10\"]))" ] }, { @@ -428,7 +458,7 @@ "metadata": {}, "outputs": [], "source": [ - "qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')" + "qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")" ] }, { @@ -456,7 +486,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')" + "df = qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")" ] }, { @@ -476,7 +506,7 @@ "metadata": {}, "outputs": [], "source": [ - "df.plot(kind='hist')" + "df.plot(kind=\"hist\")" ] }, { @@ -508,9 +538,16 @@ "source": [ "# NOTE: normally, the training & validation time range will be `fit_start_time` , `fit_end_time`\n", "# however,all the components are decomposed, so the training & validation time range is unknown when preprocessing.\n", - "dh = DataHandlerLP(instruments=['sh600519'], start_time='20170101', end_time='20191231',\n", - " infer_processors=[ZScoreNorm(fit_start_time='20170101', fit_end_time='20181231'), Fillna()],\n", - " data_loader=qdl)" + "dh = DataHandlerLP(\n", + " instruments=[\"sh600519\"],\n", + " start_time=\"20170101\",\n", + " end_time=\"20191231\",\n", + " infer_processors=[\n", + " ZScoreNorm(fit_start_time=\"20170101\", fit_end_time=\"20181231\"),\n", + " Fillna(),\n", + " ],\n", + " data_loader=qdl,\n", + ")" ] }, { @@ -550,7 +587,7 @@ "metadata": {}, "outputs": [], "source": [ - "df.plot(kind='hist')" + "df.plot(kind=\"hist\")" ] }, { @@ -586,7 +623,7 @@ "metadata": {}, "outputs": [], "source": [ - "ds = DatasetH(dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})" + "ds = DatasetH(dh, segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")})" ] }, { @@ -596,7 +633,7 @@ "metadata": {}, "outputs": [], "source": [ - "ds.prepare('train')" + "ds.prepare(\"train\")" ] }, { @@ -606,7 +643,7 @@ "metadata": {}, "outputs": [], "source": [ - "ds.prepare('valid')" + "ds.prepare(\"valid\")" ] }, { @@ -628,8 +665,12 @@ "metadata": {}, "outputs": [], "source": [ - "ds = TSDatasetH(step_len=10, handler=dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})\n", - "train_sampler = ds.prepare('train')" + "ds = TSDatasetH(\n", + " step_len=10,\n", + " handler=dh,\n", + " segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")},\n", + ")\n", + "train_sampler = ds.prepare(\"train\")" ] }, { @@ -649,7 +690,7 @@ "metadata": {}, "outputs": [], "source": [ - "train_sampler[0] # Retrieving the first example" + "train_sampler[0] # Retrieving the first example" ] }, { @@ -659,7 +700,7 @@ "metadata": {}, "outputs": [], "source": [ - "train_sampler['2018-01-08', 'sh600519'] # get the time series by <'timestamp', 'instrument_id'> index" + "train_sampler[\"2018-01-08\", \"sh600519\"] # get the time series by <'timestamp', 'instrument_id'> index" ] }, { @@ -682,11 +723,11 @@ "outputs": [], "source": [ "handler_kwargs = {\n", - " \"start_time\": \"2008-01-01\",\n", - " \"end_time\": \"2020-08-01\",\n", - " \"fit_start_time\": \"2008-01-01\",\n", - " \"fit_end_time\": \"2014-12-31\",\n", - " \"instruments\": MARKET,\n", + " \"start_time\": \"2008-01-01\",\n", + " \"end_time\": \"2020-08-01\",\n", + " \"fit_start_time\": \"2008-01-01\",\n", + " \"fit_end_time\": \"2014-12-31\",\n", + " \"instruments\": MARKET,\n", "}\n", "handler_conf = {\n", " \"class\": \"Alpha158\",\n", @@ -735,6 +776,7 @@ "outputs": [], "source": [ "from qlib.contrib.data.handler import Alpha158\n", + "\n", "hd = Alpha158(**handler_kwargs)" ] }, @@ -826,7 +868,7 @@ "metadata": {}, "outputs": [], "source": [ - "hd.process_type # appending type" + "hd.process_type # appending type" ] }, { @@ -857,16 +899,16 @@ "outputs": [], "source": [ "dataset_conf = {\n", - " \"class\": \"DatasetH\",\n", - " \"module_path\": \"qlib.data.dataset\",\n", - " \"kwargs\": {\n", - " \"handler\": hd,\n", - " \"segments\": {\n", - " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n", - " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n", - " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n", - " },\n", + " \"class\": \"DatasetH\",\n", + " \"module_path\": \"qlib.data.dataset\",\n", + " \"kwargs\": {\n", + " \"handler\": hd,\n", + " \"segments\": {\n", + " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n", + " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n", + " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n", " },\n", + " },\n", "}" ] }, @@ -908,7 +950,8 @@ "metadata": {}, "outputs": [], "source": [ - "model = init_instance_by_config({\n", + "model = init_instance_by_config(\n", + " {\n", " \"class\": \"LGBModel\",\n", " \"module_path\": \"qlib.contrib.model.gbdt\",\n", " \"kwargs\": {\n", @@ -922,7 +965,8 @@ " \"num_leaves\": 210,\n", " \"num_threads\": 20,\n", " },\n", - "})" + " }\n", + ")" ] }, { @@ -938,7 +982,7 @@ " R.save_objects(trained_model=model)\n", "\n", " rec = R.get_recorder()\n", - " rid = rec.id # save the record id\n", + " rid = rec.id # save the record id\n", "\n", " # Inference and saving signal\n", " sr = SignalRecord(model, dataset, rec)\n", @@ -1001,12 +1045,11 @@ "\n", "# backtest and analysis\n", "with R.start(experiment_name=EXP_NAME, recorder_id=rid, resume=True):\n", - "\n", " # signal-based analysis\n", " rec = R.get_recorder()\n", " sar = SigAnaRecord(rec)\n", " sar.generate()\n", - " \n", + "\n", " # portfolio-based analysis: backtest\n", " par = PortAnaRecord(rec, port_analysis_config, \"day\")\n", " par.generate()" @@ -1137,7 +1180,7 @@ "outputs": [], "source": [ "label_df = dataset.prepare(\"test\", col_set=\"label\")\n", - "label_df.columns = ['label']" + "label_df.columns = [\"label\"]" ] }, { diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb index 5f456c66a..ebdf2d33b 100644 --- a/examples/workflow_by_code.ipynb +++ b/examples/workflow_by_code.ipynb @@ -38,7 +38,7 @@ " # install qlib\n", " ! pip install --upgrade numpy\n", " ! pip install pyqlib\n", - " if 'google.colab' in sys.modules:\n", + " if \"google.colab\" in sys.modules:\n", " # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages\n", " ! pip install pyyaml==5.4.1\n", " # reload\n", @@ -50,7 +50,8 @@ " scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n", " scripts_dir.mkdir(parents=True, exist_ok=True)\n", " import requests\n", - " with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n", + "\n", + " with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\", timeout=10) as resp:\n", " with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n", " fp.write(resp.content)" ] @@ -61,14 +62,13 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "import qlib\n", "import pandas as pd\n", "from qlib.constant import REG_CN\n", "from qlib.utils import exists_qlib_data, init_instance_by_config\n", "from qlib.workflow import R\n", "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n", - "from qlib.utils import flatten_dict\n" + "from qlib.utils import flatten_dict" ] }, { @@ -86,6 +86,7 @@ " print(f\"Qlib data is not found in {provider_uri}\")\n", " sys.path.append(str(scripts_dir))\n", " from get_data import GetData\n", + "\n", " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n", "qlib.init(provider_uri=provider_uri, region=REG_CN)" ] @@ -169,7 +170,7 @@ " R.log_params(**flatten_dict(task))\n", " model.fit(dataset)\n", " R.save_objects(trained_model=model)\n", - " rid = R.get_recorder().id\n" + " rid = R.get_recorder().id" ] }, { @@ -238,7 +239,7 @@ "\n", " # backtest & analysis\n", " par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n", - " par.generate()\n" + " par.generate()" ] }, { @@ -256,6 +257,7 @@ "source": [ "from qlib.contrib.report import analysis_model, analysis_position\n", "from qlib.data import D\n", + "\n", "recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n", "print(recorder)\n", "pred_df = recorder.load_object(\"pred.pkl\")\n", @@ -317,7 +319,7 @@ "outputs": [], "source": [ "label_df = dataset.prepare(\"test\", col_set=\"label\")\n", - "label_df.columns = ['label']" + "label_df.columns = [\"label\"]" ] }, { diff --git a/setup.py b/setup.py index 6b945642b..46b6876f2 100644 --- a/setup.py +++ b/setup.py @@ -146,6 +146,9 @@ setup( # References: https://github.com/python/typeshed/issues/8799 "mypy<0.981", "flake8", + "nbqa", + "jupyter", + "nbconvert", # The 5.0.0 version of importlib-metadata removed the deprecated endpoint, # which prevented flake8 from working properly, so we restricted the version of importlib-metadata. # To help ensure the dependencies of flake8 https://github.com/python/importlib_metadata/issues/406