1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

update trade calendar & backtest workflow

This commit is contained in:
bxdd
2021-04-24 02:29:42 +08:00
parent 39deb7d27f
commit b14efa1129
10 changed files with 263 additions and 254 deletions

View File

@@ -10,10 +10,7 @@ from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.backtest import backtest
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
@@ -124,12 +121,4 @@ if __name__ == "__main__":
}
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()
backtest(**backtest_config, )

View File

@@ -2,11 +2,10 @@
# Licensed under the MIT License.
from .order import Order
from .account import Account
from .position import Position
from .exchange import Exchange
from .report import Report
from .backtest import backtest as backtest_func, get_date_range
from .backtest import backtest as backtest_func
import copy
import numpy as np
@@ -18,21 +17,6 @@ from ..config import C
logger = get_module_logger("backtest caller")
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"]):
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def get_exchange(
pred,
exchange=None,
@@ -103,36 +87,44 @@ def get_exchange(
else:
return init_instance_by_config(exchange, accept_types=Exchange)
def backtest(start_time, end_time, strategy, env, account=1e9, benchmark, **kwargs):
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"]):
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def setup_exchange(root_instance, trade_exchange=None, force=False):
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
if force:
root_instance.reset(trade_exchange=trade_exchange)
else:
if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None:
root_instance.reset(trade_exchange=trade_exchange)
if hasattr(root_instance, "sub_env"):
setup_exchange(root_instance.sub_env, trade_exchange)
if hasattr(root_instance, "sub_strategy"):
setup_exchange(root_instance.sub_strategy, trade_exchange)
def backtest(start_time, end_time, strategy, env, benchmark, account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)
trade_account = Account(init_cash=account)
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
temp_env = trade_env
while True:
if hasattr(temp_env, "trade_exchange"):
temp_env.reset(trade_exchange=trade_exchange)
if hasattr(temp_env, "sub_env"):
temp_env = temp_env.sub_env
else:
break
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_strategy.reset(start_time=start_time, end_time=end_time)
trade_state = self.sub_env.get_first_state()
while not trade_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(sub_order_list)
report_df = trade_account.report.generate_report_dataframe()
positions = trade_account.get_positions()
setup_exchange(trade_env, trade_exchange)
setup_exchange(trade_strategy, trade_exchange)
report_dict = {"report_df": report_df, "positions": positions}
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account)
return
return report_dict

View File

@@ -4,140 +4,23 @@
import numpy as np
import pandas as pd
from ...utils import get_date_by_shift, get_date_range
from ...data import D
from .account import Account
from ...config import C
from ...log import get_module_logger
from ...data.dataset.utils import get_level_index
LOG = get_module_logger("backtest")
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
"""Parameters
----------
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column
Qlib want to support multi-singal strategy in the future. So pd.Series is not used.
strategy : Strategy()
strategy part for backtest
trade_exchange : Exchange()
exchage for backtest
shift : int
whether to shift prediction by one day
verbose : bool
whether to print log
account : float
init account value
benchmark : str/list/pd.Series
`benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
`benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
`benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000905 CSI500
"""
# Convert format if the input format is not expected
if get_level_index(pred, level="datetime") == 1:
pred = pred.swaplevel().sort_index()
if isinstance(pred, pd.Series):
pred = pred.to_frame("score")
def backtest(trade_strategy, trade_env, benchmark, account):
trade_account = Account(init_cash=account)
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
if isinstance(benchmark, pd.Series):
bench = benchmark
else:
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
_temp_result = D.features(
_codes,
["$close/Ref($close,1)-1"],
predict_dates[0],
get_date_by_shift(predict_dates[-1], shift=shift),
disk_cache=1,
)
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_strategy.reset(start_time=start_time, end_time=end_time)
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
if return_order:
multi_order_list = []
# trading apart
for pred_date, trade_date in zip(predict_dates, trade_dates):
# for loop predict date and trading date
# print
if verbose:
LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date))
# 1. Load the score_series at pred_date
try:
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
score_series = score.reset_index(level="datetime", drop=True)[
"score"
] # pd.Series(index:stock_id, data: score)
except KeyError:
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
score_series = None
if score_series is not None and score_series.count() > 0: # in case of the scores are all None
# 2. Update your strategy (and model)
strategy.update(score_series, pred_date, trade_date)
# 3. Generate order list
order_list = strategy.generate_order_list(
score_series=score_series,
current=trade_account.current,
trade_exchange=trade_exchange,
pred_date=pred_date,
trade_date=trade_date,
)
else:
order_list = []
if return_order:
multi_order_list.append((trade_account, order_list, trade_date))
# 4. Get result after executing order list
# NOTE: The following operation will modify order.amount.
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
trade_info = executor.execute(trade_account, order_list, trade_date)
# 5. Update account information according to transaction
update_account(trade_account, trade_info, trade_exchange, trade_date)
# generate backtest report
trade_state = self.sub_env.get_init_state()
while not trade_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(sub_order_list)
report_df = trade_account.report.generate_report_dataframe()
report_df["bench"] = bench
positions = trade_account.get_positions()
report_dict = {"report_df": report_df, "positions": positions}
if return_order:
report_dict.update({"order_list": multi_order_list})
return report_dict
def update_account(trade_account, trade_info, trade_exchange, trade_date):
"""Update the account and strategy
Parameters
----------
trade_account : Account()
trade_info : list of [Order(), float, float, float]
(order, trade_val, trade_cost, trade_price), trade_info with out factor
trade_exchange : Exchange()
used to get the $close_price at trade_date to update account
trade_date : pd.Timestamp
"""
# update account
for [order, trade_val, trade_cost, trade_price] in trade_info:
if order.deal_amount == 0:
continue
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
# at the end of trade date, update the account based the $close_price of stocks.
trade_account.update_daily_end(today=trade_date, trader=trade_exchange)

