From b14efa11291895b1e5e0424b504818d6eda09730 Mon Sep 17 00:00:00 2001 From: bxdd Date: Sat, 24 Apr 2021 02:29:42 +0800 Subject: [PATCH] update trade calendar & backtest workflow --- examples/highfreq/backtest/workflow.py | 15 +-- qlib/backtest/__init__.py | 74 ++++++------- qlib/backtest/backtest.py | 137 ++----------------------- qlib/backtest/env.py | 80 +++++++++------ qlib/backtest/init.py | 132 ++++++++++++++++++++++++ qlib/contrib/strategy/dl_strategy.py | 13 ++- qlib/contrib/strategy/rule_strategy.py | 7 +- qlib/data/data.py | 4 +- qlib/strategy/base.py | 38 ++----- qlib/utils/__init__.py | 17 +-- 10 files changed, 263 insertions(+), 254 deletions(-) create mode 100644 qlib/backtest/init.py diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py index cddc78b92..df01e31de 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/highfreq/backtest/workflow.py @@ -10,10 +10,7 @@ from qlib.config import REG_CN from qlib.contrib.model.gbdt import LGBModel from qlib.contrib.data.handler import Alpha158 from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) +from qlib.backtest import backtest from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord @@ -124,12 +121,4 @@ if __name__ == "__main__": } - # prediction - recorder = R.get_recorder() - sr = SignalRecord(model, dataset, recorder) - sr.generate() - - # backtest. If users want to use backtest based on their own prediction, - # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. - par = PortAnaRecord(recorder, port_analysis_config) - par.generate() + backtest(**backtest_config, ) \ No newline at end of file diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 0afe03ea4..70bc03363 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -2,11 +2,10 @@ # Licensed under the MIT License. from .order import Order -from .account import Account from .position import Position from .exchange import Exchange from .report import Report -from .backtest import backtest as backtest_func, get_date_range +from .backtest import backtest as backtest_func import copy import numpy as np @@ -18,21 +17,6 @@ from ..config import C logger = get_module_logger("backtest caller") -def init_env_instance_by_config(env): - if isinstance(env, dict): - env_config = copy.copy(env) - if "kwargs" in env_config: - env_kwargs = copy.copy(env_config["kwargs"]): - if "sub_env" in env_kwargs: - env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) - if "sub_strategy" in env_kwargs: - env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) - env_config["kwargs"] = env_kwargs - return init_instance_by_config(env_config) - else: - return env - - def get_exchange( pred, exchange=None, @@ -103,36 +87,44 @@ def get_exchange( else: return init_instance_by_config(exchange, accept_types=Exchange) -def backtest(start_time, end_time, strategy, env, account=1e9, benchmark, **kwargs): +def init_env_instance_by_config(env): + if isinstance(env, dict): + env_config = copy.copy(env) + if "kwargs" in env_config: + env_kwargs = copy.copy(env_config["kwargs"]): + if "sub_env" in env_kwargs: + env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) + if "sub_strategy" in env_kwargs: + env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) + env_config["kwargs"] = env_kwargs + return init_instance_by_config(env_config) + else: + return env + +def setup_exchange(root_instance, trade_exchange=None, force=False): + if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args: + if force: + root_instance.reset(trade_exchange=trade_exchange) + else: + if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None: + root_instance.reset(trade_exchange=trade_exchange) + if hasattr(root_instance, "sub_env"): + setup_exchange(root_instance.sub_env, trade_exchange) + if hasattr(root_instance, "sub_strategy"): + setup_exchange(root_instance.sub_strategy, trade_exchange) + + +def backtest(start_time, end_time, strategy, env, benchmark, account=1e9, **kwargs): trade_strategy = init_instance_by_config(strategy) trade_env = init_env_instance_by_config(env) - trade_account = Account(init_cash=account) spec = inspect.getfullargspec(get_exchange) ex_args = {k: v for k, v in kwargs.items() if k in spec.args} trade_exchange = get_exchange(pred, **ex_args) - temp_env = trade_env - while True: - if hasattr(temp_env, "trade_exchange"): - temp_env.reset(trade_exchange=trade_exchange) - if hasattr(temp_env, "sub_env"): - temp_env = temp_env.sub_env - else: - break - - trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) - trade_strategy.reset(start_time=start_time, end_time=end_time) - trade_state = self.sub_env.get_first_state() - - - while not trade_env.finished(): - _order_list = self.sub_strategy.generate_order(**trade_state) - trade_state, trade_info = self.sub_env.execute(sub_order_list) - - report_df = trade_account.report.generate_report_dataframe() - positions = trade_account.get_positions() + setup_exchange(trade_env, trade_exchange) + setup_exchange(trade_strategy, trade_exchange) - report_dict = {"report_df": report_df, "positions": positions} + report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account) - return + return report_dict diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index b87d6afe3..cd9539725 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -4,140 +4,23 @@ import numpy as np import pandas as pd -from ...utils import get_date_by_shift, get_date_range -from ...data import D + from .account import Account -from ...config import C -from ...log import get_module_logger -from ...data.dataset.utils import get_level_index -LOG = get_module_logger("backtest") - - -def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order): - """Parameters - ---------- - pred : pandas.DataFrame - predict should has index and one `score` column - Qlib want to support multi-singal strategy in the future. So pd.Series is not used. - strategy : Strategy() - strategy part for backtest - trade_exchange : Exchange() - exchage for backtest - shift : int - whether to shift prediction by one day - verbose : bool - whether to print log - account : float - init account value - benchmark : str/list/pd.Series - `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. - example: - print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) - 2017-01-04 0.011693 - 2017-01-05 0.000721 - 2017-01-06 -0.004322 - 2017-01-09 0.006874 - 2017-01-10 -0.003350 - - `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. - `benchmark` is str, will use the daily change as the 'bench'. - benchmark code, default is SH000905 CSI500 - """ - # Convert format if the input format is not expected - if get_level_index(pred, level="datetime") == 1: - pred = pred.swaplevel().sort_index() - if isinstance(pred, pd.Series): - pred = pred.to_frame("score") +def backtest(trade_strategy, trade_env, benchmark, account): trade_account = Account(init_cash=account) - _pred_dates = pred.index.get_level_values(level="datetime") - predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max()) - if isinstance(benchmark, pd.Series): - bench = benchmark - else: - _codes = benchmark if isinstance(benchmark, list) else [benchmark] - _temp_result = D.features( - _codes, - ["$close/Ref($close,1)-1"], - predict_dates[0], - get_date_by_shift(predict_dates[-1], shift=shift), - disk_cache=1, - ) - if len(_temp_result) == 0: - raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") - bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean() + trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) + trade_strategy.reset(start_time=start_time, end_time=end_time) - trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift)) - if return_order: - multi_order_list = [] - # trading apart - for pred_date, trade_date in zip(predict_dates, trade_dates): - # for loop predict date and trading date - # print - if verbose: - LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date)) - - # 1. Load the score_series at pred_date - try: - score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate - score_series = score.reset_index(level="datetime", drop=True)[ - "score" - ] # pd.Series(index:stock_id, data: score) - except KeyError: - LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date)) - score_series = None - - if score_series is not None and score_series.count() > 0: # in case of the scores are all None - # 2. Update your strategy (and model) - strategy.update(score_series, pred_date, trade_date) - - # 3. Generate order list - order_list = strategy.generate_order_list( - score_series=score_series, - current=trade_account.current, - trade_exchange=trade_exchange, - pred_date=pred_date, - trade_date=trade_date, - ) - else: - order_list = [] - if return_order: - multi_order_list.append((trade_account, order_list, trade_date)) - # 4. Get result after executing order list - # NOTE: The following operation will modify order.amount. - # NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated - trade_info = executor.execute(trade_account, order_list, trade_date) - - # 5. Update account information according to transaction - update_account(trade_account, trade_info, trade_exchange, trade_date) - - # generate backtest report + trade_state = self.sub_env.get_init_state() + while not trade_env.finished(): + _order_list = self.sub_strategy.generate_order(**trade_state) + trade_state, trade_info = self.sub_env.execute(sub_order_list) + report_df = trade_account.report.generate_report_dataframe() - report_df["bench"] = bench positions = trade_account.get_positions() - report_dict = {"report_df": report_df, "positions": positions} - if return_order: - report_dict.update({"order_list": multi_order_list}) + return report_dict - -def update_account(trade_account, trade_info, trade_exchange, trade_date): - """Update the account and strategy - Parameters - ---------- - trade_account : Account() - trade_info : list of [Order(), float, float, float] - (order, trade_val, trade_cost, trade_price), trade_info with out factor - trade_exchange : Exchange() - used to get the $close_price at trade_date to update account - trade_date : pd.Timestamp - """ - # update account - for [order, trade_val, trade_cost, trade_price] in trade_info: - if order.deal_amount == 0: - continue - trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) - # at the end of trade date, update the account based the $close_price of stocks. - trade_account.update_daily_end(today=trade_date, trader=trade_exchange) diff --git a/qlib/backtest/env.py b/qlib/backtest/env.py index a4f1eb95e..571f33b7e 100644 --- a/qlib/backtest/env.py +++ b/qlib/backtest/env.py @@ -7,13 +7,54 @@ import warnings import pathlib import pandas as pd from loguru import Logger -from ...data import D +from ...data import D, Cal from ...utils import get_date_in_file_name from ...utils import get_pre_trading_date from ..backtest.order import Order from ..utils import init_instance_by_config +class TradeCalendarBase: -class BaseEnv: + def _reset_trade_calendar(self, start_time, end_time): + if start_time: + self.start_time = pd.Timestamp(start_time) + if end_time: + self.end_time = pd.Timestamp(end_time) + if self.start_time and self.end_time: + _calendar, freq, freq_sam = get_sample_freq_calendar(freq=step_bar) + self.calendar = _calendar + _start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam) + _trade_calendar = self.calendar[_start_index, _end_index + 1] + if _start_time != self.start_time: + self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time)) + self.start_index = _start_index - 1 + else: + self.trade_calendar = np.hstack((_trade_calendar, self.end_time)) + self.start_index = _start_index + self.end_index = _end_index + self.trade_index = 0 + self.trade_len = len(self.trade_calendar) + else: + raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.") + + def _get_trade_time(self, trade_index=1, shift=0): + trade_index = trade_index - shift + if 0 < trade_index < self.trade_len - 1: + trade_start_time = self.trade_calendar[trade_index - 1] + trade_end_time = self.trade_calendar[trade_index] - pd.Timestamp(second=1) + return trade_start_time, trade_end_time + elif trade_index == self.trade_len - 1: + trade_start_time = self.trade_calendar[trade_index - 1] + trade_end_time = self.trade_calendar[trade_index] + return trade_start_time, trade_end_time + else: + raise RuntimeError("trade_index out of range") + + def _get_calendar_time(self, trade_index=1, shift=1): + trade_index = trade_index - shift + calendar_index = self.start_index + trade_index + return self.calendar[calendar_index - 1], self.calendar[calendar_index] + +class BaseEnv(TradeCalendarBase): """ # Strategy framework document @@ -33,38 +74,19 @@ class BaseEnv: self.verbose = verbose self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs) - def _reset_trade_calendar(self, start_time, end_time): - if start_time: - self.start_time = start_time - if end_time: - self.end_time = end_time - if self.start_time and self.end_time: - _calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar) - self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time)) - self.trade_len = len(self.trade_calendar) - self.trade_index = 0 - else: - raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.") - def _get_position(self): return self.trade_account.current - def _get_trade_time(self): - if 0 < self.trade_index < self.trade_len - 1: - trade_start_time = self.trade_calendar[self.trade_index - 1] - trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1) - return trade_start_time, trade_end_time - elif self.trade_index == self.trade_len - 1: - trade_start_time = self.trade_calendar[self.trade_index - 1] - trade_end_time = self.trade_calendar[self.trade_index] - return trade_start_time, trade_end_time - else: - raise RuntimeError("trade_index out of range") + def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs): if start_time or end_time: self._reset_trade_calendar(start_time=start_time, end_time=end_time) self.trade_account = trade_account + + for k, v in kwargs: + if hasattr(self, k): + setattr(self, k, v) def get_first_state(self): init_state = {"current": self._get_position()} @@ -101,10 +123,10 @@ class SplitEnv(BaseEnv): # yield action #episode_reward = 0 super(SimulatorEnv, self).execute(**kwargs) - trade_start_time, trade_end_time = self._get_trade_time() + trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index) self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account) self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list) - trade_state = self.sub_env.get_first_state() + trade_state = self.sub_env.get_init_state() while not self.sub_env.finished(): _order_list = self.sub_strategy.generate_order(**trade_state) trade_state, trade_info = self.sub_env.execute(order_list=_order_list) @@ -140,7 +162,7 @@ class SimulatorEnv(BaseEnv): if self.finished(): raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") super(SimulatorEnv, self).execute(**kwargs) - ttrade_start_time, trade_end_time = self._get_trade_time() + ttrade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index) trade_info = [] for order in order_list: if self.trade_exchange.check_order(order) is True: diff --git a/qlib/backtest/init.py b/qlib/backtest/init.py new file mode 100644 index 000000000..06dd437db --- /dev/null +++ b/qlib/backtest/init.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .order import Order +from .account import Account +from .position import Position +from .exchange import Exchange +from .report import Report +from .backtest import backtest as backtest_func, get_date_range + +import copy +import numpy as np +import inspect +from ..utils import init_instance_by_config +from ..log import get_module_logger +from ..config import C + +logger = get_module_logger("backtest caller") + + +def init_env_instance_by_config(env): + if isinstance(env, dict): + env_config = copy.copy(env) + if "kwargs" in env_config: + env_kwargs = copy.copy(env_config["kwargs"]): + if "sub_env" in env_kwargs: + env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) + if "sub_strategy" in env_kwargs: + env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) + env_config["kwargs"] = env_kwargs + return init_instance_by_config(env_config) + else: + return env + + +def get_exchange( + exchange=None, + start_time=None, + end_time=None, + codes = "all", + subscribe_fields=[], + open_cost=0.0015, + close_cost=0.0025, + min_cost=5.0, + trade_unit=None, + limit_threshold=None, + deal_price=None, + shift=1, +): + """get_exchange + + Parameters + ---------- + + # exchange related arguments + exchange: Exchange(). + subscribe_fields: list + subscribe fields. + open_cost : float + open transaction cost. + close_cost : float + close transaction cost. + min_cost : float + min transaction cost. + trade_unit : int + 100 for China A. + deal_price: str + dealing price type: 'close', 'open', 'vwap'. + limit_threshold : float + limit move 0.1 (10%) for example, long and short with same limit. + + Returns + ------- + :class: Exchange + an initialized Exchange object + """ + + if trade_unit is None: + trade_unit = C.trade_unit + if limit_threshold is None: + limit_threshold = C.limit_threshold + if deal_price is None: + deal_price = C.deal_price + if exchange is None: + logger.info("Create new exchange") + # handle exception for deal_price + if deal_price[0] != "$": + deal_price = "$" + deal_price + + exchange = Exchange( + start_time=start_time, + end_time=end_time, + codes=codes, + deal_price=deal_price, + subscribe_fields=subscribe_fields, + limit_threshold=limit_threshold, + open_cost=open_cost, + close_cost=close_cost, + trade_unit=trade_unit, + min_cost=min_cost, + ) + else: + return init_instance_by_config(exchange, accept_types=Exchange) + +def backtest(start_time, end_time, strategy, env, account=1e9, **kwargs): + trade_strategy = init_instance_by_config(strategy) + trade_env = init_env_instance_by_config(env) + trade_account = Account(init_cash=account) + + spec = inspect.getfullargspec(get_exchange) + ex_args = {k: v for k, v in kwargs.items() if k in spec.args} + trade_exchange = get_exchange(pred, **ex_args) + +# temp_env = trade_env +# while True: +# if hasattr(temp_env, "trade_exchange"): +# temp_env.reset(trade_exchange=trade_exchange) +# if hasattr(temp_env, "sub_env"): +# temp_env = temp_env.sub_env +# else: +# break + + trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) + trade_state, _reset_info = self.sub_env.get_first_state() + trade_strategy.reset(**_reset_info) + + + while not trade_env.finished(): + _order_list = self.sub_strategy.generate_order(**trade_state) + trade_state, trade_info = self.sub_env.execute(sub_order_list) + + return diff --git a/qlib/contrib/strategy/dl_strategy.py b/qlib/contrib/strategy/dl_strategy.py index 737fd7a58..5f702fe0b 100644 --- a/qlib/contrib/strategy/dl_strategy.py +++ b/qlib/contrib/strategy/dl_strategy.py @@ -15,11 +15,11 @@ class TopkDropoutStrategy(DLStrategy): step_bar, model, dataset, - trade_exchange, topk, n_drop, start_time=None, end_time=None, + trade_exchange=None, method_sell="bottom", method_buy="top", risk_degree=0.95, @@ -54,7 +54,6 @@ class TopkDropoutStrategy(DLStrategy): strategy will make decision with the tradable state of the stock info and avoid buy and sell them. """ super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time) - self.trade_exchange = trade_exchange self.topk = topk self.n_drop = n_drop self.method_sell = method_sell @@ -68,6 +67,10 @@ class TopkDropoutStrategy(DLStrategy): self.only_tradable = only_tradable + def reset(trade_exchange=None, **kwargs): + super(TopkDropoutStrategy, self).reset(**kwargs) + self.trade_exchange = trade_exchange + def get_risk_degree(self, trade_index): """get_risk_degree Return the proportion of your total value you will used in investment. @@ -78,8 +81,8 @@ class TopkDropoutStrategy(DLStrategy): def generate_order_list(self, current, **kwargs): super(TopkDropoutStrategy, self).generate_order_list() - trade_start_time, trade_end_time = self._get_trade_time() - pred_start_time, pred_end_time = self._get_last_trade_time() + trade_start_time, trade_end_time = self._get_trade_time(self.trade_index) + pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if self.only_tradable: # If The strategy only consider tradable stock when make decision @@ -268,7 +271,7 @@ class WeightStrategyBase(DLStrategy): # generate_order_list # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list super(WeightStrategyBase, self).generate_order_list() - trade_start_time, trade_end_time = self._get_trade_time() + trade_start_time, trade_end_time = self._get_trade_time(self.trade_index) pred_start_time, pred_end_time = self._get_pred_time() pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") current_temp = copy.deepcopy(trade_account.current) diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 31968dafa..dd2e17c54 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -57,7 +57,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): def generate_order_list(self, **kwargs): super(SBBStrategyBase, self).generate_order_list() trade_start_time, trade_end_time = self._get_trade_time() - pred_start_time, pred_end_time = self._get_last_trade_time() + pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) order_list = [] for order in self.trade_order_list: if self.trade_index % 2 == 1: @@ -127,8 +127,9 @@ class SBBStrategyEMA(SBBStrategyBase): def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None): super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar) - fields = [("EMA...", "signal")] - self.signal = D.features(instruments, fields, start_time=self.start_time, end_time=self.end_time, freq=self.freq) + fields = [("EMA($close, 10) - EMA($close, 20)", "signal")] + signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) + self.signal = D.features(instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq) def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): _sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last") diff --git a/qlib/data/data.py b/qlib/data/data.py index f978f520c..98427637a 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -114,11 +114,11 @@ class CalendarProvider(abc.ABC): dict dict composed by timestamp as key and index as value for fast search. """ - flag = f"{freq}_future_{future}_sam_{freq_sam}" + flag = f"{freq}_sam_{freq_sam}_future_{future}" if flag in H["c"]: _calendar, _calendar_index = H["c"][flag] else: - flag_raw = f"{freq}_future_{future}_sam_{None}" + flag_raw = f"{freq}_sam_{None}_future_{future}" if flag_raw in H["c"]: _calendar, _calendar_index = H["c"][flag_raw] else: diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 9f9be45cb..cad093af2 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -8,41 +8,30 @@ import numpy as np import pandas as pd -from ..utils import sample_feature, get_sample_freq_calendar +from ..utils import get_sample_freq_calendar from ..data.dataset import DatasetH from ..backtest.order import Order -from .order_generator import OrderGenWInteract -from ..data.data import D +from ..backtest.env import TradeCalendarBase + """ 1. BaseStrategy 的粒度一定是数据粒度的整数倍 - 关于calendar的合并咋整 - adjust_dates这个东西啥用 - label和freq和strategy的bar分离,这个如何决策呢 """ -class BaseStrategy: +class BaseStrategy(TradeCalendarBase): def __init__(self, step_bar, start_time=None, end_time=None, **kwargs): self.step_bar = step_bar self.reset(start_time=start_time, end_time=end_time, **kwargs) - def _reset_trade_calendar(self, start_time, end_time, _calendar=None): - if start_time: - self.start_time = start_time - if end_time: - self.end_time = end_time - if self.start_time and self.end_time: - if not _calendar: - _calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar) - self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time)) - else: - self.trade_calendar = _calendar - self.trade_len = len(self.trade_calendar) - self.trade_index = 0 - else: - raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.") - - def reset(self, start_time=None, end_time=None, _calendar=None): + def reset(self, start_time=None, end_time=None, _calendar=None, **kwargs): if start_time or end_time : self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar) + + for k, v in kwargs: + if hasattr(self, k): + setattr(self, k, v) + def _get_trade_time(self): if 0 < self.trade_index < self.trade_len - 1: @@ -56,13 +45,6 @@ class BaseStrategy: else: raise RuntimeError("trade_index out of range") - def _get_last_trade_time(self, shift=1): - if self.trade_index - shift < 0: - return None, None - elif self.trade_index - shift == 0: - return None, self.trade_index[self.trade_index - shift] - else: - return self.trade_index[self.trade_index - shift - 1], self.trade_index[self.trade_index - shift] def generate_order_list(self, **kwargs): self.trade_index = self.trade_index + 1 diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 2cd2f5d13..028e60cc6 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -918,20 +918,25 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): else: raise ValueError("sample freq must be xmin, xd, xw, xm") -def get_sample_freq_calendar(start_time, end_time, freq): +def get_sample_freq_calendar(start_time=None, end_time=None, freq, **kwargs): try: - _calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq) + _calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs) + freq, freq_sam = freq, None except ValueError: + freq_sam = freq if freq.endswith(("m", "month", "w", "week", "d", "day")): try: - _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq) + _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs) + freq = "min" except ValueError: - _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq) + _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq, **kwargs) + freq = "day" elif freq.endswith(("min", "minute")): - _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq) + _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs) + freq = "min" else: raise ValueError(f"freq {freq} is not supported") - return _calendar + return _calendar, freq, freq_sam def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}): if instruments and type(instruments) is not list: