mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
fix comments
This commit is contained in:
@@ -14,8 +14,11 @@ This example uses a DropoutTopkStrategy (a strategy based on the daily frequency
|
||||
|
||||
Start backtesting by running the following command:
|
||||
```bash
|
||||
python workflow.py
|
||||
python workflow.py backtest
|
||||
```
|
||||
|
||||
Also, reports is shown in workflow.ipynb
|
||||
Start collecting data by running the following command:
|
||||
```bash
|
||||
python workflow.py collect_data
|
||||
```
|
||||
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
{
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
},
|
||||
"orig_nbformat": 2,
|
||||
"kernelspec": {
|
||||
"name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b",
|
||||
"display_name": "Python 3.8.8 ('qlib_backtest': conda)"
|
||||
},
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2,
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"################################# NOTE #################################\n",
|
||||
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
|
||||
"# in this cell, users should RESTART the runtime in order to run the #\n",
|
||||
"# following cells successfully. #\n",
|
||||
"########################################################################\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
"\n",
|
||||
"scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
|
||||
"if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
|
||||
" # download get_data.py script\n",
|
||||
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
|
||||
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" import requests\n",
|
||||
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
|
||||
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
|
||||
" fp.write(resp.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.tests.data import GetData"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# use default data\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"market = \"csi300\"\n",
|
||||
"benchmark = \"SH000300\"\n",
|
||||
"\n",
|
||||
"###################################\n",
|
||||
"# train model\n",
|
||||
"###################################\n",
|
||||
"\n",
|
||||
"data_handler_config = {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": market,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"task = {\n",
|
||||
" \"model\": {\n",
|
||||
" \"class\": \"LGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"loss\": \"mse\",\n",
|
||||
" \"colsample_bytree\": 0.8879,\n",
|
||||
" \"learning_rate\": 0.0421,\n",
|
||||
" \"subsample\": 0.8789,\n",
|
||||
" \"lambda_l1\": 205.6999,\n",
|
||||
" \"lambda_l2\": 580.9768,\n",
|
||||
" \"max_depth\": 8,\n",
|
||||
" \"num_leaves\": 210,\n",
|
||||
" \"num_threads\": 20,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"dataset\": {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
" \"module_path\": \"qlib.contrib.data.handler\",\n",
|
||||
" \"kwargs\": data_handler_config,\n",
|
||||
" },\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"# model initialization\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
"# start exp to train model\n",
|
||||
"with R.start(experiment_name=\"train_model\"):\n",
|
||||
" R.log_params(**flatten_dict(task))\n",
|
||||
" model.fit(dataset)\n",
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
" rid = R.get_recorder().id\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"outputPrepend"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trade_start_time = \"2017-01-01\"\n",
|
||||
"trade_end_time = \"2020-08-01\"\n",
|
||||
"\n",
|
||||
"port_analysis_config = {\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"class\": \"TopkDropoutStrategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"step_bar\": \"week\",\n",
|
||||
" \"model\": model,\n",
|
||||
" \"dataset\": dataset,\n",
|
||||
" \"topk\": 50,\n",
|
||||
" \"n_drop\": 5,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"env\": {\n",
|
||||
" \"class\": \"SplitExecutor\",\n",
|
||||
" \"module_path\": \"qlib.contrib.backtest.executor\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"step_bar\": \"week\",\n",
|
||||
" \"generate_report\": True,\n",
|
||||
" \"sub_env\": {\n",
|
||||
" \"class\": \"SimulatorExecutor\",\n",
|
||||
" \"module_path\": \"qlib.contrib.backtest.executor\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"step_bar\": \"day\",\n",
|
||||
" \"verbose\": True,\n",
|
||||
" \"generate_report\": True,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"sub_strategy\": {\n",
|
||||
" \"class\": \"SBBStrategyEMA\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"step_bar\": \"day\",\n",
|
||||
" \"freq\": \"day\",\n",
|
||||
" \"instruments\": market,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"backtest\": {\n",
|
||||
" \"start_time\": trade_start_time,\n",
|
||||
" \"end_time\": trade_end_time,\n",
|
||||
" \"account\": 100000000,\n",
|
||||
" \"benchmark\": benchmark,\n",
|
||||
" \"exchange_kwargs\": {\n",
|
||||
" \"freq\": \"day\",\n",
|
||||
" \"limit_threshold\": 0.095,\n",
|
||||
" \"deal_price\": \"close\",\n",
|
||||
" \"open_cost\": 0.0005,\n",
|
||||
" \"close_cost\": 0.0015,\n",
|
||||
" \"min_cost\": 5,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=\"backtest_analysis\"):\n",
|
||||
" # prediction\n",
|
||||
" recorder = R.get_recorder()\n",
|
||||
" ba_rid = recorder.id\n",
|
||||
" sr = SignalRecord(model, dataset, recorder)\n",
|
||||
" sr.generate()\n",
|
||||
"\n",
|
||||
" # backtest & analysis\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
|
||||
" par.generate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df_1d = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n",
|
||||
"positions_1d = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
|
||||
"analysis_df_1d = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")\n",
|
||||
"report_normal_df_1w = recorder.load_object(\"portfolio_analysis/report_normal_1week.pkl\")\n",
|
||||
"positions_1w = recorder.load_object(\"portfolio_analysis/positions_normal_1week.pkl\")\n",
|
||||
"analysis_df_1w = recorder.load_object(\"portfolio_analysis/port_analysis_1week.pkl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.report_graph(report_normal_df_1d)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.report_graph(report_normal_df_1w)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.risk_analysis_graph(analysis_df_1d, report_normal_df_1d)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.risk_analysis_graph(analysis_df_1w, report_normal_df_1w)"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -3,30 +3,21 @@
|
||||
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.contrib.backtest import collect_data
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
class MultiLevelTradingWorkflow:
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
@@ -68,31 +59,17 @@ if __name__ == "__main__":
|
||||
},
|
||||
},
|
||||
}
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"step_bar": "week",
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"env": {
|
||||
"executor": {
|
||||
"class": "SplitExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"kwargs": {
|
||||
"step_bar": "week",
|
||||
"sub_env": {
|
||||
"sub_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"kwargs": {
|
||||
@@ -105,11 +82,11 @@ if __name__ == "__main__":
|
||||
"class": "SBBStrategyEMA",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"freq": "day",
|
||||
"instruments": market,
|
||||
},
|
||||
},
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
@@ -128,17 +105,69 @@ if __name__ == "__main__":
|
||||
},
|
||||
}
|
||||
|
||||
with R.start(experiment_name="highfreq_backtest"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
def _train_model(self, model, dataset):
|
||||
with R.start(experiment_name="train"):
|
||||
R.log_params(**flatten_dict(self.task))
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
par = PortAnaRecord(recorder, port_analysis_config, "day")
|
||||
par.generate()
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
def backtest(self):
|
||||
self._init_qlib()
|
||||
model = init_instance_by_config(self.task["model"])
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
self._train_model(model, dataset)
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}
|
||||
self.port_analysis_config["strategy"] = strategy_config
|
||||
with R.start(experiment_name="backtest"):
|
||||
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(recorder, self.port_analysis_config, "day")
|
||||
par.generate()
|
||||
|
||||
def collect_data(self):
|
||||
self._init_qlib()
|
||||
model = init_instance_by_config(self.task["model"])
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
self._train_model(model, dataset)
|
||||
executor_config = self.port_analysis_config["executor"]
|
||||
backtest_config = self.port_analysis_config["backtest"]
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}
|
||||
data_generator = collect_data(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
for trade_decision in data_generator:
|
||||
print(trade_decision)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(MultiLevelTradingWorkflow)
|
||||
|
||||
@@ -140,6 +140,10 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
# Shift minute for highfreq minite data, used in backtest
|
||||
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:30, 2:39]
|
||||
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:30, 2:39] - shift*minute
|
||||
"min_data_shift": {0},
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
|
||||
@@ -5,13 +5,12 @@ from .account import Account
|
||||
from .exchange import Exchange
|
||||
from .executor import BaseExecutor
|
||||
from .backtest import backtest as backtest_func
|
||||
|
||||
from .backtest import collect_data as data_generator
|
||||
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...utils import init_instance_by_config
|
||||
from ...log import get_module_logger
|
||||
from ...config import C
|
||||
from .faculty import common_faculty
|
||||
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
@@ -89,8 +88,9 @@ def get_exchange(
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
|
||||
def get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}
|
||||
):
|
||||
trade_account = Account(
|
||||
init_cash=account,
|
||||
benchmark_config={
|
||||
@@ -101,14 +101,32 @@ def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=
|
||||
)
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_faculty.update(
|
||||
trade_account=trade_account,
|
||||
trade_exchange=trade_exchange,
|
||||
common_infra = {
|
||||
"trade_account": trade_account,
|
||||
"trade_exchange": trade_exchange,
|
||||
}
|
||||
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
|
||||
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
|
||||
|
||||
return trade_strategy, trade_executor
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
|
||||
trade_strategy, trade_executor = get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
|
||||
)
|
||||
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
trade_env = init_instance_by_config(env, accept_types=BaseExecutor)
|
||||
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env)
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
return report_dict
|
||||
|
||||
|
||||
def collect_data(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
|
||||
trade_strategy, trade_executor = get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
|
||||
)
|
||||
report_dict = yield from data_generator(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
return report_dict
|
||||
|
||||
@@ -9,8 +9,6 @@ import pandas as pd
|
||||
from .position import Position
|
||||
from .report import Report
|
||||
from .order import Order
|
||||
from ...data import D
|
||||
from ...utils.sample import parse_freq, sample_feature
|
||||
|
||||
|
||||
"""
|
||||
@@ -34,85 +32,14 @@ class Account:
|
||||
self.init_vars(init_cash, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of trading bar, used for updating hold count of trading bar
|
||||
benchmark_config : dict
|
||||
config of benchmark, may including the following arguments:
|
||||
- benchmark : Union[str, list, pd.Series]
|
||||
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
- If `benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000300 CSI300
|
||||
- start_time : Union[str, pd.Timestamp], optional
|
||||
- If `benchmark` is pd.Series, it will be ignored
|
||||
- Else, it represent start time of benchmark, by default None
|
||||
- end_time : Union[str, pd.Timestamp], optional
|
||||
- If `benchmark` is pd.Series, it will be ignored
|
||||
- Else, it represent end time of benchmark, by default None
|
||||
|
||||
"""
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.freq = freq
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(benchmark_config, freq)
|
||||
self.current = Position(cash=init_cash)
|
||||
self._reset_report()
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
benchmark = benchmark_config.get("benchmark", "SH000300")
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
start_time = benchmark_config.get("start_time", None)
|
||||
end_time = benchmark_config.get("end_time", None)
|
||||
|
||||
if freq is None:
|
||||
raise ValueError("benchmark freq can't be None!")
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
fields = ["$close/Ref($close,1)-1"]
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
|
||||
except ValueError:
|
||||
_, norm_freq = parse_freq(freq)
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
|
||||
except ValueError:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
elif norm_freq == "minute":
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
else:
|
||||
raise ValueError(f"benchmark freq {freq} is not supported")
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
def cal_change(x):
|
||||
return (x + 1).prod() - 1
|
||||
|
||||
_ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0 if _ret is None else _ret
|
||||
|
||||
def _reset_freq(self, freq):
|
||||
"""reset frequency"""
|
||||
if freq != self.freq:
|
||||
self.freq = freq
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _reset_report(self):
|
||||
self.report = Report()
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.positions = {}
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
@@ -120,10 +47,25 @@ class Account:
|
||||
self.val = 0
|
||||
self.earning = 0
|
||||
|
||||
def reset(self, freq=None, init_report: bool = False):
|
||||
self._reset_freq(freq)
|
||||
if init_report:
|
||||
self._reset_report()
|
||||
def reset(self, freq=None, benchmark_config=None, init_report=False):
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str, optional
|
||||
frequency of account & report, by default None
|
||||
benchmark_config : {}, optional
|
||||
benchmark config of report, by default None
|
||||
init_report : bool, optional
|
||||
whether to initialize the report, by default False
|
||||
"""
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
if benchmark_config is not None:
|
||||
self.benchmark_config = benchmark_config
|
||||
|
||||
if freq is not None or benchmark_config is not None or init_report:
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
@@ -131,7 +73,7 @@ class Account:
|
||||
def get_cash(self):
|
||||
return self.current.position["cash"]
|
||||
|
||||
def update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
def _update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
# update turnover
|
||||
self.to += trade_val
|
||||
# update cost
|
||||
@@ -155,7 +97,7 @@ class Account:
|
||||
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
if order.direction == Order.SELL:
|
||||
# sell stock
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
# update current position
|
||||
# for may sell all of stock_id
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
@@ -163,15 +105,15 @@ class Account:
|
||||
# buy stock
|
||||
# deal order, then update state
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_count(self):
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
def update_bar_report(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""
|
||||
start_time: pd.TimeStamp
|
||||
end_time: pd.TimeStamp
|
||||
trade_start_time: pd.TimeStamp
|
||||
trade_end_time: pd.TimeStamp
|
||||
quote: pd.DataFrame (code, date), collumns
|
||||
when the end of trade date
|
||||
- update rtn
|
||||
@@ -211,7 +153,8 @@ class Account:
|
||||
# judge whether the the trading is begin.
|
||||
# and don't add init account state into report, due to we don't have excess return in those days.
|
||||
self.report.update_report_record(
|
||||
trade_time=trade_start_time,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
account_value=now_account_value,
|
||||
cash=self.current.position["cash"],
|
||||
return_rate=(self.earning + self.ct) / last_account_value,
|
||||
@@ -220,7 +163,6 @@ class Account:
|
||||
turnover_rate=self.to / last_account_value,
|
||||
cost_rate=self.ct / last_account_value,
|
||||
stock_value=now_stock_value,
|
||||
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time),
|
||||
)
|
||||
# set now_account_value to position
|
||||
self.current.position["now_account_value"] = now_account_value
|
||||
@@ -234,18 +176,3 @@ class Account:
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
|
||||
def load_account(self, account_path):
|
||||
report = Report()
|
||||
position = Position()
|
||||
report.load_report(account_path / "report.csv")
|
||||
position.load_position(account_path / "position.xlsx")
|
||||
|
||||
# assign values
|
||||
self.init_vars(position.init_cash)
|
||||
self.current = position
|
||||
self.report = report
|
||||
|
||||
def save_account(self, account_path):
|
||||
self.current.save_position(account_path / "position.xlsx")
|
||||
self.report.save_report(account_path / "report.csv")
|
||||
|
||||
@@ -2,14 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
def backtest(start_time, end_time, trade_strategy, trade_env):
|
||||
def backtest(start_time, end_time, trade_strategy, trade_executor):
|
||||
|
||||
trade_env.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
_execute_state = trade_env.get_init_state()
|
||||
while not trade_env.finished():
|
||||
_order_list = trade_strategy.generate_order_list(_execute_state)
|
||||
_execute_state = trade_env.execute(_order_list)
|
||||
sub_execute_state = trade_executor.get_init_state()
|
||||
while not trade_executor.finished():
|
||||
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = trade_executor.execute(sub_trade_decision)
|
||||
|
||||
return trade_env.get_report()
|
||||
return trade_executor.get_report()
|
||||
|
||||
|
||||
def collect_data(start_time, end_time, trade_strategy, trade_executor):
|
||||
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
sub_execute_state = trade_executor.get_init_state()
|
||||
while not trade_executor.finished():
|
||||
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = yield from trade_executor.collect_data(sub_trade_decision)
|
||||
|
||||
return trade_executor.get_report()
|
||||
|
||||
@@ -11,7 +11,7 @@ import pandas as pd
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import get_level_index
|
||||
from ...config import C, REG_CN
|
||||
from ...utils.sample import sample_feature
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...log import get_module_logger
|
||||
from .order import Order
|
||||
|
||||
@@ -34,8 +34,9 @@ class Exchange:
|
||||
):
|
||||
"""__init__
|
||||
|
||||
:param start_time: start time for backtest
|
||||
:param end_time: end time for backtest
|
||||
:param freq: frequency of data
|
||||
:param start_time: closed start time for backtest
|
||||
:param end_time: closed end time for backtest
|
||||
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
|
||||
:param deal_price: str, 'close', 'open', 'vwap'
|
||||
:param subscribe_fields: list, subscribe fields
|
||||
@@ -91,7 +92,7 @@ class Exchange:
|
||||
# $factor is for rounding to the trading unit
|
||||
# $change is for calculating the limit of the stock
|
||||
|
||||
necessary_fields = {self.deal_price, "$close", "$change", "$factor"}
|
||||
necessary_fields = {self.deal_price, "$close", "$change", "$factor", "$volume"}
|
||||
subscribe_fields = list(necessary_fields | set(subscribe_fields))
|
||||
all_fields = list(necessary_fields | set(subscribe_fields))
|
||||
self.all_fields = all_fields
|
||||
@@ -167,12 +168,12 @@ class Exchange:
|
||||
trade_date
|
||||
is limtited
|
||||
"""
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["limit"], start_time, end_time, method="all").iloc[0]
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
# is suspended
|
||||
if stock_id in self.quote:
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, method=None) is None
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=None) is None
|
||||
else:
|
||||
return True
|
||||
|
||||
@@ -230,15 +231,16 @@ class Exchange:
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0]
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0]
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time):
|
||||
deal_price = sample_feature(
|
||||
self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last"
|
||||
).iloc[0]
|
||||
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0]
|
||||
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
|
||||
self.logger.warning(
|
||||
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
|
||||
@@ -248,7 +250,7 @@ class Exchange:
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time):
|
||||
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$factor", method="last").iloc[0]
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last").iloc[0]
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
|
||||
@@ -2,88 +2,18 @@ import copy
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from ...data.data import Cal
|
||||
|
||||
from ...utils import init_instance_by_config
|
||||
from ...utils.sample import get_sample_freq_calendar, parse_freq
|
||||
from ...utils.resam import parse_freq
|
||||
|
||||
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .exchange import Exchange
|
||||
from .faculty import common_faculty
|
||||
from .utils import TradeCalendarManager
|
||||
|
||||
|
||||
class BaseTradeCalendar:
|
||||
"""
|
||||
Base class providing trading calendar
|
||||
- BaseStrategy and BaseExecutor should inherited from this class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
step_bar : str
|
||||
frequency of each trading step bar
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start time of trading, by default None
|
||||
If `start_time` is None, it must be reset before trading.
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time of trading, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
|
||||
self.step_bar = step_bar
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
self.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
def _reset_trade_calendar(self, start_time, end_time):
|
||||
"""reset trade calendar"""
|
||||
if start_time and end_time:
|
||||
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
|
||||
self.calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(
|
||||
self.start_time, self.end_time, freq=freq, freq_sam=freq_sam
|
||||
)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_index = 0
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, param `start_time` or `end_time` is None.")
|
||||
|
||||
def reset(self, start_time=None, end_time=None):
|
||||
"""
|
||||
Reset start\end time of trading, and reset trading calendar
|
||||
"""
|
||||
|
||||
if start_time:
|
||||
self.start_time = pd.Timestamp(start_time)
|
||||
if end_time:
|
||||
self.end_time = pd.Timestamp(end_time)
|
||||
if self.start_time and self.end_time and (start_time or end_time):
|
||||
self._reset_trade_calendar(start_time=self.start_time, end_time=self.end_time)
|
||||
|
||||
def _get_calendar_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
calendar_index = self.start_index + trade_index
|
||||
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!")
|
||||
# trade count += 1
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
|
||||
class BaseExecutor(BaseTradeCalendar):
|
||||
class BaseExecutor:
|
||||
"""Base executor for trading"""
|
||||
|
||||
def __init__(
|
||||
@@ -91,48 +21,97 @@ class BaseExecutor(BaseTradeCalendar):
|
||||
step_bar: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_account: Account = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account, optional
|
||||
trade account for trading, by default None
|
||||
- If `trade_account` is None, self.trade_account will be set with common_faculty
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
verbose : bool, optional
|
||||
whether to print trading info, by default False
|
||||
track_data : bool, optional
|
||||
whether to generate order_list, will be used when making data for multi-level training
|
||||
- If `self.track_data` is true, when making data for training, the input `order_list` of `execute` will be generated by `get_data`
|
||||
- Else, `order_list` will not be generated
|
||||
whether to generate trade_decision, will be used when making data for multi-level training
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
|
||||
- Else, `trade_decision` will not be generated
|
||||
common_infra : dict, optional:
|
||||
common infrastructure for backtesting, may including:
|
||||
- trade_account : Account, optional
|
||||
trade account for trading
|
||||
- trade_exchange : Exchange, optional
|
||||
exchange that provides market info
|
||||
|
||||
"""
|
||||
super(BaseExecutor, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, **kwargs)
|
||||
self.trade_account = copy.copy(common_faculty.trade_account if trade_account is None else trade_account)
|
||||
self.trade_account.reset(freq=self.step_bar, init_report=True)
|
||||
self.step_bar = step_bar
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
self.track_data = track_data
|
||||
self.reset(start_time=start_time, end_time=end_time, track_data=track_data, common_infra=common_infra)
|
||||
|
||||
def reset(self, track_data: bool = None, **kwargs):
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Reset `track_data`, will be used when making data for multi-level training
|
||||
reset infrastructure for trading
|
||||
- reset trade_account
|
||||
"""
|
||||
super(BaseExecutor, self).reset(**kwargs)
|
||||
if not hasattr(self, "common_infra"):
|
||||
self.common_infra = common_infra
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if "trade_account" in common_infra:
|
||||
self.trade_account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.step_bar, init_report=True)
|
||||
|
||||
def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs):
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `track_data`, used when making data for multi-level training
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
"""
|
||||
|
||||
if track_data is not None:
|
||||
self.track_data = track_data
|
||||
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
if "start_time" in kwargs or "end_time" in kwargs:
|
||||
start_time = kwargs.get("start_time")
|
||||
end_time = kwargs.get("end_time")
|
||||
self.trade_calendar = TradeCalendarManager(step_bar=self.step_bar, start_time=start_time, end_time=end_time)
|
||||
|
||||
def get_level_infra(self):
|
||||
return {"trade_calendar": self.trade_calendar}
|
||||
|
||||
def finished(self):
|
||||
return self.trade_calendar.finished()
|
||||
|
||||
def execute(self, trade_decision):
|
||||
"""execute the trade decision and return the executed result
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : object
|
||||
|
||||
Returns
|
||||
----------
|
||||
executed state : List[Tuple[Order, float, float, float]]
|
||||
- Each element in the list represents (order, trade value, trade cost, trade price)
|
||||
"""
|
||||
raise NotImplementedError("execute is not implemented!")
|
||||
|
||||
def collect_data(self, trade_decision):
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
return self.execute(trade_decision)
|
||||
|
||||
def get_init_state(self):
|
||||
raise NotImplementedError("get_init_state in not implemeted!")
|
||||
|
||||
def execute(self, **kwargs):
|
||||
raise NotImplementedError("execute is not implemented!")
|
||||
|
||||
def get_trade_account(self):
|
||||
raise NotImplementedError("get_trade_account is not implemented!")
|
||||
|
||||
@@ -146,56 +125,75 @@ class SplitExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
sub_env: Union[BaseExecutor, dict],
|
||||
sub_executor: Union[BaseExecutor, dict],
|
||||
sub_strategy: Union[BaseStrategy, dict],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_account: Account = None,
|
||||
trade_exchange: Exchange = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sub_env : BaseExecutor
|
||||
sub_executor : BaseExecutor
|
||||
trading env in each trading bar.
|
||||
sub_strategy : BaseStrategy
|
||||
trading strategy in each trading bar
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
|
||||
exchange that provides market info, used to generate report
|
||||
- If generate_report is None, trade_exchange will be ignored
|
||||
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
"""
|
||||
self.sub_executor = init_instance_by_config(sub_executor, common_infra=common_infra, accept_types=BaseExecutor)
|
||||
self.sub_strategy = init_instance_by_config(
|
||||
sub_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
|
||||
)
|
||||
|
||||
super(SplitExecutor, self).__init__(
|
||||
step_bar=step_bar,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
common_infra=common_infra,
|
||||
**kwargs,
|
||||
)
|
||||
if generate_report:
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor)
|
||||
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=self.BaseStrategy)
|
||||
|
||||
if generate_report and trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_exchange
|
||||
- reset substrategy and subexecutor common infra
|
||||
"""
|
||||
super(SplitExecutor, self).reset_common_infra(common_infra)
|
||||
|
||||
if self.generate_report and "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
self.sub_executor.reset_common_infra(common_infra)
|
||||
self.sub_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def get_init_state(self):
|
||||
init_state = {"current": self.trade_account.current}
|
||||
return init_state
|
||||
return []
|
||||
|
||||
def _init_sub_trading(self, order_list):
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
|
||||
sub_execute_state = self.sub_env.get_init_state()
|
||||
return sub_execute_state
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
self.sub_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.sub_executor.get_level_infra()
|
||||
self.sub_strategy.reset(level_infra=sub_level_infra, rely_trade_decision=trade_decision)
|
||||
|
||||
def _update_trade_account(self):
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
self.trade_account.update_bar_count()
|
||||
if self.generate_report:
|
||||
self.trade_account.update_bar_report(
|
||||
@@ -204,30 +202,38 @@ class SplitExecutor(BaseExecutor):
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
|
||||
def execute(self, order_list):
|
||||
super(SplitExecutor, self).step()
|
||||
self._init_sub_trading(order_list)
|
||||
sub_execute_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
|
||||
sub_execute_state = self.sub_env.execute(order_list=_order_list)
|
||||
self._update_trade_account()
|
||||
return {"current": self.trade_account.current}
|
||||
def execute(self, trade_decision):
|
||||
self.trade_calendar.step()
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_state = []
|
||||
sub_execute_state = self.sub_executor.get_init_state()
|
||||
while not self.sub_executor.finished():
|
||||
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = self.sub_executor.execute(trade_decision=sub_trade_decison)
|
||||
execute_state.extend(sub_execute_state)
|
||||
if hasattr(self, "trade_account"):
|
||||
self._update_trade_account()
|
||||
|
||||
def get_data(self, order_list):
|
||||
return execute_state
|
||||
|
||||
def collect_data(self, trade_decision):
|
||||
if self.track_data:
|
||||
yield order_list
|
||||
super(SplitExecutor, self).step()
|
||||
self._init_sub_trading(order_list)
|
||||
sub_execute_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
|
||||
sub_execute_state = yield from self.sub_env.get_data(order_list=_order_list)
|
||||
self._update_trade_account()
|
||||
return {"current": self.trade_account.current}
|
||||
yield trade_decision
|
||||
self.trade_calendar.step()
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_state = []
|
||||
sub_execute_state = self.sub_executor.get_init_state()
|
||||
while not self.sub_executor.finished():
|
||||
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
|
||||
sub_execute_state = yield from self.sub_executor.collect_data(trade_decision=sub_trade_decison)
|
||||
execute_state.extend(sub_execute_state)
|
||||
if hasattr(self, "trade_account"):
|
||||
self._update_trade_account()
|
||||
|
||||
return execute_state
|
||||
|
||||
def get_report(self):
|
||||
sub_env_report_dict = self.sub_env.get_report()
|
||||
sub_env_report_dict = self.sub_executor.get_report()
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
@@ -242,46 +248,57 @@ class SimulatorExecutor(BaseExecutor):
|
||||
step_bar: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_account: Account = None,
|
||||
trade_exchange: Exchange = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
"""
|
||||
super(SimulatorExecutor, self).__init__(
|
||||
step_bar=step_bar,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
common_infra=common_infra,
|
||||
**kwargs,
|
||||
)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_exchange
|
||||
"""
|
||||
super(SimulatorExecutor, self).reset_common_infra(common_infra)
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_init_state(self):
|
||||
init_state = {"current": self.trade_account.current, "trade_info": []}
|
||||
return init_state
|
||||
return []
|
||||
|
||||
def execute(self, order_list):
|
||||
super(SimulatorExecutor, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
trade_info = []
|
||||
for order in order_list:
|
||||
def execute(self, trade_decision):
|
||||
self.trade_calendar.step()
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
execute_state = []
|
||||
for order in trade_decision:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
order, trade_account=self.trade_account
|
||||
)
|
||||
trade_info.append((order, trade_val, trade_cost, trade_price))
|
||||
execute_state.append((order, trade_val, trade_cost, trade_price))
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
print(
|
||||
@@ -323,7 +340,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
|
||||
return {"current": self.trade_account.current, "trade_info": trade_info}
|
||||
return execute_state
|
||||
|
||||
def get_report(self):
|
||||
if self.generate_report:
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
class Faculty:
|
||||
def __init__(self):
|
||||
self.__dict__["_faculty"] = dict()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.__dict__["_faculty"][key]
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.__dict__["_faculty"]:
|
||||
return self.__dict__["_faculty"][attr]
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._faculty")
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_faculty"][key] = value
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
self.__dict__["_faculty"][attr] = value
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
self.__dict__["_faculty"].update(*args, **kwargs)
|
||||
|
||||
|
||||
common_faculty = Faculty()
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import pathlib
|
||||
|
||||
@@ -3,16 +3,51 @@
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
from logging import warning
|
||||
import pandas as pd
|
||||
import pathlib
|
||||
import warnings
|
||||
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from ...utils.resam import parse_freq, resam_ts_data
|
||||
from ...data import D
|
||||
|
||||
|
||||
class Report:
|
||||
# daily report of the account
|
||||
# contain those followings: returns, costs turnovers, accounts, cash, bench, value
|
||||
# update report
|
||||
def __init__(self):
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of trading bar, used for updating hold count of trading bar
|
||||
benchmark_config : dict
|
||||
config of benchmark, may including the following arguments:
|
||||
- benchmark : Union[str, list, pd.Series]
|
||||
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
- If `benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000300 CSI300
|
||||
- start_time : Union[str, pd.Timestamp], optional
|
||||
- If `benchmark` is pd.Series, it will be ignored
|
||||
- Else, it represent start time of benchmark, by default None
|
||||
- end_time : Union[str, pd.Timestamp], optional
|
||||
- If `benchmark` is pd.Series, it will be ignored
|
||||
- Else, it represent end time of benchmark, by default None
|
||||
|
||||
"""
|
||||
self.init_vars()
|
||||
self.init_bench(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def init_vars(self):
|
||||
self.accounts = OrderedDict() # account postion value for each trade date
|
||||
@@ -24,6 +59,49 @@ class Report:
|
||||
self.benches = OrderedDict()
|
||||
self.latest_report_time = None # pd.TimeStamp
|
||||
|
||||
def init_bench(self, freq=None, benchmark_config=None):
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
if benchmark_config is not None:
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
benchmark = benchmark_config.get("benchmark", "SH000300")
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
start_time = benchmark_config.get("start_time", None)
|
||||
end_time = benchmark_config.get("end_time", None)
|
||||
|
||||
if freq is None:
|
||||
raise ValueError("benchmark freq can't be None!")
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
fields = ["$close/Ref($close,1)-1"]
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
|
||||
except ValueError:
|
||||
_, norm_freq = parse_freq(freq)
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
|
||||
except ValueError:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
elif norm_freq == "minute":
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
|
||||
else:
|
||||
raise ValueError(f"benchmark freq {freq} is not supported")
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
def cal_change(x):
|
||||
return (x + 1).prod() - 1
|
||||
|
||||
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0.0 if _ret is None else _ret
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.accounts) == 0
|
||||
|
||||
@@ -35,30 +113,39 @@ class Report:
|
||||
|
||||
def update_report_record(
|
||||
self,
|
||||
trade_time=None,
|
||||
trade_start_time=None,
|
||||
trade_end_time=None,
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
turnover_rate=None,
|
||||
cost_rate=None,
|
||||
stock_value=None,
|
||||
bench_value=None,
|
||||
):
|
||||
# check data
|
||||
if None in [trade_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]:
|
||||
if None in [
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
account_value,
|
||||
cash,
|
||||
return_rate,
|
||||
turnover_rate,
|
||||
cost_rate,
|
||||
stock_value,
|
||||
]:
|
||||
raise ValueError(
|
||||
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]"
|
||||
"None in [trade_start_time, trade_end_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
|
||||
)
|
||||
# update report data
|
||||
self.accounts[trade_time] = account_value
|
||||
self.returns[trade_time] = return_rate
|
||||
self.turnovers[trade_time] = turnover_rate
|
||||
self.costs[trade_time] = cost_rate
|
||||
self.values[trade_time] = stock_value
|
||||
self.cashes[trade_time] = cash
|
||||
self.benches[trade_time] = bench_value
|
||||
self.accounts[trade_start_time] = account_value
|
||||
self.returns[trade_start_time] = return_rate
|
||||
self.turnovers[trade_start_time] = turnover_rate
|
||||
self.costs[trade_start_time] = cost_rate
|
||||
self.values[trade_start_time] = stock_value
|
||||
self.cashes[trade_start_time] = cash
|
||||
self.benches[trade_start_time] = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
|
||||
# update latest_report_date
|
||||
self.latest_report_time = trade_time
|
||||
self.latest_report_time = trade_start_time
|
||||
# finish daily report update
|
||||
|
||||
def generate_report_dataframe(self):
|
||||
|
||||
67
qlib/contrib/backtest/utils.py
Normal file
67
qlib/contrib/backtest/utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from ...utils.resam import get_resam_calendar
|
||||
from ...data.data import Cal
|
||||
|
||||
|
||||
class TradeCalendarManager:
|
||||
"""
|
||||
Manager for trading calendar
|
||||
- BaseStrategy and BaseExecutor will use it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
step_bar : str
|
||||
frequency of each trading calendar
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
closed start of the trading calendar, by default None
|
||||
If `start_time` is None, it must be reset before trading.
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
closed end of the trade time range, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
self.step_bar = step_bar
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(start_time) if start_time else None
|
||||
self._init_trade_calendar(step_bar=step_bar, start_time=start_time, end_time=end_time)
|
||||
|
||||
def _init_trade_calendar(self, step_bar, start_time, end_time):
|
||||
"""reset trade calendar"""
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=step_bar)
|
||||
self.calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_index = 0
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
def get_step_bar(self):
|
||||
return self.step_bar
|
||||
|
||||
def get_trade_len(self):
|
||||
return self.trade_len
|
||||
|
||||
def get_trade_index(self):
|
||||
return self.trade_index
|
||||
|
||||
def get_calendar_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
calendar_index = self.start_index + trade_index
|
||||
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from logging import warn
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -10,7 +11,7 @@ import warnings
|
||||
from ..log import get_module_logger
|
||||
from .backtest import get_exchange, backtest as backtest_func
|
||||
from ..utils import get_date_range
|
||||
from ..utils.sample import parse_freq
|
||||
from ..utils.resam import parse_freq
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
@@ -20,7 +21,7 @@ from ..data.dataset.utils import get_level_index
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
def risk_analysis(r, N: int = None, freq: str = None):
|
||||
def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
"""Risk Analysis
|
||||
|
||||
Parameters
|
||||
@@ -36,8 +37,8 @@ def risk_analysis(r, N: int = None, freq: str = None):
|
||||
def cal_risk_analysis_scaler(freq):
|
||||
_count, _freq = parse_freq(freq)
|
||||
_freq_scaler = {
|
||||
"minute": 240 * 250,
|
||||
"day": 250,
|
||||
"minute": 240 * 252,
|
||||
"day": 252,
|
||||
"week": 50,
|
||||
"month": 12,
|
||||
}
|
||||
@@ -45,6 +46,8 @@ def risk_analysis(r, N: int = None, freq: str = None):
|
||||
|
||||
if N is None and freq is None:
|
||||
raise ValueError("at least one of `N` and `freq` should exist")
|
||||
if N is not None and freq is not None:
|
||||
warnings.warn("risk_analysis freq will be ignored")
|
||||
if N is None:
|
||||
N = cal_risk_analysis_scaler(freq)
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ class Operator:
|
||||
user.strategy.update(score_series, pred_date, trade_date)
|
||||
|
||||
# generate and save order list
|
||||
order_list = user.strategy.generate_order_list(
|
||||
order_list = user.strategy.generate_trade_decision(
|
||||
score_series=score_series,
|
||||
current=user.account.current,
|
||||
trade_exchange=trade_exchange,
|
||||
@@ -208,7 +208,7 @@ class Operator:
|
||||
self.logger.info("Update account state {} for {}".format(trade_date, user_id))
|
||||
|
||||
def simulate(self, id, config, exchange_config, start, end, path, bench="SH000905"):
|
||||
"""Run the ( generate_order_list -> execute_order_list -> update_account) process everyday
|
||||
"""Run the ( generate_trade_decision -> execute_order_list -> update_account) process everyday
|
||||
from start date to end date.
|
||||
|
||||
Parameters
|
||||
@@ -256,7 +256,7 @@ class Operator:
|
||||
user.strategy.update(score_series, pred_date, trade_date)
|
||||
|
||||
# 3. generate and save order list
|
||||
order_list = user.strategy.generate_order_list(
|
||||
order_list = user.strategy.generate_trade_decision(
|
||||
score_series=score_series,
|
||||
current=user.account.current,
|
||||
trade_exchange=trade_exchange,
|
||||
|
||||
@@ -10,17 +10,15 @@ import copy
|
||||
class SoftTopkStrategy(WeightStrategyBase):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
model,
|
||||
dataset,
|
||||
topk,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
max_sold_weight=1.0,
|
||||
risk_degree=0.95,
|
||||
buy_method="first_fill",
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
):
|
||||
"""Parameter
|
||||
@@ -33,7 +31,7 @@ class SoftTopkStrategy(WeightStrategyBase):
|
||||
average_fill: assign the weight to the stocks rank high averagely.
|
||||
"""
|
||||
super(SoftTopkStrategy, self).__init__(
|
||||
step_bar, model, dataset, start_time, end_time, order_generator_cls_or_obj, trade_exchange
|
||||
model, dataset, order_generator_cls_or_obj, level_infra, common_infra, **kwargs
|
||||
)
|
||||
self.topk = topk
|
||||
self.max_sold_weight = max_sold_weight
|
||||
|
||||
@@ -3,29 +3,26 @@ import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils.sample import sample_feature
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ..backtest.order import Order
|
||||
from ..backtest.faculty import common_faculty
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
class TopkDropoutStrategy(ModelStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
model,
|
||||
dataset,
|
||||
topk,
|
||||
n_drop,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
method_sell="bottom",
|
||||
method_buy="top",
|
||||
risk_degree=0.95,
|
||||
hold_thresh=1,
|
||||
only_tradable=False,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -51,8 +48,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
|
||||
)
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
self.method_sell = method_sell
|
||||
@@ -61,6 +59,20 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).reset_common_infra(common_infra)
|
||||
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -69,11 +81,11 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
# It will use 95% amoutn of your total value by default
|
||||
return self.risk_degree
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
super(TopkDropoutStrategy, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
def generate_trade_decision(self, execute_state):
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
if self.only_tradable:
|
||||
@@ -115,8 +127,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
def filter_stock(l):
|
||||
return l
|
||||
|
||||
current = execute_state.get("current")
|
||||
current_temp = copy.deepcopy(current)
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
# generate order list for this adjust date
|
||||
sell_order_list = []
|
||||
buy_order_list = []
|
||||
@@ -168,7 +179,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
continue
|
||||
if code in sell:
|
||||
# check hold limit
|
||||
if current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
|
||||
step_bar = self.trade_calendar.get_step_bar()
|
||||
if current_temp.get_stock_count(code, bar=step_bar) < self.hold_thresh:
|
||||
continue
|
||||
# sell order
|
||||
sell_amount = current_temp.get_stock_amount(code=code)
|
||||
@@ -228,22 +240,35 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
model,
|
||||
dataset,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
):
|
||||
super(WeightStrategyBase, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
|
||||
)
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(WeightStrategyBase, self).reset_common_infra(common_infra)
|
||||
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -267,7 +292,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
def generate_trade_decision(self, execute_state):
|
||||
"""
|
||||
Parameters
|
||||
-----------
|
||||
@@ -280,23 +305,22 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_date : pd.Timestamp
|
||||
date.
|
||||
"""
|
||||
# generate_order_list
|
||||
# generate_trade_decision
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
super(WeightStrategyBase, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
current = execute_state.get("current")
|
||||
current_temp = copy.deepcopy(current)
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
target_weight_position = self.generate_target_weight_position(
|
||||
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
)
|
||||
order_list = self.order_generator.generate_order_list_from_target_weight_position(
|
||||
current=current_temp,
|
||||
trade_exchange=self.trade_exchange,
|
||||
risk_degree=self.get_risk_degree(self.trade_index),
|
||||
risk_degree=self.get_risk_degree(trade_index),
|
||||
target_weight_position=target_weight_position,
|
||||
pred_start_time=pred_start_time,
|
||||
pred_end_time=pred_end_time,
|
||||
|
||||
@@ -1,80 +1,77 @@
|
||||
import copy
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from ...utils.sample import sample_feature
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import RuleStrategy, OrderEnhancement
|
||||
from ...strategy.base import RuleStrategy
|
||||
from ..backtest.order import Order
|
||||
from ..backtest.faculty import common_faculty
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
class TWAPStrategy(RuleStrategy):
|
||||
"""TWAP Strategy for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
trade_order_list=[],
|
||||
**kwargs,
|
||||
):
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange, optional
|
||||
exchange that provides market info, by default None
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
|
||||
trade_order_list : list, optional
|
||||
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.trade_order_list = trade_order_list
|
||||
super(TWAPStrategy, self).reset_common_infra(common_infra)
|
||||
if common_infra is not None:
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, trade_order_list: list = None, **kwargs):
|
||||
super(TWAPStrategy, self).reset(**kwargs)
|
||||
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_order_list is not None:
|
||||
def reset(self, rely_trade_decision: object = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rely_trade_decision : object, optional
|
||||
"""
|
||||
|
||||
super(TWAPStrategy, self).reset(rely_trade_decision=rely_trade_decision, common_infra=common_infra, **kwargs)
|
||||
if rely_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
for order in self.trade_order_list:
|
||||
for order in rely_trade_decision:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
super(TWAPStrategy, self).step()
|
||||
trade_info = execute_state.get("trade_info")
|
||||
def generate_trade_decision(self, execute_state):
|
||||
|
||||
# update the order amount
|
||||
trade_info = execute_state
|
||||
for order, _, _, _ in trade_info:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
for order in self.rely_trade_decision:
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
_order_amount = None
|
||||
# consider trade unit
|
||||
if _amount_trade_unit is None:
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
|
||||
self.trade_len - self.trade_index + 1
|
||||
)
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# split the order equally
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# split the order equally
|
||||
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index)
|
||||
// (self.trade_len - self.trade_index + 1)
|
||||
* _amount_trade_unit
|
||||
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
|
||||
)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or self.trade_index == self.trade_len
|
||||
_order_amount is None or trade_index == trade_len
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -92,7 +89,7 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
return order_list
|
||||
|
||||
|
||||
class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
class SBBStrategyBase(RuleStrategy):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
|
||||
"""
|
||||
@@ -101,81 +98,80 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
TREND_SHORT = 1
|
||||
TREND_LONG = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
trade_order_list=[],
|
||||
**kwargs,
|
||||
):
|
||||
def reset_common_infra(self, common_infra):
|
||||
super(SBBStrategyBase, self).reset_common_infra(common_infra)
|
||||
if common_infra is not None:
|
||||
if "trade_exchange" in common_infra:
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, rely_trade_decision=None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange, optional
|
||||
exchange that provides market info, by default None
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
|
||||
trade_order_list : list, optional
|
||||
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
|
||||
rely_trade_decision : object, optional
|
||||
common_infra : None, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
def reset(self, trade_order_list=None, **kwargs):
|
||||
super(SBBStrategyBase, self).reset(**kwargs)
|
||||
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_order_list is not None:
|
||||
super(SBBStrategyBase, self).reset(rely_trade_decision=rely_trade_decision, **kwargs)
|
||||
if rely_trade_decision is not None:
|
||||
self.trade_trend = {}
|
||||
self.trade_amount = {}
|
||||
for order in self.trade_order_list:
|
||||
# init the trade amount of order and predicted trade trend
|
||||
for order in rely_trade_decision:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
super(SBBStrategyBase, self).step()
|
||||
def generate_trade_decision(self, execute_state):
|
||||
|
||||
trade_info = execute_state.get("trade_info")
|
||||
# update the order amount
|
||||
trade_info = execute_state
|
||||
for order, _, _, _ in trade_info:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
trade_index = self.trade_calendar.get_trade_index()
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
if self.trade_index % 2 == 1:
|
||||
# for each order in in self.rely_trade_decision
|
||||
for order in self.rely_trade_decision:
|
||||
# predict the price trend
|
||||
if trade_index % 2 == 1:
|
||||
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
|
||||
else:
|
||||
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
|
||||
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
if self.trade_index % 2 == 1:
|
||||
if trade_index % 2 == 1:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
continue
|
||||
|
||||
# get amount of one trade unit
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
if _pred_trend == self.TREND_MID:
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
|
||||
self.trade_len - self.trade_index + 1
|
||||
)
|
||||
# split the order equally
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# cal how many trade unit
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
# split the order equally
|
||||
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index)
|
||||
// (self.trade_len - self.trade_index + 1)
|
||||
* _amount_trade_unit
|
||||
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or self.trade_index == self.trade_len
|
||||
_order_amount is None or trade_index == trade_len
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -185,36 +181,43 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
amount=_order_amount,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction, # 1 for buy
|
||||
direction=order.direction,
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
|
||||
|
||||
else:
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# N trade day last, split the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
2
|
||||
* self.trade_amount[(order.stock_id, order.direction)]
|
||||
/ (self.trade_len - self.trade_index + 2)
|
||||
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 2)
|
||||
)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
# cal how many trade unit
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
# N trade day last, split the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index + 1)
|
||||
// (self.trade_len - self.trade_index + 2)
|
||||
(trade_unit_cnt + trade_len - trade_index + 1)
|
||||
// (trade_len - trade_index + 2)
|
||||
* 2
|
||||
* _amount_trade_unit
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
|
||||
_order_amount is None or self.trade_index == self.trade_len
|
||||
_order_amount is None or trade_index == trade_len
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
|
||||
if _order_amount:
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
if self.trade_index % 2 == 1:
|
||||
if trade_index % 2 == 1:
|
||||
# in the first of two adjacent bar
|
||||
# if look short on the price, sell the stock more
|
||||
# if look long on the price, sell the stock more
|
||||
if (
|
||||
_pred_trend == self.TREND_SHORT
|
||||
and order.direction == order.SELL
|
||||
@@ -231,6 +234,9 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
)
|
||||
order_list.append(_order)
|
||||
else:
|
||||
# in the second of two adjacent bar
|
||||
# if look short on the price, buy the stock more
|
||||
# if look long on the price, sell the stock more
|
||||
if (
|
||||
_pred_trend == self.TREND_SHORT
|
||||
and order.direction == order.BUY
|
||||
@@ -246,8 +252,8 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
|
||||
if self.trade_index % 2 == 1:
|
||||
|
||||
if trade_index % 2 == 1:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
|
||||
return order_list
|
||||
@@ -260,13 +266,11 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
trade_order_list=[],
|
||||
rely_trade_decision=[],
|
||||
instruments="csi300",
|
||||
freq="day",
|
||||
level_infra={},
|
||||
common_infra={},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -278,47 +282,49 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
freq of EMA signal, by default "day"
|
||||
Note: `freq` may be different from `steb_bar`
|
||||
"""
|
||||
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs)
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
self.instruments = "all"
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
super(SBBStrategyEMA, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
def reset(self, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, **kwargs):
|
||||
def _reset_signal(self):
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self.trade_calendar.get_calendar_time(trade_index=1, shift=1)
|
||||
_, signal_end_time = self.trade_calendar.get_calendar_time(trade_index=trade_len, shift=1)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
"""
|
||||
Reset EMA signal for trading
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start time for trading, also used to calculate the start time of EMA signal, by default None
|
||||
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time for trading, also used to calculate the end time of EMA signal, by default None
|
||||
reset level-shared infra
|
||||
- After reset the trade_calendar, the signal will be changed
|
||||
"""
|
||||
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
if self.start_time and self.end_time and (start_time or end_time):
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=1, shift=1)
|
||||
_, signal_end_time = self._get_calendar_time(trade_index=self.trade_len, shift=1)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "trade_calendar" in level_infra:
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
|
||||
if stock_id not in self.signal:
|
||||
return self.TREND_MID
|
||||
else:
|
||||
_sample_signal = sample_feature(
|
||||
self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last"
|
||||
_sample_signal = resam_ts_data(
|
||||
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
|
||||
)
|
||||
if _sample_signal is None or _sample_signal.iloc[0] == 0:
|
||||
return self.TREND_MID
|
||||
|
||||
@@ -26,7 +26,7 @@ from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, co
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils.sample import sample_calendar
|
||||
from ..utils.resam import resam_calendar
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
@@ -133,7 +133,7 @@ class CalendarProvider(abc.ABC):
|
||||
if freq_sam is None:
|
||||
return _calendar, _calendar_index
|
||||
else:
|
||||
_calendar_sam = sample_calendar(_calendar, freq, freq_sam)
|
||||
_calendar_sam = resam_calendar(_calendar, freq, freq_sam)
|
||||
_calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)}
|
||||
H["c"][flag] = _calendar_sam, _calendar_sam_index
|
||||
return _calendar_sam, _calendar_sam_index
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .interpreter import StateInterpreter, ActionInterpreter
|
||||
from typing import Union
|
||||
|
||||
from .interpreter import StateInterpreter, ActionInterpreter
|
||||
from ..contrib.backtest.executor import BaseExecutor
|
||||
from ..utils import init_instance_by_config
|
||||
|
||||
|
||||
class BaseRLEnv:
|
||||
@@ -52,35 +54,22 @@ class QlibIntRLEnv(QlibRLEnv):
|
||||
def __init__(
|
||||
self,
|
||||
executor: BaseExecutor,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
state_interpret_kwargs: dict = {},
|
||||
action_interpret_kwargs: dict = {},
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state_interpreter : StateInterpreter
|
||||
state_interpreter : Union[dict, StateInterpreter]
|
||||
interpretor that interprets the qlib execute result into rl env state.
|
||||
action_interpreter : ActionInterpreter
|
||||
|
||||
action_interpreter : Union[dict, ActionInterpreter]
|
||||
interpretor that interprets the rl agent action into qlib order list
|
||||
state_interpret_kwargs : dict, optional
|
||||
arguments may be used in `state_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade exchange : Exchange
|
||||
Exchange that can provide market info
|
||||
action_interpret_kwargs: dict, optional
|
||||
arguments may be used in `action_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade_order_list : List[Order]
|
||||
If the strategy is used to split order, it presents the trade order pool.
|
||||
"""
|
||||
super(QlibIntRLEnv, self).__init__(executor=executor)
|
||||
self.state_interpreter = state_interpreter
|
||||
self.action_interpreter = action_interpreter
|
||||
self.state_interpret_kwargs = state_interpret_kwargs
|
||||
self.action_interpret_kwargs = action_interpret_kwargs
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter)
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
@@ -96,11 +85,9 @@ class QlibIntRLEnv(QlibRLEnv):
|
||||
|
||||
Returns
|
||||
-------
|
||||
env state to rl rl policy
|
||||
env state to rl policy
|
||||
"""
|
||||
_interpret_action = self.action_interpreter.interpret(action=action, **self.state_interpret_kwargs)
|
||||
_interpret_action = self.action_interpreter.interpret(action=action)
|
||||
_execute_result = self.executor.execute(_interpret_action)
|
||||
_interpret_state = self.state_interpreter.interpret(
|
||||
execute_result=_execute_result, **self.action_interpret_kwargs
|
||||
)
|
||||
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
|
||||
return _interpret_state
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
class BaseInterpreter:
|
||||
"""Base Interpreter"""
|
||||
|
||||
@staticmethod
|
||||
def interpret(**kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
@@ -13,7 +12,6 @@ class BaseInterpreter:
|
||||
class ActionInterpreter(BaseInterpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
@staticmethod
|
||||
def interpret(action, **kwargs):
|
||||
"""interpret method
|
||||
|
||||
@@ -34,7 +32,6 @@ class ActionInterpreter(BaseInterpreter):
|
||||
class StateInterpreter(BaseInterpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
@staticmethod
|
||||
def interpret(execute_result, **kwargs):
|
||||
"""interpret method
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import pandas as pd
|
||||
from typing import List, Union
|
||||
|
||||
@@ -9,16 +10,70 @@ from ..model.base import BaseModel
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..contrib.backtest.order import Order
|
||||
from ..contrib.backtest.executor import BaseTradeCalendar
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
|
||||
|
||||
class BaseStrategy(BaseTradeCalendar):
|
||||
class BaseStrategy:
|
||||
"""Base strategy for trading"""
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
"""Generate order list in each trading bar"""
|
||||
raise NotImplementedError("generator_order_list is not implemented!")
|
||||
def __init__(
|
||||
self,
|
||||
rely_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rely_trade_decision : object, optional
|
||||
the high-level trade decison on which the startegy rely, and it will be traded in [start_time , end_time] , by default None
|
||||
- If the strategy is used to split trade decison, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
level_infra : dict, optional
|
||||
level shared infrastructure for backtesting, including trade_calendar
|
||||
common_infra : dict, optional
|
||||
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
|
||||
"""
|
||||
|
||||
self.reset(level_infra=level_infra, common_infra=common_infra, rely_trade_decision=rely_trade_decision)
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if "trade_calendar" in level_infra:
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
if not hasattr(self, "common_infra"):
|
||||
self.common_infra = common_infra
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if "trade_account" in common_infra:
|
||||
self.trade_position = common_infra.get("trade_account").current
|
||||
|
||||
def reset(self, level_infra: dict = None, common_infra: dict = None, rely_trade_decision=None, **kwargs):
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade_calendar, .etc
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
- reset `rely_trade_decision`, used to make split decison
|
||||
"""
|
||||
if level_infra is not None:
|
||||
self.reset_level_infra(level_infra)
|
||||
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
if rely_trade_decision is not None:
|
||||
self.rely_trade_decision = rely_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_state):
|
||||
"""Generate trade decision in each trading bar"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
|
||||
class RuleStrategy(BaseStrategy):
|
||||
@@ -32,11 +87,11 @@ class ModelStrategy(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
rely_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -49,11 +104,10 @@ class ModelStrategy(BaseStrategy):
|
||||
kwargs : dict
|
||||
arguments that will be passed into `reset` method
|
||||
"""
|
||||
super(ModelStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
|
||||
# pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
|
||||
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
@@ -70,10 +124,10 @@ class RLStrategy(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
policy,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
rely_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -82,7 +136,7 @@ class RLStrategy(BaseStrategy):
|
||||
policy :
|
||||
RL policy for generate action
|
||||
"""
|
||||
super(RLStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
super(RLStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
self.policy = policy
|
||||
|
||||
|
||||
@@ -91,14 +145,12 @@ class RLIntStrategy(RLStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
policy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
state_interpret_kwargs: dict = {},
|
||||
action_interpret_kwargs: dict = {},
|
||||
rely_trade_decision: object = None,
|
||||
level_infra: dict = {},
|
||||
common_infra: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -112,49 +164,16 @@ class RLIntStrategy(RLStrategy):
|
||||
start time of trading, by default None
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time of trading, by default None
|
||||
state_interpret_kwargs : dict, optional
|
||||
arguments may be used in `state_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade exchange : Exchange
|
||||
Exchange that can provide market info
|
||||
action_interpret_kwargs: dict, optional
|
||||
arguments may be used in `action_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade_order_list : List[Order]
|
||||
If the strategy is used to split order, it presents the trade order pool.
|
||||
"""
|
||||
super(RLIntStrategy, self).__init__(step_bar, policy, start_time, end_time, **kwargs)
|
||||
super(RLIntStrategy, self).__init__(policy, rely_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
self.policy = policy
|
||||
self.action_interpreter = action_interpreter
|
||||
self.state_interpreter = state_interpreter
|
||||
self.state_interpret_kwargs = state_interpret_kwargs
|
||||
self.action_interpret_kwargs = action_interpret_kwargs
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter)
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
def generate_trade_decision(self, execute_state):
|
||||
super(RLStrategy, self).step()
|
||||
_interpret_state = self.state_interpretor.interpret(
|
||||
execute_result=execute_state, **self.action_interpret_kwargs
|
||||
)
|
||||
_interpret_state = self.state_interpretor.interpret(execute_result=execute_state)
|
||||
_policy_action = self.policy.step(_interpret_state)
|
||||
_order_list = self.action_interpreter.interpret(action=_policy_action, **self.state_interpret_kwargs)
|
||||
_order_list = self.action_interpreter.interpret(action=_policy_action)
|
||||
return _order_list
|
||||
|
||||
|
||||
class OrderEnhancement:
|
||||
"""
|
||||
Order enhancement for strategy
|
||||
- If the strategy is used to split orders, the enhancement should be inherited
|
||||
- If the strategy is used for portfolio management, the enhancement can be ignored
|
||||
"""
|
||||
|
||||
def reset(self, trade_order_list: List[Order] = None):
|
||||
"""reset trade orders for split strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_order_list for split strategy: List[Order], optional
|
||||
trading orders , by default None
|
||||
"""
|
||||
if trade_order_list is not None:
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import re
|
||||
import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Tuple, List, Union, Optional, Callable
|
||||
|
||||
from . import lazy_sort_index
|
||||
from ..config import C
|
||||
|
||||
|
||||
def parse_freq(freq: str) -> Tuple[int, str]:
|
||||
"""
|
||||
@@ -50,9 +55,10 @@ def parse_freq(freq: str) -> Tuple[int, str]:
|
||||
return _count, _freq_format_dict[_freq]
|
||||
|
||||
|
||||
def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
"""
|
||||
Sample the calendar with frequency freq_raw into the calendar with frequency freq_sam
|
||||
Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam
|
||||
Assumption: The fix length (240) of the calendar in each day.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -72,24 +78,36 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n
|
||||
sam_count, freq_sam = parse_freq(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
|
||||
# if freq_sam is xminute, divide each trading day into several bars evenly
|
||||
if freq_sam == "minute":
|
||||
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
|
||||
minute_index = (hour - 9) * 60 + minute - 30
|
||||
elif 13 <= hour < 15:
|
||||
minute_index = (hour - 13) * 60 + minute + 120
|
||||
def cal_sam_minute(x, sam_minutes):
|
||||
day_time = pd.Timestamp(x.date())
|
||||
shift = C.min_data_shift
|
||||
# shift represents the shift minute the market time
|
||||
# - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
|
||||
# - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
|
||||
# - mid open time of stock market is [13:30 - shift*pd.Timedelta(minutes=1)]
|
||||
# - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
|
||||
open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
|
||||
mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
|
||||
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=30) - shift * pd.Timedelta(minutes=1)
|
||||
close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)
|
||||
|
||||
if open_time <= x <= mid_close_time:
|
||||
minute_index = (x - open_time).seconds // 60
|
||||
elif mid_open_time <= x <= close_time:
|
||||
minute_index = (x - mid_open_time).seconds // 60 + 120
|
||||
else:
|
||||
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
|
||||
raise ValueError("datetime of calendar is out of range")
|
||||
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60
|
||||
return open_time + minute_index * pd.Timedelta(minutes=1)
|
||||
elif 120 <= minute_index < 240:
|
||||
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
|
||||
return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
|
||||
else:
|
||||
raise ValueError("calendar minute_index error")
|
||||
|
||||
@@ -98,14 +116,10 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n
|
||||
else:
|
||||
if raw_count > sam_count:
|
||||
raise ValueError("raw freq must be higher than sampling freq")
|
||||
_calendar_minute = np.unique(
|
||||
list(
|
||||
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
|
||||
)
|
||||
)
|
||||
if calendar_raw[0] > _calendar_minute[0]:
|
||||
_calendar_minute[0] = calendar_raw[0]
|
||||
_calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, sam_count), calendar_raw)))
|
||||
return _calendar_minute
|
||||
|
||||
# else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam == "day":
|
||||
@@ -124,14 +138,14 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n
|
||||
raise ValueError("sampling freq must be xmin, xd, xw, xm")
|
||||
|
||||
|
||||
def get_sample_freq_calendar(
|
||||
def get_resam_calendar(
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
freq: str = "day",
|
||||
future: bool = False,
|
||||
) -> Tuple[np.ndarray, str, Optional[str]]:
|
||||
"""
|
||||
Get the calendar with frequency freq.
|
||||
Get the resampled calendar with frequency freq.
|
||||
|
||||
- If the calendar with the raw frequency freq exists, return it directly
|
||||
|
||||
@@ -186,16 +200,15 @@ def get_sample_freq_calendar(
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
|
||||
def sample_feature(
|
||||
feature: Union[pd.DataFrame, pd.Series],
|
||||
def resam_ts_data(
|
||||
ts_feature: Union[pd.DataFrame, pd.Series],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
fields: Union[str, List[str]] = None,
|
||||
method: Union[str, Callable] = "last",
|
||||
method_kwargs: dict = {},
|
||||
):
|
||||
"""
|
||||
Sample value from pandas DataFrame or Series for each stock
|
||||
Resample value from time-series data
|
||||
|
||||
- If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instruemnt data with datetime in [start_time, end_time]
|
||||
Example:
|
||||
@@ -217,7 +230,7 @@ def sample_feature(
|
||||
2010-01-12 2788.688232 164587.937500
|
||||
2010-01-13 2790.604004 145460.453125
|
||||
|
||||
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
|
||||
print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
|
||||
$close $volume
|
||||
instrument
|
||||
SH600000 87.433578 28117442.0
|
||||
@@ -236,25 +249,23 @@ def sample_feature(
|
||||
2010-01-07 83.788803 20813402.0
|
||||
2010-01-08 84.730675 16044853.0
|
||||
|
||||
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
|
||||
print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", method="last"))
|
||||
|
||||
$close 87.433578
|
||||
$volume 28117442.0
|
||||
|
||||
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last"))
|
||||
print(resam_ts_data(feature['$close'], start_time="2010-01-04", end_time="2010-01-05", method="last"))
|
||||
|
||||
87.433578
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Union[pd.DataFrame, pd.Series]
|
||||
Raw feature to be sampled
|
||||
Raw time-series feature to be resampled
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start sampling time, by default None
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end sampling time, by default None
|
||||
fields : Union[str, List[str]], optional
|
||||
column names, it's ignored when sample pd.Series data, by default None(all columns)
|
||||
method : Union[str, Callable], optional
|
||||
sample method, apply method function to each stock series data, by default "last"
|
||||
- If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby
|
||||
@@ -264,24 +275,19 @@ def sample_feature(
|
||||
|
||||
Returns
|
||||
-------
|
||||
The Sampled DataFrame/Series/Value
|
||||
The Resampled DataFrame/Series/Value
|
||||
"""
|
||||
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
if fields is None:
|
||||
fields = slice(None)
|
||||
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
feature = lazy_sort_index(ts_feature)
|
||||
datetime_level = get_level_index(feature, level="datetime") == 0
|
||||
if isinstance(feature, pd.Series):
|
||||
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
|
||||
elif isinstance(feature, pd.DataFrame):
|
||||
feature = (
|
||||
feature.loc[selector_datetime, fields]
|
||||
if datetime_level
|
||||
else feature.loc[(slice(None), selector_datetime), fields]
|
||||
)
|
||||
if datetime_level:
|
||||
feature = feature.loc[selector_datetime]
|
||||
else:
|
||||
feature = feature.loc[(slice(None), selector_datetime)]
|
||||
if feature.empty:
|
||||
return None
|
||||
if isinstance(feature.index, pd.MultiIndex):
|
||||
@@ -296,5 +302,4 @@ def sample_feature(
|
||||
return method_func(feature, **method_kwargs)
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature, method)(**method_kwargs)
|
||||
|
||||
return feature
|
||||
@@ -15,7 +15,7 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..utils.sample import parse_freq
|
||||
from ..utils.resam import parse_freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
|
||||
@@ -291,8 +291,8 @@ class PortAnaRecord(RecordTemp):
|
||||
"""
|
||||
config["strategy"] : dict
|
||||
define the strategy class as well as the kwargs.
|
||||
config["env"] : dict
|
||||
define the env class as well as the kwargs.
|
||||
config["executor"] : dict
|
||||
define the executor class as well as the kwargs.
|
||||
config["backtest"] : dict
|
||||
define the backtest kwargs.
|
||||
risk_analysis_freq : int
|
||||
@@ -301,24 +301,26 @@ class PortAnaRecord(RecordTemp):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.env_config = config["env"]
|
||||
self.executor_config = config["executor"]
|
||||
self.backtest_config = config["backtest"]
|
||||
_count, _freq = parse_freq(risk_analysis_freq)
|
||||
self.risk_analysis_freq = f"{_count}{_freq}"
|
||||
self.report_freq = self._get_report_freq(self.env_config)
|
||||
self.report_freq = self._get_report_freq(self.executor_config)
|
||||
|
||||
def _get_report_freq(self, env_config):
|
||||
def _get_report_freq(self, executor_config):
|
||||
ret_freq = []
|
||||
if env_config["kwargs"].get("generate_report", False):
|
||||
_count, _freq = parse_freq(env_config["kwargs"]["step_bar"])
|
||||
if executor_config["kwargs"].get("generate_report", False):
|
||||
_count, _freq = parse_freq(executor_config["kwargs"]["step_bar"])
|
||||
ret_freq.append(f"{_count}{_freq}")
|
||||
if "sub_env" in env_config["kwargs"]:
|
||||
ret_freq.extend(self._get_report_freq(env_config["kwargs"]["sub_env"]))
|
||||
if "sub_env" in executor_config["kwargs"]:
|
||||
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))
|
||||
return ret_freq
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# custom strategy and get backtest
|
||||
report_dict = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
|
||||
report_dict = normal_backtest(
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for report_freq, (report_normal, positions_normal) in report_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"report_normal_{report_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
|
||||
Reference in New Issue
Block a user