View File

@@ -7,13 +7,54 @@ import warnings
import pathlib
import pandas as pd
from loguru import Logger
from ...data import D
from ...data import D, Cal
from ...utils import get_date_in_file_name
from ...utils import get_pre_trading_date
from ..backtest.order import Order
from ..utils import init_instance_by_config
class TradeCalendarBase:
class BaseEnv:
def _reset_trade_calendar(self, start_time, end_time):
if start_time:
self.start_time = pd.Timestamp(start_time)
if end_time:
self.end_time = pd.Timestamp(end_time)
if self.start_time and self.end_time:
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=step_bar)
self.calendar = _calendar
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam)
_trade_calendar = self.calendar[_start_index, _end_index + 1]
if _start_time != self.start_time:
self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time))
self.start_index = _start_index - 1
else:
self.trade_calendar = np.hstack((_trade_calendar, self.end_time))
self.start_index = _start_index
self.end_index = _end_index
self.trade_index = 0
self.trade_len = len(self.trade_calendar)
else:
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
def _get_trade_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
if 0 < trade_index < self.trade_len - 1:
trade_start_time = self.trade_calendar[trade_index - 1]
trade_end_time = self.trade_calendar[trade_index] - pd.Timestamp(second=1)
return trade_start_time, trade_end_time
elif trade_index == self.trade_len - 1:
trade_start_time = self.trade_calendar[trade_index - 1]
trade_end_time = self.trade_calendar[trade_index]
return trade_start_time, trade_end_time
else:
raise RuntimeError("trade_index out of range")
def _get_calendar_time(self, trade_index=1, shift=1):
trade_index = trade_index - shift
calendar_index = self.start_index + trade_index
return self.calendar[calendar_index - 1], self.calendar[calendar_index]
class BaseEnv(TradeCalendarBase):
"""
# Strategy framework document
@@ -33,38 +74,19 @@ class BaseEnv:
self.verbose = verbose
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
def _reset_trade_calendar(self, start_time, end_time):
if start_time:
self.start_time = start_time
if end_time:
self.end_time = end_time
if self.start_time and self.end_time:
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
self.trade_len = len(self.trade_calendar)
self.trade_index = 0
else:
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
def _get_position(self):
return self.trade_account.current
def _get_trade_time(self):
if 0 < self.trade_index < self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
return trade_start_time, trade_end_time
elif self.trade_index == self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index]
return trade_start_time, trade_end_time
else:
raise RuntimeError("trade_index out of range")
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
if start_time or end_time:
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
self.trade_account = trade_account
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
def get_first_state(self):
init_state = {"current": self._get_position()}
@@ -101,10 +123,10 @@ class SplitEnv(BaseEnv):
# yield action
#episode_reward = 0
super(SimulatorEnv, self).execute(**kwargs)
trade_start_time, trade_end_time = self._get_trade_time()
trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
trade_state = self.sub_env.get_first_state()
trade_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
@@ -140,7 +162,7 @@ class SimulatorEnv(BaseEnv):
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
super(SimulatorEnv, self).execute(**kwargs)
ttrade_start_time, trade_end_time = self._get_trade_time()
ttrade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index)
trade_info = []
for order in order_list:
if self.trade_exchange.check_order(order) is True:

132
qlib/backtest/init.py Normal file
View File

@@ -0,0 +1,132 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .order import Order
from .account import Account
from .position import Position
from .exchange import Exchange
from .report import Report
from .backtest import backtest as backtest_func, get_date_range
import copy
import numpy as np
import inspect
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
logger = get_module_logger("backtest caller")
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"]):
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def get_exchange(
exchange=None,
start_time=None,
end_time=None,
codes = "all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
trade_unit=None,
limit_threshold=None,
deal_price=None,
shift=1,
):
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange().
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost.
close_cost : float
close transaction cost.
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
Returns
-------
:class: Exchange
an initialized Exchange object
"""
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if exchange is None:
logger.info("Create new exchange")
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
exchange = Exchange(
start_time=start_time,
end_time=end_time,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
trade_unit=trade_unit,
min_cost=min_cost,
)
else:
return init_instance_by_config(exchange, accept_types=Exchange)
def backtest(start_time, end_time, strategy, env, account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)
trade_account = Account(init_cash=account)
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
# temp_env = trade_env
# while True:
# if hasattr(temp_env, "trade_exchange"):
# temp_env.reset(trade_exchange=trade_exchange)
# if hasattr(temp_env, "sub_env"):
# temp_env = temp_env.sub_env
# else:
# break
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_state, _reset_info = self.sub_env.get_first_state()
trade_strategy.reset(**_reset_info)
while not trade_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(sub_order_list)
return

