1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

fix trade time bug

This commit is contained in:
bxdd
2021-05-06 21:33:33 +08:00
parent ae339506b3
commit 7540ecde11
5 changed files with 56 additions and 36 deletions

View File

@@ -10,7 +10,7 @@ from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import PortAnaRecord
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
if __name__ == "__main__":
@@ -64,9 +64,9 @@ if __name__ == "__main__":
"kwargs": data_handler_config,
},
"segments": {
"train": ("2012-01-01", "2014-12-31"),
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2018-01-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
@@ -74,17 +74,16 @@ if __name__ == "__main__":
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model.fit(dataset)
trade_start_time = "2017-01-31"
trade_end_time = "2018-01-31"
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
"step_bar": "week",
"step_bar": "day",
"model": model,
"dataset": dataset,
"topk": 50,
@@ -92,28 +91,12 @@ if __name__ == "__main__":
},
},
"env": {
"class": "SplitEnv",
"class": "SimulatorEnv",
"module_path": "qlib.contrib.backtest.env",
"kwargs": {
"step_bar": "week",
"sub_env": {
"class": "SimulatorEnv",
"module_path": "qlib.contrib.backtest.env",
"kwargs": {
"step_bar": "day",
"verbose": True,
"generate_report": True,
},
},
"sub_strategy": {
"class": "SBBStrategyEMA",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"step_bar": "day",
"freq": "day",
"instruments": "csi300",
},
},
"step_bar": "day",
"verbose": True,
"generate_report": True,
},
},
"backtest": {
@@ -129,9 +112,18 @@ if __name__ == "__main__":
"min_cost": 5,
},
}
with R.start(experiment_name="highfreq_backtest"):
R.log_params(**flatten_dict(task))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
recorder = R.get_recorder()
par = PortAnaRecord(recorder, port_analysis_config, "day")
par.generate()
par.generate()

View File

@@ -94,7 +94,7 @@ class Account:
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
def cal_change(x):
return x.prod() - 1
return (x + 1).prod() - 1
_ret = sample_feature(bench, trade_start_time, trade_end_time, method=cal_change)
return 0 if _ret is None else _ret

View File

@@ -49,7 +49,7 @@ class BaseTradeCalendar:
def _get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
calendar_index = self.start_index + trade_index
return self.calendar[calendar_index - 1], self.calendar[calendar_index]
return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1)
def finished(self):
return self.trade_index >= self.trade_len - 1

View File

@@ -51,7 +51,7 @@ class TopkDropoutStrategy(ModelStrategy):
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__(
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs
)
self.topk = topk
self.n_drop = n_drop

View File

@@ -11,16 +11,30 @@ from ..backtest.order import Order
class TWAPStrategy(RuleStrategy, TradingEnhancement):
def reset(self, trade_order_list=None, **kwargs):
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_exchange=None,
**kwargs,
):
super(TWAPStrategy, self).__init__(
step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs
)
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
super(TWAPStrategy, self).reset(**kwargs)
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_order_list:
self.trade_amount = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
if trade_exchange:
self.trade_exchange = trade_exchange
def generate_order_list(self, **kwargs):
super(TopkDropoutStrategy, self).step()
super(TWAPStrategy, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
order_list = []
for order in self.trade_order_list:
@@ -44,8 +58,19 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
TREND_MID = 0
TREND_SHORT = 1
TREND_LONG = 2
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_exchange=None,
**kwargs,
):
super(SBBStrategyBase, self).__init__(
step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs
)
def reset(self, trade_order_list=None, **kwargs):
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
super(SBBStrategyBase, self).reset(**kwargs)
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_order_list:
@@ -54,6 +79,8 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
if trade_exchange:
self.trade_exchange = trade_exchange
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
@@ -127,11 +154,12 @@ class SBBStrategyEMA(SBBStrategyBase):
step_bar,
start_time=None,
end_time=None,
trade_exchange=None,
instruments="csi300",
freq="day",
**kwargs,
):
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, **kwargs)
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, trade_exchange=trade_exchange, **kwargs)
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
self.instruments = "all"