diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py index e5a832927..d031d40f2 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/highfreq/backtest/workflow.py @@ -28,7 +28,7 @@ if __name__ == "__main__": ################################### # train model ################################### - + data_handler_config = { "start_time": "2008-01-01", "end_time": "2020-08-01", @@ -70,7 +70,7 @@ if __name__ == "__main__": }, }, } - # model initialization + # model initialization model = init_instance_by_config(task["model"]) dataset = init_instance_by_config(task["dataset"]) model.fit(dataset) @@ -78,7 +78,7 @@ if __name__ == "__main__": trade_start_time = "2017-01-31" trade_end_time = "2018-01-31" - backtest_config={ + backtest_config = { "strategy": { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.model_strategy", @@ -90,7 +90,7 @@ if __name__ == "__main__": "n_drop": 5, }, }, - "env":{ + "env": { "class": "SplitEnv", "module_path": "qlib.contrib.backtest.env", "kwargs": { @@ -101,7 +101,7 @@ if __name__ == "__main__": "kwargs": { "step_bar": "day", "verbose": True, - } + }, }, "sub_strategy": { "class": "SBBStrategyEMA", @@ -110,11 +110,11 @@ if __name__ == "__main__": "step_bar": "day", "freq": "day", "instruments": "csi300", - } - } - } + }, + }, + }, }, - "backtest":{ + "backtest": { "start_time": trade_start_time, "end_time": trade_end_time, "verbose": False, @@ -125,8 +125,14 @@ if __name__ == "__main__": "open_cost": 0.0005, "close_cost": 0.0015, "min_cost": 5, - } + }, } - - report_dict = backtest(start_time=trade_start_time, end_time=trade_end_time, **backtest_config, account=1e8, deal_price="$close", verbose=False) \ No newline at end of file + report_dict = backtest( + start_time=trade_start_time, + end_time=trade_end_time, + **backtest_config, + account=1e8, + deal_price="$close", + verbose=False, + ) diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py index 4a03bbe47..21d3913e5 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/contrib/backtest/__init__.py @@ -22,7 +22,7 @@ def get_exchange( freq="day", start_time=None, end_time=None, - codes = "all", + codes="all", subscribe_fields=[], open_cost=0.0015, close_cost=0.0025, @@ -89,6 +89,7 @@ def get_exchange( else: return init_instance_by_config(exchange, accept_types=Exchange) + def init_env_instance_by_config(env): if isinstance(env, dict): env_config = copy.copy(env) @@ -103,6 +104,7 @@ def init_env_instance_by_config(env): else: return env + def setup_exchange(root_instance, trade_exchange=None, force=False): if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args: if force: @@ -114,8 +116,8 @@ def setup_exchange(root_instance, trade_exchange=None, force=False): setup_exchange(root_instance.sub_env, trade_exchange) if hasattr(root_instance, "sub_strategy"): setup_exchange(root_instance.sub_strategy, trade_exchange) - - + + def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs): trade_strategy = init_instance_by_config(strategy) trade_env = init_env_instance_by_config(env) diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index 8bf7dedb7..ad88e274a 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -11,7 +11,6 @@ from .order import Order from ...utils import parse_freq, sample_feature - """ rtn & earning in the Account rtn: @@ -87,7 +86,7 @@ class Account: elif norm_freq == "minute": _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) else: - raise ValueError(f"benchmark freq {freq} is not supported") + raise ValueError(f"benchmark freq {freq} is not supported") if len(_temp_result) == 0: raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0) @@ -95,20 +94,20 @@ class Account: def _sample_benchmark(self, bench, trade_start_time, trade_end_time): def cal_change(x): return x.prod() - 1 + return sample_feature(bench, trade_start_time, trade_end_time, method=cal_change) - def reset(self, benchmark=None, freq=None,**kwargs): + def reset(self, benchmark=None, freq=None, **kwargs): if benchmark: self.benchmark = benchmark if freq: self.freq = freq - if self.freq and self.benchmark and (freq or benchmark) + if self.freq and self.benchmark and (freq or benchmark): self.bench = self._cal_benchmark(self.benchmark, self.start_time, self.end_time, self.freq) for k, v in kwargs: if hasattr(k): setattr(k, v) - def get_positions(self): return self.positions @@ -203,7 +202,7 @@ class Account: turnover_rate=self.to / last_account_value, cost_rate=self.ct / last_account_value, stock_value=now_stock_value, - bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time) + bench_value=self._sample_benchmark(self.bench, trade_start_time, trade_end_time), ) # set now_account_value to position self.current.position["now_account_value"] = now_account_value diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index a7e009a9a..d6fcb509d 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -7,6 +7,7 @@ import pandas as pd from .account import Account + def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account): trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time) @@ -17,10 +18,9 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account while not trade_env.finished(): _order_list = trade_strategy.generate_order_list(**trade_state) trade_state, trade_info = trade_env.execute(_order_list) - + report_df = trade_account.report.generate_report_dataframe() positions = trade_account.get_positions() report_dict = {"report_df": report_df, "positions": positions} return report_dict - diff --git a/qlib/contrib/backtest/env.py b/qlib/contrib/backtest/env.py index 9fa993e7b..ade5caf24 100644 --- a/qlib/contrib/backtest/env.py +++ b/qlib/contrib/backtest/env.py @@ -1,5 +1,3 @@ - - import re import json import copy @@ -14,15 +12,8 @@ from .report import Report from .order import Order - class BaseTradeCalendar: - def __init__( - self, - step_bar, - start_time=None, - end_time=None, - **kwargs - ): + def __init__(self, step_bar, start_time=None, end_time=None, **kwargs): self.step_bar = step_bar self.reset(start_time=start_time, end_time=end_time) @@ -36,8 +27,10 @@ class BaseTradeCalendar: if self.start_time and self.end_time: _calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar) self.calendar = _calendar - _start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam) - _trade_calendar = self.calendar[_start_index: _end_index + 1] + _start_time, _end_time, _start_index, _end_index = Cal.locate_index( + self.start_time, self.end_time, freq=freq, freq_sam=freq_sam + ) + _trade_calendar = self.calendar[_start_index : _end_index + 1] self.start_index = _start_index self.end_index = _end_index self.trade_len = _end_index - _start_index + 1 @@ -52,7 +45,7 @@ class BaseTradeCalendar: for k, v in kwargs: if hasattr(self, k): setattr(self, k, v) - + def _get_calendar_time(self, trade_index=1, shift=0): trade_index = trade_index - shift calendar_index = self.start_index + trade_index @@ -64,6 +57,7 @@ class BaseTradeCalendar: def step(self): self.trade_index = self.trade_index + 1 + class BaseEnv(BaseTradeCalendar): """ # Strategy framework document @@ -83,8 +77,10 @@ class BaseEnv(BaseTradeCalendar): ): self.generate_report = update_report self.verbose = verbose - super(BaseEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs) - + super(BaseEnv, self).__init__( + step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs + ) + def reset(self, trade_account=None, **kwargs): super(BaseEnv, self).reset(**kwargs) if trade_account: @@ -94,7 +90,7 @@ class BaseEnv(BaseTradeCalendar): def get_init_state(self): init_state = {"current": self.trade_account.current} return init_state - + def execute(self, **kwargs): raise NotImplementedError("execute is not implemented!") @@ -104,23 +100,32 @@ class BaseEnv(BaseTradeCalendar): def get_report(self): raise NotImplementedError("get_report is not implemented!") + class SplitEnv(BaseEnv): def __init__( - self, - step_bar, + self, + step_bar, sub_env, sub_strategy, - start_time=None, - end_time=None, + start_time=None, + end_time=None, trade_account=None, update_report=False, verbose=False, - **kwargs + **kwargs, ): self.sub_env = sub_env self.sub_strategy = sub_strategy - super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, update_report=update_report, verbose=verbose, **kwargs) - + super(SplitEnv, self).__init__( + step_bar=step_bar, + start_time=start_time, + end_time=end_time, + trade_account=trade_account, + update_report=update_report, + verbose=verbose, + **kwargs, + ) + def reset(self, trade_account=None, **kwargs): super(SplitEnv, self).reset(trade_account=trade_account, **kwargs) if trade_account: @@ -129,9 +134,9 @@ class SplitEnv(BaseEnv): def execute(self, order_list, **kwargs): if self.finished(): raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") - #if self.track: + # if self.track: # yield action - #episode_reward = 0 + # episode_reward = 0 super(SplitEnv, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time) @@ -140,9 +145,11 @@ class SplitEnv(BaseEnv): while not self.sub_env.finished(): _order_list = self.sub_strategy.generate_order_list(**trade_state) trade_state, trade_info = self.sub_env.execute(order_list=_order_list) - + if self.generate_report: - self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange) + self.trade_account.update_report( + trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange + ) _obs = {"current": self.trade_account.current} _info = {} return _obs, _info @@ -150,31 +157,40 @@ class SplitEnv(BaseEnv): def get_report(self): _report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None _positions = self.trade_account.get_positions() if self.generate_report else None - return [(_report,_positions), *sub_env.get_report()] - -class SimulatorEnv(BaseEnv): + return [(_report, _positions), *sub_env.get_report()] + +class SimulatorEnv(BaseEnv): def __init__( - self, - step_bar, - start_time=None, - end_time=None, - trade_account=None, + self, + step_bar, + start_time=None, + end_time=None, + trade_account=None, trade_exchange=None, update_report=False, verbose=False, **kwargs, ): - super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, update_report=update_report, verbose=verbose, **kwargs) + super(SimulatorEnv, self).__init__( + step_bar=step_bar, + start_time=start_time, + end_time=end_time, + trade_account=trade_account, + trade_exchange=trade_exchange, + update_report=update_report, + verbose=verbose, + **kwargs, + ) def reset(self, trade_exchange=None, **kwargs): super(SimulatorEnv, self).reset(**kwargs) if trade_exchange: - self.trade_exchange=trade_exchange + self.trade_exchange = trade_exchange def execute(self, order_list, **kwargs): """ - Return: obs, done, info + Return: obs, done, info """ if self.finished(): raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") @@ -184,7 +200,9 @@ class SimulatorEnv(BaseEnv): for order in order_list: if self.trade_exchange.check_order(order) is True: # execute the order - trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=self.trade_account) + trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( + order, trade_account=self.trade_account + ) trade_info.append((order, trade_val, trade_cost, trade_price)) if self.verbose: if order.direction == Order.SELL: # sell @@ -214,7 +232,9 @@ class SimulatorEnv(BaseEnv): # do nothing pass if self.generate_report: - self.trade_account.update_report(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange) + self.trade_account.update_report( + trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange + ) _obs = {"current": self.trade_account.current} _info = {"trade_info": trade_info} return _obs, _info @@ -222,9 +242,4 @@ class SimulatorEnv(BaseEnv): def get_report(self): _report = self.trade_account.report.generate_report_dataframe() if self.generate_report else None _positions = self.trade_account.get_positions() if self.generate_report else None - return [ - { - "report": _report, - "positions": _positions - } - ] \ No newline at end of file + return [{"report": _report, "positions": _positions}] diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py index 399f9e151..a25b9b4a0 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/contrib/backtest/exchange.py @@ -16,7 +16,6 @@ from ...log import get_module_logger from .order import Order - class Exchange: def __init__( self, @@ -101,14 +100,15 @@ class Exchange: self.min_cost = min_cost self.limit_threshold = limit_threshold - self.extra_quote = extra_quote self.set_quote(codes, start_time, end_time) def set_quote(self, codes, start_time, end_time): if len(codes) == 0: codes = D.instruments() - self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(subset=["$close"]) + self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna( + subset=["$close"] + ) self.quote.columns = self.all_fields if self.quote[self.deal_price].isna().any(): @@ -168,7 +168,6 @@ class Exchange: is limtited """ return sample_feature(self.quote[stock_id], start_time, end_time, fields="limit", method="all").iloc[0] - def check_stock_suspended(self, stock_id, start_time, end_time): # is suspended @@ -180,7 +179,9 @@ class Exchange: def is_stock_tradable(self, stock_id, start_time, end_time): # check if stock can be traded # same as check in check_order - if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time): + if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit( + stock_id, start_time, end_time + ): return False else: return True @@ -235,9 +236,13 @@ class Exchange: return sample_feature(self.quote[stock_id], start_time, end_time, fields="$close", method="last").iloc[0] def get_deal_price(self, stock_id, start_time, end_time): - deal_price = sample_feature(self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last").iloc[0] + deal_price = sample_feature( + self.quote[stock_id], start_time, end_time, fields=self.deal_price, method="last" + ).iloc[0] if np.isclose(deal_price, 0.0) or np.isnan(deal_price): - self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!") + self.logger.warning( + f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!" + ) self.logger.warning(f"setting deal_price to close price") deal_price = self.get_close(stock_id, start_time, end_time) return deal_price @@ -274,7 +279,9 @@ class Exchange: amount_dict = {} for stock_id in weight_position: - if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): + if weight_position[stock_id] > 0.0 and self.is_stock_tradable( + stock_id=stock_id, start_time=start_time, end_time=end_time + ): amount_dict[stock_id] = ( cash * weight_position[stock_id] @@ -377,7 +384,10 @@ class Exchange: self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False ): - value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id] + value += ( + self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) + * amount_dict[stock_id] + ) return value def round_amount_by_trade_unit(self, deal_amount, factor): diff --git a/qlib/contrib/backtest/interpreter.py b/qlib/contrib/backtest/interpreter.py index 94d6f9ec2..7f33c809d 100644 --- a/qlib/contrib/backtest/interpreter.py +++ b/qlib/contrib/backtest/interpreter.py @@ -1,15 +1,16 @@ - class BaseInterpreter: @staticmethod def interpret(**kwargs): raise NotImplementedError("interpret is not implemented!") + class ActionInterpreter: @staticmethod def interpret(action, **kwargs): return action + class StateInterpreter: @staticmethod def interpret(state, **kwargs): - return state \ No newline at end of file + return state diff --git a/qlib/contrib/backtest/report.py b/qlib/contrib/backtest/report.py index 3bee440e0..57e56c9a3 100644 --- a/qlib/contrib/backtest/report.py +++ b/qlib/contrib/backtest/report.py @@ -45,16 +45,7 @@ class Report: bench_value=None, ): # check data - if None in [ - trade_time, - account_value, - cash, - return_rate, - turnover_rate, - cost_rate, - stock_value, - bench_value - ]: + if None in [trade_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]: raise ValueError( "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value, bench_value]" ) @@ -108,5 +99,5 @@ class Report: turnover_rate=r.loc[trade_time]["turnover"], cost_rate=r.loc[trade_time]["cost"], stock_value=r.loc[trade_time]["value"], - bench_value=r.loc[trade_time]["bench"] + bench_value=r.loc[trade_time]["bench"], ) diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py index b138edb23..e308c1a05 100644 --- a/qlib/contrib/strategy/__init__.py +++ b/qlib/contrib/strategy/__init__.py @@ -7,12 +7,10 @@ from .model_strategy import ( WeightStrategyBase, ) -from .rule_strategy import( +from .rule_strategy import ( TWAPStrategy, SBBStrategyBase, SBBStrategyEMA, ) -from .cost_control import ( - SoftTopkStrategy -) \ No newline at end of file +from .cost_control import SoftTopkStrategy diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 9aab96377..95280dc2f 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -53,7 +53,9 @@ class TopkDropoutStrategy(ModelStrategy): else: strategy will make decision with the tradable state of the stock info and avoid buy and sell them. """ - super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange) + super(TopkDropoutStrategy, self).__init__( + step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange + ) self.topk = topk self.n_drop = n_drop self.method_sell = method_sell @@ -65,8 +67,7 @@ class TopkDropoutStrategy(ModelStrategy): self.stock_count = {} self.hold_thresh = hold_thresh self.only_tradable = only_tradable - - + def reset(self, trade_exchange=None, **kwargs): super(TopkDropoutStrategy, self).reset(**kwargs) if trade_exchange: @@ -94,7 +95,9 @@ class TopkDropoutStrategy(ModelStrategy): cur_n = 0 res = [] for si in reversed(l) if reverse else l: - if self.trade_exchange.is_stock_tradable(stock_id=si, start_time=trade_start_time, end_time=trade_end_time): + if self.trade_exchange.is_stock_tradable( + stock_id=si, start_time=trade_start_time, end_time=trade_end_time + ): res.append(si) cur_n += 1 if cur_n >= n: @@ -105,7 +108,13 @@ class TopkDropoutStrategy(ModelStrategy): return get_first_n(l, n, reverse=True) def filter_stock(l): - return [si for si in l if self.trade_exchange.is_stock_tradable(stock_id=si, start_time=trade_start_time, end_time=trade_end_time)] + return [ + si + for si in l + if self.trade_exchange.is_stock_tradable( + stock_id=si, start_time=trade_start_time, end_time=trade_end_time + ) + ] else: # Otherwise, the stock will make decision with out the stock tradable info @@ -166,11 +175,16 @@ class TopkDropoutStrategy(ModelStrategy): buy_signal = pred_score.sort_values(ascending=False).iloc[: self.topk].index for code in current_stock_list: - if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time): + if not self.trade_exchange.is_stock_tradable( + stock_id=code, start_time=trade_start_time, end_time=trade_end_time + ): continue if code in sell: # check hold limit - if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh: + if ( + self.stock_count[code] < self.thresh + or current_temp.get_stock_count(code, bar=self.step_bar) < self.hold_thresh + ): # can not sell this code # no buy signal, but the stock is kept self.stock_count[code] += 1 @@ -188,7 +202,9 @@ class TopkDropoutStrategy(ModelStrategy): # is order executable if self.trade_exchange.check_order(sell_order): sell_order_list.append(sell_order) - trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(sell_order, position=current_temp) + trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( + sell_order, position=current_temp + ) # update cash cash += trade_val - trade_cost # sold @@ -213,10 +229,14 @@ class TopkDropoutStrategy(ModelStrategy): # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit for code in buy: # check is stock suspended - if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time): + if not self.trade_exchange.is_stock_tradable( + stock_id=code, start_time=trade_start_time, end_time=trade_end_time + ): continue # buy order - buy_price = self.trade_exchange.get_deal_price(stock_id=code, start_time=trade_start_time, end_time=trade_end_time) + buy_price = self.trade_exchange.get_deal_price( + stock_id=code, start_time=trade_start_time, end_time=trade_end_time + ) buy_amount = value / buy_price factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time) buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor) @@ -231,17 +251,24 @@ class TopkDropoutStrategy(ModelStrategy): buy_order_list.append(buy_order) self.stock_count[code] = 1 return sell_order_list + buy_order_list - + + class WeightStrategyBase(ModelStrategy): - def __init__(self, step_bar, start_time=None, end_time=None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, **kwargs): + def __init__( + self, + step_bar, + start_time=None, + end_time=None, + order_generator_cls_or_obj=OrderGenWInteract, + trade_exchange=None, + **kwargs, + ): super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time) self.trade_exchange = trade_exchange if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj - - def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time): """ diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index d263f658d..93bf7b2fe 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -81,10 +81,16 @@ class OrderGenWInteract(OrderGenerator): # calculate current_tradable_value current_amount_dict = current.get_stock_amount_dict() current_total_value = trade_exchange.calculate_amount_position_value( - amount_dict=current_amount_dict, trade_start_time=trade_start_time, trade_end_time=trade_end_time, only_tradable=False + amount_dict=current_amount_dict, + trade_start_time=trade_start_time, + trade_end_time=trade_end_time, + only_tradable=False, ) current_tradable_value = trade_exchange.calculate_amount_position_value( - amount_dict=current_amount_dict, trade_start_time=trade_start_time, trade_end_time=trade_end_time, only_tradable=True + amount_dict=current_amount_dict, + trade_start_time=trade_start_time, + trade_end_time=trade_end_time, + only_tradable=True, ) # add cash current_tradable_value += current.get_cash() @@ -97,7 +103,9 @@ class OrderGenWInteract(OrderGenerator): # value. Then just sell all the stocks target_amount_dict = copy.deepcopy(current_amount_dict.copy()) for stock_id in list(target_amount_dict.keys()): - if trade_exchange.is_stock_tradable(stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time): + if trade_exchange.is_stock_tradable( + stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time + ): del target_amount_dict[stock_id] else: # consider cost rate @@ -108,13 +116,13 @@ class OrderGenWInteract(OrderGenerator): target_amount_dict = trade_exchange.generate_amount_position_from_weight_position( weight_position=target_weight_position, cash=current_tradable_value, - trade_start_time=trade_start_time, + trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) order_list = trade_exchange.generate_order_for_target_amount_position( target_position=target_amount_dict, current_position=current_amount_dict, - trade_start_time=trade_start_time, + trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) return order_list @@ -161,7 +169,9 @@ class OrderGenWOInteract(OrderGenerator): amount_dict = {} for stock_id in target_weight_position: # Current rule will ignore the stock that not hold and cannot be traded at predict date - if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time): + if trade_exchange.is_stock_tradable( + stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time + ): amount_dict[stock_id] = ( risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date) ) diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index b432ccea2..1acf55314 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -11,7 +11,6 @@ from ..backtest.order import Order class TWAPStrategy(RuleStrategy, TradingEnhancement): - def reset(self, trade_order_list=None, **kwargs): super(TWAPStrategy, self).reset(**kwargs) TradingEnhancement.reset(self, trade_order_list=trade_order_list) @@ -19,7 +18,6 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): self.trade_amount = {} for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len - def generate_order_list(self, **kwargs): super(TopkDropoutStrategy, self).step() @@ -37,10 +35,12 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): order_list.append(_order) return order_list + class SBBStrategyBase(RuleStrategy, TradingEnhancement): """ - (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy. + (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy. """ + TREND_MID = 0 TREND_SHORT = 1 TREND_LONG = 2 @@ -50,11 +50,10 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): TradingEnhancement.reset(self, trade_order_list=trade_order_list) if trade_order_list: self.trade_amount = {} - self.trade_trend = {} + self.trade_trend = {} for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID - def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") @@ -81,10 +80,15 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): order_list.append(_order) else: if self.trade_index % 2 == 1: - if _pred_trend == self.TREND_SHORT and order.direction == order.SELL or _pred_trend == self.TREND_LONG and order.direction == order.BUY: + if ( + _pred_trend == self.TREND_SHORT + and order.direction == order.SELL + or _pred_trend == self.TREND_LONG + and order.direction == order.BUY + ): _order = Order( stock_id=order.stock_id, - amount=2*self.trade_amount[(order.stock_id, order.direction)], + amount=2 * self.trade_amount[(order.stock_id, order.direction)], start_time=trade_start_time, end_time=trade_end_time, direction=order.direction, # 1 for buy @@ -92,31 +96,37 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): ) order_list.append(_order) else: - if _pred_trend == self.TREND_SHORT and order.direction == order.BUY or _pred_trend == self.TREND_LONG and order.direction == order.SELL: + if ( + _pred_trend == self.TREND_SHORT + and order.direction == order.BUY + or _pred_trend == self.TREND_LONG + and order.direction == order.SELL + ): _order = Order( stock_id=order.stock_id, - amount=2*self.trade_amount[(order.stock_id, order.direction)], + amount=2 * self.trade_amount[(order.stock_id, order.direction)], start_time=trade_start_time, end_time=trade_end_time, direction=order.direction, # 1 for buy factor=order.factor, - ) + ) order_list.append(_order) - if self.trade_index % 2 == 1: + if self.trade_index % 2 == 1: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend return order_list - + class SBBStrategyEMA(SBBStrategyBase): """ - (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA). + (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA). """ + def __init__( - self, - step_bar, - start_time=None, - end_time=None, + self, + step_bar, + start_time=None, + end_time=None, instruments="csi300", freq="day", **kwargs, @@ -139,22 +149,25 @@ class SBBStrategyEMA(SBBStrategyBase): if self.start_time and self.end_time: fields = ["EMA($close, 10)-EMA($close, 20)"] signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) - signal_df = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq) + signal_df = D.features( + self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq + ) signal_df = self._convert_index_format(signal_df) signal_df.columns = ["signal"] self.signal = {} for stock_id, stock_val in signal_df.groupby(level="instrument"): self.signal[stock_id] = stock_val - + def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): if stock_id not in self.signal: return self.TREND_MID else: - _sample_signal = sample_feature(self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last") + _sample_signal = sample_feature( + self.signal[stock_id], pred_start_time, pred_end_time, fields="signal", method="last" + ) if _sample_signal is None or _sample_signal.iloc[0] == 0: return self.TREND_MID elif _sample_signal.iloc[0] > 0: return self.TREND_LONG else: return self.TREND_SHORT - \ No newline at end of file diff --git a/qlib/data/data.py b/qlib/data/data.py index a8d5a42ab..c34c02236 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -126,7 +126,7 @@ class CalendarProvider(abc.ABC): _calendar = np.array(self.load_calendar(freq, future)) _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search H["c"][flag_raw] = _calendar, _calendar_index - + if freq_sam is None: return _calendar, _calendar_index else: @@ -134,7 +134,6 @@ class CalendarProvider(abc.ABC): _calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)} H["c"][flag] = _calendar_sam, _calendar_sam_index return _calendar_sam, _calendar_sam_index - def _uri(self, start_time, end_time, freq, future=False): """Get the uri of calendar generation task.""" @@ -560,7 +559,8 @@ class LocalCalendarProvider(CalendarProvider): else: end_time = _calendar[-1] st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future) - return _calendar[si : ei + 1] + return _calendar[si : ei + 1] + class LocalInstrumentProvider(InstrumentProvider): """Local instrument data provider class @@ -767,7 +767,7 @@ class ClientCalendarProvider(CalendarProvider): self.conn = conn def calendar(self, start_time=None, end_time=None, freq="day", future=False): - + self.conn.send_request( request_type="calendar", request_content={ diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index fb5b44334..e5840d66a 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -20,8 +20,9 @@ from ..contrib.backtest.env import BaseTradeCalendar - adjust_dates这个东西啥用 - label和freq和strategy的bar分离,这个如何决策呢 """ + + class BaseStrategy(BaseTradeCalendar): - def generate_order_list(self, **kwargs): raise NotImplementedError("generator_order_list is not implemented!") @@ -29,12 +30,13 @@ class BaseStrategy(BaseTradeCalendar): class RuleStrategy(BaseStrategy): pass + class ModelStrategy(BaseStrategy): - def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None, **kwargs): + def __init__(self, step_bar, model, dataset: DatasetH, start_time=None, end_time=None, **kwargs): self.model = model self.dataset = dataset self.pred_scores = self._convert_index_format(self.model.predict(dataset)) - #pred_score_dates = self.pred_scores.index.get_level_values(level="datetime") + # pred_score_dates = self.pred_scores.index.get_level_values(level="datetime") super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) def _convert_index_format(self, df): @@ -43,12 +45,11 @@ class ModelStrategy(BaseStrategy): return df def _update_model(self): - """update pred score - """ + """update pred score""" raise NotImplementedError("_update_model is not implemented!") + class TradingEnhancement: def reset(self, trade_order_list=None): if trade_order_list: self.trade_order_list = trade_order_list - diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index ea573d819..a6bba1f38 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -801,6 +801,7 @@ def fname_to_code(fname: str): fname = fname.lstrip(prefix) return fname + ########################## Sample ############################ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam): """ @@ -810,16 +811,17 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam): freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam if freq_sam.endswith(("minute", "min")): + def cal_next_sam_minute(x, sam_minutes): hour = x.hour minute = x.minute if 9 <= hour <= 11: - minute_index = (11 - hour)*60 + 30 - minute + 120 + minute_index = (11 - hour) * 60 + 30 - minute + 120 elif 13 <= hour <= 15: - minute_index = (15 - hour)*60 - minute + minute_index = (15 - hour) * 60 - minute else: raise ValueError("calendar hour must be in [9, 11] or [13, 15]") - + minute_index = minute_index // sam_minutes * sam_minutes if 0 <= minute_index < 120: @@ -838,32 +840,40 @@ def sample_calendar_bac(calendar_raw, freq_raw, freq_sam): if raw_minutes > sam_minutes: raise ValueError("raw freq must be higher than sample freq") - _calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59), calendar_raw))) + _calendar_minute = np.unique( + list( + map( + lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59), + calendar_raw, + ) + ) + ) return _calendar_minute else: _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 23, 59, 59), calendar_raw))) if freq_sam.endswith(("day", "d")): sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3]) - return _calendar_day[(len(_calendar_day) + sam_days - 1)%sam_days::sam_days] + return _calendar_day[(len(_calendar_day) + sam_days - 1) % sam_days :: sam_days] elif freq_sam.endswith(("week", "w")): sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4]) _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) _calendar_week = _calendar_day[np.ediff1d(_day_in_week[::-1], to_begin=1)[::-1] > 0] - return _calendar_week[(len(_calendar_week) + sam_weeks - 1)%sam_weeks::sam_weeks] + return _calendar_week[(len(_calendar_week) + sam_weeks - 1) % sam_weeks :: sam_weeks] elif freq_sam.endswith(("month", "m")): sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5]) _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) _calendar_month = _calendar_day[np.ediff1d(_day_in_month[::-1], to_begin=1)[::-1] > 0] - return _calendar_month[(len(_calendar_month) + sam_months - 1)%sam_months::sam_months] + return _calendar_month[(len(_calendar_month) + sam_months - 1) % sam_months :: sam_months] else: raise ValueError("sample freq must be xmin, xd, xw, xm") + def parse_freq(freq): freq = freq.lower() - search_obj =re.search("^([0-9]*)([a-z]+)", freq) + search_obj = re.search("^([0-9]*)([a-z]+)", freq) if search_obj is None: raise ValueError("freq format is not supported") _count = int(search_obj.group(1) if search_obj.group(1) else "1") @@ -881,9 +891,12 @@ def parse_freq(freq): try: _freq = _freq_format_dict.get(_freq) except KeyError: - raise ValueError("freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min") + raise ValueError( + "freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min" + ) return _count, _freq + def sample_calendar(calendar_raw, freq_raw, freq_sam): """ freq_raw : "min" or "day" @@ -893,16 +906,17 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): if not len(calendar_raw): return calendar_raw if freq_sam == "minute": + def cal_next_sam_minute(x, sam_minutes): hour = x.hour minute = x.minute if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30): - minute_index = (hour - 9)*60 + minute - 30 + minute_index = (hour - 9) * 60 + minute - 30 elif 13 <= hour < 15: - minute_index = (hour - 13)*60 + minute + 120 + minute_index = (hour - 13) * 60 + minute + 120 else: raise ValueError("calendar hour must be in [9, 11] or [13, 15]") - + minute_index = minute_index // sam_minutes * sam_minutes if 0 <= minute_index < 120: @@ -917,7 +931,11 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): else: if raw_count > sam_count: raise ValueError("raw freq must be higher than sample freq") - _calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw))) + _calendar_minute = np.unique( + list( + map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw) + ) + ) if calendar_raw[0] > _calendar_minute[0]: _calendar_minute[0] = calendar_raw[0] return _calendar_minute @@ -937,7 +955,8 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): return _calendar_month[::sam_count] else: raise ValueError("sample freq must be xmin, xd, xw, xm") - + + def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs): _, norm_freq = parse_freq(freq) @@ -963,23 +982,28 @@ def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwarg raise ValueError(f"freq {freq} is not supported") return _calendar, freq, freq_sam + def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}): selector_datetime = slice(start_time, end_time) fields = fields if fields else slice(None) from ..data.dataset.utils import get_level_index - + datetime_level = get_level_index(feature, level="datetime") == 0 if isinstance(feature, pd.Series): feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)] elif isinstance(feature, pd.DataFrame): - feature = feature.loc[selector_datetime, fields] if datetime_level else feature.loc[(slice(None), selector_datetime), fields] + feature = ( + feature.loc[selector_datetime, fields] + if datetime_level + else feature.loc[(slice(None), selector_datetime), fields] + ) if feature.empty: return None if isinstance(feature.index, pd.MultiIndex): if callable(method): method_func = method - return feature.groupby(level="instrument").apply(lambda x:method_func(x, **method_kwargs)) + return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs)) elif isinstance(method, str): return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) else: @@ -988,7 +1012,5 @@ def sample_feature(feature, start_time=None, end_time=None, fields=None, method= return method_func(feature, **method_kwargs) elif isinstance(method, str): return getattr(feature, method)(**method_kwargs) - - return feature - \ No newline at end of file + return feature diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 3d7188bcc..b7935ae08 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -254,13 +254,19 @@ class PortAnaRecord(SignalRecord): for report_dep, (report_normal, positions_normal) in enumerate(report_list): if report_dict is None: if self.risk_analysis_dep == report_dep: - warnings.warn(f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`") + warnings.warn( + f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`" + ) continue - - self.recorder.save_objects(**{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) - self.recorder.save_objects(**{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) + + self.recorder.save_objects( + **{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path() + ) + self.recorder.save_objects( + **{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path() + ) # analysis - self.risk_analysis_dep == report_dep: + if self.risk_analysis_dep == report_dep: analysis = dict() analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) analysis["excess_return_with_cost"] = risk_analysis( @@ -270,7 +276,9 @@ class PortAnaRecord(SignalRecord): # log metrics self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict())) # save results - self.recorder.save_objects(**{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects( + **{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path() + ) logger.info( f"Portfolio analysis record 'port_analysis_{report_dep}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" )