From 0c6e50545541bdcc14bb62b2b631b7e35f74ba65 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 25 May 2021 02:38:34 +0800 Subject: [PATCH] fix comments --- examples/multi_level_trading/README.md | 7 +- examples/multi_level_trading/workflow.ipynb | 305 -------------------- examples/multi_level_trading/workflow.py | 111 ++++--- qlib/config.py | 4 + qlib/contrib/backtest/__init__.py | 42 ++- qlib/contrib/backtest/account.py | 131 ++------- qlib/contrib/backtest/backtest.py | 31 +- qlib/contrib/backtest/exchange.py | 26 +- qlib/contrib/backtest/executor.py | 299 ++++++++++--------- qlib/contrib/backtest/faculty.py | 28 -- qlib/contrib/backtest/position.py | 2 +- qlib/contrib/backtest/report.py | 113 +++++++- qlib/contrib/backtest/utils.py | 67 +++++ qlib/contrib/evaluate.py | 11 +- qlib/contrib/online/operator.py | 6 +- qlib/contrib/strategy/cost_control.py | 8 +- qlib/contrib/strategy/model_strategy.py | 86 ++++-- qlib/contrib/strategy/rule_strategy.py | 260 +++++++++-------- qlib/data/data.py | 4 +- qlib/rl/env.py | 39 +-- qlib/rl/interpreter.py | 3 - qlib/strategy/base.py | 135 +++++---- qlib/utils/{sample.py => resam.py} | 91 +++--- qlib/workflow/record_temp.py | 24 +- 24 files changed, 855 insertions(+), 978 deletions(-) delete mode 100644 examples/multi_level_trading/workflow.ipynb delete mode 100644 qlib/contrib/backtest/faculty.py create mode 100644 qlib/contrib/backtest/utils.py rename qlib/utils/{sample.py => resam.py} (75%) diff --git a/examples/multi_level_trading/README.md b/examples/multi_level_trading/README.md index f69afb13b..6761b84ff 100644 --- a/examples/multi_level_trading/README.md +++ b/examples/multi_level_trading/README.md @@ -14,8 +14,11 @@ This example uses a DropoutTopkStrategy (a strategy based on the daily frequency Start backtesting by running the following command: ```bash - python workflow.py + python workflow.py backtest ``` -Also, reports is shown in workflow.ipynb +Start collecting data by running the following command: +```bash + python workflow.py collect_data +``` diff --git a/examples/multi_level_trading/workflow.ipynb b/examples/multi_level_trading/workflow.ipynb deleted file mode 100644 index a122a39fc..000000000 --- a/examples/multi_level_trading/workflow.ipynb +++ /dev/null @@ -1,305 +0,0 @@ -{ - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8" - }, - "orig_nbformat": 2, - "kernelspec": { - "name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b", - "display_name": "Python 3.8.8 ('qlib_backtest': conda)" - }, - "metadata": { - "interpreter": { - "hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2, - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright (c) Microsoft Corporation.\n", - "# Licensed under the MIT License." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys, site\n", - "from pathlib import Path\n", - "\n", - "################################# NOTE #################################\n", - "# Please be aware that if colab installs the latest numpy and pyqlib #\n", - "# in this cell, users should RESTART the runtime in order to run the #\n", - "# following cells successfully. #\n", - "########################################################################\n", - "\n", - "try:\n", - " import qlib\n", - "except ImportError:\n", - " # install qlib\n", - " ! pip install --upgrade numpy\n", - " ! pip install pyqlib\n", - " # reload\n", - " site.main()\n", - "\n", - "scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n", - "if not scripts_dir.joinpath(\"get_data.py\").exists():\n", - " # download get_data.py script\n", - " scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n", - " scripts_dir.mkdir(parents=True, exist_ok=True)\n", - " import requests\n", - " with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n", - " with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n", - " fp.write(resp.content)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import pandas as pd\n", - "from qlib.config import REG_CN\n", - "from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict\n", - "from qlib.workflow import R\n", - "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n", - "from qlib.tests.data import GetData" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# use default data\n", - "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n", - "if not exists_qlib_data(provider_uri):\n", - " print(f\"Qlib data is not found in {provider_uri}\")\n", - " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n", - "\n", - "qlib.init(provider_uri=provider_uri, region=REG_CN)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "market = \"csi300\"\n", - "benchmark = \"SH000300\"\n", - "\n", - "###################################\n", - "# train model\n", - "###################################\n", - "\n", - "data_handler_config = {\n", - " \"start_time\": \"2008-01-01\",\n", - " \"end_time\": \"2020-08-01\",\n", - " \"fit_start_time\": \"2008-01-01\",\n", - " \"fit_end_time\": \"2014-12-31\",\n", - " \"instruments\": market,\n", - "}\n", - "\n", - "task = {\n", - " \"model\": {\n", - " \"class\": \"LGBModel\",\n", - " \"module_path\": \"qlib.contrib.model.gbdt\",\n", - " \"kwargs\": {\n", - " \"loss\": \"mse\",\n", - " \"colsample_bytree\": 0.8879,\n", - " \"learning_rate\": 0.0421,\n", - " \"subsample\": 0.8789,\n", - " \"lambda_l1\": 205.6999,\n", - " \"lambda_l2\": 580.9768,\n", - " \"max_depth\": 8,\n", - " \"num_leaves\": 210,\n", - " \"num_threads\": 20,\n", - " },\n", - " },\n", - " \"dataset\": {\n", - " \"class\": \"DatasetH\",\n", - " \"module_path\": \"qlib.data.dataset\",\n", - " \"kwargs\": {\n", - " \"handler\": {\n", - " \"class\": \"Alpha158\",\n", - " \"module_path\": \"qlib.contrib.data.handler\",\n", - " \"kwargs\": data_handler_config,\n", - " },\n", - " \"segments\": {\n", - " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n", - " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n", - " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n", - " },\n", - " },\n", - " },\n", - "}\n", - "# model initialization\n", - "model = init_instance_by_config(task[\"model\"])\n", - "dataset = init_instance_by_config(task[\"dataset\"])\n", - "\n", - "# start exp to train model\n", - "with R.start(experiment_name=\"train_model\"):\n", - " R.log_params(**flatten_dict(task))\n", - " model.fit(dataset)\n", - " R.save_objects(trained_model=model)\n", - " rid = R.get_recorder().id\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "outputPrepend" - ] - }, - "outputs": [], - "source": [ - "trade_start_time = \"2017-01-01\"\n", - "trade_end_time = \"2020-08-01\"\n", - "\n", - "port_analysis_config = {\n", - " \"strategy\": {\n", - " \"class\": \"TopkDropoutStrategy\",\n", - " \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n", - " \"kwargs\": {\n", - " \"step_bar\": \"week\",\n", - " \"model\": model,\n", - " \"dataset\": dataset,\n", - " \"topk\": 50,\n", - " \"n_drop\": 5,\n", - " },\n", - " },\n", - " \"env\": {\n", - " \"class\": \"SplitExecutor\",\n", - " \"module_path\": \"qlib.contrib.backtest.executor\",\n", - " \"kwargs\": {\n", - " \"step_bar\": \"week\",\n", - " \"generate_report\": True,\n", - " \"sub_env\": {\n", - " \"class\": \"SimulatorExecutor\",\n", - " \"module_path\": \"qlib.contrib.backtest.executor\",\n", - " \"kwargs\": {\n", - " \"step_bar\": \"day\",\n", - " \"verbose\": True,\n", - " \"generate_report\": True,\n", - " },\n", - " },\n", - " \"sub_strategy\": {\n", - " \"class\": \"SBBStrategyEMA\",\n", - " \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n", - " \"kwargs\": {\n", - " \"step_bar\": \"day\",\n", - " \"freq\": \"day\",\n", - " \"instruments\": market,\n", - " },\n", - " },\n", - " },\n", - " },\n", - " \"backtest\": {\n", - " \"start_time\": trade_start_time,\n", - " \"end_time\": trade_end_time,\n", - " \"account\": 100000000,\n", - " \"benchmark\": benchmark,\n", - " \"exchange_kwargs\": {\n", - " \"freq\": \"day\",\n", - " \"limit_threshold\": 0.095,\n", - " \"deal_price\": \"close\",\n", - " \"open_cost\": 0.0005,\n", - " \"close_cost\": 0.0015,\n", - " \"min_cost\": 5,\n", - " },\n", - " },\n", - "}\n", - "# backtest and analysis\n", - "with R.start(experiment_name=\"backtest_analysis\"):\n", - " # prediction\n", - " recorder = R.get_recorder()\n", - " ba_rid = recorder.id\n", - " sr = SignalRecord(model, dataset, recorder)\n", - " sr.generate()\n", - "\n", - " # backtest & analysis\n", - " par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n", - " par.generate()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from qlib.contrib.report import analysis_model, analysis_position\n", - "from qlib.data import D\n", - "recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n", - "pred_df = recorder.load_object(\"pred.pkl\")\n", - "pred_df_dates = pred_df.index.get_level_values(level='datetime')\n", - "report_normal_df_1d = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n", - "positions_1d = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n", - "analysis_df_1d = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")\n", - "report_normal_df_1w = recorder.load_object(\"portfolio_analysis/report_normal_1week.pkl\")\n", - "positions_1w = recorder.load_object(\"portfolio_analysis/positions_normal_1week.pkl\")\n", - "analysis_df_1w = recorder.load_object(\"portfolio_analysis/port_analysis_1week.pkl\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "analysis_position.report_graph(report_normal_df_1d)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "analysis_position.report_graph(report_normal_df_1w)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "analysis_position.risk_analysis_graph(analysis_df_1d, report_normal_df_1d)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "analysis_position.risk_analysis_graph(analysis_df_1w, report_normal_df_1w)" - ] - } - ] -} \ No newline at end of file diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 8bfb4f3ec..390044480 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -3,30 +3,21 @@ import qlib +import fire from qlib.config import REG_CN from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData +from qlib.contrib.backtest import collect_data -if __name__ == "__main__": - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - - qlib.init(provider_uri=provider_uri, region=REG_CN) +class MultiLevelTradingWorkflow: market = "csi300" benchmark = "SH000300" - ################################### - # train model - ################################### - data_handler_config = { "start_time": "2008-01-01", "end_time": "2020-08-01", @@ -68,31 +59,17 @@ if __name__ == "__main__": }, }, } - # model initialization - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) trade_start_time = "2017-01-01" trade_end_time = "2020-08-01" port_analysis_config = { - "strategy": { - "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", - "kwargs": { - "step_bar": "week", - "model": model, - "dataset": dataset, - "topk": 50, - "n_drop": 5, - }, - }, - "env": { + "executor": { "class": "SplitExecutor", "module_path": "qlib.contrib.backtest.executor", "kwargs": { "step_bar": "week", - "sub_env": { + "sub_executor": { "class": "SimulatorExecutor", "module_path": "qlib.contrib.backtest.executor", "kwargs": { @@ -105,11 +82,11 @@ if __name__ == "__main__": "class": "SBBStrategyEMA", "module_path": "qlib.contrib.strategy.rule_strategy", "kwargs": { - "step_bar": "day", "freq": "day", "instruments": market, }, }, + "track_data": True, }, }, "backtest": { @@ -128,17 +105,69 @@ if __name__ == "__main__": }, } - with R.start(experiment_name="highfreq_backtest"): - R.log_params(**flatten_dict(task)) - model.fit(dataset) - R.save_objects(**{"params.pkl": model}) + def _init_qlib(self): + """initialize qlib""" + # use yahoo_cn_1min data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + qlib.init(provider_uri=provider_uri, region=REG_CN) - # prediction - recorder = R.get_recorder() - sr = SignalRecord(model, dataset, recorder) - sr.generate() + def _train_model(self, model, dataset): + with R.start(experiment_name="train"): + R.log_params(**flatten_dict(self.task)) + model.fit(dataset) + R.save_objects(**{"params.pkl": model}) - # backtest. If users want to use backtest based on their own prediction, - # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. - par = PortAnaRecord(recorder, port_analysis_config, "day") - par.generate() + # prediction + recorder = R.get_recorder() + sr = SignalRecord(model, dataset, recorder) + sr.generate() + + def backtest(self): + self._init_qlib() + model = init_instance_by_config(self.task["model"]) + dataset = init_instance_by_config(self.task["dataset"]) + self._train_model(model, dataset) + strategy_config = { + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.model_strategy", + "kwargs": { + "model": model, + "dataset": dataset, + "topk": 50, + "n_drop": 5, + }, + } + self.port_analysis_config["strategy"] = strategy_config + with R.start(experiment_name="backtest"): + + recorder = R.get_recorder() + par = PortAnaRecord(recorder, self.port_analysis_config, "day") + par.generate() + + def collect_data(self): + self._init_qlib() + model = init_instance_by_config(self.task["model"]) + dataset = init_instance_by_config(self.task["dataset"]) + self._train_model(model, dataset) + executor_config = self.port_analysis_config["executor"] + backtest_config = self.port_analysis_config["backtest"] + strategy_config = { + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.model_strategy", + "kwargs": { + "model": model, + "dataset": dataset, + "topk": 50, + "n_drop": 5, + }, + } + data_generator = collect_data(executor=executor_config, strategy=strategy_config, **backtest_config) + for trade_decision in data_generator: + print(trade_decision) + + +if __name__ == "__main__": + fire.Fire(MultiLevelTradingWorkflow) diff --git a/qlib/config.py b/qlib/config.py index 75ab0fa3e..3bcf79ddb 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -140,6 +140,10 @@ _default_config = { "default_exp_name": "Experiment", }, }, + # Shift minute for highfreq minite data, used in backtest + # if min_data_shift == 0, use default market time [9:30, 11:29, 1:30, 2:39] + # if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:30, 2:39] - shift*minute + "min_data_shift": {0}, } MODE_CONF = { diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py index 8cfbf9674..effab026b 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/contrib/backtest/__init__.py @@ -5,13 +5,12 @@ from .account import Account from .exchange import Exchange from .executor import BaseExecutor from .backtest import backtest as backtest_func - +from .backtest import collect_data as data_generator from ...strategy.base import BaseStrategy from ...utils import init_instance_by_config from ...log import get_module_logger from ...config import C -from .faculty import common_faculty logger = get_module_logger("backtest caller") @@ -89,8 +88,9 @@ def get_exchange( return init_instance_by_config(exchange, accept_types=Exchange) -def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=1e9, exchange_kwargs={}): - +def get_strategy_executor( + start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={} +): trade_account = Account( init_cash=account, benchmark_config={ @@ -101,14 +101,32 @@ def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account= ) trade_exchange = get_exchange(**exchange_kwargs) - common_faculty.update( - trade_account=trade_account, - trade_exchange=trade_exchange, + common_infra = { + "trade_account": trade_account, + "trade_exchange": trade_exchange, + } + + trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra) + trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra) + + return trade_strategy, trade_executor + + +def backtest(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}): + + trade_strategy, trade_executor = get_strategy_executor( + start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs ) - - trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy) - trade_env = init_instance_by_config(env, accept_types=BaseExecutor) - - report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env) + report_dict = backtest_func(start_time, end_time, trade_strategy, trade_executor) + + return report_dict + + +def collect_data(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}): + + trade_strategy, trade_executor = get_strategy_executor( + start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs + ) + report_dict = yield from data_generator(start_time, end_time, trade_strategy, trade_executor) return report_dict diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index df7614979..c7571bc98 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -9,8 +9,6 @@ import pandas as pd from .position import Position from .report import Report from .order import Order -from ...data import D -from ...utils.sample import parse_freq, sample_feature """ @@ -34,85 +32,14 @@ class Account: self.init_vars(init_cash, freq, benchmark_config) def init_vars(self, init_cash, freq: str, benchmark_config: dict): - """ - Parameters - ---------- - freq : str - frequency of trading bar, used for updating hold count of trading bar - benchmark_config : dict - config of benchmark, may including the following arguments: - - benchmark : Union[str, list, pd.Series] - - If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. - example: - print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) - 2017-01-04 0.011693 - 2017-01-05 0.000721 - 2017-01-06 -0.004322 - 2017-01-09 0.006874 - 2017-01-10 -0.003350 - - If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. - - If `benchmark` is str, will use the daily change as the 'bench'. - benchmark code, default is SH000300 CSI300 - - start_time : Union[str, pd.Timestamp], optional - - If `benchmark` is pd.Series, it will be ignored - - Else, it represent start time of benchmark, by default None - - end_time : Union[str, pd.Timestamp], optional - - If `benchmark` is pd.Series, it will be ignored - - Else, it represent end time of benchmark, by default None - """ # init cash self.init_cash = init_cash - self.freq = freq - self.benchmark_config = benchmark_config - self.bench = self._cal_benchmark(benchmark_config, freq) self.current = Position(cash=init_cash) - self._reset_report() + self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True) - def _cal_benchmark(self, benchmark_config, freq): - benchmark = benchmark_config.get("benchmark", "SH000300") - if isinstance(benchmark, pd.Series): - return benchmark - else: - start_time = benchmark_config.get("start_time", None) - end_time = benchmark_config.get("end_time", None) - - if freq is None: - raise ValueError("benchmark freq can't be None!") - _codes = benchmark if isinstance(benchmark, list) else [benchmark] - fields = ["$close/Ref($close,1)-1"] - try: - _temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1) - except ValueError: - _, norm_freq = parse_freq(freq) - if norm_freq in ["month", "week", "day"]: - try: - _temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1) - except ValueError: - _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) - elif norm_freq == "minute": - _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) - else: - raise ValueError(f"benchmark freq {freq} is not supported") - if len(_temp_result) == 0: - raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") - return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0) - - def _sample_benchmark(self, bench, trade_start_time, trade_end_time): - def cal_change(x): - return (x + 1).prod() - 1 - - _ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change) - return 0 if _ret is None else _ret - - def _reset_freq(self, freq): - """reset frequency""" - if freq != self.freq: - self.freq = freq - self.bench = self._cal_benchmark(self.benchmark_config, self.freq) - - def _reset_report(self): - self.report = Report() + def reset_report(self, freq, benchmark_config): + self.report = Report(freq, benchmark_config) self.positions = {} self.rtn = 0 self.ct = 0 @@ -120,10 +47,25 @@ class Account: self.val = 0 self.earning = 0 - def reset(self, freq=None, init_report: bool = False): - self._reset_freq(freq) - if init_report: - self._reset_report() + def reset(self, freq=None, benchmark_config=None, init_report=False): + """reset freq and report of account + + Parameters + ---------- + freq : str, optional + frequency of account & report, by default None + benchmark_config : {}, optional + benchmark config of report, by default None + init_report : bool, optional + whether to initialize the report, by default False + """ + if freq is not None: + self.freq = freq + if benchmark_config is not None: + self.benchmark_config = benchmark_config + + if freq is not None or benchmark_config is not None or init_report: + self.reset_report(self.freq, self.benchmark_config) def get_positions(self): return self.positions @@ -131,7 +73,7 @@ class Account: def get_cash(self): return self.current.position["cash"] - def update_state_from_order(self, order, trade_val, cost, trade_price): + def _update_state_from_order(self, order, trade_val, cost, trade_price): # update turnover self.to += trade_val # update cost @@ -155,7 +97,7 @@ class Account: # The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation if order.direction == Order.SELL: # sell stock - self.update_state_from_order(order, trade_val, cost, trade_price) + self._update_state_from_order(order, trade_val, cost, trade_price) # update current position # for may sell all of stock_id self.current.update_order(order, trade_val, cost, trade_price) @@ -163,15 +105,15 @@ class Account: # buy stock # deal order, then update state self.current.update_order(order, trade_val, cost, trade_price) - self.update_state_from_order(order, trade_val, cost, trade_price) + self._update_state_from_order(order, trade_val, cost, trade_price) def update_bar_count(self): self.current.add_count_all(bar=self.freq) def update_bar_report(self, trade_start_time, trade_end_time, trade_exchange): """ - start_time: pd.TimeStamp - end_time: pd.TimeStamp + trade_start_time: pd.TimeStamp + trade_end_time: pd.TimeStamp quote: pd.DataFrame (code, date), collumns when the end of trade date - update rtn @@ -211,7 +153,8 @@ class Account: # judge whether the the trading is begin. # and don't add init account state into report, due to we don't have excess return in those days. self.report.update_report_record( - trade_time=trade_start_time, + trade_start_time=trade_start_time, + trade_end_time=trade_end_time, account_value=now_account_value, cash=self.current.position["cash"], return_rate=(self.earning + self.ct) / last_account_value, @@ -220,7 +163,6 @@ class Account: turnover_rate=self.to / last_account_value, cost_rate=self.ct / last_account_value, stock_value=now_stock_value, - bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time), ) # set now_account_value to position self.current.position["now_account_value"] = now_account_value @@ -234,18 +176,3 @@ class Account: self.rtn = 0 self.ct = 0 self.to = 0 - - def load_account(self, account_path): - report = Report() - position = Position() - report.load_report(account_path / "report.csv") - position.load_position(account_path / "position.xlsx") - - # assign values - self.init_vars(position.init_cash) - self.current = position - self.report = report - - def save_account(self, account_path): - self.current.save_position(account_path / "position.xlsx") - self.report.save_report(account_path / "report.csv") diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index 73785c771..33c73de7a 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -2,14 +2,29 @@ # Licensed under the MIT License. -def backtest(start_time, end_time, trade_strategy, trade_env): +def backtest(start_time, end_time, trade_strategy, trade_executor): - trade_env.reset(start_time=start_time, end_time=end_time) - trade_strategy.reset(start_time=start_time, end_time=end_time) + trade_executor.reset(start_time=start_time, end_time=end_time) + level_infra = trade_executor.get_level_infra() + trade_strategy.reset(level_infra=level_infra) - _execute_state = trade_env.get_init_state() - while not trade_env.finished(): - _order_list = trade_strategy.generate_order_list(_execute_state) - _execute_state = trade_env.execute(_order_list) + sub_execute_state = trade_executor.get_init_state() + while not trade_executor.finished(): + sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state) + sub_execute_state = trade_executor.execute(sub_trade_decision) - return trade_env.get_report() + return trade_executor.get_report() + + +def collect_data(start_time, end_time, trade_strategy, trade_executor): + + trade_executor.reset(start_time=start_time, end_time=end_time) + level_infra = trade_executor.get_level_infra() + trade_strategy.reset(level_infra=level_infra) + + sub_execute_state = trade_executor.get_init_state() + while not trade_executor.finished(): + sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state) + sub_execute_state = yield from trade_executor.collect_data(sub_trade_decision) + + return trade_executor.get_report() diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py index 86045fd7a..09b7f2a63 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/contrib/backtest/exchange.py @@ -11,7 +11,7 @@ import pandas as pd from ...data.data import D from ...data.dataset.utils import get_level_index from ...config import C, REG_CN -from ...utils.sample import sample_feature +from ...utils.resam import resam_ts_data from ...log import get_module_logger from .order import Order @@ -34,8 +34,9 @@ class Exchange: ): """__init__ - :param start_time: start time for backtest - :param end_time: end time for backtest + :param freq: frequency of data + :param start_time: closed start time for backtest + :param end_time: closed end time for backtest :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50) :param deal_price: str, 'close', 'open', 'vwap' :param subscribe_fields: list, subscribe fields @@ -91,7 +92,7 @@ class Exchange: # $factor is for rounding to the trading unit # $change is for calculating the limit of the stock - necessary_fields = {self.deal_price, "$close", "$change", "$factor"} + necessary_fields = {self.deal_price, "$close", "$change", "$factor", "$volume"} subscribe_fields = list(necessary_fields | set(subscribe_fields)) all_fields = list(necessary_fields | set(subscribe_fields)) self.all_fields = all_fields @@ -167,12 +168,12 @@ class Exchange: trade_date is limtited """ - return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0] + return resam_ts_data(self.quote[stock_id]["limit"], start_time, end_time, method="all").iloc[0] def check_stock_suspended(self, stock_id, start_time, end_time): # is suspended if stock_id in self.quote: - return sample_feature(self.quote[stock_id], start_time, end_time, method=None) is None + return resam_ts_data(self.quote[stock_id], start_time, end_time, method=None) is None else: return True @@ -230,15 +231,16 @@ class Exchange: return trade_val, trade_cost, trade_price def get_quote_info(self, stock_id, start_time, end_time): - return sample_feature(self.quote[stock_id], start_time, end_time, method="last").iloc[0] + return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0] def get_close(self, stock_id, start_time, end_time): - return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0] + return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0] + + def get_volume(self, stock_id, start_time, end_time): + return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0] def get_deal_price(self, stock_id, start_time, end_time): - deal_price = sample_feature( - self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last" - ).iloc[0] + deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0] if np.isclose(deal_price, 0.0) or np.isnan(deal_price): self.logger.warning( f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!" @@ -248,7 +250,7 @@ class Exchange: return deal_price def get_factor(self, stock_id, start_time, end_time): - return sample_feature(self.quote[stock_id], start_time, end_time, fields="$factor", method="last").iloc[0] + return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last").iloc[0] def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): """ diff --git a/qlib/contrib/backtest/executor.py b/qlib/contrib/backtest/executor.py index 943b26f9c..8a57d2986 100644 --- a/qlib/contrib/backtest/executor.py +++ b/qlib/contrib/backtest/executor.py @@ -2,88 +2,18 @@ import copy import warnings import pandas as pd from typing import Union -from ...data.data import Cal from ...utils import init_instance_by_config -from ...utils.sample import get_sample_freq_calendar, parse_freq +from ...utils.resam import parse_freq from .order import Order from .account import Account from .exchange import Exchange -from .faculty import common_faculty +from .utils import TradeCalendarManager -class BaseTradeCalendar: - """ - Base class providing trading calendar - - BaseStrategy and BaseExecutor should inherited from this class - """ - - def __init__( - self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None - ): - """ - Parameters - ---------- - step_bar : str - frequency of each trading step bar - start_time : Union[str, pd.Timestamp], optional - start time of trading, by default None - If `start_time` is None, it must be reset before trading. - end_time : Union[str, pd.Timestamp], optional - end time of trading, by default None - If `end_time` is None, it must be reset before trading. - """ - - self.step_bar = step_bar - self.start_time = pd.Timestamp(start_time) if start_time else None - self.end_time = pd.Timestamp(end_time) if end_time else None - self.reset(start_time=start_time, end_time=end_time) - - def _reset_trade_calendar(self, start_time, end_time): - """reset trade calendar""" - if start_time and end_time: - _calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar) - self.calendar = _calendar - _, _, _start_index, _end_index = Cal.locate_index( - self.start_time, self.end_time, freq=freq, freq_sam=freq_sam - ) - self.start_index = _start_index - self.end_index = _end_index - self.trade_len = _end_index - _start_index + 1 - self.trade_index = 0 - else: - raise ValueError("failed to reset trade calendar, param `start_time` or `end_time` is None.") - - def reset(self, start_time=None, end_time=None): - """ - Reset start\end time of trading, and reset trading calendar - """ - - if start_time: - self.start_time = pd.Timestamp(start_time) - if end_time: - self.end_time = pd.Timestamp(end_time) - if self.start_time and self.end_time and (start_time or end_time): - self._reset_trade_calendar(start_time=self.start_time, end_time=self.end_time) - - def _get_calendar_time(self, trade_index=1, shift=0): - trade_index = trade_index - shift - calendar_index = self.start_index + trade_index - return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1) - - def finished(self): - return self.trade_index >= self.trade_len - - def step(self): - if self.finished(): - raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!") - # trade count += 1 - self.trade_index = self.trade_index + 1 - - -class BaseExecutor(BaseTradeCalendar): +class BaseExecutor: """Base executor for trading""" def __init__( @@ -91,48 +21,97 @@ class BaseExecutor(BaseTradeCalendar): step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - trade_account: Account = None, generate_report: bool = False, verbose: bool = False, track_data: bool = False, + common_infra: dict = {}, **kwargs, ): """ Parameters ---------- - trade_account : Account, optional - trade account for trading, by default None - - If `trade_account` is None, self.trade_account will be set with common_faculty generate_report : bool, optional whether to generate report, by default False verbose : bool, optional whether to print trading info, by default False track_data : bool, optional - whether to generate order_list, will be used when making data for multi-level training - - If `self.track_data` is true, when making data for training, the input `order_list` of `execute` will be generated by `get_data` - - Else, `order_list` will not be generated + whether to generate trade_decision, will be used when making data for multi-level training + - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data` + - Else, `trade_decision` will not be generated + common_infra : dict, optional: + common infrastructure for backtesting, may including: + - trade_account : Account, optional + trade account for trading + - trade_exchange : Exchange, optional + exchange that provides market info + """ - super(BaseExecutor, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, **kwargs) - self.trade_account = copy.copy(common_faculty.trade_account if trade_account is None else trade_account) - self.trade_account.reset(freq=self.step_bar, init_report=True) + self.step_bar = step_bar self.generate_report = generate_report self.verbose = verbose self.track_data = track_data + self.reset(start_time=start_time, end_time=end_time, track_data=track_data, common_infra=common_infra) - def reset(self, track_data: bool = None, **kwargs): + def reset_common_infra(self, common_infra): """ - Reset `track_data`, will be used when making data for multi-level training + reset infrastructure for trading + - reset trade_account """ - super(BaseExecutor, self).reset(**kwargs) + if not hasattr(self, "common_infra"): + self.common_infra = common_infra + else: + self.common_infra.update(common_infra) + + if "trade_account" in common_infra: + self.trade_account = copy.copy(common_infra.get("trade_account")) + self.trade_account.reset(freq=self.step_bar, init_report=True) + + def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs): + """ + - reset `start_time` and `end_time`, used in trade calendar + - reset `track_data`, used when making data for multi-level training + - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc + """ + if track_data is not None: self.track_data = track_data + if common_infra is not None: + self.reset_common_infra(common_infra) + + if "start_time" in kwargs or "end_time" in kwargs: + start_time = kwargs.get("start_time") + end_time = kwargs.get("end_time") + self.trade_calendar = TradeCalendarManager(step_bar=self.step_bar, start_time=start_time, end_time=end_time) + + def get_level_infra(self): + return {"trade_calendar": self.trade_calendar} + + def finished(self): + return self.trade_calendar.finished() + + def execute(self, trade_decision): + """execute the trade decision and return the executed result + + Parameters + ---------- + trade_decision : object + + Returns + ---------- + executed state : List[Tuple[Order, float, float, float]] + - Each element in the list represents (order, trade value, trade cost, trade price) + """ + raise NotImplementedError("execute is not implemented!") + + def collect_data(self, trade_decision): + if self.track_data: + yield trade_decision + return self.execute(trade_decision) + def get_init_state(self): raise NotImplementedError("get_init_state in not implemeted!") - def execute(self, **kwargs): - raise NotImplementedError("execute is not implemented!") - def get_trade_account(self): raise NotImplementedError("get_trade_account is not implemented!") @@ -146,56 +125,75 @@ class SplitExecutor(BaseExecutor): def __init__( self, step_bar: str, - sub_env: Union[BaseExecutor, dict], + sub_executor: Union[BaseExecutor, dict], sub_strategy: Union[BaseStrategy, dict], start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - trade_account: Account = None, trade_exchange: Exchange = None, generate_report: bool = False, verbose: bool = False, track_data: bool = False, + common_infra: dict = {}, **kwargs, ): """ Parameters ---------- - sub_env : BaseExecutor + sub_executor : BaseExecutor trading env in each trading bar. sub_strategy : BaseStrategy trading strategy in each trading bar trade_exchange : Exchange - exchange that provides market info - - If `trade_exchange` is None, self.trade_exchange will be set with common_faculty + exchange that provides market info, used to generate report + - If generate_report is None, trade_exchange will be ignored + - Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra """ + self.sub_executor = init_instance_by_config(sub_executor, common_infra=common_infra, accept_types=BaseExecutor) + self.sub_strategy = init_instance_by_config( + sub_strategy, common_infra=common_infra, accept_types=self.BaseStrategy + ) + super(SplitExecutor, self).__init__( step_bar=step_bar, start_time=start_time, end_time=end_time, - trade_account=trade_account, generate_report=generate_report, verbose=verbose, track_data=track_data, + common_infra=common_infra, **kwargs, ) - if generate_report: - self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange - self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor) - self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=self.BaseStrategy) + + if generate_report and trade_exchange is not None: + self.trade_exchange = trade_exchange + + def reset_common_infra(self, common_infra): + """ + reset infrastructure for trading + - reset trade_exchange + - reset substrategy and subexecutor common infra + """ + super(SplitExecutor, self).reset_common_infra(common_infra) + + if self.generate_report and "trade_exchange" in common_infra: + self.trade_exchange = common_infra.get("trade_exchange") + + self.sub_executor.reset_common_infra(common_infra) + self.sub_strategy.reset_common_infra(common_infra) def get_init_state(self): - init_state = {"current": self.trade_account.current} - return init_state + return [] - def _init_sub_trading(self, order_list): - trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) - self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time) - self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list) - sub_execute_state = self.sub_env.get_init_state() - return sub_execute_state + def _init_sub_trading(self, trade_decision): + trade_index = self.trade_calendar.get_trade_index() + trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) + self.sub_executor.reset(start_time=trade_start_time, end_time=trade_end_time) + sub_level_infra = self.sub_executor.get_level_infra() + self.sub_strategy.reset(level_infra=sub_level_infra, rely_trade_decision=trade_decision) def _update_trade_account(self): - trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) + trade_index = self.trade_calendar.get_trade_index() + trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) self.trade_account.update_bar_count() if self.generate_report: self.trade_account.update_bar_report( @@ -204,30 +202,38 @@ class SplitExecutor(BaseExecutor): trade_exchange=self.trade_exchange, ) - def execute(self, order_list): - super(SplitExecutor, self).step() - self._init_sub_trading(order_list) - sub_execute_state = self.sub_env.get_init_state() - while not self.sub_env.finished(): - _order_list = self.sub_strategy.generate_order_list(sub_execute_state) - sub_execute_state = self.sub_env.execute(order_list=_order_list) - self._update_trade_account() - return {"current": self.trade_account.current} + def execute(self, trade_decision): + self.trade_calendar.step() + self._init_sub_trading(trade_decision) + execute_state = [] + sub_execute_state = self.sub_executor.get_init_state() + while not self.sub_executor.finished(): + sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state) + sub_execute_state = self.sub_executor.execute(trade_decision=sub_trade_decison) + execute_state.extend(sub_execute_state) + if hasattr(self, "trade_account"): + self._update_trade_account() - def get_data(self, order_list): + return execute_state + + def collect_data(self, trade_decision): if self.track_data: - yield order_list - super(SplitExecutor, self).step() - self._init_sub_trading(order_list) - sub_execute_state = self.sub_env.get_init_state() - while not self.sub_env.finished(): - _order_list = self.sub_strategy.generate_order_list(sub_execute_state) - sub_execute_state = yield from self.sub_env.get_data(order_list=_order_list) - self._update_trade_account() - return {"current": self.trade_account.current} + yield trade_decision + self.trade_calendar.step() + self._init_sub_trading(trade_decision) + execute_state = [] + sub_execute_state = self.sub_executor.get_init_state() + while not self.sub_executor.finished(): + sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state) + sub_execute_state = yield from self.sub_executor.collect_data(trade_decision=sub_trade_decison) + execute_state.extend(sub_execute_state) + if hasattr(self, "trade_account"): + self._update_trade_account() + + return execute_state def get_report(self): - sub_env_report_dict = self.sub_env.get_report() + sub_env_report_dict = self.sub_executor.get_report() if self.generate_report: _report = self.trade_account.report.generate_report_dataframe() _positions = self.trade_account.get_positions() @@ -242,46 +248,57 @@ class SimulatorExecutor(BaseExecutor): step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - trade_account: Account = None, trade_exchange: Exchange = None, generate_report: bool = False, verbose: bool = False, track_data: bool = False, + common_infra: dict = {}, **kwargs, ): """ Parameters ---------- trade_exchange : Exchange - exchange that provides market info + exchange that provides market info, used to deal order and generate report + - If `trade_exchange` is None, self.trade_exchange will be set with common_infra """ super(SimulatorExecutor, self).__init__( step_bar=step_bar, start_time=start_time, end_time=end_time, - trade_account=trade_account, generate_report=generate_report, verbose=verbose, track_data=track_data, + common_infra=common_infra, **kwargs, ) - self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange + if trade_exchange is not None: + self.trade_exchange = trade_exchange + + def reset_common_infra(self, common_infra): + """ + reset infrastructure for trading + - reset trade_exchange + """ + super(SimulatorExecutor, self).reset_common_infra(common_infra) + if "trade_exchange" in common_infra: + self.trade_exchange = common_infra.get("trade_exchange") def get_init_state(self): - init_state = {"current": self.trade_account.current, "trade_info": []} - return init_state + return [] - def execute(self, order_list): - super(SimulatorExecutor, self).step() - trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) - trade_info = [] - for order in order_list: + def execute(self, trade_decision): + self.trade_calendar.step() + trade_index = self.trade_calendar.get_trade_index() + trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) + execute_state = [] + for order in trade_decision: if self.trade_exchange.check_order(order) is True: # execute the order trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( order, trade_account=self.trade_account ) - trade_info.append((order, trade_val, trade_cost, trade_price)) + execute_state.append((order, trade_val, trade_cost, trade_price)) if self.verbose: if order.direction == Order.SELL: # sell print( @@ -323,7 +340,7 @@ class SimulatorExecutor(BaseExecutor): trade_exchange=self.trade_exchange, ) - return {"current": self.trade_account.current, "trade_info": trade_info} + return execute_state def get_report(self): if self.generate_report: diff --git a/qlib/contrib/backtest/faculty.py b/qlib/contrib/backtest/faculty.py deleted file mode 100644 index 34ad14cbc..000000000 --- a/qlib/contrib/backtest/faculty.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -class Faculty: - def __init__(self): - self.__dict__["_faculty"] = dict() - - def __getitem__(self, key): - return self.__dict__["_faculty"][key] - - def __getattr__(self, attr): - if attr in self.__dict__["_faculty"]: - return self.__dict__["_faculty"][attr] - - raise AttributeError(f"No such {attr} in self._faculty") - - def __setitem__(self, key, value): - self.__dict__["_faculty"][key] = value - - def __setattr__(self, attr, value): - self.__dict__["_faculty"][attr] = value - - def update(self, *args, **kwargs): - self.__dict__["_faculty"].update(*args, **kwargs) - - -common_faculty = Faculty() diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 0b39990b3..978ea0387 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +import numpy as np import pandas as pd import copy import pathlib diff --git a/qlib/contrib/backtest/report.py b/qlib/contrib/backtest/report.py index 57e56c9a3..3763f5214 100644 --- a/qlib/contrib/backtest/report.py +++ b/qlib/contrib/backtest/report.py @@ -3,16 +3,51 @@ from collections import OrderedDict +from logging import warning import pandas as pd import pathlib +import warnings + +from pandas.core.frame import DataFrame + +from ...utils.resam import parse_freq, resam_ts_data +from ...data import D class Report: # daily report of the account # contain those followings: returns, costs turnovers, accounts, cash, bench, value # update report - def __init__(self): + def __init__(self, freq: str = "day", benchmark_config: dict = {}): + """ + Parameters + ---------- + freq : str + frequency of trading bar, used for updating hold count of trading bar + benchmark_config : dict + config of benchmark, may including the following arguments: + - benchmark : Union[str, list, pd.Series] + - If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. + example: + print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) + 2017-01-04 0.011693 + 2017-01-05 0.000721 + 2017-01-06 -0.004322 + 2017-01-09 0.006874 + 2017-01-10 -0.003350 + - If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. + - If `benchmark` is str, will use the daily change as the 'bench'. + benchmark code, default is SH000300 CSI300 + - start_time : Union[str, pd.Timestamp], optional + - If `benchmark` is pd.Series, it will be ignored + - Else, it represent start time of benchmark, by default None + - end_time : Union[str, pd.Timestamp], optional + - If `benchmark` is pd.Series, it will be ignored + - Else, it represent end time of benchmark, by default None + + """ self.init_vars() + self.init_bench(freq=freq, benchmark_config=benchmark_config) def init_vars(self): self.accounts = OrderedDict() # account postion value for each trade date @@ -24,6 +59,49 @@ class Report: self.benches = OrderedDict() self.latest_report_time = None # pd.TimeStamp + def init_bench(self, freq=None, benchmark_config=None): + if freq is not None: + self.freq = freq + if benchmark_config is not None: + self.benchmark_config = benchmark_config + self.bench = self._cal_benchmark(self.benchmark_config, self.freq) + + def _cal_benchmark(self, benchmark_config, freq): + benchmark = benchmark_config.get("benchmark", "SH000300") + if isinstance(benchmark, pd.Series): + return benchmark + else: + start_time = benchmark_config.get("start_time", None) + end_time = benchmark_config.get("end_time", None) + + if freq is None: + raise ValueError("benchmark freq can't be None!") + _codes = benchmark if isinstance(benchmark, list) else [benchmark] + fields = ["$close/Ref($close,1)-1"] + try: + _temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1) + except ValueError: + _, norm_freq = parse_freq(freq) + if norm_freq in ["month", "week", "day"]: + try: + _temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1) + except ValueError: + _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) + elif norm_freq == "minute": + _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) + else: + raise ValueError(f"benchmark freq {freq} is not supported") + if len(_temp_result) == 0: + raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") + return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0) + + def _sample_benchmark(self, bench, trade_start_time, trade_end_time): + def cal_change(x): + return (x + 1).prod() - 1 + + _ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change) + return 0.0 if _ret is None else _ret + def is_empty(self): return len(self.accounts) == 0 @@ -35,30 +113,39 @@ class Report: def update_report_record( self, - trade_time=None, + trade_start_time=None, + trade_end_time=None, account_value=None, cash=None, return_rate=None, turnover_rate=None, cost_rate=None, stock_value=None, - bench_value=None, ): # check data - if None in [trade_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]: + if None in [ + trade_start_time, + trade_end_time, + account_value, + cash, + return_rate, + turnover_rate, + cost_rate, + stock_value, + ]: raise ValueError( - "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]" + "None in [trade_start_time, trade_end_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]" ) # update report data - self.accounts[trade_time] = account_value - self.returns[trade_time] = return_rate - self.turnovers[trade_time] = turnover_rate - self.costs[trade_time] = cost_rate - self.values[trade_time] = stock_value - self.cashes[trade_time] = cash - self.benches[trade_time] = bench_value + self.accounts[trade_start_time] = account_value + self.returns[trade_start_time] = return_rate + self.turnovers[trade_start_time] = turnover_rate + self.costs[trade_start_time] = cost_rate + self.values[trade_start_time] = stock_value + self.cashes[trade_start_time] = cash + self.benches[trade_start_time] = self._sample_benchmark(self.bench, trade_start_time, trade_end_time) # update latest_report_date - self.latest_report_time = trade_time + self.latest_report_time = trade_start_time # finish daily report update def generate_report_dataframe(self): diff --git a/qlib/contrib/backtest/utils.py b/qlib/contrib/backtest/utils.py new file mode 100644 index 000000000..1a4173887 --- /dev/null +++ b/qlib/contrib/backtest/utils.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pandas as pd +from typing import Union + +from ...utils.resam import get_resam_calendar +from ...data.data import Cal + + +class TradeCalendarManager: + """ + Manager for trading calendar + - BaseStrategy and BaseExecutor will use it + """ + + def __init__( + self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None + ): + """ + Parameters + ---------- + step_bar : str + frequency of each trading calendar + start_time : Union[str, pd.Timestamp], optional + closed start of the trading calendar, by default None + If `start_time` is None, it must be reset before trading. + end_time : Union[str, pd.Timestamp], optional + closed end of the trade time range, by default None + If `end_time` is None, it must be reset before trading. + """ + self.step_bar = step_bar + self.start_time = pd.Timestamp(start_time) if start_time else None + self.end_time = pd.Timestamp(start_time) if start_time else None + self._init_trade_calendar(step_bar=step_bar, start_time=start_time, end_time=end_time) + + def _init_trade_calendar(self, step_bar, start_time, end_time): + """reset trade calendar""" + _calendar, freq, freq_sam = get_resam_calendar(freq=step_bar) + self.calendar = _calendar + _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam) + self.start_index = _start_index + self.end_index = _end_index + self.trade_len = _end_index - _start_index + 1 + self.trade_index = 0 + + def finished(self): + return self.trade_index >= self.trade_len + + def step(self): + if self.finished(): + raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!") + self.trade_index = self.trade_index + 1 + + def get_step_bar(self): + return self.step_bar + + def get_trade_len(self): + return self.trade_len + + def get_trade_index(self): + return self.trade_index + + def get_calendar_time(self, trade_index=1, shift=0): + trade_index = trade_index - shift + calendar_index = self.start_index + trade_index + return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1) diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 10f80671e..59a831f3e 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function +from logging import warn import numpy as np import pandas as pd @@ -10,7 +11,7 @@ import warnings from ..log import get_module_logger from .backtest import get_exchange, backtest as backtest_func from ..utils import get_date_range -from ..utils.sample import parse_freq +from ..utils.resam import parse_freq from ..data import D from ..config import C @@ -20,7 +21,7 @@ from ..data.dataset.utils import get_level_index logger = get_module_logger("Evaluate") -def risk_analysis(r, N: int = None, freq: str = None): +def risk_analysis(r, N: int = None, freq: str = "day"): """Risk Analysis Parameters @@ -36,8 +37,8 @@ def risk_analysis(r, N: int = None, freq: str = None): def cal_risk_analysis_scaler(freq): _count, _freq = parse_freq(freq) _freq_scaler = { - "minute": 240 * 250, - "day": 250, + "minute": 240 * 252, + "day": 252, "week": 50, "month": 12, } @@ -45,6 +46,8 @@ def risk_analysis(r, N: int = None, freq: str = None): if N is None and freq is None: raise ValueError("at least one of `N` and `freq` should exist") + if N is not None and freq is not None: + warnings.warn("risk_analysis freq will be ignored") if N is None: N = cal_risk_analysis_scaler(freq) diff --git a/qlib/contrib/online/operator.py b/qlib/contrib/online/operator.py index d2307dad5..8d78f2c50 100644 --- a/qlib/contrib/online/operator.py +++ b/qlib/contrib/online/operator.py @@ -118,7 +118,7 @@ class Operator: user.strategy.update(score_series, pred_date, trade_date) # generate and save order list - order_list = user.strategy.generate_order_list( + order_list = user.strategy.generate_trade_decision( score_series=score_series, current=user.account.current, trade_exchange=trade_exchange, @@ -208,7 +208,7 @@ class Operator: self.logger.info("Update account state {} for {}".format(trade_date, user_id)) def simulate(self, id, config, exchange_config, start, end, path, bench="SH000905"): - """Run the ( generate_order_list -> execute_order_list -> update_account) process everyday + """Run the ( generate_trade_decision -> execute_order_list -> update_account) process everyday from start date to end date. Parameters @@ -256,7 +256,7 @@ class Operator: user.strategy.update(score_series, pred_date, trade_date) # 3. generate and save order list - order_list = user.strategy.generate_order_list( + order_list = user.strategy.generate_trade_decision( score_series=score_series, current=user.account.current, trade_exchange=trade_exchange, diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index 8b3e3db18..58e3fccc4 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -10,17 +10,15 @@ import copy class SoftTopkStrategy(WeightStrategyBase): def __init__( self, - step_bar, model, dataset, topk, - start_time=None, - end_time=None, order_generator_cls_or_obj=OrderGenWInteract, - trade_exchange=None, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill", + level_infra={}, + common_infra={}, **kwargs, ): """Parameter @@ -33,7 +31,7 @@ class SoftTopkStrategy(WeightStrategyBase): average_fill: assign the weight to the stocks rank high averagely. """ super(SoftTopkStrategy, self).__init__( - step_bar, model, dataset, start_time, end_time, order_generator_cls_or_obj, trade_exchange + model, dataset, order_generator_cls_or_obj, level_infra, common_infra, **kwargs ) self.topk = topk self.max_sold_weight = max_sold_weight diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index b3bb33a88..336cfa534 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -3,29 +3,26 @@ import warnings import numpy as np import pandas as pd -from ...utils.sample import sample_feature +from ...utils.resam import resam_ts_data from ...strategy.base import ModelStrategy from ..backtest.order import Order -from ..backtest.faculty import common_faculty from .order_generator import OrderGenWInteract class TopkDropoutStrategy(ModelStrategy): def __init__( self, - step_bar, model, dataset, topk, n_drop, - start_time=None, - end_time=None, - trade_exchange=None, method_sell="bottom", method_buy="top", risk_degree=0.95, hold_thresh=1, only_tradable=False, + level_infra={}, + common_infra={}, **kwargs, ): """ @@ -51,8 +48,9 @@ class TopkDropoutStrategy(ModelStrategy): else: strategy will make decision with the tradable state of the stock info and avoid buy and sell them. """ - super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs) - self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange + super(TopkDropoutStrategy, self).__init__( + model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs + ) self.topk = topk self.n_drop = n_drop self.method_sell = method_sell @@ -61,6 +59,20 @@ class TopkDropoutStrategy(ModelStrategy): self.hold_thresh = hold_thresh self.only_tradable = only_tradable + def reset_common_infra(self, common_infra): + """ + Parameters + ---------- + common_infra : dict, optional + common infrastructure for backtesting, by default None + - It should include `trade_account`, used to get position + - It should include `trade_exchange`, used to provide market info + """ + super(TopkDropoutStrategy, self).reset_common_infra(common_infra) + + if "trade_exchange" in common_infra: + self.trade_exchange = common_infra.get("trade_exchange") + def get_risk_degree(self, trade_index=None): """get_risk_degree Return the proportion of your total value you will used in investment. @@ -69,11 +81,11 @@ class TopkDropoutStrategy(ModelStrategy): # It will use 95% amoutn of your total value by default return self.risk_degree - def generate_order_list(self, execute_state): - super(TopkDropoutStrategy, self).step() - trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) - pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) - pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + def generate_trade_decision(self, execute_state): + trade_index = self.trade_calendar.get_trade_index() + trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) + pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1) + pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: return [] if self.only_tradable: @@ -115,8 +127,7 @@ class TopkDropoutStrategy(ModelStrategy): def filter_stock(l): return l - current = execute_state.get("current") - current_temp = copy.deepcopy(current) + current_temp = copy.deepcopy(self.trade_position) # generate order list for this adjust date sell_order_list = [] buy_order_list = [] @@ -168,7 +179,8 @@ class TopkDropoutStrategy(ModelStrategy): continue if code in sell: # check hold limit - if current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh: + step_bar = self.trade_calendar.get_step_bar() + if current_temp.get_stock_count(code, bar=step_bar) < self.hold_thresh: continue # sell order sell_amount = current_temp.get_stock_amount(code=code) @@ -228,22 +240,35 @@ class TopkDropoutStrategy(ModelStrategy): class WeightStrategyBase(ModelStrategy): def __init__( self, - step_bar, model, dataset, - start_time=None, - end_time=None, order_generator_cls_or_obj=OrderGenWInteract, - trade_exchange=None, + level_infra={}, + common_infra={}, **kwargs, ): - super(WeightStrategyBase, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs) - self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange + super(WeightStrategyBase, self).__init__( + model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs + ) if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj + def reset_common_infra(self, common_infra): + """ + Parameters + ---------- + common_infra : dict, optional + common infrastructure for backtesting, by default None + - It should include `trade_account`, used to get position + - It should include `trade_exchange`, used to provide market info + """ + super(WeightStrategyBase, self).reset_common_infra(common_infra) + + if "trade_exchange" in common_infra: + self.trade_exchange = common_infra.get("trade_exchange") + def get_risk_degree(self, trade_index=None): """get_risk_degree Return the proportion of your total value you will used in investment. @@ -267,7 +292,7 @@ class WeightStrategyBase(ModelStrategy): """ raise NotImplementedError() - def generate_order_list(self, execute_state): + def generate_trade_decision(self, execute_state): """ Parameters ----------- @@ -280,23 +305,22 @@ class WeightStrategyBase(ModelStrategy): trade_date : pd.Timestamp date. """ - # generate_order_list + # generate_trade_decision # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list - super(WeightStrategyBase, self).step() - trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) - pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) - pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + trade_index = self.trade_calendar.get_trade_index() + trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) + pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1) + pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: return [] - current = execute_state.get("current") - current_temp = copy.deepcopy(current) + current_temp = copy.deepcopy(self.trade_position) target_weight_position = self.generate_target_weight_position( score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time ) order_list = self.order_generator.generate_order_list_from_target_weight_position( current=current_temp, trade_exchange=self.trade_exchange, - risk_degree=self.get_risk_degree(self.trade_index), + risk_degree=self.get_risk_degree(trade_index), target_weight_position=target_weight_position, pred_start_time=pred_start_time, pred_end_time=pred_end_time, diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 0e0f2b907..2265a9dc5 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -1,80 +1,77 @@ -import copy import warnings -import numpy as np -import pandas as pd -from typing import Union -from ...utils.sample import sample_feature +from ...utils.resam import resam_ts_data from ...data.data import D from ...data.dataset.utils import convert_index_format -from ...strategy.base import RuleStrategy, OrderEnhancement +from ...strategy.base import RuleStrategy from ..backtest.order import Order -from ..backtest.faculty import common_faculty -class TWAPStrategy(RuleStrategy, OrderEnhancement): +class TWAPStrategy(RuleStrategy): """TWAP Strategy for trading""" - def __init__( - self, - step_bar, - start_time=None, - end_time=None, - trade_exchange=None, - trade_order_list=[], - **kwargs, - ): + def reset_common_infra(self, common_infra): """ Parameters ---------- - trade_exchange : Exchange, optional - exchange that provides market info, by default None - - If `trade_exchange` is None, self.trade_exchange will be set with common_faculty - trade_order_list : list, optional - order list to trade, which the strategy will trade in [start_time , end_time] , by default [] + common_infra : dict, optional + common infrastructure for backtesting, by default None + - It should include `trade_account`, used to get position + - It should include `trade_exchange`, used to provide market info """ - super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) - self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange - self.trade_order_list = trade_order_list + super(TWAPStrategy, self).reset_common_infra(common_infra) + if common_infra is not None: + if "trade_exchange" in common_infra: + self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, trade_order_list: list = None, **kwargs): - super(TWAPStrategy, self).reset(**kwargs) - OrderEnhancement.reset(self, trade_order_list=trade_order_list) - if trade_order_list is not None: + def reset(self, rely_trade_decision: object = None, **kwargs): + """ + Parameters + ---------- + rely_trade_decision : object, optional + """ + + super(TWAPStrategy, self).reset(rely_trade_decision=rely_trade_decision, common_infra=common_infra, **kwargs) + if rely_trade_decision is not None: self.trade_amount = {} - for order in self.trade_order_list: + for order in rely_trade_decision: self.trade_amount[(order.stock_id, order.direction)] = order.amount - def generate_order_list(self, execute_state): - super(TWAPStrategy, self).step() - trade_info = execute_state.get("trade_info") + def generate_trade_decision(self, execute_state): + + # update the order amount + trade_info = execute_state for order, _, _, _ in trade_info: self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) + trade_index = self.trade_calendar.get_trade_index() + trade_len = self.trade_calendar.get_trade_len() + trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) order_list = [] - for order in self.trade_order_list: + for order in self.rely_trade_decision: if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): continue _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) _order_amount = None + # consider trade unit if _amount_trade_unit is None: - _order_amount = self.trade_amount[(order.stock_id, order.direction)] / ( - self.trade_len - self.trade_index + 1 - ) - if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + # split the order equally + _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1) + # without considering trade unit + elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + # split the order equally + # floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1)) trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) _order_amount = ( - (trade_unit_cnt + self.trade_len - self.trade_index) - // (self.trade_len - self.trade_index + 1) - * _amount_trade_unit + (trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit ) if order.direction == order.SELL: + # sell all amount at last if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount is None or self.trade_index == self.trade_len + _order_amount is None or trade_index == trade_len ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] @@ -92,7 +89,7 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement): return order_list -class SBBStrategyBase(RuleStrategy, OrderEnhancement): +class SBBStrategyBase(RuleStrategy): """ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy. """ @@ -101,81 +98,80 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): TREND_SHORT = 1 TREND_LONG = 2 - def __init__( - self, - step_bar, - start_time=None, - end_time=None, - trade_exchange=None, - trade_order_list=[], - **kwargs, - ): + def reset_common_infra(self, common_infra): + super(SBBStrategyBase, self).reset_common_infra(common_infra) + if common_infra is not None: + if "trade_exchange" in common_infra: + self.trade_exchange = common_infra.get("trade_exchange") + + def reset(self, rely_trade_decision=None, **kwargs): """ Parameters ---------- - trade_exchange : Exchange, optional - exchange that provides market info, by default None - - If `trade_exchange` is None, self.trade_exchange will be set with common_faculty - trade_order_list : list, optional - order list to trade, which the strategy will trade in [start_time , end_time] , by default [] + rely_trade_decision : object, optional + common_infra : None, optional + common infrastructure for backtesting, by default None + - It should include `trade_account`, used to get position + - It should include `trade_exchange`, used to provide market info """ - super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs) - self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange - self.trade_order_list = trade_order_list - - def reset(self, trade_order_list=None, **kwargs): - super(SBBStrategyBase, self).reset(**kwargs) - OrderEnhancement.reset(self, trade_order_list=trade_order_list) - if trade_order_list is not None: + super(SBBStrategyBase, self).reset(rely_trade_decision=rely_trade_decision, **kwargs) + if rely_trade_decision is not None: self.trade_trend = {} self.trade_amount = {} - for order in self.trade_order_list: + # init the trade amount of order and predicted trade trend + for order in rely_trade_decision: self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID self.trade_amount[(order.stock_id, order.direction)] = order.amount def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") - def generate_order_list(self, execute_state): - super(SBBStrategyBase, self).step() + def generate_trade_decision(self, execute_state): - trade_info = execute_state.get("trade_info") + # update the order amount + trade_info = execute_state for order, _, _, _ in trade_info: self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - - trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) - pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) + trade_index = self.trade_calendar.get_trade_index() + trade_len = self.trade_calendar.get_trade_len() + trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index) + pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1) order_list = [] - for order in self.trade_order_list: - if self.trade_index % 2 == 1: + # for each order in in self.rely_trade_decision + for order in self.rely_trade_decision: + # predict the price trend + if trade_index % 2 == 1: _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) else: _pred_trend = self.trade_trend[(order.stock_id, order.direction)] - + # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): - if self.trade_index % 2 == 1: + if trade_index % 2 == 1: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend continue - + # get amount of one trade unit _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) if _pred_trend == self.TREND_MID: _order_amount = None + # considering trade unit if _amount_trade_unit is None: - _order_amount = self.trade_amount[(order.stock_id, order.direction)] / ( - self.trade_len - self.trade_index + 1 - ) + # split the order equally + _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1) + # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + # cal how many trade unit trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + # split the order equally + # floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1)) _order_amount = ( - (trade_unit_cnt + self.trade_len - self.trade_index) - // (self.trade_len - self.trade_index + 1) - * _amount_trade_unit + (trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit ) if order.direction == order.SELL: + # sell all amount at last if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount is None or self.trade_index == self.trade_len + _order_amount is None or trade_index == trade_len ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] @@ -185,36 +181,43 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): amount=_order_amount, start_time=trade_start_time, end_time=trade_end_time, - direction=order.direction, # 1 for buy + direction=order.direction, factor=order.factor, ) order_list.append(_order) - # print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit) + else: _order_amount = None + # considering trade unit if _amount_trade_unit is None: + # N trade day last, split the order into N + 1 parts, and trade 2 parts _order_amount = ( - 2 - * self.trade_amount[(order.stock_id, order.direction)] - / (self.trade_len - self.trade_index + 2) + 2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 2) ) + # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + # cal how many trade unit trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) + # N trade day last, split the order into N + 1 parts, and trade 2 parts _order_amount = ( - (trade_unit_cnt + self.trade_len - self.trade_index + 1) - // (self.trade_len - self.trade_index + 2) + (trade_unit_cnt + trade_len - trade_index + 1) + // (trade_len - trade_index + 2) * 2 * _amount_trade_unit ) if order.direction == order.SELL: + # sell all amount at last if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and ( - _order_amount is None or self.trade_index == self.trade_len + _order_amount is None or trade_index == trade_len ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] if _order_amount: _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) - if self.trade_index % 2 == 1: + if trade_index % 2 == 1: + # in the first of two adjacent bar + # if look short on the price, sell the stock more + # if look long on the price, sell the stock more if ( _pred_trend == self.TREND_SHORT and order.direction == order.SELL @@ -231,6 +234,9 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): ) order_list.append(_order) else: + # in the second of two adjacent bar + # if look short on the price, buy the stock more + # if look long on the price, sell the stock more if ( _pred_trend == self.TREND_SHORT and order.direction == order.BUY @@ -246,8 +252,8 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement): factor=order.factor, ) order_list.append(_order) - # print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit) - if self.trade_index % 2 == 1: + + if trade_index % 2 == 1: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend return order_list @@ -260,13 +266,11 @@ class SBBStrategyEMA(SBBStrategyBase): def __init__( self, - step_bar, - start_time=None, - end_time=None, - trade_exchange=None, - trade_order_list=[], + rely_trade_decision=[], instruments="csi300", freq="day", + level_infra={}, + common_infra={}, **kwargs, ): """ @@ -278,47 +282,49 @@ class SBBStrategyEMA(SBBStrategyBase): freq of EMA signal, by default "day" Note: `freq` may be different from `steb_bar` """ - super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs) if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") self.instruments = "all" if isinstance(instruments, str): self.instruments = D.instruments(instruments) self.freq = freq + super(SBBStrategyEMA, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs) - def reset(self, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, **kwargs): + def _reset_signal(self): + trade_len = self.trade_calendar.get_trade_len() + fields = ["EMA($close, 10)-EMA($close, 20)"] + signal_start_time, _ = self.trade_calendar.get_calendar_time(trade_index=1, shift=1) + _, signal_end_time = self.trade_calendar.get_calendar_time(trade_index=trade_len, shift=1) + signal_df = D.features( + self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq + ) + signal_df = convert_index_format(signal_df) + signal_df.columns = ["signal"] + self.signal = {} + for stock_id, stock_val in signal_df.groupby(level="instrument"): + self.signal[stock_id] = stock_val + + def reset_level_infra(self, level_infra): """ - Reset EMA signal for trading - - Parameters - ---------- - start_time : Union[str, pd.Timestamp], optional - start time for trading, also used to calculate the start time of EMA signal, by default None - - end_time : Union[str, pd.Timestamp], optional - end time for trading, also used to calculate the end time of EMA signal, by default None + reset level-shared infra + - After reset the trade_calendar, the signal will be changed """ - super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs) - if self.start_time and self.end_time and (start_time or end_time): - fields = ["EMA($close, 10)-EMA($close, 20)"] - signal_start_time, _ = self._get_calendar_time(trade_index=1, shift=1) - _, signal_end_time = self._get_calendar_time(trade_index=self.trade_len, shift=1) - signal_df = D.features( - self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq - ) - signal_df = convert_index_format(signal_df) - signal_df.columns = ["signal"] - self.signal = {} - for stock_id, stock_val in signal_df.groupby(level="instrument"): - self.signal[stock_id] = stock_val + if not hasattr(self, "level_infra"): + self.level_infra = level_infra + else: + self.level_infra.update(level_infra) + + if "trade_calendar" in level_infra: + self.trade_calendar = level_infra.get("trade_calendar") + self._reset_signal() def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): if stock_id not in self.signal: return self.TREND_MID else: - _sample_signal = sample_feature( - self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last" + _sample_signal = resam_ts_data( + self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last" ) if _sample_signal is None or _sample_signal.iloc[0] == 0: return self.TREND_MID diff --git a/qlib/data/data.py b/qlib/data/data.py index 91a21da9f..394c3271e 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -26,7 +26,7 @@ from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, co from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path -from ..utils.sample import sample_calendar +from ..utils.resam import resam_calendar class CalendarProvider(abc.ABC): @@ -133,7 +133,7 @@ class CalendarProvider(abc.ABC): if freq_sam is None: return _calendar, _calendar_index else: - _calendar_sam = sample_calendar(_calendar, freq, freq_sam) + _calendar_sam = resam_calendar(_calendar, freq, freq_sam) _calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)} H["c"][flag] = _calendar_sam, _calendar_sam_index return _calendar_sam, _calendar_sam_index diff --git a/qlib/rl/env.py b/qlib/rl/env.py index fae17918d..2fef7a659 100644 --- a/qlib/rl/env.py +++ b/qlib/rl/env.py @@ -1,9 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .interpreter import StateInterpreter, ActionInterpreter +from typing import Union +from .interpreter import StateInterpreter, ActionInterpreter from ..contrib.backtest.executor import BaseExecutor +from ..utils import init_instance_by_config class BaseRLEnv: @@ -52,35 +54,22 @@ class QlibIntRLEnv(QlibRLEnv): def __init__( self, executor: BaseExecutor, - state_interpreter: StateInterpreter, - action_interpreter: ActionInterpreter, - state_interpret_kwargs: dict = {}, - action_interpret_kwargs: dict = {}, + state_interpreter: Union[dict, StateInterpreter], + action_interpreter: Union[dict, ActionInterpreter], ): """ Parameters ---------- - state_interpreter : StateInterpreter + state_interpreter : Union[dict, StateInterpreter] interpretor that interprets the qlib execute result into rl env state. - action_interpreter : ActionInterpreter + + action_interpreter : Union[dict, ActionInterpreter] interpretor that interprets the rl agent action into qlib order list - state_interpret_kwargs : dict, optional - arguments may be used in `state_interpreter.interpret`, by default {} - such as the following arguments: - - trade exchange : Exchange - Exchange that can provide market info - action_interpret_kwargs: dict, optional - arguments may be used in `action_interpreter.interpret`, by default {} - such as the following arguments: - - trade_order_list : List[Order] - If the strategy is used to split order, it presents the trade order pool. """ super(QlibIntRLEnv, self).__init__(executor=executor) - self.state_interpreter = state_interpreter - self.action_interpreter = action_interpreter - self.state_interpret_kwargs = state_interpret_kwargs - self.action_interpret_kwargs = action_interpret_kwargs + self.state_interpreter = init_instance_by_config(state_interpreter) + self.action_interpreter = init_instance_by_config(action_interpreter) def step(self, action): """ @@ -96,11 +85,9 @@ class QlibIntRLEnv(QlibRLEnv): Returns ------- - env state to rl rl policy + env state to rl policy """ - _interpret_action = self.action_interpreter.interpret(action=action, **self.state_interpret_kwargs) + _interpret_action = self.action_interpreter.interpret(action=action) _execute_result = self.executor.execute(_interpret_action) - _interpret_state = self.state_interpreter.interpret( - execute_result=_execute_result, **self.action_interpret_kwargs - ) + _interpret_state = self.state_interpreter.interpret(execute_result=_execute_result) return _interpret_state diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index 3c94aac09..1e310e8ad 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -5,7 +5,6 @@ class BaseInterpreter: """Base Interpreter""" - @staticmethod def interpret(**kwargs): raise NotImplementedError("interpret is not implemented!") @@ -13,7 +12,6 @@ class BaseInterpreter: class ActionInterpreter(BaseInterpreter): """Action Interpreter that interpret rl agent action into qlib orders""" - @staticmethod def interpret(action, **kwargs): """interpret method @@ -34,7 +32,6 @@ class ActionInterpreter(BaseInterpreter): class StateInterpreter(BaseInterpreter): """State Interpreter that interpret execution result of qlib executor into rl env state""" - @staticmethod def interpret(execute_result, **kwargs): """interpret method diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 5534998e9..dad994303 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import copy import pandas as pd from typing import List, Union @@ -9,16 +10,70 @@ from ..model.base import BaseModel from ..data.dataset import DatasetH from ..data.dataset.utils import convert_index_format from ..contrib.backtest.order import Order -from ..contrib.backtest.executor import BaseTradeCalendar from ..rl.interpreter import ActionInterpreter, StateInterpreter +from ..utils import init_instance_by_config -class BaseStrategy(BaseTradeCalendar): +class BaseStrategy: """Base strategy for trading""" - def generate_order_list(self, execute_state): - """Generate order list in each trading bar""" - raise NotImplementedError("generator_order_list is not implemented!") + def __init__( + self, + rely_trade_decision: object = None, + level_infra: dict = {}, + common_infra: dict = {}, + ): + """ + Parameters + ---------- + rely_trade_decision : object, optional + the high-level trade decison on which the startegy rely, and it will be traded in [start_time , end_time] , by default None + - If the strategy is used to split trade decison, it will be used + - If the strategy is used for portfolio management, it can be ignored + level_infra : dict, optional + level shared infrastructure for backtesting, including trade_calendar + common_infra : dict, optional + common infrastructure for backtesting, including trade_account, trade_exchange, .etc + """ + + self.reset(level_infra=level_infra, common_infra=common_infra, rely_trade_decision=rely_trade_decision) + + def reset_level_infra(self, level_infra): + if not hasattr(self, "level_infra"): + self.level_infra = level_infra + else: + self.level_infra.update(level_infra) + + if "trade_calendar" in level_infra: + self.trade_calendar = level_infra.get("trade_calendar") + + def reset_common_infra(self, common_infra): + if not hasattr(self, "common_infra"): + self.common_infra = common_infra + else: + self.common_infra.update(common_infra) + + if "trade_account" in common_infra: + self.trade_position = common_infra.get("trade_account").current + + def reset(self, level_infra: dict = None, common_infra: dict = None, rely_trade_decision=None, **kwargs): + """ + - reset `level_infra`, used to reset trade_calendar, .etc + - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc + - reset `rely_trade_decision`, used to make split decison + """ + if level_infra is not None: + self.reset_level_infra(level_infra) + + if common_infra is not None: + self.reset_common_infra(common_infra) + + if rely_trade_decision is not None: + self.rely_trade_decision = rely_trade_decision + + def generate_trade_decision(self, execute_state): + """Generate trade decision in each trading bar""" + raise NotImplementedError("generate_trade_decision is not implemented!") class RuleStrategy(BaseStrategy): @@ -32,11 +87,11 @@ class ModelStrategy(BaseStrategy): def __init__( self, - step_bar: str, model: BaseModel, dataset: DatasetH, - start_time: Union[str, pd.Timestamp] = None, - end_time: Union[str, pd.Timestamp] = None, + rely_trade_decision: object = None, + level_infra: dict = {}, + common_infra: dict = {}, **kwargs, ): """ @@ -49,11 +104,10 @@ class ModelStrategy(BaseStrategy): kwargs : dict arguments that will be passed into `reset` method """ + super(ModelStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs) self.model = model self.dataset = dataset self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime") - # pred_score_dates = self.pred_scores.index.get_level_values(level="datetime") - super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) def _update_model(self): """ @@ -70,10 +124,10 @@ class RLStrategy(BaseStrategy): def __init__( self, - step_bar: str, policy, - start_time: Union[str, pd.Timestamp] = None, - end_time: Union[str, pd.Timestamp] = None, + rely_trade_decision: object = None, + level_infra: dict = {}, + common_infra: dict = {}, **kwargs, ): """ @@ -82,7 +136,7 @@ class RLStrategy(BaseStrategy): policy : RL policy for generate action """ - super(RLStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) + super(RLStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs) self.policy = policy @@ -91,14 +145,12 @@ class RLIntStrategy(RLStrategy): def __init__( self, - step_bar: str, policy, state_interpreter: StateInterpreter, action_interpreter: ActionInterpreter, - start_time: Union[str, pd.Timestamp] = None, - end_time: Union[str, pd.Timestamp] = None, - state_interpret_kwargs: dict = {}, - action_interpret_kwargs: dict = {}, + rely_trade_decision: object = None, + level_infra: dict = {}, + common_infra: dict = {}, **kwargs, ): """ @@ -112,49 +164,16 @@ class RLIntStrategy(RLStrategy): start time of trading, by default None end_time : Union[str, pd.Timestamp], optional end time of trading, by default None - state_interpret_kwargs : dict, optional - arguments may be used in `state_interpreter.interpret`, by default {} - such as the following arguments: - - trade exchange : Exchange - Exchange that can provide market info - action_interpret_kwargs: dict, optional - arguments may be used in `action_interpreter.interpret`, by default {} - such as the following arguments: - - trade_order_list : List[Order] - If the strategy is used to split order, it presents the trade order pool. """ - super(RLIntStrategy, self).__init__(step_bar, policy, start_time, end_time, **kwargs) + super(RLIntStrategy, self).__init__(policy, rely_trade_decision, level_infra, common_infra, **kwargs) self.policy = policy - self.action_interpreter = action_interpreter - self.state_interpreter = state_interpreter - self.state_interpret_kwargs = state_interpret_kwargs - self.action_interpret_kwargs = action_interpret_kwargs + self.state_interpreter = init_instance_by_config(state_interpreter) + self.action_interpreter = init_instance_by_config(action_interpreter) - def generate_order_list(self, execute_state): + def generate_trade_decision(self, execute_state): super(RLStrategy, self).step() - _interpret_state = self.state_interpretor.interpret( - execute_result=execute_state, **self.action_interpret_kwargs - ) + _interpret_state = self.state_interpretor.interpret(execute_result=execute_state) _policy_action = self.policy.step(_interpret_state) - _order_list = self.action_interpreter.interpret(action=_policy_action, **self.state_interpret_kwargs) + _order_list = self.action_interpreter.interpret(action=_policy_action) return _order_list - - -class OrderEnhancement: - """ - Order enhancement for strategy - - If the strategy is used to split orders, the enhancement should be inherited - - If the strategy is used for portfolio management, the enhancement can be ignored - """ - - def reset(self, trade_order_list: List[Order] = None): - """reset trade orders for split strategy - - Parameters - ---------- - trade_order_list for split strategy: List[Order], optional - trading orders , by default None - """ - if trade_order_list is not None: - self.trade_order_list = trade_order_list diff --git a/qlib/utils/sample.py b/qlib/utils/resam.py similarity index 75% rename from qlib/utils/sample.py rename to qlib/utils/resam.py index 9f67d4981..8933b3a82 100644 --- a/qlib/utils/sample.py +++ b/qlib/utils/resam.py @@ -1,8 +1,13 @@ import re +import datetime + import numpy as np import pandas as pd from typing import Tuple, List, Union, Optional, Callable +from . import lazy_sort_index +from ..config import C + def parse_freq(freq: str) -> Tuple[int, str]: """ @@ -50,9 +55,10 @@ def parse_freq(freq: str) -> Tuple[int, str]: return _count, _freq_format_dict[_freq] -def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray: +def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray: """ - Sample the calendar with frequency freq_raw into the calendar with frequency freq_sam + Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam + Assumption: The fix length (240) of the calendar in each day. Parameters ---------- @@ -72,24 +78,36 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n sam_count, freq_sam = parse_freq(freq_sam) if not len(calendar_raw): return calendar_raw + + # if freq_sam is xminute, divide each trading day into several bars evenly if freq_sam == "minute": - def cal_next_sam_minute(x, sam_minutes): - hour = x.hour - minute = x.minute - if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30): - minute_index = (hour - 9) * 60 + minute - 30 - elif 13 <= hour < 15: - minute_index = (hour - 13) * 60 + minute + 120 + def cal_sam_minute(x, sam_minutes): + day_time = pd.Timestamp(x.date()) + shift = C.min_data_shift + # shift represents the shift minute the market time + # - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)] + # - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)] + # - mid open time of stock market is [13:30 - shift*pd.Timedelta(minutes=1)] + # - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)] + open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1) + mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1) + mid_open_time = day_time + pd.Timedelta(hours=13, minutes=30) - shift * pd.Timedelta(minutes=1) + close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1) + + if open_time <= x <= mid_close_time: + minute_index = (x - open_time).seconds // 60 + elif mid_open_time <= x <= close_time: + minute_index = (x - mid_open_time).seconds // 60 + 120 else: - raise ValueError("calendar hour must be in [9, 11] or [13, 15]") + raise ValueError("datetime of calendar is out of range") minute_index = minute_index // sam_minutes * sam_minutes if 0 <= minute_index < 120: - return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60 + return open_time + minute_index * pd.Timedelta(minutes=1) elif 120 <= minute_index < 240: - return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60 + return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1) else: raise ValueError("calendar minute_index error") @@ -98,14 +116,10 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n else: if raw_count > sam_count: raise ValueError("raw freq must be higher than sampling freq") - _calendar_minute = np.unique( - list( - map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw) - ) - ) - if calendar_raw[0] > _calendar_minute[0]: - _calendar_minute[0] = calendar_raw[0] + _calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, sam_count), calendar_raw))) return _calendar_minute + + # else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly else: _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) if freq_sam == "day": @@ -124,14 +138,14 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n raise ValueError("sampling freq must be xmin, xd, xw, xm") -def get_sample_freq_calendar( +def get_resam_calendar( start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, freq: str = "day", future: bool = False, ) -> Tuple[np.ndarray, str, Optional[str]]: """ - Get the calendar with frequency freq. + Get the resampled calendar with frequency freq. - If the calendar with the raw frequency freq exists, return it directly @@ -186,16 +200,15 @@ def get_sample_freq_calendar( return _calendar, freq, freq_sam -def sample_feature( - feature: Union[pd.DataFrame, pd.Series], +def resam_ts_data( + ts_feature: Union[pd.DataFrame, pd.Series], start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - fields: Union[str, List[str]] = None, method: Union[str, Callable] = "last", method_kwargs: dict = {}, ): """ - Sample value from pandas DataFrame or Series for each stock + Resample value from time-series data - If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instruemnt data with datetime in [start_time, end_time] Example: @@ -217,7 +230,7 @@ def sample_feature( 2010-01-12 2788.688232 164587.937500 2010-01-13 2790.604004 145460.453125 - print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) + print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) $close $volume instrument SH600000 87.433578 28117442.0 @@ -236,25 +249,23 @@ def sample_feature( 2010-01-07 83.788803 20813402.0 2010-01-08 84.730675 16044853.0 - print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) + print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", method="last")) $close 87.433578 $volume 28117442.0 - print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last")) + print(resam_ts_data(feature['$close'], start_time="2010-01-04", end_time="2010-01-05", method="last")) 87.433578 Parameters ---------- feature : Union[pd.DataFrame, pd.Series] - Raw feature to be sampled + Raw time-series feature to be resampled start_time : Union[str, pd.Timestamp], optional start sampling time, by default None end_time : Union[str, pd.Timestamp], optional end sampling time, by default None - fields : Union[str, List[str]], optional - column names, it's ignored when sample pd.Series data, by default None(all columns) method : Union[str, Callable], optional sample method, apply method function to each stock series data, by default "last" - If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby @@ -264,24 +275,19 @@ def sample_feature( Returns ------- - The Sampled DataFrame/Series/Value + The Resampled DataFrame/Series/Value """ selector_datetime = slice(start_time, end_time) - if fields is None: - fields = slice(None) from ..data.dataset.utils import get_level_index + feature = lazy_sort_index(ts_feature) datetime_level = get_level_index(feature, level="datetime") == 0 - if isinstance(feature, pd.Series): - feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)] - elif isinstance(feature, pd.DataFrame): - feature = ( - feature.loc[selector_datetime, fields] - if datetime_level - else feature.loc[(slice(None), selector_datetime), fields] - ) + if datetime_level: + feature = feature.loc[selector_datetime] + else: + feature = feature.loc[(slice(None), selector_datetime)] if feature.empty: return None if isinstance(feature.index, pd.MultiIndex): @@ -296,5 +302,4 @@ def sample_feature( return method_func(feature, **method_kwargs) elif isinstance(method, str): return getattr(feature, method)(**method_kwargs) - return feature diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 6bb6341f0..02a282035 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -15,7 +15,7 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..utils.sample import parse_freq +from ..utils.resam import parse_freq from ..strategy.base import BaseStrategy from ..contrib.eva.alpha import calc_ic, calc_long_short_return @@ -291,8 +291,8 @@ class PortAnaRecord(RecordTemp): """ config["strategy"] : dict define the strategy class as well as the kwargs. - config["env"] : dict - define the env class as well as the kwargs. + config["executor"] : dict + define the executor class as well as the kwargs. config["backtest"] : dict define the backtest kwargs. risk_analysis_freq : int @@ -301,24 +301,26 @@ class PortAnaRecord(RecordTemp): super().__init__(recorder=recorder, **kwargs) self.strategy_config = config["strategy"] - self.env_config = config["env"] + self.executor_config = config["executor"] self.backtest_config = config["backtest"] _count, _freq = parse_freq(risk_analysis_freq) self.risk_analysis_freq = f"{_count}{_freq}" - self.report_freq = self._get_report_freq(self.env_config) + self.report_freq = self._get_report_freq(self.executor_config) - def _get_report_freq(self, env_config): + def _get_report_freq(self, executor_config): ret_freq = [] - if env_config["kwargs"].get("generate_report", False): - _count, _freq = parse_freq(env_config["kwargs"]["step_bar"]) + if executor_config["kwargs"].get("generate_report", False): + _count, _freq = parse_freq(executor_config["kwargs"]["step_bar"]) ret_freq.append(f"{_count}{_freq}") - if "sub_env" in env_config["kwargs"]: - ret_freq.extend(self._get_report_freq(env_config["kwargs"]["sub_env"])) + if "sub_env" in executor_config["kwargs"]: + ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"])) return ret_freq def generate(self, **kwargs): # custom strategy and get backtest - report_dict = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config) + report_dict = normal_backtest( + executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config + ) for report_freq, (report_normal, positions_normal) in report_dict.items(): self.recorder.save_objects( **{f"report_normal_{report_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()