1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

fix comments

This commit is contained in:
bxdd
2021-05-25 02:38:34 +08:00
parent eaa719df17
commit 0c6e505455
24 changed files with 855 additions and 978 deletions

View File

@@ -14,8 +14,11 @@ This example uses a DropoutTopkStrategy (a strategy based on the daily frequency
Start backtesting by running the following command:
```bash
python workflow.py
python workflow.py backtest
```
Also, reports is shown in workflow.ipynb
Start collecting data by running the following command:
```bash
python workflow.py collect_data
```

View File

@@ -1,305 +0,0 @@
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b",
"display_name": "Python 3.8.8 ('qlib_backtest': conda)"
},
"metadata": {
"interpreter": {
"hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b"
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys, site\n",
"from pathlib import Path\n",
"\n",
"################################# NOTE #################################\n",
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
"# in this cell, users should RESTART the runtime in order to run the #\n",
"# following cells successfully. #\n",
"########################################################################\n",
"\n",
"try:\n",
" import qlib\n",
"except ImportError:\n",
" # install qlib\n",
" ! pip install --upgrade numpy\n",
" ! pip install pyqlib\n",
" # reload\n",
" site.main()\n",
"\n",
"scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
"if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
" # download get_data.py script\n",
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
" import requests\n",
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
" fp.write(resp.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import pandas as pd\n",
"from qlib.config import REG_CN\n",
"from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict\n",
"from qlib.workflow import R\n",
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
"from qlib.tests.data import GetData"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use default data\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
"\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"market = \"csi300\"\n",
"benchmark = \"SH000300\"\n",
"\n",
"###################################\n",
"# train model\n",
"###################################\n",
"\n",
"data_handler_config = {\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": market,\n",
"}\n",
"\n",
"task = {\n",
" \"model\": {\n",
" \"class\": \"LGBModel\",\n",
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
" \"kwargs\": {\n",
" \"loss\": \"mse\",\n",
" \"colsample_bytree\": 0.8879,\n",
" \"learning_rate\": 0.0421,\n",
" \"subsample\": 0.8789,\n",
" \"lambda_l1\": 205.6999,\n",
" \"lambda_l2\": 580.9768,\n",
" \"max_depth\": 8,\n",
" \"num_leaves\": 210,\n",
" \"num_threads\": 20,\n",
" },\n",
" },\n",
" \"dataset\": {\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": {\n",
" \"class\": \"Alpha158\",\n",
" \"module_path\": \"qlib.contrib.data.handler\",\n",
" \"kwargs\": data_handler_config,\n",
" },\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" },\n",
" },\n",
"}\n",
"# model initialization\n",
"model = init_instance_by_config(task[\"model\"])\n",
"dataset = init_instance_by_config(task[\"dataset\"])\n",
"\n",
"# start exp to train model\n",
"with R.start(experiment_name=\"train_model\"):\n",
" R.log_params(**flatten_dict(task))\n",
" model.fit(dataset)\n",
" R.save_objects(trained_model=model)\n",
" rid = R.get_recorder().id\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"outputPrepend"
]
},
"outputs": [],
"source": [
"trade_start_time = \"2017-01-01\"\n",
"trade_end_time = \"2020-08-01\"\n",
"\n",
"port_analysis_config = {\n",
" \"strategy\": {\n",
" \"class\": \"TopkDropoutStrategy\",\n",
" \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n",
" \"kwargs\": {\n",
" \"step_bar\": \"week\",\n",
" \"model\": model,\n",
" \"dataset\": dataset,\n",
" \"topk\": 50,\n",
" \"n_drop\": 5,\n",
" },\n",
" },\n",
" \"env\": {\n",
" \"class\": \"SplitExecutor\",\n",
" \"module_path\": \"qlib.contrib.backtest.executor\",\n",
" \"kwargs\": {\n",
" \"step_bar\": \"week\",\n",
" \"generate_report\": True,\n",
" \"sub_env\": {\n",
" \"class\": \"SimulatorExecutor\",\n",
" \"module_path\": \"qlib.contrib.backtest.executor\",\n",
" \"kwargs\": {\n",
" \"step_bar\": \"day\",\n",
" \"verbose\": True,\n",
" \"generate_report\": True,\n",
" },\n",
" },\n",
" \"sub_strategy\": {\n",
" \"class\": \"SBBStrategyEMA\",\n",
" \"module_path\": \"qlib.contrib.strategy.rule_strategy\",\n",
" \"kwargs\": {\n",
" \"step_bar\": \"day\",\n",
" \"freq\": \"day\",\n",
" \"instruments\": market,\n",
" },\n",
" },\n",
" },\n",
" },\n",
" \"backtest\": {\n",
" \"start_time\": trade_start_time,\n",
" \"end_time\": trade_end_time,\n",
" \"account\": 100000000,\n",
" \"benchmark\": benchmark,\n",
" \"exchange_kwargs\": {\n",
" \"freq\": \"day\",\n",
" \"limit_threshold\": 0.095,\n",
" \"deal_price\": \"close\",\n",
" \"open_cost\": 0.0005,\n",
" \"close_cost\": 0.0015,\n",
" \"min_cost\": 5,\n",
" },\n",
" },\n",
"}\n",
"# backtest and analysis\n",
"with R.start(experiment_name=\"backtest_analysis\"):\n",
" # prediction\n",
" recorder = R.get_recorder()\n",
" ba_rid = recorder.id\n",
" sr = SignalRecord(model, dataset, recorder)\n",
" sr.generate()\n",
"\n",
" # backtest & analysis\n",
" par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
" par.generate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
"report_normal_df_1d = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n",
"positions_1d = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
"analysis_df_1d = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")\n",
"report_normal_df_1w = recorder.load_object(\"portfolio_analysis/report_normal_1week.pkl\")\n",
"positions_1w = recorder.load_object(\"portfolio_analysis/positions_normal_1week.pkl\")\n",
"analysis_df_1w = recorder.load_object(\"portfolio_analysis/port_analysis_1week.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal_df_1d)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal_df_1w)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df_1d, report_normal_df_1d)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df_1w, report_normal_df_1w)"
]
}
]
}

View File

@@ -3,30 +3,21 @@
import qlib
import fire
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
from qlib.contrib.backtest import collect_data
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
class MultiLevelTradingWorkflow:
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
@@ -68,31 +59,17 @@ if __name__ == "__main__":
},
},
}
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
"step_bar": "week",
"model": model,
"dataset": dataset,
"topk": 50,
"n_drop": 5,
},
},
"env": {
"executor": {
"class": "SplitExecutor",
"module_path": "qlib.contrib.backtest.executor",
"kwargs": {
"step_bar": "week",
"sub_env": {
"sub_executor": {
"class": "SimulatorExecutor",
"module_path": "qlib.contrib.backtest.executor",
"kwargs": {
@@ -105,11 +82,11 @@ if __name__ == "__main__":
"class": "SBBStrategyEMA",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"step_bar": "day",
"freq": "day",
"instruments": market,
},
},
"track_data": True,
},
},
"backtest": {
@@ -128,17 +105,69 @@ if __name__ == "__main__":
},
}
with R.start(experiment_name="highfreq_backtest"):
R.log_params(**flatten_dict(task))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
def _train_model(self, model, dataset):
with R.start(experiment_name="train"):
R.log_params(**flatten_dict(self.task))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config, "day")
par.generate()
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
def backtest(self):
self._init_qlib()
model = init_instance_by_config(self.task["model"])
dataset = init_instance_by_config(self.task["dataset"])
self._train_model(model, dataset)
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"topk": 50,
"n_drop": 5,
},
}
self.port_analysis_config["strategy"] = strategy_config
with R.start(experiment_name="backtest"):
recorder = R.get_recorder()
par = PortAnaRecord(recorder, self.port_analysis_config, "day")
par.generate()
def collect_data(self):
self._init_qlib()
model = init_instance_by_config(self.task["model"])
dataset = init_instance_by_config(self.task["dataset"])
self._train_model(model, dataset)
executor_config = self.port_analysis_config["executor"]
backtest_config = self.port_analysis_config["backtest"]
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"topk": 50,
"n_drop": 5,
},
}
data_generator = collect_data(executor=executor_config, strategy=strategy_config, **backtest_config)
for trade_decision in data_generator:
print(trade_decision)
if __name__ == "__main__":
fire.Fire(MultiLevelTradingWorkflow)

View File

@@ -140,6 +140,10 @@ _default_config = {
"default_exp_name": "Experiment",
},
},
# Shift minute for highfreq minite data, used in backtest
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:30, 2:39]
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:30, 2:39] - shift*minute
"min_data_shift": {0},
}
MODE_CONF = {

View File

@@ -5,13 +5,12 @@ from .account import Account
from .exchange import Exchange
from .executor import BaseExecutor
from .backtest import backtest as backtest_func
from .backtest import collect_data as data_generator
from ...strategy.base import BaseStrategy
from ...utils import init_instance_by_config
from ...log import get_module_logger
from ...config import C
from .faculty import common_faculty
logger = get_module_logger("backtest caller")
@@ -89,8 +88,9 @@ def get_exchange(
return init_instance_by_config(exchange, accept_types=Exchange)
def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=1e9, exchange_kwargs={}):
def get_strategy_executor(
start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}
):
trade_account = Account(
init_cash=account,
benchmark_config={
@@ -101,14 +101,32 @@ def backtest(start_time, end_time, strategy, env, benchmark="SH000300", account=
)
trade_exchange = get_exchange(**exchange_kwargs)
common_faculty.update(
trade_account=trade_account,
trade_exchange=trade_exchange,
common_infra = {
"trade_account": trade_account,
"trade_exchange": trade_exchange,
}
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
return trade_strategy, trade_executor
def backtest(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
trade_strategy, trade_executor = get_strategy_executor(
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
)
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
trade_env = init_instance_by_config(env, accept_types=BaseExecutor)
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env)
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_executor)
return report_dict
def collect_data(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
trade_strategy, trade_executor = get_strategy_executor(
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
)
report_dict = yield from data_generator(start_time, end_time, trade_strategy, trade_executor)
return report_dict

View File

@@ -9,8 +9,6 @@ import pandas as pd
from .position import Position
from .report import Report
from .order import Order
from ...data import D
from ...utils.sample import parse_freq, sample_feature
"""
@@ -34,85 +32,14 @@ class Account:
self.init_vars(init_cash, freq, benchmark_config)
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
"""
Parameters
----------
freq : str
frequency of trading bar, used for updating hold count of trading bar
benchmark_config : dict
config of benchmark, may including the following arguments:
- benchmark : Union[str, list, pd.Series]
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
- If `benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000300 CSI300
- start_time : Union[str, pd.Timestamp], optional
- If `benchmark` is pd.Series, it will be ignored
- Else, it represent start time of benchmark, by default None
- end_time : Union[str, pd.Timestamp], optional
- If `benchmark` is pd.Series, it will be ignored
- Else, it represent end time of benchmark, by default None
"""
# init cash
self.init_cash = init_cash
self.freq = freq
self.benchmark_config = benchmark_config
self.bench = self._cal_benchmark(benchmark_config, freq)
self.current = Position(cash=init_cash)
self._reset_report()
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
def _cal_benchmark(self, benchmark_config, freq):
benchmark = benchmark_config.get("benchmark", "SH000300")
if isinstance(benchmark, pd.Series):
return benchmark
else:
start_time = benchmark_config.get("start_time", None)
end_time = benchmark_config.get("end_time", None)
if freq is None:
raise ValueError("benchmark freq can't be None!")
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
fields = ["$close/Ref($close,1)-1"]
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
except ValueError:
_, norm_freq = parse_freq(freq)
if norm_freq in ["month", "week", "day"]:
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
except ValueError:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
elif norm_freq == "minute":
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
else:
raise ValueError(f"benchmark freq {freq} is not supported")
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
def cal_change(x):
return (x + 1).prod() - 1
_ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
return 0 if _ret is None else _ret
def _reset_freq(self, freq):
"""reset frequency"""
if freq != self.freq:
self.freq = freq
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
def _reset_report(self):
self.report = Report()
def reset_report(self, freq, benchmark_config):
self.report = Report(freq, benchmark_config)
self.positions = {}
self.rtn = 0
self.ct = 0
@@ -120,10 +47,25 @@ class Account:
self.val = 0
self.earning = 0
def reset(self, freq=None, init_report: bool = False):
self._reset_freq(freq)
if init_report:
self._reset_report()
def reset(self, freq=None, benchmark_config=None, init_report=False):
"""reset freq and report of account
Parameters
----------
freq : str, optional
frequency of account & report, by default None
benchmark_config : {}, optional
benchmark config of report, by default None
init_report : bool, optional
whether to initialize the report, by default False
"""
if freq is not None:
self.freq = freq
if benchmark_config is not None:
self.benchmark_config = benchmark_config
if freq is not None or benchmark_config is not None or init_report:
self.reset_report(self.freq, self.benchmark_config)
def get_positions(self):
return self.positions
@@ -131,7 +73,7 @@ class Account:
def get_cash(self):
return self.current.position["cash"]
def update_state_from_order(self, order, trade_val, cost, trade_price):
def _update_state_from_order(self, order, trade_val, cost, trade_price):
# update turnover
self.to += trade_val
# update cost
@@ -155,7 +97,7 @@ class Account:
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
if order.direction == Order.SELL:
# sell stock
self.update_state_from_order(order, trade_val, cost, trade_price)
self._update_state_from_order(order, trade_val, cost, trade_price)
# update current position
# for may sell all of stock_id
self.current.update_order(order, trade_val, cost, trade_price)
@@ -163,15 +105,15 @@ class Account:
# buy stock
# deal order, then update state
self.current.update_order(order, trade_val, cost, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
self._update_state_from_order(order, trade_val, cost, trade_price)
def update_bar_count(self):
self.current.add_count_all(bar=self.freq)
def update_bar_report(self, trade_start_time, trade_end_time, trade_exchange):
"""
start_time: pd.TimeStamp
end_time: pd.TimeStamp
trade_start_time: pd.TimeStamp
trade_end_time: pd.TimeStamp
quote: pd.DataFrame (code, date), collumns
when the end of trade date
- update rtn
@@ -211,7 +153,8 @@ class Account:
# judge whether the the trading is begin.
# and don't add init account state into report, due to we don't have excess return in those days.
self.report.update_report_record(
trade_time=trade_start_time,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
account_value=now_account_value,
cash=self.current.position["cash"],
return_rate=(self.earning + self.ct) / last_account_value,
@@ -220,7 +163,6 @@ class Account:
turnover_rate=self.to / last_account_value,
cost_rate=self.ct / last_account_value,
stock_value=now_stock_value,
bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time),
)
# set now_account_value to position
self.current.position["now_account_value"] = now_account_value
@@ -234,18 +176,3 @@ class Account:
self.rtn = 0
self.ct = 0
self.to = 0
def load_account(self, account_path):
report = Report()
position = Position()
report.load_report(account_path / "report.csv")
position.load_position(account_path / "position.xlsx")
# assign values
self.init_vars(position.init_cash)
self.current = position
self.report = report
def save_account(self, account_path):
self.current.save_position(account_path / "position.xlsx")
self.report.save_report(account_path / "report.csv")

View File

@@ -2,14 +2,29 @@
# Licensed under the MIT License.
def backtest(start_time, end_time, trade_strategy, trade_env):
def backtest(start_time, end_time, trade_strategy, trade_executor):
trade_env.reset(start_time=start_time, end_time=end_time)
trade_strategy.reset(start_time=start_time, end_time=end_time)
trade_executor.reset(start_time=start_time, end_time=end_time)
level_infra = trade_executor.get_level_infra()
trade_strategy.reset(level_infra=level_infra)
_execute_state = trade_env.get_init_state()
while not trade_env.finished():
_order_list = trade_strategy.generate_order_list(_execute_state)
_execute_state = trade_env.execute(_order_list)
sub_execute_state = trade_executor.get_init_state()
while not trade_executor.finished():
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = trade_executor.execute(sub_trade_decision)
return trade_env.get_report()
return trade_executor.get_report()
def collect_data(start_time, end_time, trade_strategy, trade_executor):
trade_executor.reset(start_time=start_time, end_time=end_time)
level_infra = trade_executor.get_level_infra()
trade_strategy.reset(level_infra=level_infra)
sub_execute_state = trade_executor.get_init_state()
while not trade_executor.finished():
sub_trade_decision = trade_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = yield from trade_executor.collect_data(sub_trade_decision)
return trade_executor.get_report()

View File

@@ -11,7 +11,7 @@ import pandas as pd
from ...data.data import D
from ...data.dataset.utils import get_level_index
from ...config import C, REG_CN
from ...utils.sample import sample_feature
from ...utils.resam import resam_ts_data
from ...log import get_module_logger
from .order import Order
@@ -34,8 +34,9 @@ class Exchange:
):
"""__init__
:param start_time: start time for backtest
:param end_time: end time for backtest
:param freq: frequency of data
:param start_time: closed start time for backtest
:param end_time: closed end time for backtest
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
:param deal_price: str, 'close', 'open', 'vwap'
:param subscribe_fields: list, subscribe fields
@@ -91,7 +92,7 @@ class Exchange:
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock
necessary_fields = {self.deal_price, "$close", "$change", "$factor"}
necessary_fields = {self.deal_price, "$close", "$change", "$factor", "$volume"}
subscribe_fields = list(necessary_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(subscribe_fields))
self.all_fields = all_fields
@@ -167,12 +168,12 @@ class Exchange:
trade_date
is limtited
"""
return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0]
return resam_ts_data(self.quote[stock_id]["limit"], start_time, end_time, method="all").iloc[0]
def check_stock_suspended(self, stock_id, start_time, end_time):
# is suspended
if stock_id in self.quote:
return sample_feature(self.quote[stock_id], start_time, end_time, method=None) is None
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=None) is None
else:
return True
@@ -230,15 +231,16 @@ class Exchange:
return trade_val, trade_cost, trade_price
def get_quote_info(self, stock_id, start_time, end_time):
return sample_feature(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
def get_close(self, stock_id, start_time, end_time):
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0]
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0]
def get_volume(self, stock_id, start_time, end_time):
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0]
def get_deal_price(self, stock_id, start_time, end_time):
deal_price = sample_feature(
self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last"
).iloc[0]
deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0]
if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
self.logger.warning(
f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!"
@@ -248,7 +250,7 @@ class Exchange:
return deal_price
def get_factor(self, stock_id, start_time, end_time):
return sample_feature(self.quote[stock_id], start_time, end_time, fields="$factor", method="last").iloc[0]
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last").iloc[0]
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
"""

View File

@@ -2,88 +2,18 @@ import copy
import warnings
import pandas as pd
from typing import Union
from ...data.data import Cal
from ...utils import init_instance_by_config
from ...utils.sample import get_sample_freq_calendar, parse_freq
from ...utils.resam import parse_freq
from .order import Order
from .account import Account
from .exchange import Exchange
from .faculty import common_faculty
from .utils import TradeCalendarManager
class BaseTradeCalendar:
"""
Base class providing trading calendar
- BaseStrategy and BaseExecutor should inherited from this class
"""
def __init__(
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
):
"""
Parameters
----------
step_bar : str
frequency of each trading step bar
start_time : Union[str, pd.Timestamp], optional
start time of trading, by default None
If `start_time` is None, it must be reset before trading.
end_time : Union[str, pd.Timestamp], optional
end time of trading, by default None
If `end_time` is None, it must be reset before trading.
"""
self.step_bar = step_bar
self.start_time = pd.Timestamp(start_time) if start_time else None
self.end_time = pd.Timestamp(end_time) if end_time else None
self.reset(start_time=start_time, end_time=end_time)
def _reset_trade_calendar(self, start_time, end_time):
"""reset trade calendar"""
if start_time and end_time:
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
self.calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(
self.start_time, self.end_time, freq=freq, freq_sam=freq_sam
)
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
self.trade_index = 0
else:
raise ValueError("failed to reset trade calendar, param `start_time` or `end_time` is None.")
def reset(self, start_time=None, end_time=None):
"""
Reset start\end time of trading, and reset trading calendar
"""
if start_time:
self.start_time = pd.Timestamp(start_time)
if end_time:
self.end_time = pd.Timestamp(end_time)
if self.start_time and self.end_time and (start_time or end_time):
self._reset_trade_calendar(start_time=self.start_time, end_time=self.end_time)
def _get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
calendar_index = self.start_index + trade_index
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)
def finished(self):
return self.trade_index >= self.trade_len
def step(self):
if self.finished():
raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!")
# trade count += 1
self.trade_index = self.trade_index + 1
class BaseExecutor(BaseTradeCalendar):
class BaseExecutor:
"""Base executor for trading"""
def __init__(
@@ -91,48 +21,97 @@ class BaseExecutor(BaseTradeCalendar):
step_bar: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_account: Account = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: dict = {},
**kwargs,
):
"""
Parameters
----------
trade_account : Account, optional
trade account for trading, by default None
- If `trade_account` is None, self.trade_account will be set with common_faculty
generate_report : bool, optional
whether to generate report, by default False
verbose : bool, optional
whether to print trading info, by default False
track_data : bool, optional
whether to generate order_list, will be used when making data for multi-level training
- If `self.track_data` is true, when making data for training, the input `order_list` of `execute` will be generated by `get_data`
- Else, `order_list` will not be generated
whether to generate trade_decision, will be used when making data for multi-level training
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
- Else, `trade_decision` will not be generated
common_infra : dict, optional:
common infrastructure for backtesting, may including:
- trade_account : Account, optional
trade account for trading
- trade_exchange : Exchange, optional
exchange that provides market info
"""
super(BaseExecutor, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, **kwargs)
self.trade_account = copy.copy(common_faculty.trade_account if trade_account is None else trade_account)
self.trade_account.reset(freq=self.step_bar, init_report=True)
self.step_bar = step_bar
self.generate_report = generate_report
self.verbose = verbose
self.track_data = track_data
self.reset(start_time=start_time, end_time=end_time, track_data=track_data, common_infra=common_infra)
def reset(self, track_data: bool = None, **kwargs):
def reset_common_infra(self, common_infra):
"""
Reset `track_data`, will be used when making data for multi-level training
reset infrastructure for trading
- reset trade_account
"""
super(BaseExecutor, self).reset(**kwargs)
if not hasattr(self, "common_infra"):
self.common_infra = common_infra
else:
self.common_infra.update(common_infra)
if "trade_account" in common_infra:
self.trade_account = copy.copy(common_infra.get("trade_account"))
self.trade_account.reset(freq=self.step_bar, init_report=True)
def reset(self, track_data: bool = None, common_infra: dict = None, **kwargs):
"""
- reset `start_time` and `end_time`, used in trade calendar
- reset `track_data`, used when making data for multi-level training
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
"""
if track_data is not None:
self.track_data = track_data
if common_infra is not None:
self.reset_common_infra(common_infra)
if "start_time" in kwargs or "end_time" in kwargs:
start_time = kwargs.get("start_time")
end_time = kwargs.get("end_time")
self.trade_calendar = TradeCalendarManager(step_bar=self.step_bar, start_time=start_time, end_time=end_time)
def get_level_infra(self):
return {"trade_calendar": self.trade_calendar}
def finished(self):
return self.trade_calendar.finished()
def execute(self, trade_decision):
"""execute the trade decision and return the executed result
Parameters
----------
trade_decision : object
Returns
----------
executed state : List[Tuple[Order, float, float, float]]
- Each element in the list represents (order, trade value, trade cost, trade price)
"""
raise NotImplementedError("execute is not implemented!")
def collect_data(self, trade_decision):
if self.track_data:
yield trade_decision
return self.execute(trade_decision)
def get_init_state(self):
raise NotImplementedError("get_init_state in not implemeted!")
def execute(self, **kwargs):
raise NotImplementedError("execute is not implemented!")
def get_trade_account(self):
raise NotImplementedError("get_trade_account is not implemented!")
@@ -146,56 +125,75 @@ class SplitExecutor(BaseExecutor):
def __init__(
self,
step_bar: str,
sub_env: Union[BaseExecutor, dict],
sub_executor: Union[BaseExecutor, dict],
sub_strategy: Union[BaseStrategy, dict],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_account: Account = None,
trade_exchange: Exchange = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: dict = {},
**kwargs,
):
"""
Parameters
----------
sub_env : BaseExecutor
sub_executor : BaseExecutor
trading env in each trading bar.
sub_strategy : BaseStrategy
trading strategy in each trading bar
trade_exchange : Exchange
exchange that provides market info
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
exchange that provides market info, used to generate report
- If generate_report is None, trade_exchange will be ignored
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
self.sub_executor = init_instance_by_config(sub_executor, common_infra=common_infra, accept_types=BaseExecutor)
self.sub_strategy = init_instance_by_config(
sub_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
)
super(SplitExecutor, self).__init__(
step_bar=step_bar,
start_time=start_time,
end_time=end_time,
trade_account=trade_account,
generate_report=generate_report,
verbose=verbose,
track_data=track_data,
common_infra=common_infra,
**kwargs,
)
if generate_report:
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor)
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=self.BaseStrategy)
if generate_report and trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
reset infrastructure for trading
- reset trade_exchange
- reset substrategy and subexecutor common infra
"""
super(SplitExecutor, self).reset_common_infra(common_infra)
if self.generate_report and "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
self.sub_executor.reset_common_infra(common_infra)
self.sub_strategy.reset_common_infra(common_infra)
def get_init_state(self):
init_state = {"current": self.trade_account.current}
return init_state
return []
def _init_sub_trading(self, order_list):
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
sub_execute_state = self.sub_env.get_init_state()
return sub_execute_state
def _init_sub_trading(self, trade_decision):
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
self.sub_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
sub_level_infra = self.sub_executor.get_level_infra()
self.sub_strategy.reset(level_infra=sub_level_infra, rely_trade_decision=trade_decision)
def _update_trade_account(self):
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
self.trade_account.update_bar_count()
if self.generate_report:
self.trade_account.update_bar_report(
@@ -204,30 +202,38 @@ class SplitExecutor(BaseExecutor):
trade_exchange=self.trade_exchange,
)
def execute(self, order_list):
super(SplitExecutor, self).step()
self._init_sub_trading(order_list)
sub_execute_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
sub_execute_state = self.sub_env.execute(order_list=_order_list)
self._update_trade_account()
return {"current": self.trade_account.current}
def execute(self, trade_decision):
self.trade_calendar.step()
self._init_sub_trading(trade_decision)
execute_state = []
sub_execute_state = self.sub_executor.get_init_state()
while not self.sub_executor.finished():
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = self.sub_executor.execute(trade_decision=sub_trade_decison)
execute_state.extend(sub_execute_state)
if hasattr(self, "trade_account"):
self._update_trade_account()
def get_data(self, order_list):
return execute_state
def collect_data(self, trade_decision):
if self.track_data:
yield order_list
super(SplitExecutor, self).step()
self._init_sub_trading(order_list)
sub_execute_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(sub_execute_state)
sub_execute_state = yield from self.sub_env.get_data(order_list=_order_list)
self._update_trade_account()
return {"current": self.trade_account.current}
yield trade_decision
self.trade_calendar.step()
self._init_sub_trading(trade_decision)
execute_state = []
sub_execute_state = self.sub_executor.get_init_state()
while not self.sub_executor.finished():
sub_trade_decison = self.sub_strategy.generate_trade_decision(sub_execute_state)
sub_execute_state = yield from self.sub_executor.collect_data(trade_decision=sub_trade_decison)
execute_state.extend(sub_execute_state)
if hasattr(self, "trade_account"):
self._update_trade_account()
return execute_state
def get_report(self):
sub_env_report_dict = self.sub_env.get_report()
sub_env_report_dict = self.sub_executor.get_report()
if self.generate_report:
_report = self.trade_account.report.generate_report_dataframe()
_positions = self.trade_account.get_positions()
@@ -242,46 +248,57 @@ class SimulatorExecutor(BaseExecutor):
step_bar: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_account: Account = None,
trade_exchange: Exchange = None,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: dict = {},
**kwargs,
):
"""
Parameters
----------
trade_exchange : Exchange
exchange that provides market info
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
"""
super(SimulatorExecutor, self).__init__(
step_bar=step_bar,
start_time=start_time,
end_time=end_time,
trade_account=trade_account,
generate_report=generate_report,
verbose=verbose,
track_data=track_data,
common_infra=common_infra,
**kwargs,
)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
if trade_exchange is not None:
self.trade_exchange = trade_exchange
def reset_common_infra(self, common_infra):
"""
reset infrastructure for trading
- reset trade_exchange
"""
super(SimulatorExecutor, self).reset_common_infra(common_infra)
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def get_init_state(self):
init_state = {"current": self.trade_account.current, "trade_info": []}
return init_state
return []
def execute(self, order_list):
super(SimulatorExecutor, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
trade_info = []
for order in order_list:
def execute(self, trade_decision):
self.trade_calendar.step()
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
execute_state = []
for order in trade_decision:
if self.trade_exchange.check_order(order) is True:
# execute the order
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
order, trade_account=self.trade_account
)
trade_info.append((order, trade_val, trade_cost, trade_price))
execute_state.append((order, trade_val, trade_cost, trade_price))
if self.verbose:
if order.direction == Order.SELL: # sell
print(
@@ -323,7 +340,7 @@ class SimulatorExecutor(BaseExecutor):
trade_exchange=self.trade_exchange,
)
return {"current": self.trade_account.current, "trade_info": trade_info}
return execute_state
def get_report(self):
if self.generate_report:

View File

@@ -1,28 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class Faculty:
def __init__(self):
self.__dict__["_faculty"] = dict()
def __getitem__(self, key):
return self.__dict__["_faculty"][key]
def __getattr__(self, attr):
if attr in self.__dict__["_faculty"]:
return self.__dict__["_faculty"][attr]
raise AttributeError(f"No such {attr} in self._faculty")
def __setitem__(self, key, value):
self.__dict__["_faculty"][key] = value
def __setattr__(self, attr, value):
self.__dict__["_faculty"][attr] = value
def update(self, *args, **kwargs):
self.__dict__["_faculty"].update(*args, **kwargs)
common_faculty = Faculty()

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
import copy
import pathlib

View File

@@ -3,16 +3,51 @@
from collections import OrderedDict
from logging import warning
import pandas as pd
import pathlib
import warnings
from pandas.core.frame import DataFrame
from ...utils.resam import parse_freq, resam_ts_data
from ...data import D
class Report:
# daily report of the account
# contain those followings: returns, costs turnovers, accounts, cash, bench, value
# update report
def __init__(self):
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
"""
Parameters
----------
freq : str
frequency of trading bar, used for updating hold count of trading bar
benchmark_config : dict
config of benchmark, may including the following arguments:
- benchmark : Union[str, list, pd.Series]
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
- If `benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000300 CSI300
- start_time : Union[str, pd.Timestamp], optional
- If `benchmark` is pd.Series, it will be ignored
- Else, it represent start time of benchmark, by default None
- end_time : Union[str, pd.Timestamp], optional
- If `benchmark` is pd.Series, it will be ignored
- Else, it represent end time of benchmark, by default None
"""
self.init_vars()
self.init_bench(freq=freq, benchmark_config=benchmark_config)
def init_vars(self):
self.accounts = OrderedDict() # account postion value for each trade date
@@ -24,6 +59,49 @@ class Report:
self.benches = OrderedDict()
self.latest_report_time = None # pd.TimeStamp
def init_bench(self, freq=None, benchmark_config=None):
if freq is not None:
self.freq = freq
if benchmark_config is not None:
self.benchmark_config = benchmark_config
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
def _cal_benchmark(self, benchmark_config, freq):
benchmark = benchmark_config.get("benchmark", "SH000300")
if isinstance(benchmark, pd.Series):
return benchmark
else:
start_time = benchmark_config.get("start_time", None)
end_time = benchmark_config.get("end_time", None)
if freq is None:
raise ValueError("benchmark freq can't be None!")
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
fields = ["$close/Ref($close,1)-1"]
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
except ValueError:
_, norm_freq = parse_freq(freq)
if norm_freq in ["month", "week", "day"]:
try:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
except ValueError:
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
elif norm_freq == "minute":
_temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1)
else:
raise ValueError(f"benchmark freq {freq} is not supported")
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
def cal_change(x):
return (x + 1).prod() - 1
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
return 0.0 if _ret is None else _ret
def is_empty(self):
return len(self.accounts) == 0
@@ -35,30 +113,39 @@ class Report:
def update_report_record(
self,
trade_time=None,
trade_start_time=None,
trade_end_time=None,
account_value=None,
cash=None,
return_rate=None,
turnover_rate=None,
cost_rate=None,
stock_value=None,
bench_value=None,
):
# check data
if None in [trade_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]:
if None in [
trade_start_time,
trade_end_time,
account_value,
cash,
return_rate,
turnover_rate,
cost_rate,
stock_value,
]:
raise ValueError(
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]"
"None in [trade_start_time, trade_end_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
)
# update report data
self.accounts[trade_time] = account_value
self.returns[trade_time] = return_rate
self.turnovers[trade_time] = turnover_rate
self.costs[trade_time] = cost_rate
self.values[trade_time] = stock_value
self.cashes[trade_time] = cash
self.benches[trade_time] = bench_value
self.accounts[trade_start_time] = account_value
self.returns[trade_start_time] = return_rate
self.turnovers[trade_start_time] = turnover_rate
self.costs[trade_start_time] = cost_rate
self.values[trade_start_time] = stock_value
self.cashes[trade_start_time] = cash
self.benches[trade_start_time] = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
# update latest_report_date
self.latest_report_time = trade_time
self.latest_report_time = trade_start_time
# finish daily report update
def generate_report_dataframe(self):

View File

@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Union
from ...utils.resam import get_resam_calendar
from ...data.data import Cal
class TradeCalendarManager:
"""
Manager for trading calendar
- BaseStrategy and BaseExecutor will use it
"""
def __init__(
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
):
"""
Parameters
----------
step_bar : str
frequency of each trading calendar
start_time : Union[str, pd.Timestamp], optional
closed start of the trading calendar, by default None
If `start_time` is None, it must be reset before trading.
end_time : Union[str, pd.Timestamp], optional
closed end of the trade time range, by default None
If `end_time` is None, it must be reset before trading.
"""
self.step_bar = step_bar
self.start_time = pd.Timestamp(start_time) if start_time else None
self.end_time = pd.Timestamp(start_time) if start_time else None
self._init_trade_calendar(step_bar=step_bar, start_time=start_time, end_time=end_time)
def _init_trade_calendar(self, step_bar, start_time, end_time):
"""reset trade calendar"""
_calendar, freq, freq_sam = get_resam_calendar(freq=step_bar)
self.calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
self.trade_index = 0
def finished(self):
return self.trade_index >= self.trade_len
def step(self):
if self.finished():
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
self.trade_index = self.trade_index + 1
def get_step_bar(self):
return self.step_bar
def get_trade_len(self):
return self.trade_len
def get_trade_index(self):
return self.trade_index
def get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
calendar_index = self.start_index + trade_index
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)

View File

@@ -3,6 +3,7 @@
from __future__ import division
from __future__ import print_function
from logging import warn
import numpy as np
import pandas as pd
@@ -10,7 +11,7 @@ import warnings
from ..log import get_module_logger
from .backtest import get_exchange, backtest as backtest_func
from ..utils import get_date_range
from ..utils.sample import parse_freq
from ..utils.resam import parse_freq
from ..data import D
from ..config import C
@@ -20,7 +21,7 @@ from ..data.dataset.utils import get_level_index
logger = get_module_logger("Evaluate")
def risk_analysis(r, N: int = None, freq: str = None):
def risk_analysis(r, N: int = None, freq: str = "day"):
"""Risk Analysis
Parameters
@@ -36,8 +37,8 @@ def risk_analysis(r, N: int = None, freq: str = None):
def cal_risk_analysis_scaler(freq):
_count, _freq = parse_freq(freq)
_freq_scaler = {
"minute": 240 * 250,
"day": 250,
"minute": 240 * 252,
"day": 252,
"week": 50,
"month": 12,
}
@@ -45,6 +46,8 @@ def risk_analysis(r, N: int = None, freq: str = None):
if N is None and freq is None:
raise ValueError("at least one of `N` and `freq` should exist")
if N is not None and freq is not None:
warnings.warn("risk_analysis freq will be ignored")
if N is None:
N = cal_risk_analysis_scaler(freq)

View File

@@ -118,7 +118,7 @@ class Operator:
user.strategy.update(score_series, pred_date, trade_date)
# generate and save order list
order_list = user.strategy.generate_order_list(
order_list = user.strategy.generate_trade_decision(
score_series=score_series,
current=user.account.current,
trade_exchange=trade_exchange,
@@ -208,7 +208,7 @@ class Operator:
self.logger.info("Update account state {} for {}".format(trade_date, user_id))
def simulate(self, id, config, exchange_config, start, end, path, bench="SH000905"):
"""Run the ( generate_order_list -> execute_order_list -> update_account) process everyday
"""Run the ( generate_trade_decision -> execute_order_list -> update_account) process everyday
from start date to end date.
Parameters
@@ -256,7 +256,7 @@ class Operator:
user.strategy.update(score_series, pred_date, trade_date)
# 3. generate and save order list
order_list = user.strategy.generate_order_list(
order_list = user.strategy.generate_trade_decision(
score_series=score_series,
current=user.account.current,
trade_exchange=trade_exchange,

View File

@@ -10,17 +10,15 @@ import copy
class SoftTopkStrategy(WeightStrategyBase):
def __init__(
self,
step_bar,
model,
dataset,
topk,
start_time=None,
end_time=None,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
max_sold_weight=1.0,
risk_degree=0.95,
buy_method="first_fill",
level_infra={},
common_infra={},
**kwargs,
):
"""Parameter
@@ -33,7 +31,7 @@ class SoftTopkStrategy(WeightStrategyBase):
average_fill: assign the weight to the stocks rank high averagely.
"""
super(SoftTopkStrategy, self).__init__(
step_bar, model, dataset, start_time, end_time, order_generator_cls_or_obj, trade_exchange
model, dataset, order_generator_cls_or_obj, level_infra, common_infra, **kwargs
)
self.topk = topk
self.max_sold_weight = max_sold_weight

View File

@@ -3,29 +3,26 @@ import warnings
import numpy as np
import pandas as pd
from ...utils.sample import sample_feature
from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ..backtest.order import Order
from ..backtest.faculty import common_faculty
from .order_generator import OrderGenWInteract
class TopkDropoutStrategy(ModelStrategy):
def __init__(
self,
step_bar,
model,
dataset,
topk,
n_drop,
start_time=None,
end_time=None,
trade_exchange=None,
method_sell="bottom",
method_buy="top",
risk_degree=0.95,
hold_thresh=1,
only_tradable=False,
level_infra={},
common_infra={},
**kwargs,
):
"""
@@ -51,8 +48,9 @@ class TopkDropoutStrategy(ModelStrategy):
else:
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
)
self.topk = topk
self.n_drop = n_drop
self.method_sell = method_sell
@@ -61,6 +59,20 @@ class TopkDropoutStrategy(ModelStrategy):
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(TopkDropoutStrategy, self).reset_common_infra(common_infra)
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_index=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -69,11 +81,11 @@ class TopkDropoutStrategy(ModelStrategy):
# It will use 95% amoutn of your total value by default
return self.risk_degree
def generate_order_list(self, execute_state):
super(TopkDropoutStrategy, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
def generate_trade_decision(self, execute_state):
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
if self.only_tradable:
@@ -115,8 +127,7 @@ class TopkDropoutStrategy(ModelStrategy):
def filter_stock(l):
return l
current = execute_state.get("current")
current_temp = copy.deepcopy(current)
current_temp = copy.deepcopy(self.trade_position)
# generate order list for this adjust date
sell_order_list = []
buy_order_list = []
@@ -168,7 +179,8 @@ class TopkDropoutStrategy(ModelStrategy):
continue
if code in sell:
# check hold limit
if current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
step_bar = self.trade_calendar.get_step_bar()
if current_temp.get_stock_count(code, bar=step_bar) < self.hold_thresh:
continue
# sell order
sell_amount = current_temp.get_stock_amount(code=code)
@@ -228,22 +240,35 @@ class TopkDropoutStrategy(ModelStrategy):
class WeightStrategyBase(ModelStrategy):
def __init__(
self,
step_bar,
model,
dataset,
start_time=None,
end_time=None,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
level_infra={},
common_infra={},
**kwargs,
):
super(WeightStrategyBase, self).__init__(step_bar, model, dataset, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
super(WeightStrategyBase, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
def reset_common_infra(self, common_infra):
"""
Parameters
----------
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(WeightStrategyBase, self).reset_common_infra(common_infra)
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def get_risk_degree(self, trade_index=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -267,7 +292,7 @@ class WeightStrategyBase(ModelStrategy):
"""
raise NotImplementedError()
def generate_order_list(self, execute_state):
def generate_trade_decision(self, execute_state):
"""
Parameters
-----------
@@ -280,23 +305,22 @@ class WeightStrategyBase(ModelStrategy):
trade_date : pd.Timestamp
date.
"""
# generate_order_list
# generate_trade_decision
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
super(WeightStrategyBase, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
trade_index = self.trade_calendar.get_trade_index()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
current = execute_state.get("current")
current_temp = copy.deepcopy(current)
current_temp = copy.deepcopy(self.trade_position)
target_weight_position = self.generate_target_weight_position(
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
)
order_list = self.order_generator.generate_order_list_from_target_weight_position(
current=current_temp,
trade_exchange=self.trade_exchange,
risk_degree=self.get_risk_degree(self.trade_index),
risk_degree=self.get_risk_degree(trade_index),
target_weight_position=target_weight_position,
pred_start_time=pred_start_time,
pred_end_time=pred_end_time,

View File

@@ -1,80 +1,77 @@
import copy
import warnings
import numpy as np
import pandas as pd
from typing import Union
from ...utils.sample import sample_feature
from ...utils.resam import resam_ts_data
from ...data.data import D
from ...data.dataset.utils import convert_index_format
from ...strategy.base import RuleStrategy, OrderEnhancement
from ...strategy.base import RuleStrategy
from ..backtest.order import Order
from ..backtest.faculty import common_faculty
class TWAPStrategy(RuleStrategy, OrderEnhancement):
class TWAPStrategy(RuleStrategy):
"""TWAP Strategy for trading"""
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_exchange=None,
trade_order_list=[],
**kwargs,
):
def reset_common_infra(self, common_infra):
"""
Parameters
----------
trade_exchange : Exchange, optional
exchange that provides market info, by default None
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
trade_order_list : list, optional
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
common_infra : dict, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(TWAPStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.trade_order_list = trade_order_list
super(TWAPStrategy, self).reset_common_infra(common_infra)
if common_infra is not None:
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, trade_order_list: list = None, **kwargs):
super(TWAPStrategy, self).reset(**kwargs)
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_order_list is not None:
def reset(self, rely_trade_decision: object = None, **kwargs):
"""
Parameters
----------
rely_trade_decision : object, optional
"""
super(TWAPStrategy, self).reset(rely_trade_decision=rely_trade_decision, common_infra=common_infra, **kwargs)
if rely_trade_decision is not None:
self.trade_amount = {}
for order in self.trade_order_list:
for order in rely_trade_decision:
self.trade_amount[(order.stock_id, order.direction)] = order.amount
def generate_order_list(self, execute_state):
super(TWAPStrategy, self).step()
trade_info = execute_state.get("trade_info")
def generate_trade_decision(self, execute_state):
# update the order amount
trade_info = execute_state
for order, _, _, _ in trade_info:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
trade_index = self.trade_calendar.get_trade_index()
trade_len = self.trade_calendar.get_trade_len()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
order_list = []
for order in self.trade_order_list:
for order in self.rely_trade_decision:
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
continue
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
_order_amount = None
# consider trade unit
if _amount_trade_unit is None:
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
self.trade_len - self.trade_index + 1
)
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# split the order equally
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# split the order equally
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
_order_amount = (
(trade_unit_cnt + self.trade_len - self.trade_index)
// (self.trade_len - self.trade_index + 1)
* _amount_trade_unit
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount is None or self.trade_index == self.trade_len
_order_amount is None or trade_index == trade_len
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
@@ -92,7 +89,7 @@ class TWAPStrategy(RuleStrategy, OrderEnhancement):
return order_list
class SBBStrategyBase(RuleStrategy, OrderEnhancement):
class SBBStrategyBase(RuleStrategy):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
"""
@@ -101,81 +98,80 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
TREND_SHORT = 1
TREND_LONG = 2
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_exchange=None,
trade_order_list=[],
**kwargs,
):
def reset_common_infra(self, common_infra):
super(SBBStrategyBase, self).reset_common_infra(common_infra)
if common_infra is not None:
if "trade_exchange" in common_infra:
self.trade_exchange = common_infra.get("trade_exchange")
def reset(self, rely_trade_decision=None, **kwargs):
"""
Parameters
----------
trade_exchange : Exchange, optional
exchange that provides market info, by default None
- If `trade_exchange` is None, self.trade_exchange will be set with common_faculty
trade_order_list : list, optional
order list to trade, which the strategy will trade in [start_time , end_time] , by default []
rely_trade_decision : object, optional
common_infra : None, optional
common infrastructure for backtesting, by default None
- It should include `trade_account`, used to get position
- It should include `trade_exchange`, used to provide market info
"""
super(SBBStrategyBase, self).__init__(step_bar, start_time, end_time, **kwargs)
self.trade_exchange = common_faculty.trade_exchange if trade_exchange is None else trade_exchange
self.trade_order_list = trade_order_list
def reset(self, trade_order_list=None, **kwargs):
super(SBBStrategyBase, self).reset(**kwargs)
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_order_list is not None:
super(SBBStrategyBase, self).reset(rely_trade_decision=rely_trade_decision, **kwargs)
if rely_trade_decision is not None:
self.trade_trend = {}
self.trade_amount = {}
for order in self.trade_order_list:
# init the trade amount of order and predicted trade trend
for order in rely_trade_decision:
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
self.trade_amount[(order.stock_id, order.direction)] = order.amount
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
def generate_order_list(self, execute_state):
super(SBBStrategyBase, self).step()
def generate_trade_decision(self, execute_state):
trade_info = execute_state.get("trade_info")
# update the order amount
trade_info = execute_state
for order, _, _, _ in trade_info:
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
trade_index = self.trade_calendar.get_trade_index()
trade_len = self.trade_calendar.get_trade_len()
trade_start_time, trade_end_time = self.trade_calendar.get_calendar_time(trade_index)
pred_start_time, pred_end_time = self.trade_calendar.get_calendar_time(trade_index, shift=1)
order_list = []
for order in self.trade_order_list:
if self.trade_index % 2 == 1:
# for each order in in self.rely_trade_decision
for order in self.rely_trade_decision:
# predict the price trend
if trade_index % 2 == 1:
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
else:
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
# if not tradable, continue
if not self.trade_exchange.is_stock_tradable(
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
):
if self.trade_index % 2 == 1:
if trade_index % 2 == 1:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
continue
# get amount of one trade unit
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
if _pred_trend == self.TREND_MID:
_order_amount = None
# considering trade unit
if _amount_trade_unit is None:
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
self.trade_len - self.trade_index + 1
)
# split the order equally
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# cal how many trade unit
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# split the order equally
# floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1))
_order_amount = (
(trade_unit_cnt + self.trade_len - self.trade_index)
// (self.trade_len - self.trade_index + 1)
* _amount_trade_unit
(trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
_order_amount is None or self.trade_index == self.trade_len
_order_amount is None or trade_index == trade_len
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
@@ -185,36 +181,43 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
amount=_order_amount,
start_time=trade_start_time,
end_time=trade_end_time,
direction=order.direction, # 1 for buy
direction=order.direction,
factor=order.factor,
)
order_list.append(_order)
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
else:
_order_amount = None
# considering trade unit
if _amount_trade_unit is None:
# N trade day last, split the order into N + 1 parts, and trade 2 parts
_order_amount = (
2
* self.trade_amount[(order.stock_id, order.direction)]
/ (self.trade_len - self.trade_index + 2)
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 2)
)
# without considering trade unit
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
# cal how many trade unit
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# N trade day last, split the order into N + 1 parts, and trade 2 parts
_order_amount = (
(trade_unit_cnt + self.trade_len - self.trade_index + 1)
// (self.trade_len - self.trade_index + 2)
(trade_unit_cnt + trade_len - trade_index + 1)
// (trade_len - trade_index + 2)
* 2
* _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
_order_amount is None or self.trade_index == self.trade_len
_order_amount is None or trade_index == trade_len
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
if _order_amount:
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
if self.trade_index % 2 == 1:
if trade_index % 2 == 1:
# in the first of two adjacent bar
# if look short on the price, sell the stock more
# if look long on the price, sell the stock more
if (
_pred_trend == self.TREND_SHORT
and order.direction == order.SELL
@@ -231,6 +234,9 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
)
order_list.append(_order)
else:
# in the second of two adjacent bar
# if look short on the price, buy the stock more
# if look long on the price, sell the stock more
if (
_pred_trend == self.TREND_SHORT
and order.direction == order.BUY
@@ -246,8 +252,8 @@ class SBBStrategyBase(RuleStrategy, OrderEnhancement):
factor=order.factor,
)
order_list.append(_order)
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
if self.trade_index % 2 == 1:
if trade_index % 2 == 1:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
return order_list
@@ -260,13 +266,11 @@ class SBBStrategyEMA(SBBStrategyBase):
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_exchange=None,
trade_order_list=[],
rely_trade_decision=[],
instruments="csi300",
freq="day",
level_infra={},
common_infra={},
**kwargs,
):
"""
@@ -278,47 +282,49 @@ class SBBStrategyEMA(SBBStrategyBase):
freq of EMA signal, by default "day"
Note: `freq` may be different from `steb_bar`
"""
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange, trade_order_list, **kwargs)
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
self.instruments = "all"
if isinstance(instruments, str):
self.instruments = D.instruments(instruments)
self.freq = freq
super(SBBStrategyEMA, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
def reset(self, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, **kwargs):
def _reset_signal(self):
trade_len = self.trade_calendar.get_trade_len()
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self.trade_calendar.get_calendar_time(trade_index=1, shift=1)
_, signal_end_time = self.trade_calendar.get_calendar_time(trade_index=trade_len, shift=1)
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
def reset_level_infra(self, level_infra):
"""
Reset EMA signal for trading
Parameters
----------
start_time : Union[str, pd.Timestamp], optional
start time for trading, also used to calculate the start time of EMA signal, by default None
end_time : Union[str, pd.Timestamp], optional
end time for trading, also used to calculate the end time of EMA signal, by default None
reset level-shared infra
- After reset the trade_calendar, the signal will be changed
"""
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
if self.start_time and self.end_time and (start_time or end_time):
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self._get_calendar_time(trade_index=1, shift=1)
_, signal_end_time = self._get_calendar_time(trade_index=self.trade_len, shift=1)
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
for stock_id, stock_val in signal_df.groupby(level="instrument"):
self.signal[stock_id] = stock_val
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
else:
self.level_infra.update(level_infra)
if "trade_calendar" in level_infra:
self.trade_calendar = level_infra.get("trade_calendar")
self._reset_signal()
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
if stock_id not in self.signal:
return self.TREND_MID
else:
_sample_signal = sample_feature(
self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last"
_sample_signal = resam_ts_data(
self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last"
)
if _sample_signal is None or _sample_signal.iloc[0] == 0:
return self.TREND_MID

View File

@@ -26,7 +26,7 @@ from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, co
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
from ..utils.sample import sample_calendar
from ..utils.resam import resam_calendar
class CalendarProvider(abc.ABC):
@@ -133,7 +133,7 @@ class CalendarProvider(abc.ABC):
if freq_sam is None:
return _calendar, _calendar_index
else:
_calendar_sam = sample_calendar(_calendar, freq, freq_sam)
_calendar_sam = resam_calendar(_calendar, freq, freq_sam)
_calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)}
H["c"][flag] = _calendar_sam, _calendar_sam_index
return _calendar_sam, _calendar_sam_index

View File

@@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .interpreter import StateInterpreter, ActionInterpreter
from typing import Union
from .interpreter import StateInterpreter, ActionInterpreter
from ..contrib.backtest.executor import BaseExecutor
from ..utils import init_instance_by_config
class BaseRLEnv:
@@ -52,35 +54,22 @@ class QlibIntRLEnv(QlibRLEnv):
def __init__(
self,
executor: BaseExecutor,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
state_interpret_kwargs: dict = {},
action_interpret_kwargs: dict = {},
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
):
"""
Parameters
----------
state_interpreter : StateInterpreter
state_interpreter : Union[dict, StateInterpreter]
interpretor that interprets the qlib execute result into rl env state.
action_interpreter : ActionInterpreter
action_interpreter : Union[dict, ActionInterpreter]
interpretor that interprets the rl agent action into qlib order list
state_interpret_kwargs : dict, optional
arguments may be used in `state_interpreter.interpret`, by default {}
such as the following arguments:
- trade exchange : Exchange
Exchange that can provide market info
action_interpret_kwargs: dict, optional
arguments may be used in `action_interpreter.interpret`, by default {}
such as the following arguments:
- trade_order_list : List[Order]
If the strategy is used to split order, it presents the trade order pool.
"""
super(QlibIntRLEnv, self).__init__(executor=executor)
self.state_interpreter = state_interpreter
self.action_interpreter = action_interpreter
self.state_interpret_kwargs = state_interpret_kwargs
self.action_interpret_kwargs = action_interpret_kwargs
self.state_interpreter = init_instance_by_config(state_interpreter)
self.action_interpreter = init_instance_by_config(action_interpreter)
def step(self, action):
"""
@@ -96,11 +85,9 @@ class QlibIntRLEnv(QlibRLEnv):
Returns
-------
env state to rl rl policy
env state to rl policy
"""
_interpret_action = self.action_interpreter.interpret(action=action, **self.state_interpret_kwargs)
_interpret_action = self.action_interpreter.interpret(action=action)
_execute_result = self.executor.execute(_interpret_action)
_interpret_state = self.state_interpreter.interpret(
execute_result=_execute_result, **self.action_interpret_kwargs
)
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
return _interpret_state

View File

@@ -5,7 +5,6 @@
class BaseInterpreter:
"""Base Interpreter"""
@staticmethod
def interpret(**kwargs):
raise NotImplementedError("interpret is not implemented!")
@@ -13,7 +12,6 @@ class BaseInterpreter:
class ActionInterpreter(BaseInterpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
@staticmethod
def interpret(action, **kwargs):
"""interpret method
@@ -34,7 +32,6 @@ class ActionInterpreter(BaseInterpreter):
class StateInterpreter(BaseInterpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
@staticmethod
def interpret(execute_result, **kwargs):
"""interpret method

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import pandas as pd
from typing import List, Union
@@ -9,16 +10,70 @@ from ..model.base import BaseModel
from ..data.dataset import DatasetH
from ..data.dataset.utils import convert_index_format
from ..contrib.backtest.order import Order
from ..contrib.backtest.executor import BaseTradeCalendar
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
class BaseStrategy(BaseTradeCalendar):
class BaseStrategy:
"""Base strategy for trading"""
def generate_order_list(self, execute_state):
"""Generate order list in each trading bar"""
raise NotImplementedError("generator_order_list is not implemented!")
def __init__(
self,
rely_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
):
"""
Parameters
----------
rely_trade_decision : object, optional
the high-level trade decison on which the startegy rely, and it will be traded in [start_time , end_time] , by default None
- If the strategy is used to split trade decison, it will be used
- If the strategy is used for portfolio management, it can be ignored
level_infra : dict, optional
level shared infrastructure for backtesting, including trade_calendar
common_infra : dict, optional
common infrastructure for backtesting, including trade_account, trade_exchange, .etc
"""
self.reset(level_infra=level_infra, common_infra=common_infra, rely_trade_decision=rely_trade_decision)
def reset_level_infra(self, level_infra):
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
else:
self.level_infra.update(level_infra)
if "trade_calendar" in level_infra:
self.trade_calendar = level_infra.get("trade_calendar")
def reset_common_infra(self, common_infra):
if not hasattr(self, "common_infra"):
self.common_infra = common_infra
else:
self.common_infra.update(common_infra)
if "trade_account" in common_infra:
self.trade_position = common_infra.get("trade_account").current
def reset(self, level_infra: dict = None, common_infra: dict = None, rely_trade_decision=None, **kwargs):
"""
- reset `level_infra`, used to reset trade_calendar, .etc
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
- reset `rely_trade_decision`, used to make split decison
"""
if level_infra is not None:
self.reset_level_infra(level_infra)
if common_infra is not None:
self.reset_common_infra(common_infra)
if rely_trade_decision is not None:
self.rely_trade_decision = rely_trade_decision
def generate_trade_decision(self, execute_state):
"""Generate trade decision in each trading bar"""
raise NotImplementedError("generate_trade_decision is not implemented!")
class RuleStrategy(BaseStrategy):
@@ -32,11 +87,11 @@ class ModelStrategy(BaseStrategy):
def __init__(
self,
step_bar: str,
model: BaseModel,
dataset: DatasetH,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
rely_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
**kwargs,
):
"""
@@ -49,11 +104,10 @@ class ModelStrategy(BaseStrategy):
kwargs : dict
arguments that will be passed into `reset` method
"""
super(ModelStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
self.model = model
self.dataset = dataset
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
# pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
def _update_model(self):
"""
@@ -70,10 +124,10 @@ class RLStrategy(BaseStrategy):
def __init__(
self,
step_bar: str,
policy,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
rely_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
**kwargs,
):
"""
@@ -82,7 +136,7 @@ class RLStrategy(BaseStrategy):
policy :
RL policy for generate action
"""
super(RLStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
super(RLStrategy, self).__init__(rely_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
@@ -91,14 +145,12 @@ class RLIntStrategy(RLStrategy):
def __init__(
self,
step_bar: str,
policy,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
state_interpret_kwargs: dict = {},
action_interpret_kwargs: dict = {},
rely_trade_decision: object = None,
level_infra: dict = {},
common_infra: dict = {},
**kwargs,
):
"""
@@ -112,49 +164,16 @@ class RLIntStrategy(RLStrategy):
start time of trading, by default None
end_time : Union[str, pd.Timestamp], optional
end time of trading, by default None
state_interpret_kwargs : dict, optional
arguments may be used in `state_interpreter.interpret`, by default {}
such as the following arguments:
- trade exchange : Exchange
Exchange that can provide market info
action_interpret_kwargs: dict, optional
arguments may be used in `action_interpreter.interpret`, by default {}
such as the following arguments:
- trade_order_list : List[Order]
If the strategy is used to split order, it presents the trade order pool.
"""
super(RLIntStrategy, self).__init__(step_bar, policy, start_time, end_time, **kwargs)
super(RLIntStrategy, self).__init__(policy, rely_trade_decision, level_infra, common_infra, **kwargs)
self.policy = policy
self.action_interpreter = action_interpreter
self.state_interpreter = state_interpreter
self.state_interpret_kwargs = state_interpret_kwargs
self.action_interpret_kwargs = action_interpret_kwargs
self.state_interpreter = init_instance_by_config(state_interpreter)
self.action_interpreter = init_instance_by_config(action_interpreter)
def generate_order_list(self, execute_state):
def generate_trade_decision(self, execute_state):
super(RLStrategy, self).step()
_interpret_state = self.state_interpretor.interpret(
execute_result=execute_state, **self.action_interpret_kwargs
)
_interpret_state = self.state_interpretor.interpret(execute_result=execute_state)
_policy_action = self.policy.step(_interpret_state)
_order_list = self.action_interpreter.interpret(action=_policy_action, **self.state_interpret_kwargs)
_order_list = self.action_interpreter.interpret(action=_policy_action)
return _order_list
class OrderEnhancement:
"""
Order enhancement for strategy
- If the strategy is used to split orders, the enhancement should be inherited
- If the strategy is used for portfolio management, the enhancement can be ignored
"""
def reset(self, trade_order_list: List[Order] = None):
"""reset trade orders for split strategy
Parameters
----------
trade_order_list for split strategy: List[Order], optional
trading orders , by default None
"""
if trade_order_list is not None:
self.trade_order_list = trade_order_list

View File

@@ -1,8 +1,13 @@
import re
import datetime
import numpy as np
import pandas as pd
from typing import Tuple, List, Union, Optional, Callable
from . import lazy_sort_index
from ..config import C
def parse_freq(freq: str) -> Tuple[int, str]:
"""
@@ -50,9 +55,10 @@ def parse_freq(freq: str) -> Tuple[int, str]:
return _count, _freq_format_dict[_freq]
def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
"""
Sample the calendar with frequency freq_raw into the calendar with frequency freq_sam
Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam
Assumption: The fix length (240) of the calendar in each day.
Parameters
----------
@@ -72,24 +78,36 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n
sam_count, freq_sam = parse_freq(freq_sam)
if not len(calendar_raw):
return calendar_raw
# if freq_sam is xminute, divide each trading day into several bars evenly
if freq_sam == "minute":
def cal_next_sam_minute(x, sam_minutes):
hour = x.hour
minute = x.minute
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
minute_index = (hour - 9) * 60 + minute - 30
elif 13 <= hour < 15:
minute_index = (hour - 13) * 60 + minute + 120
def cal_sam_minute(x, sam_minutes):
day_time = pd.Timestamp(x.date())
shift = C.min_data_shift
# shift represents the shift minute the market time
# - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
# - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
# - mid open time of stock market is [13:30 - shift*pd.Timedelta(minutes=1)]
# - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=30) - shift * pd.Timedelta(minutes=1)
close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)
if open_time <= x <= mid_close_time:
minute_index = (x - open_time).seconds // 60
elif mid_open_time <= x <= close_time:
minute_index = (x - mid_open_time).seconds // 60 + 120
else:
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
raise ValueError("datetime of calendar is out of range")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120:
return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60
return open_time + minute_index * pd.Timedelta(minutes=1)
elif 120 <= minute_index < 240:
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
else:
raise ValueError("calendar minute_index error")
@@ -98,14 +116,10 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n
else:
if raw_count > sam_count:
raise ValueError("raw freq must be higher than sampling freq")
_calendar_minute = np.unique(
list(
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
)
)
if calendar_raw[0] > _calendar_minute[0]:
_calendar_minute[0] = calendar_raw[0]
_calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, sam_count), calendar_raw)))
return _calendar_minute
# else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
else:
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
if freq_sam == "day":
@@ -124,14 +138,14 @@ def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> n
raise ValueError("sampling freq must be xmin, xd, xw, xm")
def get_sample_freq_calendar(
def get_resam_calendar(
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
freq: str = "day",
future: bool = False,
) -> Tuple[np.ndarray, str, Optional[str]]:
"""
Get the calendar with frequency freq.
Get the resampled calendar with frequency freq.
- If the calendar with the raw frequency freq exists, return it directly
@@ -186,16 +200,15 @@ def get_sample_freq_calendar(
return _calendar, freq, freq_sam
def sample_feature(
feature: Union[pd.DataFrame, pd.Series],
def resam_ts_data(
ts_feature: Union[pd.DataFrame, pd.Series],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
fields: Union[str, List[str]] = None,
method: Union[str, Callable] = "last",
method_kwargs: dict = {},
):
"""
Sample value from pandas DataFrame or Series for each stock
Resample value from time-series data
- If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instruemnt data with datetime in [start_time, end_time]
Example:
@@ -217,7 +230,7 @@ def sample_feature(
2010-01-12 2788.688232 164587.937500
2010-01-13 2790.604004 145460.453125
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
$close $volume
instrument
SH600000 87.433578 28117442.0
@@ -236,25 +249,23 @@ def sample_feature(
2010-01-07 83.788803 20813402.0
2010-01-08 84.730675 16044853.0
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", method="last"))
$close 87.433578
$volume 28117442.0
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last"))
print(resam_ts_data(feature['$close'], start_time="2010-01-04", end_time="2010-01-05", method="last"))
87.433578
Parameters
----------
feature : Union[pd.DataFrame, pd.Series]
Raw feature to be sampled
Raw time-series feature to be resampled
start_time : Union[str, pd.Timestamp], optional
start sampling time, by default None
end_time : Union[str, pd.Timestamp], optional
end sampling time, by default None
fields : Union[str, List[str]], optional
column names, it's ignored when sample pd.Series data, by default None(all columns)
method : Union[str, Callable], optional
sample method, apply method function to each stock series data, by default "last"
- If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby
@@ -264,24 +275,19 @@ def sample_feature(
Returns
-------
The Sampled DataFrame/Series/Value
The Resampled DataFrame/Series/Value
"""
selector_datetime = slice(start_time, end_time)
if fields is None:
fields = slice(None)
from ..data.dataset.utils import get_level_index
feature = lazy_sort_index(ts_feature)
datetime_level = get_level_index(feature, level="datetime") == 0
if isinstance(feature, pd.Series):
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
elif isinstance(feature, pd.DataFrame):
feature = (
feature.loc[selector_datetime, fields]
if datetime_level
else feature.loc[(slice(None), selector_datetime), fields]
)
if datetime_level:
feature = feature.loc[selector_datetime]
else:
feature = feature.loc[(slice(None), selector_datetime)]
if feature.empty:
return None
if isinstance(feature.index, pd.MultiIndex):
@@ -296,5 +302,4 @@ def sample_feature(
return method_func(feature, **method_kwargs)
elif isinstance(method, str):
return getattr(feature, method)(**method_kwargs)
return feature

View File

@@ -15,7 +15,7 @@ from ..data.dataset.handler import DataHandlerLP
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
from ..utils.sample import parse_freq
from ..utils.resam import parse_freq
from ..strategy.base import BaseStrategy
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
@@ -291,8 +291,8 @@ class PortAnaRecord(RecordTemp):
"""
config["strategy"] : dict
define the strategy class as well as the kwargs.
config["env"] : dict
define the env class as well as the kwargs.
config["executor"] : dict
define the executor class as well as the kwargs.
config["backtest"] : dict
define the backtest kwargs.
risk_analysis_freq : int
@@ -301,24 +301,26 @@ class PortAnaRecord(RecordTemp):
super().__init__(recorder=recorder, **kwargs)
self.strategy_config = config["strategy"]
self.env_config = config["env"]
self.executor_config = config["executor"]
self.backtest_config = config["backtest"]
_count, _freq = parse_freq(risk_analysis_freq)
self.risk_analysis_freq = f"{_count}{_freq}"
self.report_freq = self._get_report_freq(self.env_config)
self.report_freq = self._get_report_freq(self.executor_config)
def _get_report_freq(self, env_config):
def _get_report_freq(self, executor_config):
ret_freq = []
if env_config["kwargs"].get("generate_report", False):
_count, _freq = parse_freq(env_config["kwargs"]["step_bar"])
if executor_config["kwargs"].get("generate_report", False):
_count, _freq = parse_freq(executor_config["kwargs"]["step_bar"])
ret_freq.append(f"{_count}{_freq}")
if "sub_env" in env_config["kwargs"]:
ret_freq.extend(self._get_report_freq(env_config["kwargs"]["sub_env"]))
if "sub_env" in executor_config["kwargs"]:
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))
return ret_freq
def generate(self, **kwargs):
# custom strategy and get backtest
report_dict = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
report_dict = normal_backtest(
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for report_freq, (report_normal, positions_normal) in report_dict.items():
self.recorder.save_objects(
**{f"report_normal_{report_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()