diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py index 38c1eecc8..e5a832927 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/highfreq/backtest/workflow.py @@ -81,7 +81,7 @@ if __name__ == "__main__": backtest_config={ "strategy": { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.dl_strategy", + "module_path": "qlib.contrib.strategy.model_strategy", "kwargs": { "step_bar": "week", "model": model, @@ -113,6 +113,18 @@ if __name__ == "__main__": } } } + }, + "backtest":{ + "start_time": trade_start_time, + "end_time": trade_end_time, + "verbose": False, + "limit_threshold": 0.095, + "account": 100000000, + "benchmark": benchmark, + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, } } diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py index 8796d0057..4a03bbe47 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/contrib/backtest/__init__.py @@ -19,6 +19,7 @@ logger = get_module_logger("backtest caller") def get_exchange( exchange=None, + freq="day", start_time=None, end_time=None, codes = "all", @@ -72,6 +73,7 @@ def get_exchange( deal_price = "$" + deal_price exchange = Exchange( + freq=freq, start_time=start_time, end_time=end_time, codes=codes, diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index c44d26d7b..981e3c07a 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -3,10 +3,13 @@ import copy +import pandas as pd from .position import Position from .report import Report from .order import Order +from ...utils import parse_freq, sample_feature + """ @@ -26,21 +29,86 @@ rtn & earning in the Account class Account: - def __init__(self, init_cash, last_trade_time=None): - self.init_vars(init_cash, last_trade_time) + def __init__(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None): + self.init_vars(init_cash, benchmark, start_time, end_time) - def init_vars(self, init_cash, last_trade_time=None): + def init_vars(self, init_cash, benchmark=None, start_time=None, end_time=None, freq=None): + """ + Parameters + ---------- + - benchmark: str/list/pd.Series + `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. + example: + print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) + 2017-01-04 0.011693 + 2017-01-05 0.000721 + 2017-01-06 -0.004322 + 2017-01-09 0.006874 + 2017-01-10 -0.003350 + `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. + `benchmark` is str, will use the daily change as the 'bench'. + benchmark code, default is SH000905 CSI500 + + """ # init cash self.init_cash = init_cash + self.benchmark = benchmark + self.start_time = start_time + self.end_time = end_time + self.freq = freq self.current = Position(cash=init_cash) self.positions = {} self.rtn = 0 self.ct = 0 self.to = 0 self.val = 0 - self.report = Report() self.earning = 0 - self.last_trade_time = last_trade_time + self.report = Report() + if freq and benchmark: + self.bench = self._cal_benchmark(benchmark, start_time, end_time, freq) + + def _cal_benchmark(self, benchmark, start_time=None, end_time=None, freq=None): + if isinstance(benchmark, pd.Series): + return benchmark + else: + if freq is None: + raise ValueError("benchmark freq can't be None!") + _codes = benchmark if isinstance(benchmark, list) else [benchmark] + fields = ["$close/Ref($close,1)-1"] + try: + _temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1) + except ValueError: + _, norm_freq = parse_freq(freq) + if norm_freq in ["month", "week", "day"]: + try: + _temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1) + except ValueError: + _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) + elif norm_freq == "minute": + _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) + else: + raise ValueError(f"benchmark freq {freq} is not supported") + if len(_temp_result) == 0: + raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") + return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0) + + def _sample_benchmark(self, bench, trade_start_time, trade_end_time): + def cal_change(x): + return x.prod() - 1 + return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change) + + def reset(self, benchmark=None, freq=None,**kwargs): + if benchmark: + self.benchmark = benchmark + if freq: + self.freq = freq + if self.freq and self.benchmark and (freq or benchmark) + self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq) + + for k, v in kwargs: + if hasattr(k): + setattr(k, v) + def get_positions(self): return self.positions @@ -83,7 +151,7 @@ class Account: self.current.update_order(order, trade_val, cost, trade_price) self.update_state_from_order(order, trade_val, cost, trade_price) - def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange): + def update_report(self, trade_start_time, trade_end_time, trade_exchange): """ start_time: pd.TimeStamp end_time: pd.TimeStamp @@ -100,20 +168,17 @@ class Account: """ # update price for stock in the position and the profit from changed_price stock_list = self.current.get_stock_list() - profit = 0 for code in stock_list: # if suspend, no new price to be updated, profit is 0 if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): continue bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) - profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"] self.current.update_stock_price(stock_id=code, price=bar_close) - self.rtn += profit # update holding day count - self.current.add_count_all() + self.current.add_count_all(bar=self.freq) # update value self.val = self.current.calculate_value() - # update earning (2nd view of return) + # update earning # account_value - last_account_value # for the first trade date, account_value - init_cash # self.report.is_empty() to judge is_first_trade_date @@ -138,6 +203,7 @@ class Account: turnover_rate=self.to / last_account_value, cost_rate=self.ct / last_account_value, stock_value=now_stock_value, + bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time) ) # set now_account_value to position self.current.position["now_account_value"] = now_account_value @@ -148,23 +214,20 @@ class Account: # finish today's updation # reset the daily variables - self.rtn = 0 self.ct = 0 self.to = 0 - self.last_trade_time = (trade_start_time, trade_end_time) def load_account(self, account_path): report = Report() position = Position() - last_trade_time = position.load_position(account_path / "position.xlsx") report.load_report(account_path / "report.csv") + position.load_position(account_path / "position.xlsx") # assign values self.init_vars(position.init_cash) self.current = position self.report = report - self.last_trade_time = last_trade_time def save_account(self, account_path): - self.current.save_position(account_path / "position.xlsx", self.last_trade_time) + self.current.save_position(account_path / "position.xlsx") self.report.save_report(account_path / "report.csv") diff --git a/qlib/contrib/backtest/env.py b/qlib/contrib/backtest/env.py index 85a6c1ec3..9fa993e7b 100644 --- a/qlib/contrib/backtest/env.py +++ b/qlib/contrib/backtest/env.py @@ -9,12 +9,26 @@ import numpy as np import pandas as pd from ...data.data import Cal from ...utils import get_sample_freq_calendar +from .position import Position +from .report import Report from .order import Order -class TradeCalendarBase: + +class BaseTradeCalendar: + def __init__( + self, + step_bar, + start_time=None, + end_time=None, + **kwargs + ): + self.step_bar = step_bar + self.reset(start_time=start_time, end_time=end_time) def _reset_trade_calendar(self, start_time, end_time): + if not start_time and not end_time: + return if start_time: self.start_time = pd.Timestamp(start_time) if end_time: @@ -24,37 +38,33 @@ class TradeCalendarBase: self.calendar = _calendar _start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam) _trade_calendar = self.calendar[_start_index: _end_index + 1] - if _start_time != self.start_time: - self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time)) - self.start_index = _start_index - 1 - else: - self.trade_calendar = np.hstack((_trade_calendar, self.end_time)) - self.start_index = _start_index + self.start_index = _start_index self.end_index = _end_index + self.trade_len = _end_index - _start_index + 1 self.trade_index = 0 - self.trade_len = len(self.trade_calendar) else: raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.") - def _get_trade_time(self, trade_index=1, shift=0): - trade_index = trade_index - shift - if 0 < trade_index < self.trade_len - 1: - trade_start_time = self.trade_calendar[trade_index - 1] - trade_end_time = self.trade_calendar[trade_index] - pd.Timedelta(seconds=1) - return trade_start_time, trade_end_time - elif trade_index == self.trade_len - 1: - trade_start_time = self.trade_calendar[trade_index - 1] - trade_end_time = self.trade_calendar[trade_index] - return trade_start_time, trade_end_time - else: - raise RuntimeError("trade_index out of range") + def reset(self, start_time=None, end_time=None, **kwargs): + if start_time or end_time: + self._reset_trade_calendar(start_time=start_time, end_time=end_time) + + for k, v in kwargs: + if hasattr(self, k): + setattr(self, k, v) - def _get_calendar_time(self, trade_index=1, shift=1): + def _get_calendar_time(self, trade_index=1, shift=0): trade_index = trade_index - shift calendar_index = self.start_index + trade_index return self.calendar[calendar_index - 1], self.calendar[calendar_index] -class BaseEnv(TradeCalendarBase): + def finished(self): + return self.trade_index >= self.trade_len + + def step(self): + self.trade_index = self.trade_index + 1 + +class BaseEnv(BaseTradeCalendar): """ # Strategy framework document @@ -67,38 +77,32 @@ class BaseEnv(TradeCalendarBase): start_time=None, end_time=None, trade_account=None, + update_report=False, verbose=False, **kwargs, ): - self.step_bar = step_bar + self.generate_report = update_report self.verbose = verbose - self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs) - - def _get_position(self): - return self.trade_account.current + super(BaseEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs) - - def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs): - if start_time or end_time: - self._reset_trade_calendar(start_time=start_time, end_time=end_time) + def reset(self, trade_account=None, **kwargs): + super(BaseEnv, self).reset(**kwargs) if trade_account: self.trade_account = trade_account - - for k, v in kwargs: - if hasattr(self, k): - setattr(self, k, v) + self.trade_account.reset(freq=self.step_bar, report=Report(), positions={}) def get_init_state(self): - init_state = {"current": self._get_position()} + init_state = {"current": self.trade_account.current} return init_state + def execute(self, **kwargs): + raise NotImplementedError("execute is not implemented!") - def execute(self, order_list=None, **kwargs): - self.trade_index = self.trade_index + 1 - - def finished(self): - return self.trade_index >= self.trade_len - 1 + def get_trade_account(self): + raise NotImplementedError("get_trade_account is not implemented!") + def get_report(self): + raise NotImplementedError("get_report is not implemented!") class SplitEnv(BaseEnv): def __init__( @@ -109,33 +113,44 @@ class SplitEnv(BaseEnv): start_time=None, end_time=None, trade_account=None, + update_report=False, verbose=False, **kwargs ): self.sub_env = sub_env self.sub_strategy = sub_strategy - super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, verbose=verbose) + super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, update_report=update_report, verbose=verbose, **kwargs) + def reset(self, trade_account=None, **kwargs): + super(SplitEnv, self).reset(trade_account=trade_account, **kwargs) + if trade_account: + self.sub_env.reset(trade_account=copy.copy(trade_account)) + def execute(self, order_list, **kwargs): if self.finished(): raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") #if self.track: # yield action #episode_reward = 0 - super(SplitEnv, self).execute(**kwargs) - trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index) - self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account) + super(SplitEnv, self).step() + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) + self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time) self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list) trade_state = self.sub_env.get_init_state() while not self.sub_env.finished(): _order_list = self.sub_strategy.generate_order_list(**trade_state) trade_state, trade_info = self.sub_env.execute(order_list=_order_list) - #episode_reward += sub_reward - _obs = {"current": self._get_position()} + + if self.generate_report: + self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange) + _obs = {"current": self.trade_account.current} _info = {} return _obs, _info - + def get_report(self): + _report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None + _positions = self.trade_account.get_positions() if self.generate_report else None + return [(_report,_positions), *sub_env.get_report()] class SimulatorEnv(BaseEnv): @@ -146,10 +161,11 @@ class SimulatorEnv(BaseEnv): end_time=None, trade_account=None, trade_exchange=None, + update_report=False, verbose=False, **kwargs, ): - super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose, **kwargs) + super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, update_report=update_report, verbose=verbose, **kwargs) def reset(self, trade_exchange=None, **kwargs): super(SimulatorEnv, self).reset(**kwargs) @@ -162,8 +178,8 @@ class SimulatorEnv(BaseEnv): """ if self.finished(): raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") - super(SimulatorEnv, self).execute(**kwargs) - trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index) + super(SimulatorEnv, self).step() + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) trade_info = [] for order in order_list: if self.trade_exchange.check_order(order) is True: @@ -197,7 +213,18 @@ class SimulatorEnv(BaseEnv): print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id)) # do nothing pass - self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange) - _obs = {"current": self._get_position()} + if self.generate_report: + self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange) + _obs = {"current": self.trade_account.current} _info = {"trade_info": trade_info} - return _obs, _info \ No newline at end of file + return _obs, _info + + def get_report(self): + _report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None + _positions = self.trade_account.get_positions() if self.generate_report else None + return [ + { + "report": _report, + "positions": _positions + } + ] \ No newline at end of file diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py index 62f6c63bd..399f9e151 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/contrib/backtest/exchange.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd from ...data.data import D +from ...data.dataset.utils import get_level_index from ...config import C, REG_CN from ...utils import sample_feature from ...log import get_module_logger @@ -19,6 +20,7 @@ from .order import Order class Exchange: def __init__( self, + freq="day", start_time=None, end_time=None, codes="all", @@ -55,6 +57,7 @@ class Exchange: target on this day). index: MultipleIndex(instrument, pd.Datetime) """ + self.freq = freq self.start_time = start_time self.end_time = end_time if trade_unit is None: @@ -105,7 +108,7 @@ class Exchange: def set_quote(self, codes, start_time, end_time): if len(codes) == 0: codes = D.instruments() - self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"]) + self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(subset=["$close"]) self.quote.columns = self.all_fields if self.quote[self.deal_price].isna().any(): @@ -146,7 +149,14 @@ class Exchange: quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) # update quote: pd.DataFrame to dict, for search use - self.quote = quote_df + if get_level_index(quote_df, level="datetime") == 1: + quote_df = quote_df.swaplevel().sort_index() + + quote_dict = {} + for stock_id, stock_val in quote_df.groupby(level="instrument"): + quote_dict[stock_id] = stock_val + + self.quote = quote_dict def _update_limit(self, buy_limit, sell_limit): self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False) @@ -157,13 +167,15 @@ class Exchange: trade_date is limtited """ - return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0] + return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0] def check_stock_suspended(self, stock_id, start_time, end_time): # is suspended - return sample_feature(self.quote, stock_id, start_time, end_time).empty - + if stock_id in self.quote: + return sample_feature(self.quote[stock_id], start_time, end_time, method=None) is None + else: + return True def is_stock_tradable(self, stock_id, start_time, end_time): # check if stock can be traded @@ -217,13 +229,13 @@ class Exchange: return trade_val, trade_cost, trade_price def get_quote_info(self, stock_id, start_time, end_time): - return sample_feature(self.quote, stock_id, start_time, end_time) + return sample_feature(self.quote[stock_id], start_time, end_time, method="last").iloc[0] def get_close(self, stock_id, start_time, end_time): - return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last").iloc[0] + return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0] def get_deal_price(self, stock_id, start_time, end_time): - deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last").iloc[0] + deal_price = sample_feature(self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last").iloc[0] if np.isclose(deal_price, 0.0) or np.isnan(deal_price): self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!") self.logger.warning(f"setting deal_price to close price") @@ -231,7 +243,7 @@ class Exchange: return deal_price def get_factor(self, stock_id, start_time, end_time): - return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last").iloc[0] + return sample_feature(self.quote[stock_id], start_time, end_time, fields="$factor", method="last").iloc[0] def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): """ diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index ac1a471f8..6eb2c97b8 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -38,7 +38,6 @@ class Position: def init_stock(self, stock_id, amount, price=None): self.position[stock_id] = {} - self.position[stock_id]["count"] = 0 # update count in the end of this date self.position[stock_id]["amount"] = amount self.position[stock_id]["price"] = price self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date @@ -87,8 +86,8 @@ class Position: def update_stock_price(self, stock_id, price): self.position[stock_id]["price"] = price - def update_stock_count(self, stock_id, count): - self.position[stock_id]["count"] = count + def update_stock_count(self, stock_id, bar, count): + self.position[stock_id][f"count_{bar}"] = count def update_stock_weight(self, stock_id, weight): self.position[stock_id]["weight"] = weight @@ -118,8 +117,11 @@ class Position: def get_stock_amount(self, code): return self.position[code]["amount"] - def get_stock_count(self, code): - return self.position[code]["count"] + def get_stock_count(self, code, bar): + if f"count_{bar}" in self.position[code]: + return self.position[code][f"count_{bar}"] + else: + return 0 def get_stock_weight(self, code): return self.position[code]["weight"] @@ -153,25 +155,26 @@ class Position: d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value return d - def add_count_all(self): + def add_count_all(self, bar): stock_list = self.get_stock_list() for code in stock_list: - self.position[code]["count"] += 1 + if f"count_{bar}" in self.position[code]: + self.position[code][f"count_{bar}"] += 1 + else: + self.position[code][f"count_{bar}"] = 1 def update_weight_all(self): weight_dict = self.get_stock_weight_dict() for stock_code, weight in weight_dict.items(): self.update_stock_weight(stock_code, weight) - def save_position(self, path, last_trade_time): + def save_position(self, path): path = pathlib.Path(path) p = copy.deepcopy(self.position) cash = pd.Series(dtype=np.float) cash["init_cash"] = self.init_cash cash["cash"] = p["cash"] cash["now_account_value"] = p["now_account_value"] - cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None - cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None del p["cash"] del p["now_account_value"] positions = pd.DataFrame.from_dict(p, orient="index") @@ -183,8 +186,8 @@ class Position: """load position information from a file should have format below sheet "position" - columns: ['stock', 'count', 'amount', 'price', 'weight'] - 'count': , + columns: ['stock', f'count_{bar}', 'amount', 'price', 'weight'] + f'count_{bar}': , 'amount': , 'price': , 'weight': , @@ -202,16 +205,9 @@ class Position: init_cash = cash_record.loc["init_cash"].values[0] cash = cash_record.loc["cash"].values[0] now_account_value = cash_record.loc["now_account_value"].values[0] - last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0] - last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0] - # assign values self.position = {} self.init_cash = init_cash self.position = positions self.position["cash"] = cash self.position["now_account_value"] = now_account_value - - last_trade_start_time = None if pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time) - last_trade_end_time = None if pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time) - return last_trade_start_time, last_trade_end_time diff --git a/qlib/contrib/backtest/report.py b/qlib/contrib/backtest/report.py index 9a57156f2..3bee440e0 100644 --- a/qlib/contrib/backtest/report.py +++ b/qlib/contrib/backtest/report.py @@ -21,6 +21,7 @@ class Report: self.costs = OrderedDict() # trade cost for each trade date self.values = OrderedDict() # value for each trade date self.cashes = OrderedDict() + self.benches = OrderedDict() self.latest_report_time = None # pd.TimeStamp def is_empty(self): @@ -41,6 +42,7 @@ class Report: turnover_rate=None, cost_rate=None, stock_value=None, + bench_value=None, ): # check data if None in [ @@ -51,9 +53,10 @@ class Report: turnover_rate, cost_rate, stock_value, + bench_value ]: raise ValueError( - "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]" + "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]" ) # update report data self.accounts[trade_time] = account_value @@ -62,6 +65,7 @@ class Report: self.costs[trade_time] = cost_rate self.values[trade_time] = stock_value self.cashes[trade_time] = cash + self.benches[trade_time] = bench_value # update latest_report_date self.latest_report_time = trade_time # finish daily report update @@ -74,7 +78,8 @@ class Report: report["cost"] = pd.Series(self.costs) report["value"] = pd.Series(self.values) report["cash"] = pd.Series(self.cashes) - report.index.name = "trade_time" + report["bench"] = pd.Series(self.benches) + report.index.name = "datetime" return report def save_report(self, path): @@ -84,7 +89,7 @@ class Report: def load_report(self, path): """load report from a file should have format like - columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash'] + columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash', 'bench'] :param path: str/ pathlib.Path() """ @@ -103,4 +108,5 @@ class Report: turnover_rate=r.loc[trade_time]["turnover"], cost_rate=r.loc[trade_time]["cost"], stock_value=r.loc[trade_time]["value"], + bench_value=r.loc[trade_time]["bench"] ) diff --git a/qlib/contrib/report/analysis_position/parse_position.py b/qlib/contrib/report/analysis_position/parse_position.py index fe1d61137..c5d48ff8e 100644 --- a/qlib/contrib/report/analysis_position/parse_position.py +++ b/qlib/contrib/report/analysis_position/parse_position.py @@ -41,7 +41,7 @@ def parse_position(position: dict = None) -> pd.DataFrame: for _trading_date, _value in position.items(): # pd_date type: pd.Timestamp _cash = _value.pop("cash") - for _item in ["today_account_value"]: + for _item in ["now_account_value"]: if _item in _value: _value.pop(_item) diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py index 678b048c2..b138edb23 100644 --- a/qlib/contrib/strategy/__init__.py +++ b/qlib/contrib/strategy/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. -from .dl_strategy import ( +from .model_strategy import ( TopkDropoutStrategy, WeightStrategyBase, ) diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index 962936f9f..111cc276a 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. -from .dl_strategy import WeightStrategyBase +from .model_strategy import WeightStrategyBase import copy diff --git a/qlib/contrib/strategy/dl_strategy.py b/qlib/contrib/strategy/model_strategy.py similarity index 95% rename from qlib/contrib/strategy/dl_strategy.py rename to qlib/contrib/strategy/model_strategy.py index 4c7d16eea..9aab96377 100644 --- a/qlib/contrib/strategy/dl_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -81,10 +81,12 @@ class TopkDropoutStrategy(ModelStrategy): return self.risk_degree def generate_order_list(self, current, **kwargs): - super(TopkDropoutStrategy, self).generate_order_list() - trade_start_time, trade_end_time = self._get_trade_time(self.trade_index) + super(TopkDropoutStrategy, self).step() + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + if pred_score is None: + return [] if self.only_tradable: # If The strategy only consider tradable stock when make decision # It needs following actions to filter stocks @@ -168,7 +170,7 @@ class TopkDropoutStrategy(ModelStrategy): continue if code in sell: # check hold limit - if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh: + if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh: # can not sell this code # no buy signal, but the stock is kept self.stock_count[code] += 1 @@ -271,10 +273,12 @@ class WeightStrategyBase(ModelStrategy): """ # generate_order_list # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list - super(WeightStrategyBase, self).generate_order_list() - trade_start_time, trade_end_time = self._get_trade_time(self.trade_index) - pred_start_time, pred_end_time = self._get_pred_time() + super(WeightStrategyBase, self).step() + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) + pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + if pred_score is None: + return [] current_temp = copy.deepcopy(trade_account.current) target_weight_position = self.generate_target_weight_position( score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index b51ec9aca..b432ccea2 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -5,6 +5,7 @@ import pandas as pd from ...utils import sample_feature from ...data.data import D +from ...data.dataset.utils import get_level_index from ...strategy.base import RuleStrategy, TradingEnhancement from ..backtest.order import Order @@ -21,8 +22,8 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): def generate_order_list(self, **kwargs): - super(TopkDropoutStrategy, self).generate_order_list() - trade_start_time, trade_end_time = self._get_trade_time() + super(TopkDropoutStrategy, self).step() + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) order_list = [] for order in self.trade_order_list: _order = Order( @@ -59,8 +60,8 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): raise NotImplementedError("pred_price_trend method is not implemented!") def generate_order_list(self, **kwargs): - super(SBBStrategyBase, self).generate_order_list() - trade_start_time, trade_end_time = self._get_trade_time() + super(SBBStrategyBase, self).step() + trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) order_list = [] for order in self.trade_order_list: @@ -127,21 +128,33 @@ class SBBStrategyEMA(SBBStrategyBase): if isinstance(instruments, str): self.instruments = D.instruments(instruments) self.freq = freq - + + def _convert_index_format(self, df): + if get_level_index(df, level="datetime") == 1: + df = df.swaplevel().sort_index() + return df def _reset_trade_calendar(self, start_time=None, end_time=None): super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time) if self.start_time and self.end_time: fields = ["EMA($close, 10)-EMA($close, 20)"] signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) - self.signal = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq) - self.signal.columns = ["signal"] - + signal_df = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq) + signal_df = self._convert_index_format(signal_df) + signal_df.columns = ["signal"] + self.signal = {} + for stock_id, stock_val in signal_df.groupby(level="instrument"): + self.signal[stock_id] = stock_val + def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): - _sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last") - if _sample_signal.empty: + if stock_id not in self.signal: return self.TREND_MID - elif _sample_signal.iloc[0] > 0: - return self.TREND_LONG else: - return self.TREND_SHORT \ No newline at end of file + _sample_signal = sample_feature(self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last") + if _sample_signal is None or _sample_signal.iloc[0] == 0: + return self.TREND_MID + elif _sample_signal.iloc[0] > 0: + return self.TREND_LONG + else: + return self.TREND_SHORT + \ No newline at end of file diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 193906dcd..fb5b44334 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -12,7 +12,7 @@ from ..utils import get_sample_freq_calendar from ..data.dataset import DatasetH from ..data.dataset.utils import get_level_index from ..contrib.backtest.order import Order -from ..contrib.backtest.env import TradeCalendarBase +from ..contrib.backtest.env import BaseTradeCalendar """ 1. BaseStrategy 的粒度一定是数据粒度的整数倍 @@ -20,22 +20,10 @@ from ..contrib.backtest.env import TradeCalendarBase - adjust_dates这个东西啥用 - label和freq和strategy的bar分离,这个如何决策呢 """ -class BaseStrategy(TradeCalendarBase): - def __init__(self, step_bar, start_time=None, end_time=None, **kwargs): - self.step_bar = step_bar - self.reset(start_time=start_time, end_time=end_time, **kwargs) - - def reset(self, start_time=None, end_time=None, **kwargs): - if start_time or end_time : - self._reset_trade_calendar(start_time=start_time, end_time=end_time) - - for k, v in kwargs: - if hasattr(self, k): - setattr(self, k, v) - +class BaseStrategy(BaseTradeCalendar): def generate_order_list(self, **kwargs): - self.trade_index = self.trade_index + 1 + raise NotImplementedError("generator_order_list is not implemented!") class RuleStrategy(BaseStrategy): @@ -50,14 +38,14 @@ class ModelStrategy(BaseStrategy): super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) def _convert_index_format(self, df): - if get_level_index(df, level="datetime") == 0: + if get_level_index(df, level="datetime") == 1: df = df.swaplevel().sort_index() return df def _update_model(self): """update pred score """ - pass + raise NotImplementedError("_update_model is not implemented!") class TradingEnhancement: def reset(self, trade_order_list=None): diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 0f365956d..ea573d819 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -861,15 +861,38 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam): else: raise ValueError("sample freq must be xmin, xd, xw, xm") +def parse_freq(freq): + freq = freq.lower() + search_obj =re.search("^([0-9]*)([a-z]+)", freq) + if search_obj is None: + raise ValueError("freq format is not supported") + _count = int(search_obj.group(1) if search_obj.group(1) else "1") + _freq = search_obj.group(2) + _freq_format_dict = { + "month": "month", + "mon": "month", + "week": "week", + "w": "week", + "day": "day", + "d": "day", + "minute": "minute", + "min": "minute", + } + try: + _freq = _freq_format_dict.get(_freq) + except KeyError: + raise ValueError("freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min") + return _count, _freq + def sample_calendar(calendar_raw, freq_raw, freq_sam): """ freq_raw : "min" or "day" """ - freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw - freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam + raw_count, freq_raw = parse_freq(freq_raw) + sam_count, freq_sam = parse_freq(freq_sam) if not len(calendar_raw): return calendar_raw - if freq_sam.endswith(("minute", "min")): + if freq_sam == "minute": def cal_next_sam_minute(x, sam_minutes): hour = x.hour minute = x.minute @@ -888,38 +911,36 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60 else: raise ValueError("calendar minute_index error") - sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6]) - if not freq_raw.endswith(("minute", "min")): + + if req_raw != "minute": raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") else: - raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6]) - if raw_minutes > sam_minutes: + if raw_count > sam_count: raise ValueError("raw freq must be higher than sample freq") - _calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 0), calendar_raw))) + _calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw))) if calendar_raw[0] > _calendar_minute[0]: _calendar_minute[0] = calendar_raw[0] return _calendar_minute else: _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) - if freq_sam.endswith(("day", "d")): - sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3]) - return _calendar_day[::sam_days] + if freq_sam == "day": + return _calendar_day[::sam_count] - elif freq_sam.endswith(("week", "w")): - sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4]) + elif freq_sam == "week": _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) _calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0] - return _calendar_week[::sam_weeks] + return _calendar_week[::sam_count] - elif freq_sam.endswith(("month", "m")): - sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5]) + elif freq_sam == "month": _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) _calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0] - return _calendar_month[::sam_months] + return _calendar_month[::sam_count] else: raise ValueError("sample freq must be xmin, xd, xw, xm") def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs): + _, norm_freq = parse_freq(freq) + from ..data.data import Cal try: @@ -927,34 +948,47 @@ def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwarg freq, freq_sam = freq, None except ValueError: freq_sam = freq - if freq.endswith(("m", "month", "w", "week", "d", "day")): + if norm_freq in ["month", "week", "day"]: try: - _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs) - freq = "min" - except ValueError: _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs) freq = "day" - elif freq.endswith(("min", "minute")): + except ValueError: + raise + _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs) + freq = "min" + elif norm_freq == "minute": _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs) freq = "min" else: raise ValueError(f"freq {freq} is not supported") return _calendar, freq, freq_sam -def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}): - if instruments and not isinstance(instruments, list): - instruments = [instruments] - selector_inst = slice(None) if instruments is None else instruments +def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}): selector_datetime = slice(start_time, end_time) - if isinstance(feature, pd.Series): - feature = feature.loc[(selector_inst, selector_datetime)] - if fields: - warnings.warn(f"sample series feature, {fields} is ignored!") - elif isinstance(feature, pd.DataFrame): - selector_fields = slice(None) if fields is None else fields - feature = feature.loc[(selector_inst, selector_datetime), selector_fields] - if method: - return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) - else: - return feature + fields = fields if fields else slice(None) + from ..data.dataset.utils import get_level_index + + datetime_level = get_level_index(feature, level="datetime") == 0 + if isinstance(feature, pd.Series): + feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)] + elif isinstance(feature, pd.DataFrame): + feature = feature.loc[selector_datetime, fields] if datetime_level else feature.loc[(slice(None), selector_datetime), fields] + if feature.empty: + return None + if isinstance(feature.index, pd.MultiIndex): + if callable(method): + method_func = method + return feature.groupby(level="instrument").apply(lambda x:method_func(x, **method_kwargs)) + elif isinstance(method, str): + return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) + else: + if callable(method): + method_func = method + return method_func(feature, **method_kwargs) + elif isinstance(method, str): + return getattr(feature, method)(**method_kwargs) + + return feature + + \ No newline at end of file diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 2c1b6fecc..51a9a305c 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -233,8 +233,8 @@ class PortAnaRecord(SignalRecord): super().__init__(recorder=recorder, **kwargs) self.strategy_config = config["strategy"] + self.env_config = config["env"] self.backtest_config = config["backtest"] - self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy) def generate(self, **kwargs): # check previously stored prediction results @@ -244,36 +244,32 @@ class PortAnaRecord(SignalRecord): super().generate() # custom strategy and get backtest - pred_score = super().load("pred.pkl") - report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config) - report_normal = report_dict.get("report_df") - positions_normal = report_dict.get("positions") - self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) - self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) - order_normal = report_dict.get("order_list") - if order_normal: - self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path()) - - # analysis - analysis = dict() - analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) - analysis["excess_return_with_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"] - report_normal["cost"] - ) - # save portfolio analysis results - analysis_df = pd.concat(analysis) # type: pd.DataFrame - # log metrics - self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict())) - # save results - self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()) - logger.info( - f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" - ) - # print out results - pprint("The following are analysis results of the excess return without cost.") - pprint(analysis["excess_return_without_cost"]) - pprint("The following are analysis results of the excess return with cost.") - pprint(analysis["excess_return_with_cost"]) + report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config) + for report_id, (report_normal, positions_normal) in enumerate(report_list): + if report_dict is None: + continue + + self.recorder.save_objects(**{f"report_normal_{report_id}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects(**{f"positions_norma_{report_id}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) + # analysis + analysis = dict() + analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_with_cost"] = risk_analysis( + report_normal["return"] - report_normal["bench"] - report_normal["cost"] + ) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + # log metrics + self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict())) + # save results + self.recorder.save_objects(**{f"port_analysis.pkl_{report_id}": analysis_df}, artifact_path=PortAnaRecord.get_path()) + logger.info( + f"Portfolio analysis record 'port_analysis_{report_id}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" + ) + # print out results + pprint("The following are analysis results of the excess return without cost.") + pprint(analysis["excess_return_without_cost"]) + pprint("The following are analysis results of the excess return with cost.") + pprint(analysis["excess_return_with_cost"]) def list(self): return [