mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
update trade calendar & backtest workflow
This commit is contained in:
@@ -10,10 +10,7 @@ from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.backtest import backtest
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
@@ -124,12 +121,4 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
backtest(**backtest_config, )
|
||||
@@ -2,11 +2,10 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func, get_date_range
|
||||
from .backtest import backtest as backtest_func
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
@@ -18,21 +17,6 @@ from ..config import C
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"]):
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
@@ -103,36 +87,44 @@ def get_exchange(
|
||||
else:
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, account=1e9, benchmark, **kwargs):
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"]):
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
|
||||
if force:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
else:
|
||||
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
|
||||
root_instance.reset(trade_exchange=trade_exchange)
|
||||
if hasattr(root_instance, "sub_env"):
|
||||
setup_exchange(root_instance.sub_env, trade_exchange)
|
||||
if hasattr(root_instance, "sub_strategy"):
|
||||
setup_exchange(root_instance.sub_strategy, trade_exchange)
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark, account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
trade_account = Account(init_cash=account)
|
||||
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
temp_env = trade_env
|
||||
while True:
|
||||
if hasattr(temp_env, "trade_exchange"):
|
||||
temp_env.reset(trade_exchange=trade_exchange)
|
||||
if hasattr(temp_env, "sub_env"):
|
||||
temp_env = temp_env.sub_env
|
||||
else:
|
||||
break
|
||||
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
trade_state = self.sub_env.get_first_state()
|
||||
|
||||
|
||||
while not trade_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(sub_order_list)
|
||||
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
positions = trade_account.get_positions()
|
||||
setup_exchange(trade_env, trade_exchange)
|
||||
setup_exchange(trade_strategy, trade_exchange)
|
||||
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
|
||||
|
||||
return
|
||||
return report_dict
|
||||
|
||||
@@ -4,140 +4,23 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from ...utils import get_date_by_shift, get_date_range
|
||||
from ...data import D
|
||||
|
||||
from .account import Account
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...data.dataset.utils import get_level_index
|
||||
|
||||
LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column
|
||||
Qlib want to support multi-singal strategy in the future. So pd.Series is not used.
|
||||
strategy : Strategy()
|
||||
strategy part for backtest
|
||||
trade_exchange : Exchange()
|
||||
exchage for backtest
|
||||
shift : int
|
||||
whether to shift prediction by one day
|
||||
verbose : bool
|
||||
whether to print log
|
||||
account : float
|
||||
init account value
|
||||
benchmark : str/list/pd.Series
|
||||
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
|
||||
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
`benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000905 CSI500
|
||||
"""
|
||||
# Convert format if the input format is not expected
|
||||
if get_level_index(pred, level="datetime") == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame("score")
|
||||
def backtest(trade_strategy, trade_env, benchmark, account):
|
||||
|
||||
trade_account = Account(init_cash=account)
|
||||
_pred_dates = pred.index.get_level_values(level="datetime")
|
||||
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
|
||||
if isinstance(benchmark, pd.Series):
|
||||
bench = benchmark
|
||||
else:
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
_temp_result = D.features(
|
||||
_codes,
|
||||
["$close/Ref($close,1)-1"],
|
||||
predict_dates[0],
|
||||
get_date_by_shift(predict_dates[-1], shift=shift),
|
||||
disk_cache=1,
|
||||
)
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
|
||||
if return_order:
|
||||
multi_order_list = []
|
||||
# trading apart
|
||||
for pred_date, trade_date in zip(predict_dates, trade_dates):
|
||||
# for loop predict date and trading date
|
||||
# print
|
||||
if verbose:
|
||||
LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date))
|
||||
|
||||
# 1. Load the score_series at pred_date
|
||||
try:
|
||||
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
|
||||
score_series = score.reset_index(level="datetime", drop=True)[
|
||||
"score"
|
||||
] # pd.Series(index:stock_id, data: score)
|
||||
except KeyError:
|
||||
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
|
||||
score_series = None
|
||||
|
||||
if score_series is not None and score_series.count() > 0: # in case of the scores are all None
|
||||
# 2. Update your strategy (and model)
|
||||
strategy.update(score_series, pred_date, trade_date)
|
||||
|
||||
# 3. Generate order list
|
||||
order_list = strategy.generate_order_list(
|
||||
score_series=score_series,
|
||||
current=trade_account.current,
|
||||
trade_exchange=trade_exchange,
|
||||
pred_date=pred_date,
|
||||
trade_date=trade_date,
|
||||
)
|
||||
else:
|
||||
order_list = []
|
||||
if return_order:
|
||||
multi_order_list.append((trade_account, order_list, trade_date))
|
||||
# 4. Get result after executing order list
|
||||
# NOTE: The following operation will modify order.amount.
|
||||
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
|
||||
trade_info = executor.execute(trade_account, order_list, trade_date)
|
||||
|
||||
# 5. Update account information according to transaction
|
||||
update_account(trade_account, trade_info, trade_exchange, trade_date)
|
||||
|
||||
# generate backtest report
|
||||
trade_state = self.sub_env.get_init_state()
|
||||
while not trade_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(sub_order_list)
|
||||
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
report_df["bench"] = bench
|
||||
positions = trade_account.get_positions()
|
||||
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
if return_order:
|
||||
report_dict.update({"order_list": multi_order_list})
|
||||
|
||||
return report_dict
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
"""Update the account and strategy
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account()
|
||||
trade_info : list of [Order(), float, float, float]
|
||||
(order, trade_val, trade_cost, trade_price), trade_info with out factor
|
||||
trade_exchange : Exchange()
|
||||
used to get the $close_price at trade_date to update account
|
||||
trade_date : pd.Timestamp
|
||||
"""
|
||||
# update account
|
||||
for [order, trade_val, trade_cost, trade_price] in trade_info:
|
||||
if order.deal_amount == 0:
|
||||
continue
|
||||
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
|
||||
# at the end of trade date, update the account based the $close_price of stocks.
|
||||
trade_account.update_daily_end(today=trade_date, trader=trade_exchange)
|
||||
|
||||
@@ -7,13 +7,54 @@ import warnings
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
from loguru import Logger
|
||||
from ...data import D
|
||||
from ...data import D, Cal
|
||||
from ...utils import get_date_in_file_name
|
||||
from ...utils import get_pre_trading_date
|
||||
from ..backtest.order import Order
|
||||
from ..utils import init_instance_by_config
|
||||
class TradeCalendarBase:
|
||||
|
||||
class BaseEnv:
|
||||
def _reset_trade_calendar(self, start_time, end_time):
|
||||
if start_time:
|
||||
self.start_time = pd.Timestamp(start_time)
|
||||
if end_time:
|
||||
self.end_time = pd.Timestamp(end_time)
|
||||
if self.start_time and self.end_time:
|
||||
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=step_bar)
|
||||
self.calendar = _calendar
|
||||
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
|
||||
_trade_calendar = self.calendar[_start_index, _end_index + 1]
|
||||
if _start_time != self.start_time:
|
||||
self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time))
|
||||
self.start_index = _start_index - 1
|
||||
else:
|
||||
self.trade_calendar = np.hstack((_trade_calendar, self.end_time))
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_index = 0
|
||||
self.trade_len = len(self.trade_calendar)
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
|
||||
|
||||
def _get_trade_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
if 0 < trade_index < self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[trade_index] - pd.Timestamp(second=1)
|
||||
return trade_start_time, trade_end_time
|
||||
elif trade_index == self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[trade_index]
|
||||
return trade_start_time, trade_end_time
|
||||
else:
|
||||
raise RuntimeError("trade_index out of range")
|
||||
|
||||
def _get_calendar_time(self, trade_index=1, shift=1):
|
||||
trade_index = trade_index - shift
|
||||
calendar_index = self.start_index + trade_index
|
||||
return self.calendar[calendar_index - 1], self.calendar[calendar_index]
|
||||
|
||||
class BaseEnv(TradeCalendarBase):
|
||||
"""
|
||||
# Strategy framework document
|
||||
|
||||
@@ -33,38 +74,19 @@ class BaseEnv:
|
||||
self.verbose = verbose
|
||||
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
|
||||
|
||||
def _reset_trade_calendar(self, start_time, end_time):
|
||||
if start_time:
|
||||
self.start_time = start_time
|
||||
if end_time:
|
||||
self.end_time = end_time
|
||||
if self.start_time and self.end_time:
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
self.trade_len = len(self.trade_calendar)
|
||||
self.trade_index = 0
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
|
||||
|
||||
def _get_position(self):
|
||||
return self.trade_account.current
|
||||
|
||||
def _get_trade_time(self):
|
||||
if 0 < self.trade_index < self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
|
||||
return trade_start_time, trade_end_time
|
||||
elif self.trade_index == self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index]
|
||||
return trade_start_time, trade_end_time
|
||||
else:
|
||||
raise RuntimeError("trade_index out of range")
|
||||
|
||||
|
||||
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
self.trade_account = trade_account
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
def get_first_state(self):
|
||||
init_state = {"current": self._get_position()}
|
||||
@@ -101,10 +123,10 @@ class SplitEnv(BaseEnv):
|
||||
# yield action
|
||||
#episode_reward = 0
|
||||
super(SimulatorEnv, self).execute(**kwargs)
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
|
||||
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
|
||||
trade_state = self.sub_env.get_first_state()
|
||||
trade_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
|
||||
@@ -140,7 +162,7 @@ class SimulatorEnv(BaseEnv):
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
super(SimulatorEnv, self).execute(**kwargs)
|
||||
ttrade_start_time, trade_end_time = self._get_trade_time()
|
||||
ttrade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
|
||||
trade_info = []
|
||||
for order in order_list:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
|
||||
132
qlib/backtest/init.py
Normal file
132
qlib/backtest/init.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func, get_date_range
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"]):
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
|
||||
def get_exchange(
|
||||
exchange=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes = "all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Exchange
|
||||
an initialized Exchange object
|
||||
"""
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
|
||||
exchange = Exchange(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
trade_unit=trade_unit,
|
||||
min_cost=min_cost,
|
||||
)
|
||||
else:
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
trade_account = Account(init_cash=account)
|
||||
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
# temp_env = trade_env
|
||||
# while True:
|
||||
# if hasattr(temp_env, "trade_exchange"):
|
||||
# temp_env.reset(trade_exchange=trade_exchange)
|
||||
# if hasattr(temp_env, "sub_env"):
|
||||
# temp_env = temp_env.sub_env
|
||||
# else:
|
||||
# break
|
||||
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_state, _reset_info = self.sub_env.get_first_state()
|
||||
trade_strategy.reset(**_reset_info)
|
||||
|
||||
|
||||
while not trade_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(sub_order_list)
|
||||
|
||||
return
|
||||
@@ -15,11 +15,11 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
step_bar,
|
||||
model,
|
||||
dataset,
|
||||
trade_exchange,
|
||||
topk,
|
||||
n_drop,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_exchange=None,
|
||||
method_sell="bottom",
|
||||
method_buy="top",
|
||||
risk_degree=0.95,
|
||||
@@ -54,7 +54,6 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time)
|
||||
self.trade_exchange = trade_exchange
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
self.method_sell = method_sell
|
||||
@@ -68,6 +67,10 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
|
||||
def reset(trade_exchange=None, **kwargs):
|
||||
super(TopkDropoutStrategy, self).reset(**kwargs)
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def get_risk_degree(self, trade_index):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -78,8 +81,8 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
|
||||
def generate_order_list(self, current, **kwargs):
|
||||
super(TopkDropoutStrategy, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
pred_start_time, pred_end_time = self._get_last_trade_time()
|
||||
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if self.only_tradable:
|
||||
# If The strategy only consider tradable stock when make decision
|
||||
@@ -268,7 +271,7 @@ class WeightStrategyBase(DLStrategy):
|
||||
# generate_order_list
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
super(WeightStrategyBase, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_pred_time()
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
current_temp = copy.deepcopy(trade_account.current)
|
||||
|
||||
@@ -57,7 +57,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
def generate_order_list(self, **kwargs):
|
||||
super(SBBStrategyBase, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
pred_start_time, pred_end_time = self._get_last_trade_time()
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
if self.trade_index % 2 == 1:
|
||||
@@ -127,8 +127,9 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
|
||||
def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None):
|
||||
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar)
|
||||
fields = [("EMA...", "signal")]
|
||||
self.signal = D.features(instruments, fields, start_time=self.start_time, end_time=self.end_time, freq=self.freq)
|
||||
fields = [("EMA($close, 10) - EMA($close, 20)", "signal")]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
self.signal = D.features(instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
_sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last")
|
||||
|
||||
@@ -114,11 +114,11 @@ class CalendarProvider(abc.ABC):
|
||||
dict
|
||||
dict composed by timestamp as key and index as value for fast search.
|
||||
"""
|
||||
flag = f"{freq}_future_{future}_sam_{freq_sam}"
|
||||
flag = f"{freq}_sam_{freq_sam}_future_{future}"
|
||||
if flag in H["c"]:
|
||||
_calendar, _calendar_index = H["c"][flag]
|
||||
else:
|
||||
flag_raw = f"{freq}_future_{future}_sam_{None}"
|
||||
flag_raw = f"{freq}_sam_{None}_future_{future}"
|
||||
if flag_raw in H["c"]:
|
||||
_calendar, _calendar_index = H["c"][flag_raw]
|
||||
else:
|
||||
|
||||
@@ -8,41 +8,30 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
from ..utils import sample_feature, get_sample_freq_calendar
|
||||
from ..utils import get_sample_freq_calendar
|
||||
from ..data.dataset import DatasetH
|
||||
from ..backtest.order import Order
|
||||
from .order_generator import OrderGenWInteract
|
||||
from ..data.data import D
|
||||
from ..backtest.env import TradeCalendarBase
|
||||
|
||||
"""
|
||||
1. BaseStrategy 的粒度一定是数据粒度的整数倍
|
||||
- 关于calendar的合并咋整
|
||||
- adjust_dates这个东西啥用
|
||||
- label和freq和strategy的bar分离,这个如何决策呢
|
||||
"""
|
||||
class BaseStrategy:
|
||||
class BaseStrategy(TradeCalendarBase):
|
||||
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
|
||||
def _reset_trade_calendar(self, start_time, end_time, _calendar=None):
|
||||
if start_time:
|
||||
self.start_time = start_time
|
||||
if end_time:
|
||||
self.end_time = end_time
|
||||
if self.start_time and self.end_time:
|
||||
if not _calendar:
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
else:
|
||||
self.trade_calendar = _calendar
|
||||
self.trade_len = len(self.trade_calendar)
|
||||
self.trade_index = 0
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
|
||||
|
||||
def reset(self, start_time=None, end_time=None, _calendar=None):
|
||||
def reset(self, start_time=None, end_time=None, _calendar=None, **kwargs):
|
||||
if start_time or end_time :
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar)
|
||||
|
||||
for k, v in kwargs:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
def _get_trade_time(self):
|
||||
if 0 < self.trade_index < self.trade_len - 1:
|
||||
@@ -56,13 +45,6 @@ class BaseStrategy:
|
||||
else:
|
||||
raise RuntimeError("trade_index out of range")
|
||||
|
||||
def _get_last_trade_time(self, shift=1):
|
||||
if self.trade_index - shift < 0:
|
||||
return None, None
|
||||
elif self.trade_index - shift == 0:
|
||||
return None, self.trade_index[self.trade_index - shift]
|
||||
else:
|
||||
return self.trade_index[self.trade_index - shift - 1], self.trade_index[self.trade_index - shift]
|
||||
def generate_order_list(self, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
|
||||
@@ -918,20 +918,25 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
def get_sample_freq_calendar(start_time, end_time, freq):
|
||||
def get_sample_freq_calendar(start_time=None, end_time=None, freq, **kwargs):
|
||||
try:
|
||||
_calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq)
|
||||
_calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
|
||||
freq, freq_sam = freq, None
|
||||
except ValueError:
|
||||
freq_sam = freq
|
||||
if freq.endswith(("m", "month", "w", "week", "d", "day")):
|
||||
try:
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq)
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
except ValueError:
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq)
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq, **kwargs)
|
||||
freq = "day"
|
||||
elif freq.endswith(("min", "minute")):
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq)
|
||||
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
else:
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
return _calendar
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}):
|
||||
if instruments and type(instruments) is not list:
|
||||
|
||||
Reference in New Issue
Block a user