View File

@@ -15,11 +15,11 @@ class TopkDropoutStrategy(DLStrategy):
step_bar,
model,
dataset,
trade_exchange,
topk,
n_drop,
start_time=None,
end_time=None,
trade_exchange=None,
method_sell="bottom",
method_buy="top",
risk_degree=0.95,
@@ -54,7 +54,6 @@ class TopkDropoutStrategy(DLStrategy):
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time)
self.trade_exchange = trade_exchange
self.topk = topk
self.n_drop = n_drop
self.method_sell = method_sell
@@ -68,6 +67,10 @@ class TopkDropoutStrategy(DLStrategy):
self.only_tradable = only_tradable
def reset(trade_exchange=None, **kwargs):
super(TopkDropoutStrategy, self).reset(**kwargs)
self.trade_exchange = trade_exchange
def get_risk_degree(self, trade_index):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -78,8 +81,8 @@ class TopkDropoutStrategy(DLStrategy):
def generate_order_list(self, current, **kwargs):
super(TopkDropoutStrategy, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time()
pred_start_time, pred_end_time = self._get_last_trade_time()
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if self.only_tradable:
# If The strategy only consider tradable stock when make decision
@@ -268,7 +271,7 @@ class WeightStrategyBase(DLStrategy):
# generate_order_list
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
super(WeightStrategyBase, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time()
trade_start_time, trade_end_time = self._get_trade_time(self.trade_index)
pred_start_time, pred_end_time = self._get_pred_time()
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
current_temp = copy.deepcopy(trade_account.current)

View File

@@ -57,7 +57,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
def generate_order_list(self, **kwargs):
super(SBBStrategyBase, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time()
pred_start_time, pred_end_time = self._get_last_trade_time()
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
order_list = []
for order in self.trade_order_list:
if self.trade_index % 2 == 1:
@@ -127,8 +127,9 @@ class SBBStrategyEMA(SBBStrategyBase):
def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None):
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar)
fields = [("EMA...", "signal")]
self.signal = D.features(instruments, fields, start_time=self.start_time, end_time=self.end_time, freq=self.freq)
fields = [("EMA($close, 10) - EMA($close, 20)", "signal")]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
self.signal = D.features(instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq)
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
_sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last")

View File

@@ -114,11 +114,11 @@ class CalendarProvider(abc.ABC):
dict
dict composed by timestamp as key and index as value for fast search.
"""
flag = f"{freq}_future_{future}_sam_{freq_sam}"
flag = f"{freq}_sam_{freq_sam}_future_{future}"
if flag in H["c"]:
_calendar, _calendar_index = H["c"][flag]
else:
flag_raw = f"{freq}_future_{future}_sam_{None}"
flag_raw = f"{freq}_sam_{None}_future_{future}"
if flag_raw in H["c"]:
_calendar, _calendar_index = H["c"][flag_raw]
else:

View File

@@ -8,41 +8,30 @@ import numpy as np
import pandas as pd
from ..utils import sample_feature, get_sample_freq_calendar
from ..utils import get_sample_freq_calendar
from ..data.dataset import DatasetH
from ..backtest.order import Order
from .order_generator import OrderGenWInteract
from ..data.data import D
from ..backtest.env import TradeCalendarBase
"""
1. BaseStrategy 的粒度一定是数据粒度的整数倍
- 关于calendar的合并咋整
- adjust_dates这个东西啥用
- label和freq和strategy的bar分离这个如何决策呢
"""
class BaseStrategy:
class BaseStrategy(TradeCalendarBase):
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time, **kwargs)
def _reset_trade_calendar(self, start_time, end_time, _calendar=None):
if start_time:
self.start_time = start_time
if end_time:
self.end_time = end_time
if self.start_time and self.end_time:
if not _calendar:
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
else:
self.trade_calendar = _calendar
self.trade_len = len(self.trade_calendar)
self.trade_index = 0
else:
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
def reset(self, start_time=None, end_time=None, _calendar=None):
def reset(self, start_time=None, end_time=None, _calendar=None, **kwargs):
if start_time or end_time :
self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar)
for k, v in kwargs:
if hasattr(self, k):
setattr(self, k, v)
def _get_trade_time(self):
if 0 < self.trade_index < self.trade_len - 1:
@@ -56,13 +45,6 @@ class BaseStrategy:
else:
raise RuntimeError("trade_index out of range")
def _get_last_trade_time(self, shift=1):
if self.trade_index - shift < 0:
return None, None
elif self.trade_index - shift == 0:
return None, self.trade_index[self.trade_index - shift]
else:
return self.trade_index[self.trade_index - shift - 1], self.trade_index[self.trade_index - shift]
def generate_order_list(self, **kwargs):
self.trade_index = self.trade_index + 1

View File

@@ -918,20 +918,25 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def get_sample_freq_calendar(start_time, end_time, freq):
def get_sample_freq_calendar(start_time=None, end_time=None, freq, **kwargs):
try:
_calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq)
_calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
freq, freq_sam = freq, None
except ValueError:
freq_sam = freq
if freq.endswith(("m", "month", "w", "week", "d", "day")):
try:
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq)
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
except ValueError:
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq)
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq, **kwargs)
freq = "day"
elif freq.endswith(("min", "minute")):
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq)
_calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
else:
raise ValueError(f"freq {freq} is not supported")
return _calendar
return _calendar, freq, freq_sam
def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}):
if instruments and type(instruments) is not list: