1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

fix bug in recorder

This commit is contained in:
bxdd
2021-04-30 01:06:05 +08:00
parent f404a031f3
commit a109df3f46
8 changed files with 63 additions and 83 deletions

View File

@@ -10,6 +10,8 @@ from qlib.config import REG_CN
from qlib.contrib.strategy import TopkDropoutStrategy
from qlib.contrib.backtest import backtest
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import PortAnaRecord
from qlib.tests.data import GetData
if __name__ == "__main__":
@@ -78,7 +80,7 @@ if __name__ == "__main__":
trade_start_time = "2017-01-31"
trade_end_time = "2018-01-31"
backtest_config = {
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
@@ -101,6 +103,7 @@ if __name__ == "__main__":
"kwargs": {
"step_bar": "day",
"verbose": True,
"generate_report": True,
},
},
"sub_strategy": {
@@ -128,11 +131,19 @@ if __name__ == "__main__":
},
}
report_dict = backtest(
start_time=trade_start_time,
end_time=trade_end_time,
**backtest_config,
account=1e8,
deal_price="$close",
verbose=False,
)
#report_dict = backtest(
# start_time=trade_start_time,
# end_time=trade_end_time,
# **backtest_config,
# account=1e8,
# benchmark=benchmark,
# deal_price="$close",
# verbose=False,
#)
with R.start(experiment_name="highfreq_backtest"):
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
recorder = R.get_recorder()
par = PortAnaRecord(recorder, port_analysis_config, 1)
par.generate()

View File

@@ -118,7 +118,7 @@ def setup_exchange(root_instance, trade_exchange=None, force=False):
setup_exchange(root_instance.sub_strategy, trade_exchange)
def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs):
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)

View File

@@ -8,6 +8,7 @@ import pandas as pd
from .position import Position
from .report import Report
from .order import Order
from ...data import D
from ...utils import parse_freq, sample_feature
@@ -95,7 +96,8 @@ class Account:
def cal_change(x):
return x.prod() - 1
return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
_ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
return 0 if _ret is None else _ret
def reset(self, benchmark=None, freq=None, **kwargs):
if benchmark:
@@ -105,9 +107,9 @@ class Account:
if self.freq and self.benchmark and (freq or benchmark):
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
for k, v in kwargs:
if hasattr(k):
setattr(k, v)
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
def get_positions(self):
return self.positions
@@ -150,7 +152,7 @@ class Account:
self.current.update_order(order, trade_val, cost, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
def update_report(self, trade_start_time, trade_end_time, trade_exchange):
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange, update_report):
"""
start_time: pd.TimeStamp
end_time: pd.TimeStamp
@@ -166,6 +168,9 @@ class Account:
:return: None
"""
# update price for stock in the position and the profit from changed_price
self.current.add_count_all(bar=self.freq)
if update_report is None:
return
stock_list = self.current.get_stock_list()
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
@@ -174,7 +179,7 @@ class Account:
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
self.current.update_stock_price(stock_id=code, price=bar_close)
# update holding day count
self.current.add_count_all(bar=self.freq)
# update value
self.val = self.current.calculate_value()
# update earning
@@ -212,7 +217,7 @@ class Account:
self.positions[trade_start_time] = copy.deepcopy(self.current)
# finish today's updation
# reset the daily variables
# reset the bar variables
self.rtn = 0
self.ct = 0
self.to = 0

View File

@@ -19,8 +19,4 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account
_order_list = trade_strategy.generate_order_list(**trade_state)
trade_state, trade_info = trade_env.execute(_order_list)
report_df = trade_account.report.generate_report_dataframe()
positions = trade_account.get_positions()
report_dict = {"report_df": report_df, "positions": positions}
return report_dict
return trade_env.get_report()

View File

@@ -42,7 +42,7 @@ class BaseTradeCalendar:
if start_time or end_time:
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
for k, v in kwargs:
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
@@ -52,7 +52,7 @@ class BaseTradeCalendar:
return self.calendar[calendar_index - 1], self.calendar[calendar_index]
def finished(self):
return self.trade_index >= self.trade_len
return self.trade_index >= self.trade_len - 1
def step(self):
self.trade_index = self.trade_index + 1
@@ -71,11 +71,11 @@ class BaseEnv(BaseTradeCalendar):
start_time=None,
end_time=None,
trade_account=None,
update_report=False,
generate_report=False,
verbose=False,
**kwargs,
):
self.generate_report = update_report
self.generate_report = generate_report
self.verbose = verbose
super(BaseEnv, self).__init__(
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
@@ -110,7 +110,8 @@ class SplitEnv(BaseEnv):
start_time=None,
end_time=None,
trade_account=None,
update_report=False,
trade_exchange=None,
generate_report=False,
verbose=False,
**kwargs,
):
@@ -121,15 +122,18 @@ class SplitEnv(BaseEnv):
start_time=start_time,
end_time=end_time,
trade_account=trade_account,
update_report=update_report,
trade_exchange=trade_exchange,
generate_report=generate_report,
verbose=verbose,
**kwargs,
)
def reset(self, trade_account=None, **kwargs):
def reset(self, trade_account=None, trade_exchange=None, **kwargs):
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
if trade_account:
self.sub_env.reset(trade_account=copy.copy(trade_account))
if trade_exchange:
self.trade_exchange = trade_exchange
def execute(self, order_list, **kwargs):
if self.finished():
@@ -146,10 +150,9 @@ class SplitEnv(BaseEnv):
_order_list = self.sub_strategy.generate_order_list(**trade_state)
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
if self.generate_report:
self.trade_account.update_report(
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
)
self.trade_account.update_bar_end(
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange, update_report=self.generate_report
)
_obs = {"current": self.trade_account.current}
_info = {}
return _obs, _info
@@ -157,7 +160,7 @@ class SplitEnv(BaseEnv):
def get_report(self):
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
_positions = self.trade_account.get_positions() if self.generate_report else None
return [(_report, _positions), *sub_env.get_report()]
return [(_report, _positions), *self.sub_env.get_report()]
class SimulatorEnv(BaseEnv):
@@ -168,7 +171,7 @@ class SimulatorEnv(BaseEnv):
end_time=None,
trade_account=None,
trade_exchange=None,
update_report=False,
generate_report=False,
verbose=False,
**kwargs,
):
@@ -178,7 +181,7 @@ class SimulatorEnv(BaseEnv):
end_time=end_time,
trade_account=trade_account,
trade_exchange=trade_exchange,
update_report=update_report,
generate_report=generate_report,
verbose=verbose,
**kwargs,
)
@@ -231,10 +234,9 @@ class SimulatorEnv(BaseEnv):
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
# do nothing
pass
if self.generate_report:
self.trade_account.update_report(
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
)
self.trade_account.update_bar_end(
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange, update_report=self.generate_report
)
_obs = {"current": self.trade_account.current}
_info = {"trade_info": trade_info}
return _obs, _info
@@ -242,4 +244,4 @@ class SimulatorEnv(BaseEnv):
def get_report(self):
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
_positions = self.trade_account.get_positions() if self.generate_report else None
return [{"report": _report, "positions": _positions}]
return [(_report, _positions)]

View File

@@ -9,7 +9,7 @@ import pandas as pd
import warnings
from ..log import get_module_logger
from .backtest import get_exchange, backtest as backtest_func
from .backtest.backtest import get_date_range
from ..utils import get_date_range
from ..data import D
from ..config import C

View File

@@ -23,7 +23,6 @@ class TopkDropoutStrategy(ModelStrategy):
method_sell="bottom",
method_buy="top",
risk_degree=0.95,
thresh=1,
hold_thresh=1,
only_tradable=False,
**kwargs,
@@ -41,11 +40,9 @@ class TopkDropoutStrategy(ModelStrategy):
dropout method_buy, random/top.
risk_degree : float
position percentage of total value.
thresh : int
minimun holding days since last buy singal of the stock.
hold_thresh : int
minimum holding days
before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh.
before sell stock , will check current.get_stock_count(order.stock_id) >= self.hold_thresh.
only_tradable : bool
will the strategy only consider the tradable stock when buying and selling.
if only_tradable:
@@ -61,10 +58,6 @@ class TopkDropoutStrategy(ModelStrategy):
self.method_sell = method_sell
self.method_buy = method_buy
self.risk_degree = risk_degree
self.thresh = thresh
# self.stock_count['code'] will be the days the stock has been hold
# since last buy signal. This is designed for thresh
self.stock_count = {}
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
@@ -170,10 +163,7 @@ class TopkDropoutStrategy(ModelStrategy):
# Get the stock list we really want to buy
buy = today[: len(sell) + self.topk - len(last)]
# buy singal: if a stock falls into topk, it appear in the buy_sinal
buy_signal = pred_score.sort_values(ascending=False).iloc[: self.topk].index
#print("flag", len(sell), len(buy), self.topk, len(last))
for code in current_stock_list:
if not self.trade_exchange.is_stock_tradable(
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
@@ -181,13 +171,7 @@ class TopkDropoutStrategy(ModelStrategy):
continue
if code in sell:
# check hold limit
if (
self.stock_count[code] < self.thresh
or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh
):
# can not sell this code
# no buy signal, but the stock is kept
self.stock_count[code] += 1
if current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
continue
# sell order
sell_amount = current_temp.get_stock_amount(code=code)
@@ -207,18 +191,6 @@ class TopkDropoutStrategy(ModelStrategy):
)
# update cash
cash += trade_val - trade_cost
# sold
self.stock_count[code] = 0
else:
# no buy signal, but the stock is kept
self.stock_count[code] += 1
elif code in buy_signal:
# NOTE: This is different from the original version
# get new buy signal
# Only the stock fall in to topk will produce buy signal
self.stock_count[code] = 1
else:
self.stock_count[code] += 1
# buy new stock
# note the current has been changed
current_stock_list = current_temp.get_stock_list()
@@ -249,7 +221,6 @@ class TopkDropoutStrategy(ModelStrategy):
factor=factor,
)
buy_order_list.append(buy_order)
self.stock_count[code] = 1
return sell_order_list + buy_order_list

View File

@@ -14,8 +14,9 @@ from ..data.dataset.handler import DataHandlerLP
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
from ..strategy.base import BaseStrategy
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
from ..contrib.strategy.strategy import BaseStrategy
logger = get_module_logger("workflow", "INFO")
@@ -212,7 +213,7 @@ class SigAnaRecord(SignalRecord):
return paths
class PortAnaRecord(SignalRecord):
class PortAnaRecord(RecordTemp):
"""
This is the Portfolio Analysis Record class that generates the analysis results such as those of backtest. This class inherits the ``RecordTemp`` class.
@@ -243,16 +244,10 @@ class PortAnaRecord(SignalRecord):
self.risk_analysis_dep = risk_analysis_dep
def generate(self, **kwargs):
# check previously stored prediction results
try:
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
except FileExistsError:
super().generate()
# custom strategy and get backtest
report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
for report_dep, (report_normal, positions_normal) in enumerate(report_list):
if report_dict is None:
if report_normal is None:
if self.risk_analysis_dep == report_dep:
warnings.warn(
f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`"