mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
fix bug in recorder
This commit is contained in:
@@ -10,6 +10,8 @@ from qlib.config import REG_CN
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.backtest import backtest
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -78,7 +80,7 @@ if __name__ == "__main__":
|
||||
trade_start_time = "2017-01-31"
|
||||
trade_end_time = "2018-01-31"
|
||||
|
||||
backtest_config = {
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
@@ -101,6 +103,7 @@ if __name__ == "__main__":
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"verbose": True,
|
||||
"generate_report": True,
|
||||
},
|
||||
},
|
||||
"sub_strategy": {
|
||||
@@ -128,11 +131,19 @@ if __name__ == "__main__":
|
||||
},
|
||||
}
|
||||
|
||||
report_dict = backtest(
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
**backtest_config,
|
||||
account=1e8,
|
||||
deal_price="$close",
|
||||
verbose=False,
|
||||
)
|
||||
#report_dict = backtest(
|
||||
# start_time=trade_start_time,
|
||||
# end_time=trade_end_time,
|
||||
# **backtest_config,
|
||||
# account=1e8,
|
||||
# benchmark=benchmark,
|
||||
# deal_price="$close",
|
||||
# verbose=False,
|
||||
#)
|
||||
|
||||
with R.start(experiment_name="highfreq_backtest"):
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(recorder, port_analysis_config, 1)
|
||||
par.generate()
|
||||
@@ -118,7 +118,7 @@ def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
setup_exchange(root_instance.sub_strategy, trade_exchange)
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs):
|
||||
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import pandas as pd
|
||||
from .position import Position
|
||||
from .report import Report
|
||||
from .order import Order
|
||||
from ...data import D
|
||||
from ...utils import parse_freq, sample_feature
|
||||
|
||||
|
||||
@@ -95,7 +96,8 @@ class Account:
|
||||
def cal_change(x):
|
||||
return x.prod() - 1
|
||||
|
||||
return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
_ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0 if _ret is None else _ret
|
||||
|
||||
def reset(self, benchmark=None, freq=None, **kwargs):
|
||||
if benchmark:
|
||||
@@ -105,9 +107,9 @@ class Account:
|
||||
if self.freq and self.benchmark and (freq or benchmark):
|
||||
self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq)
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(k):
|
||||
setattr(k, v)
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
@@ -150,7 +152,7 @@ class Account:
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_report(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange, update_report):
|
||||
"""
|
||||
start_time: pd.TimeStamp
|
||||
end_time: pd.TimeStamp
|
||||
@@ -166,6 +168,9 @@ class Account:
|
||||
:return: None
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
if update_report is None:
|
||||
return
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
@@ -174,7 +179,7 @@ class Account:
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
# update holding day count
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
# update value
|
||||
self.val = self.current.calculate_value()
|
||||
# update earning
|
||||
@@ -212,7 +217,7 @@ class Account:
|
||||
self.positions[trade_start_time] = copy.deepcopy(self.current)
|
||||
|
||||
# finish today's updation
|
||||
# reset the daily variables
|
||||
# reset the bar variables
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
|
||||
@@ -19,8 +19,4 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account
|
||||
_order_list = trade_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = trade_env.execute(_order_list)
|
||||
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
positions = trade_account.get_positions()
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
|
||||
return report_dict
|
||||
return trade_env.get_report()
|
||||
|
||||
@@ -42,7 +42,7 @@ class BaseTradeCalendar:
|
||||
if start_time or end_time:
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
|
||||
for k, v in kwargs:
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
@@ -52,7 +52,7 @@ class BaseTradeCalendar:
|
||||
return self.calendar[calendar_index - 1], self.calendar[calendar_index]
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len
|
||||
return self.trade_index >= self.trade_len - 1
|
||||
|
||||
def step(self):
|
||||
self.trade_index = self.trade_index + 1
|
||||
@@ -71,11 +71,11 @@ class BaseEnv(BaseTradeCalendar):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
update_report=False,
|
||||
generate_report=False,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.generate_report = update_report
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
super(BaseEnv, self).__init__(
|
||||
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
|
||||
@@ -110,7 +110,8 @@ class SplitEnv(BaseEnv):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
update_report=False,
|
||||
trade_exchange=None,
|
||||
generate_report=False,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -121,15 +122,18 @@ class SplitEnv(BaseEnv):
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
update_report=update_report,
|
||||
trade_exchange=trade_exchange,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reset(self, trade_account=None, **kwargs):
|
||||
def reset(self, trade_account=None, trade_exchange=None, **kwargs):
|
||||
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
|
||||
if trade_account:
|
||||
self.sub_env.reset(trade_account=copy.copy(trade_account))
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
if self.finished():
|
||||
@@ -146,10 +150,9 @@ class SplitEnv(BaseEnv):
|
||||
_order_list = self.sub_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
|
||||
|
||||
if self.generate_report:
|
||||
self.trade_account.update_report(
|
||||
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
|
||||
)
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange, update_report=self.generate_report
|
||||
)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {}
|
||||
return _obs, _info
|
||||
@@ -157,7 +160,7 @@ class SplitEnv(BaseEnv):
|
||||
def get_report(self):
|
||||
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
|
||||
_positions = self.trade_account.get_positions() if self.generate_report else None
|
||||
return [(_report, _positions), *sub_env.get_report()]
|
||||
return [(_report, _positions), *self.sub_env.get_report()]
|
||||
|
||||
|
||||
class SimulatorEnv(BaseEnv):
|
||||
@@ -168,7 +171,7 @@ class SimulatorEnv(BaseEnv):
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
trade_exchange=None,
|
||||
update_report=False,
|
||||
generate_report=False,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -178,7 +181,7 @@ class SimulatorEnv(BaseEnv):
|
||||
end_time=end_time,
|
||||
trade_account=trade_account,
|
||||
trade_exchange=trade_exchange,
|
||||
update_report=update_report,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -231,10 +234,9 @@ class SimulatorEnv(BaseEnv):
|
||||
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
# do nothing
|
||||
pass
|
||||
if self.generate_report:
|
||||
self.trade_account.update_report(
|
||||
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange
|
||||
)
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange, update_report=self.generate_report
|
||||
)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {"trade_info": trade_info}
|
||||
return _obs, _info
|
||||
@@ -242,4 +244,4 @@ class SimulatorEnv(BaseEnv):
|
||||
def get_report(self):
|
||||
_report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None
|
||||
_positions = self.trade_account.get_positions() if self.generate_report else None
|
||||
return [{"report": _report, "positions": _positions}]
|
||||
return [(_report, _positions)]
|
||||
|
||||
@@ -9,7 +9,7 @@ import pandas as pd
|
||||
import warnings
|
||||
from ..log import get_module_logger
|
||||
from .backtest import get_exchange, backtest as backtest_func
|
||||
from .backtest.backtest import get_date_range
|
||||
from ..utils import get_date_range
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
|
||||
@@ -23,7 +23,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
method_sell="bottom",
|
||||
method_buy="top",
|
||||
risk_degree=0.95,
|
||||
thresh=1,
|
||||
hold_thresh=1,
|
||||
only_tradable=False,
|
||||
**kwargs,
|
||||
@@ -41,11 +40,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
dropout method_buy, random/top.
|
||||
risk_degree : float
|
||||
position percentage of total value.
|
||||
thresh : int
|
||||
minimun holding days since last buy singal of the stock.
|
||||
hold_thresh : int
|
||||
minimum holding days
|
||||
before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh.
|
||||
before sell stock , will check current.get_stock_count(order.stock_id) >= self.hold_thresh.
|
||||
only_tradable : bool
|
||||
will the strategy only consider the tradable stock when buying and selling.
|
||||
if only_tradable:
|
||||
@@ -61,10 +58,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.method_sell = method_sell
|
||||
self.method_buy = method_buy
|
||||
self.risk_degree = risk_degree
|
||||
self.thresh = thresh
|
||||
# self.stock_count['code'] will be the days the stock has been hold
|
||||
# since last buy signal. This is designed for thresh
|
||||
self.stock_count = {}
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
@@ -170,10 +163,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
|
||||
# Get the stock list we really want to buy
|
||||
buy = today[: len(sell) + self.topk - len(last)]
|
||||
|
||||
# buy singal: if a stock falls into topk, it appear in the buy_sinal
|
||||
buy_signal = pred_score.sort_values(ascending=False).iloc[: self.topk].index
|
||||
|
||||
#print("flag", len(sell), len(buy), self.topk, len(last))
|
||||
for code in current_stock_list:
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
@@ -181,13 +171,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
continue
|
||||
if code in sell:
|
||||
# check hold limit
|
||||
if (
|
||||
self.stock_count[code] < self.thresh
|
||||
or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh
|
||||
):
|
||||
# can not sell this code
|
||||
# no buy signal, but the stock is kept
|
||||
self.stock_count[code] += 1
|
||||
if current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh:
|
||||
continue
|
||||
# sell order
|
||||
sell_amount = current_temp.get_stock_amount(code=code)
|
||||
@@ -207,18 +191,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
)
|
||||
# update cash
|
||||
cash += trade_val - trade_cost
|
||||
# sold
|
||||
self.stock_count[code] = 0
|
||||
else:
|
||||
# no buy signal, but the stock is kept
|
||||
self.stock_count[code] += 1
|
||||
elif code in buy_signal:
|
||||
# NOTE: This is different from the original version
|
||||
# get new buy signal
|
||||
# Only the stock fall in to topk will produce buy signal
|
||||
self.stock_count[code] = 1
|
||||
else:
|
||||
self.stock_count[code] += 1
|
||||
# buy new stock
|
||||
# note the current has been changed
|
||||
current_stock_list = current_temp.get_stock_list()
|
||||
@@ -249,7 +221,6 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
factor=factor,
|
||||
)
|
||||
buy_order_list.append(buy_order)
|
||||
self.stock_count[code] = 1
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
|
||||
|
||||
@@ -14,8 +14,9 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
from ..contrib.strategy.strategy import BaseStrategy
|
||||
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
@@ -212,7 +213,7 @@ class SigAnaRecord(SignalRecord):
|
||||
return paths
|
||||
|
||||
|
||||
class PortAnaRecord(SignalRecord):
|
||||
class PortAnaRecord(RecordTemp):
|
||||
"""
|
||||
This is the Portfolio Analysis Record class that generates the analysis results such as those of backtest. This class inherits the ``RecordTemp`` class.
|
||||
|
||||
@@ -243,16 +244,10 @@ class PortAnaRecord(SignalRecord):
|
||||
self.risk_analysis_dep = risk_analysis_dep
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# check previously stored prediction results
|
||||
try:
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
# custom strategy and get backtest
|
||||
report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
|
||||
for report_dep, (report_normal, positions_normal) in enumerate(report_list):
|
||||
if report_dict is None:
|
||||
if report_normal is None:
|
||||
if self.risk_analysis_dep == report_dep:
|
||||
warnings.warn(
|
||||
f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`"
|
||||
|
||||
Reference in New Issue
Block a user