diff --git a/examples/multi_level_trading/README.md b/examples/multi_level_trading/README.md new file mode 100644 index 000000000..f69afb13b --- /dev/null +++ b/examples/multi_level_trading/README.md @@ -0,0 +1,21 @@ +# Multi-level Trading + +This worflow is an example for multi-level trading. + +## Introduction + +Qlib supports backtesting of various strategies, including portfolio management strategies, order split strategies, model-based strategies (such as deep learning models), rule-based strategies, and RL-based strategies. + +And, Qlib also supports multi-level trading and backtesting. It means that users can use different strategies to trade at different frequencies. + +This example uses a DropoutTopkStrategy (a strategy based on the daily frequency Lightgbm model) in weekly frequency for portfolio generation. And, at the daily frequency level, this example uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to split orders. + +## Usage + +Start backtesting by running the following command: +```bash + python workflow.py +``` + +Also, reports is shown in workflow.ipynb + diff --git a/examples/multi_level_trading/workflow.ipynb b/examples/multi_level_trading/workflow.ipynb new file mode 100644 index 000000000..a122a39fc --- /dev/null +++ b/examples/multi_level_trading/workflow.ipynb @@ -0,0 +1,305 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b", + "display_name": "Python 3.8.8 ('qlib_backtest': conda)" + }, + "metadata": { + "interpreter": { + "hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys, site\n", + "from pathlib import Path\n", + "\n", + "################################# NOTE #################################\n", + "# Please be aware that if colab installs the latest numpy and pyqlib #\n", + "# in this cell, users should RESTART the runtime in order to run the #\n", + "# following cells successfully. #\n", + "########################################################################\n", + "\n", + "try:\n", + " import qlib\n", + "except ImportError:\n", + " # install qlib\n", + " ! pip install --upgrade numpy\n", + " ! pip install pyqlib\n", + " # reload\n", + " site.main()\n", + "\n", + "scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n", + "if not scripts_dir.joinpath(\"get_data.py\").exists():\n", + " # download get_data.py script\n", + " scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n", + " scripts_dir.mkdir(parents=True, exist_ok=True)\n", + " import requests\n", + " with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n", + " with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n", + " fp.write(resp.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import pandas as pd\n", + "from qlib.config import REG_CN\n", + "from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict\n", + "from qlib.workflow import R\n", + "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n", + "from qlib.tests.data import GetData" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# use default data\n", + "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n", + "if not exists_qlib_data(provider_uri):\n", + " print(f\"Qlib data is not found in {provider_uri}\")\n", + " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n", + "\n", + "qlib.init(provider_uri=provider_uri, region=REG_CN)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "market = \"csi300\"\n", + "benchmark = \"SH000300\"\n", + "\n", + "###################################\n", + "# train model\n", + "###################################\n", + "\n", + "data_handler_config = {\n", + " \"start_time\": \"2008-01-01\",\n", + " \"end_time\": \"2020-08-01\",\n", + " \"fit_start_time\": \"2008-01-01\",\n", + " \"fit_end_time\": \"2014-12-31\",\n", + " \"instruments\": market,\n", + "}\n", + "\n", + "task = {\n", + " \"model\": {\n", + " \"class\": \"LGBModel\",\n", + " \"module_path\": \"qlib.contrib.model.gbdt\",\n", + " \"kwargs\": {\n", + " \"loss\": \"mse\",\n", + " \"colsample_bytree\": 0.8879,\n", + " \"learning_rate\": 0.0421,\n", + " \"subsample\": 0.8789,\n", + " \"lambda_l1\": 205.6999,\n", + " \"lambda_l2\": 580.9768,\n", + " \"max_depth\": 8,\n", + " \"num_leaves\": 210,\n", + " \"num_threads\": 20,\n", + " },\n", + " },\n", + " \"dataset\": {\n", + " \"class\": \"DatasetH\",\n", + " \"module_path\": \"qlib.data.dataset\",\n", + " \"kwargs\": {\n", + " \"handler\": {\n", + " \"class\": \"Alpha158\",\n", + " \"module_path\": \"qlib.contrib.data.handler\",\n", + " \"kwargs\": data_handler_config,\n", + " },\n", + " \"segments\": {\n", + " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n", + " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n", + " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n", + " },\n", + " },\n", + " },\n", + "}\n", + "# model initialization\n", + "model = init_instance_by_config(task[\"model\"])\n", + "dataset = init_instance_by_config(task[\"dataset\"])\n", + "\n", + "# start exp to train model\n", + "with R.start(experiment_name=\"train_model\"):\n", + " R.log_params(**flatten_dict(task))\n", + " model.fit(dataset)\n", + " R.save_objects(trained_model=model)\n", + " rid = R.get_recorder().id\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "outputPrepend" + ] + }, + "outputs": [], + "source": [ + "trade_start_time = \"2017-01-01\"\n", + "trade_end_time = \"2020-08-01\"\n", + "\n", + "port_analysis_config = {\n", + " \"strategy\": {\n", + " \"class\": \"TopkDropoutStrategy\",\n", + " \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n", + " \"kwargs\": {\n", + " \"step_bar\": \"week\",\n", + " \"model\": model,\n", + " \"dataset\": dataset,\n", + " \"topk\": 50,\n", + " \"n_drop\": 5,\n", + " },\n", + " },\n", + " \"env\": {\n", + " \"class\": \"SplitExecutor\",\n", + " \"module_path\": \"qlib.contrib.backtest.executor\",\n", + " \"kwargs\": {\n", + " \"step_bar\": \"week\",\n", + " \"generate_report\": True,\n", + " \"sub_env\": {\n", + " \"class\": \"SimulatorExecutor\",\n", + " \"module_path\": \"qlib.contrib.backtest.executor\",\n", + " \"kwargs\": {\n", + " \"step_bar\": \"day\",\n", + " \"verbose\": True,\n", + " \"generate_report\": True,\n", + " },\n", + " },\n", + " \"sub_strategy\": {\n", + " \"class\": \"SBBStrategyEMA\",\n", + " \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n", + " \"kwargs\": {\n", + " \"step_bar\": \"day\",\n", + " \"freq\": \"day\",\n", + " \"instruments\": market,\n", + " },\n", + " },\n", + " },\n", + " },\n", + " \"backtest\": {\n", + " \"start_time\": trade_start_time,\n", + " \"end_time\": trade_end_time,\n", + " \"account\": 100000000,\n", + " \"benchmark\": benchmark,\n", + " \"exchange_kwargs\": {\n", + " \"freq\": \"day\",\n", + " \"limit_threshold\": 0.095,\n", + " \"deal_price\": \"close\",\n", + " \"open_cost\": 0.0005,\n", + " \"close_cost\": 0.0015,\n", + " \"min_cost\": 5,\n", + " },\n", + " },\n", + "}\n", + "# backtest and analysis\n", + "with R.start(experiment_name=\"backtest_analysis\"):\n", + " # prediction\n", + " recorder = R.get_recorder()\n", + " ba_rid = recorder.id\n", + " sr = SignalRecord(model, dataset, recorder)\n", + " sr.generate()\n", + "\n", + " # backtest & analysis\n", + " par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n", + " par.generate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from qlib.contrib.report import analysis_model, analysis_position\n", + "from qlib.data import D\n", + "recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n", + "pred_df = recorder.load_object(\"pred.pkl\")\n", + "pred_df_dates = pred_df.index.get_level_values(level='datetime')\n", + "report_normal_df_1d = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n", + "positions_1d = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n", + "analysis_df_1d = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")\n", + "report_normal_df_1w = recorder.load_object(\"portfolio_analysis/report_normal_1week.pkl\")\n", + "positions_1w = recorder.load_object(\"portfolio_analysis/positions_normal_1week.pkl\")\n", + "analysis_df_1w = recorder.load_object(\"portfolio_analysis/port_analysis_1week.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.report_graph(report_normal_df_1d)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.report_graph(report_normal_df_1w)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.risk_analysis_graph(analysis_df_1d, report_normal_df_1d)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.risk_analysis_graph(analysis_df_1w, report_normal_df_1w)" + ] + } + ] +} \ No newline at end of file diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 77689b3f7..8bfb4f3ec 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -1,11 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys -from pathlib import Path import qlib -import pandas as pd from qlib.config import REG_CN from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict diff --git a/qlib/contrib/backtest/executor.py b/qlib/contrib/backtest/executor.py index 96be0778f..ef0f205ce 100644 --- a/qlib/contrib/backtest/executor.py +++ b/qlib/contrib/backtest/executor.py @@ -127,8 +127,7 @@ class BaseExecutor(BaseTradeCalendar): self.track_data = track_data def get_init_state(self): - init_state = {"current": self.trade_account.current} - return init_state + raise NotImplementedError("get_init_state in not implemeted!") def execute(self, **kwargs): raise NotImplementedError("execute is not implemented!") @@ -180,9 +179,12 @@ class SplitExecutor(BaseExecutor): if generate_report: self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor) - self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=self.BaseStrategy) + def get_init_state(self): + init_state = {"current": self.trade_account.current} + return init_state + def _init_sub_trading(self, order_list): trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time) @@ -263,6 +265,10 @@ class SimulatorExecutor(BaseExecutor): ) self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange + def get_init_state(self): + init_state = {"current": self.trade_account.current, "trade_info": []} + return init_state + def execute(self, order_list): super(SimulatorExecutor, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 222c56568..3a37d71d3 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -36,6 +36,10 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement): def generate_order_list(self, execute_state): super(TWAPStrategy, self).step() + trade_info = execute_state.get("trade_info") + for order, _, _, _ in trade_info: + self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) order_list = [] for order in self.trade_order_list: @@ -56,7 +60,15 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement): // (self.trade_len - self.trade_index) * _amount_trade_unit ) + + if order.direction == order.SELL: + if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( + _order_amount is None or self.trade_index == self.trade_len - 1 + ): + _order_amount = self.trade_amount[(order.stock_id, order.direction)] + if _order_amount: + _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) _order = Order( stock_id=order.stock_id, amount=_order_amount, @@ -106,8 +118,11 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): def generate_order_list(self, execute_state): super(SBBStrategyBase, self).step() - if not self.trade_order_list: - return [] + + trade_info = execute_state.get("trade_info") + for order, _, _, _ in trade_info: + self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) order_list = [] @@ -139,11 +154,12 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): * _amount_trade_unit ) if order.direction == order.SELL: - if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and _order_amount is None: + if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( + _order_amount is None or self.trade_index == self.trade_len - 1 + ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] if _order_amount: - self.trade_amount[(order.stock_id, order.direction)] -= _order_amount _order = Order( stock_id=order.stock_id, amount=_order_amount, @@ -171,12 +187,13 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): * _amount_trade_unit ) if order.direction == order.SELL: - if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and _order_amount is None: + if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and ( + _order_amount is None or self.trade_index == self.trade_len - 1 + ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] if _order_amount: _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) - self.trade_amount[(order.stock_id, order.direction)] -= _order_amount if self.trade_index % 2 == 1: if ( _pred_trend == self.TREND_SHORT