From af0053eb17ee932d9cd9d3e4625c35258a0c0dc9 Mon Sep 17 00:00:00 2001 From: bxdd Date: Sat, 24 Apr 2021 22:37:36 +0800 Subject: [PATCH] fix bug --- examples/highfreq/backtest/workflow.py | 35 +- qlib/backtest/__init__.py | 130 ------- qlib/backtest/account.py | 170 --------- qlib/backtest/backtest.py | 26 -- qlib/backtest/exchange.py | 429 ----------------------- qlib/backtest/import numpy as np | 90 ----- qlib/backtest/init.py | 132 ------- qlib/backtest/order.py | 30 -- qlib/backtest/position.py | 217 ------------ qlib/backtest/profit_attribution.py | 324 ----------------- qlib/backtest/report.py | 106 ------ qlib/contrib/backtest/__init__.py | 284 +++------------ qlib/contrib/backtest/account.py | 49 +-- qlib/contrib/backtest/backtest.py | 138 +------- qlib/{ => contrib}/backtest/env.py | 41 +-- qlib/contrib/backtest/exchange.py | 100 +++--- qlib/contrib/backtest/interpreter.py | 15 + qlib/contrib/backtest/order.py | 5 +- qlib/contrib/backtest/position.py | 32 +- qlib/contrib/backtest/report.py | 42 +-- qlib/contrib/backtest_new/backtest.py | 0 qlib/contrib/strategy/__init__.py | 4 +- qlib/contrib/strategy/cost_control.py | 2 +- qlib/contrib/strategy/dl_strategy.py | 23 +- qlib/contrib/strategy/order_generator.py | 4 +- qlib/contrib/strategy/rule_strategy.py | 56 +-- qlib/data/data.py | 8 +- qlib/strategy/base.py | 41 +-- qlib/utils/__init__.py | 28 +- 29 files changed, 314 insertions(+), 2247 deletions(-) delete mode 100644 qlib/backtest/__init__.py delete mode 100644 qlib/backtest/account.py delete mode 100644 qlib/backtest/backtest.py delete mode 100644 qlib/backtest/exchange.py delete mode 100644 qlib/backtest/import numpy as np delete mode 100644 qlib/backtest/init.py delete mode 100644 qlib/backtest/order.py delete mode 100644 qlib/backtest/position.py delete mode 100644 qlib/backtest/profit_attribution.py delete mode 100644 qlib/backtest/report.py rename qlib/{ => contrib}/backtest/env.py (89%) create mode 100644 qlib/contrib/backtest/interpreter.py delete mode 100644 qlib/contrib/backtest_new/backtest.py diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py index df01e31de..3e0e1524b 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/highfreq/backtest/workflow.py @@ -7,13 +7,9 @@ from pathlib import Path import qlib import pandas as pd from qlib.config import REG_CN -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.backtest import backtest +from qlib.contrib.strategy import TopkDropoutStrategy +from qlib.contrib.backtest import backtest from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict -from qlib.workflow import R -from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData if __name__ == "__main__": @@ -67,9 +63,9 @@ if __name__ == "__main__": "kwargs": data_handler_config, }, "segments": { - "train": ("2008-01-01", "2014-12-31"), + "train": ("2012-01-01", "2014-12-31"), "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), + "test": ("2017-01-01", "2018-01-31"), }, }, }, @@ -79,41 +75,40 @@ if __name__ == "__main__": dataset = init_instance_by_config(task["dataset"]) model.fit(dataset) - trade_start_time = "2017-01-01" - trade_end_time = "2020-08-01" - trade_exchange = get_exchange(start_time=trade_start_time, end_time=trade_end_time) + trade_start_time = "2017-01-31" + trade_end_time = "2018-01-31" backtest_config={ "strategy": { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.dl_strategy", "kwargs": { - "step_bar": "day", + "step_bar": "week", "model": model, "dataset": dataset, - "trade_exchange": trade_exchange, "topk": 50, "n_drop": 5, }, }, "env":{ "class": "SplitEnv", - "module_path": "qlib.backtest.env", + "module_path": "qlib.contrib.backtest.env", "kwargs": { - "step_bar": "day", + "step_bar": "week", "sub_env": { "class": "SimulatorEnv", - "module_path": "qlib.backtest.env", + "module_path": "qlib.contrib.backtest.env", "kwargs": { - "step_bar": "1min", - "trade_exchange": trade_exchange, + "step_bar": "day", } }, "sub_strategy": { "class": "SBBStrategyEMA", "module_path": "qlib.contrib.strategy.rule_strategy", "kwargs": { - "step_bar": "1min", + "step_bar": "day", + "freq": "day", + "instruments": "csi300", } } } @@ -121,4 +116,4 @@ if __name__ == "__main__": } - backtest(**backtest_config, ) \ No newline at end of file + report_dict = backtest(start_time=trade_start_time, end_time=trade_end_time, **backtest_config, account=1e8, deal_price="$close", verbose=False) \ No newline at end of file diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py deleted file mode 100644 index 70bc03363..000000000 --- a/qlib/backtest/__init__.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .order import Order -from .position import Position -from .exchange import Exchange -from .report import Report -from .backtest import backtest as backtest_func - -import copy -import numpy as np -import inspect -from ..utils import init_instance_by_config -from ..log import get_module_logger -from ..config import C - -logger = get_module_logger("backtest caller") - - -def get_exchange( - pred, - exchange=None, - start_time=None, - end_time=None, - codes = "all", - subscribe_fields=[], - open_cost=0.0015, - close_cost=0.0025, - min_cost=5.0, - trade_unit=None, - limit_threshold=None, - deal_price=None, - shift=1, -): - """get_exchange - - Parameters - ---------- - - # exchange related arguments - exchange: Exchange(). - subscribe_fields: list - subscribe fields. - open_cost : float - open transaction cost. - close_cost : float - close transaction cost. - min_cost : float - min transaction cost. - trade_unit : int - 100 for China A. - deal_price: str - dealing price type: 'close', 'open', 'vwap'. - limit_threshold : float - limit move 0.1 (10%) for example, long and short with same limit. - - Returns - ------- - :class: Exchange - an initialized Exchange object - """ - - if trade_unit is None: - trade_unit = C.trade_unit - if limit_threshold is None: - limit_threshold = C.limit_threshold - if deal_price is None: - deal_price = C.deal_price - if exchange is None: - logger.info("Create new exchange") - # handle exception for deal_price - if deal_price[0] != "$": - deal_price = "$" + deal_price - - exchange = Exchange( - start_time=start_time, - end_time=end_time, - codes=codes, - deal_price=deal_price, - subscribe_fields=subscribe_fields, - limit_threshold=limit_threshold, - open_cost=open_cost, - close_cost=close_cost, - trade_unit=trade_unit, - min_cost=min_cost, - ) - else: - return init_instance_by_config(exchange, accept_types=Exchange) - -def init_env_instance_by_config(env): - if isinstance(env, dict): - env_config = copy.copy(env) - if "kwargs" in env_config: - env_kwargs = copy.copy(env_config["kwargs"]): - if "sub_env" in env_kwargs: - env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) - if "sub_strategy" in env_kwargs: - env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) - env_config["kwargs"] = env_kwargs - return init_instance_by_config(env_config) - else: - return env - -def setup_exchange(root_instance, trade_exchange=None, force=False): - if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args: - if force: - root_instance.reset(trade_exchange=trade_exchange) - else: - if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None: - root_instance.reset(trade_exchange=trade_exchange) - if hasattr(root_instance, "sub_env"): - setup_exchange(root_instance.sub_env, trade_exchange) - if hasattr(root_instance, "sub_strategy"): - setup_exchange(root_instance.sub_strategy, trade_exchange) - - -def backtest(start_time, end_time, strategy, env, benchmark, account=1e9, **kwargs): - trade_strategy = init_instance_by_config(strategy) - trade_env = init_env_instance_by_config(env) - - spec = inspect.getfullargspec(get_exchange) - ex_args = {k: v for k, v in kwargs.items() if k in spec.args} - trade_exchange = get_exchange(pred, **ex_args) - - setup_exchange(trade_env, trade_exchange) - setup_exchange(trade_strategy, trade_exchange) - - report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account) - - return report_dict diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py deleted file mode 100644 index c44d26d7b..000000000 --- a/qlib/backtest/account.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import copy - -from .position import Position -from .report import Report -from .order import Order - - -""" -rtn & earning in the Account - rtn: - from order's view - 1.change if any order is executed, sell order or buy order - 2.change at the end of today, (today_clse - stock_price) * amount - earning - from value of current position - earning will be updated at the end of trade date - earning = today_value - pre_value - **is consider cost** - while earning is the difference of two position value, so it considers cost, it is the true return rate - in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning -""" - - -class Account: - def __init__(self, init_cash, last_trade_time=None): - self.init_vars(init_cash, last_trade_time) - - def init_vars(self, init_cash, last_trade_time=None): - # init cash - self.init_cash = init_cash - self.current = Position(cash=init_cash) - self.positions = {} - self.rtn = 0 - self.ct = 0 - self.to = 0 - self.val = 0 - self.report = Report() - self.earning = 0 - self.last_trade_time = last_trade_time - - def get_positions(self): - return self.positions - - def get_cash(self): - return self.current.position["cash"] - - def update_state_from_order(self, order, trade_val, cost, trade_price): - # update turnover - self.to += trade_val - # update cost - self.ct += cost - # update return - # update self.rtn from order - trade_amount = trade_val / trade_price - if order.direction == Order.SELL: # 0 for sell - # when sell stock, get profit from price change - profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount - self.rtn += profit # note here do not consider cost - elif order.direction == Order.BUY: # 1 for buy - # when buy stock, we get return for the rtn computing method - # profit in buy order is to make self.rtn is consistent with self.earning at the end of date - profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val - self.rtn += profit - - def update_order(self, order, trade_val, cost, trade_price): - # if stock is sold out, no stock price information in Position, then we should update account first, then update current position - # if stock is bought, there is no stock in current position, update current, then update account - # The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation - trade_amount = trade_val / trade_price - if order.direction == Order.SELL: - # sell stock - self.update_state_from_order(order, trade_val, cost, trade_price) - # update current position - # for may sell all of stock_id - self.current.update_order(order, trade_val, cost, trade_price) - else: - # buy stock - # deal order, then update state - self.current.update_order(order, trade_val, cost, trade_price) - self.update_state_from_order(order, trade_val, cost, trade_price) - - def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange): - """ - start_time: pd.TimeStamp - end_time: pd.TimeStamp - quote: pd.DataFrame (code, date), collumns - when the end of trade date - - update rtn - - update price for each asset - - update value for this account - - update earning (2nd view of return ) - - update holding day, count of stock - - update position hitory - - update report - :return: None - """ - # update price for stock in the position and the profit from changed_price - stock_list = self.current.get_stock_list() - profit = 0 - for code in stock_list: - # if suspend, no new price to be updated, profit is 0 - if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): - continue - bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) - profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"] - self.current.update_stock_price(stock_id=code, price=bar_close) - self.rtn += profit - # update holding day count - self.current.add_count_all() - # update value - self.val = self.current.calculate_value() - # update earning (2nd view of return) - # account_value - last_account_value - # for the first trade date, account_value - init_cash - # self.report.is_empty() to judge is_first_trade_date - # get last_account_value, now_account_value, now_stock_value - if self.report.is_empty(): - last_account_value = self.init_cash - else: - last_account_value = self.report.get_latest_account_value() - now_account_value = self.current.calculate_value() - now_stock_value = self.current.calculate_stock_value() - self.earning = now_account_value - last_account_value - # update report for today - # judge whether the the trading is begin. - # and don't add init account state into report, due to we don't have excess return in those days. - self.report.update_report_record( - trade_time=trade_start_time, - account_value=now_account_value, - cash=self.current.position["cash"], - return_rate=(self.earning + self.ct) / last_account_value, - # here use earning to calculate return, position's view, earning consider cost, true return - # in order to make same definition with original backtest in evaluate.py - turnover_rate=self.to / last_account_value, - cost_rate=self.ct / last_account_value, - stock_value=now_stock_value, - ) - # set now_account_value to position - self.current.position["now_account_value"] = now_account_value - self.current.update_weight_all() - # update positions - # note use deepcopy - self.positions[trade_start_time] = copy.deepcopy(self.current) - - # finish today's updation - # reset the daily variables - self.rtn = 0 - self.ct = 0 - self.to = 0 - self.last_trade_time = (trade_start_time, trade_end_time) - - def load_account(self, account_path): - report = Report() - position = Position() - last_trade_time = position.load_position(account_path / "position.xlsx") - report.load_report(account_path / "report.csv") - - # assign values - self.init_vars(position.init_cash) - self.current = position - self.report = report - self.last_trade_time = last_trade_time - - def save_account(self, account_path): - self.current.save_position(account_path / "position.xlsx", self.last_trade_time) - self.report.save_report(account_path / "report.csv") diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py deleted file mode 100644 index cd9539725..000000000 --- a/qlib/backtest/backtest.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import numpy as np -import pandas as pd - -from .account import Account - -def backtest(trade_strategy, trade_env, benchmark, account): - - trade_account = Account(init_cash=account) - trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) - trade_strategy.reset(start_time=start_time, end_time=end_time) - - trade_state = self.sub_env.get_init_state() - while not trade_env.finished(): - _order_list = self.sub_strategy.generate_order(**trade_state) - trade_state, trade_info = self.sub_env.execute(sub_order_list) - - report_df = trade_account.report.generate_report_dataframe() - positions = trade_account.get_positions() - report_dict = {"report_df": report_df, "positions": positions} - - return report_dict - diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py deleted file mode 100644 index 985cf92e8..000000000 --- a/qlib/backtest/exchange.py +++ /dev/null @@ -1,429 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import random -import logging - -import numpy as np -import pandas as pd - -from ..data import D -from ..utils import sample_feature -from .order import Order -from ..config import C, REG_CN -from ..log import get_module_logger - - -class Exchange: - def __init__( - self, - start_time=None, - end_time=None, - codes="all", - deal_price=None, - subscribe_fields=[], - limit_threshold=None, - open_cost=0.0015, - close_cost=0.0025, - trade_unit=None, - min_cost=5, - extra_quote=None, - ): - """__init__ - - :param start_time: start time for backtest - :param end_time: end time for backtest - :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50) - :param deal_price: str, 'close', 'open', 'vwap' - :param subscribe_fields: list, subscribe fields - :param limit_threshold: float, 0.1 for example, default None - :param open_cost: cost rate for open, default 0.0015 - :param close_cost: cost rate for close, default 0.0025 - :param trade_unit: trade unit, 100 for China A market - :param min_cost: min cost, default 5 - :param extra_quote: pandas, dataframe consists of - columns: like ['$vwap', '$close', '$factor', 'limit']. - The limit indicates that the etf is tradable on a specific day. - Necessary fields: - $close is for calculating the total value at end of each day. - Optional fields: - $vwap is only necessary when we use the $vwap price as the deal price - $factor is for rounding to the trading unit - limit will be set to False by default(False indicates we can buy this - target on this day). - index: MultipleIndex(instrument, pd.Datetime) - """ - self.start_time = start_time - self.end_time = end_time - if trade_unit is None: - trade_unit = C.trade_unit - if limit_threshold is None: - limit_threshold = C.limit_threshold - if deal_price is None: - deal_price = C.deal_price - - self.logger = get_module_logger("online operator", level=logging.INFO) - - self.trade_unit = trade_unit - - # TODO: the quote, trade_dates, codes are not necessray. - # It is just for performance consideration. - if limit_threshold is None: - if C.region == REG_CN: - self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold") - elif abs(limit_threshold) > 0.1: - if C.region == REG_CN: - self.logger.warning(f"limit_threshold may not be set to a reasonable value") - - if deal_price[0] != "$": - self.deal_price = "$" + deal_price - else: - self.deal_price = deal_price - if isinstance(codes, str): - codes = D.instruments(codes) - self.codes = codes - # Necessary fields - # $close is for calculating the total value at end of each day. - # $factor is for rounding to the trading unit - # $change is for calculating the limit of the stock - - necessary_fields = {self.deal_price, "$close", "$change", "$factor"} - subscribe_fields = list(necessary_fields | set(subscribe_fields)) - all_fields = list(necessary_fields | set(subscribe_fields)) - self.all_fields = all_fields - self.open_cost = open_cost - self.close_cost = close_cost - self.min_cost = min_cost - self.limit_threshold = limit_threshold - - - self.extra_quote = extra_quote - self.set_quote(codes, start_time, end_time) - - def set_quote(self, codes, start_time, end_time): - if len(codes) == 0: - codes = D.instruments() - self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"]) - self.quote.columns = self.all_fields - - if self.quote[self.deal_price].isna().any(): - self.logger.warning("{} field data contains nan.".format(self.deal_price)) - - if self.quote["$factor"].isna().any(): - # The 'factor.day.bin' file not exists, and `factor` field contains `nan` - # Use adjusted price - self.trade_w_adj_price = True - self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") - else: - # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` - # Use normal price - self.trade_w_adj_price = False - # update limit - # check limit_threshold - if self.limit_threshold is None: - self.quote["limit"] = False - else: - # set limit - self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold) - - quote_df = self.quote - if self.extra_quote is not None: - # process extra_quote - if "$close" not in self.extra_quote: - raise ValueError("$close is necessray in extra_quote") - if self.deal_price not in self.extra_quote.columns: - self.extra_quote[self.deal_price] = self.extra_quote["$close"] - self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.") - if "$factor" not in self.extra_quote.columns: - self.extra_quote["$factor"] = 1.0 - self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") - if "limit" not in self.extra_quote.columns: - self.extra_quote["limit"] = False - self.logger.warning("No limit set for extra_quote. All stock will be tradable.") - assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"} - quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) - - # update quote: pd.DataFrame to dict, for search use - self.quote = quote_df - - def _update_limit(self, buy_limit, sell_limit): - self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False) - - def check_stock_limit(self, stock_id, start_time, end_time): - """Parameter - stock_id - trade_date - is limtited - """ - return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0, 0] - - - def check_stock_suspended(self, stock_id, start_time, end_time): - # is suspended - return sample_feature(self.quote, stock_id, start_time, end_time).empty is False - - - def is_stock_tradable(self, stock_id, start_time, end_time): - # check if stock can be traded - # same as check in check_order - if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time): - return False - else: - return True - - def check_order(self, order): - # check limit and suspended - if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit( - order.stock_id, order.start_time, order.end_time - ): - return False - else: - return True - - def deal_order(self, order, trade_account=None, position=None): - """ - Deal order when the actual transaction - - :param order: Deal the order. - :param trade_account: Trade account to be updated after dealing the order. - :param position: position to be updated after dealing the order. - :return: trade_val, trade_cost, trade_price - """ - # need to check order first - # TODO: check the order unit limit in the exchange!!!! - # The order limit is related to the adj factor and the cur_amount. - # factor = self.quote[(order.stock_id, order.trade_date)]['$factor'] - # cur_amount = trade_account.current.get_stock_amount(order.stock_id) - if self.check_order(order) is False: - raise AttributeError("need to check order first") - if trade_account is not None and position is not None: - raise ValueError("trade_account and position can only choose one") - - trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) - trade_val, trade_cost = self._calc_trade_info_by_order( - order, trade_account.current if trade_account else position - ) - # update account - if trade_val > 0: - # If the order can only be deal 0 trade_val. Nothing to be updated - # Otherwise, it will result some stock with 0 amount in the position - if trade_account: - trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) - elif position: - position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) - - return trade_val, trade_cost, trade_price - - def get_quote_info(self, stock_id, start_time, end_time): - return sample_feature(self.quote, stock_id, start_time, end_time) - - def get_close(self, stock_id, start_time, end_time): - return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last") - - def get_deal_price(self, stock_id, start_time, end_time): - deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last") - deal_price = self.quote[(stock_id, trade_date)][self.deal_price] - if np.isclose(deal_price, 0.0) or np.isnan(deal_price): - self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!") - self.logger.warning(f"setting deal_price to close price") - deal_price = self.get_close(stock_id, start_time, end_time) - return deal_price - - def get_factor(self, stock_id, start_time, end_time): - return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last") - - def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): - """ - The generate the target position according to the weight and the cash. - NOTE: All the cash will assigned to the tadable stock. - - Parameter: - weight_position : dict {stock_id : weight}; allocate cash by weight_position - among then, weight must be in this range: 0 < weight < 1 - cash : cash - trade_date : trade date - """ - - # calculate the total weight of tradable value - tradable_weight = 0.0 - for stock_id in weight_position: - if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): - # weight_position must be greater than 0 and less than 1 - if weight_position[stock_id] < 0 or weight_position[stock_id] > 1: - raise ValueError( - "weight_position is {}, " - "weight_position is not in the range of (0, 1).".format(weight_position[stock_id]) - ) - tradable_weight += weight_position[stock_id] - - if tradable_weight - 1.0 >= 1e-5: - raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight)) - - amount_dict = {} - for stock_id in weight_position: - if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): - amount_dict[stock_id] = ( - cash - * weight_position[stock_id] - / tradable_weight - // self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) - ) - return amount_dict - - def get_real_deal_amount(self, current_amount, target_amount, factor): - """ - Calculate the real adjust deal amount when considering the trading unit - - :param current_amount: - :param target_amount: - :param factor: - :return real_deal_amount; Positive deal_amount indicates buying more stock. - """ - if current_amount == target_amount: - return 0 - elif current_amount < target_amount: - deal_amount = target_amount - current_amount - deal_amount = self.round_amount_by_trade_unit(deal_amount, factor) - return deal_amount - else: - if target_amount == 0: - return -current_amount - else: - deal_amount = current_amount - target_amount - deal_amount = self.round_amount_by_trade_unit(deal_amount, factor) - return -deal_amount - - def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time): - """Parameter: - target_position : dict { stock_id : amount } - current_postion : dict { stock_id : amount} - trade_unit : trade_unit - down sample : for amount 321 and trade_unit 100, deal_amount is 300 - deal order on trade_date - """ - # split buy and sell for further use - buy_order_list = [] - sell_order_list = [] - # three parts: kept stock_id, dropped stock_id, new stock_id - # handle kept stock_id - - # because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different; - # so here we sort stock_id, and then randomly shuffle the order of stock_id - # because the same random seed is used, the final stock_id order is fixed - sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys()))) - random.seed(0) - random.shuffle(sorted_ids) - for stock_id in sorted_ids: - - # Do not generate order for the nontradable stocks - if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): - continue - - target_amount = target_position.get(stock_id, 0) - current_amount = current_position.get(stock_id, 0) - factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time) - - deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor) - if deal_amount == 0: - continue - elif deal_amount > 0: - # buy stock - buy_order_list.append( - Order( - stock_id=stock_id, - amount=deal_amount, - direction=Order.BUY, - start_time=start_time, - end_time=end_time, - factor=factor, - ) - ) - else: - # sell stock - sell_order_list.append( - Order( - stock_id=stock_id, - amount=abs(deal_amount), - direction=Order.SELL, - start_time=start_time, - end_time=end_time, - factor=factor, - ) - ) - # return order_list : buy + sell - return sell_order_list + buy_order_list - - def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False): - """Parameter - position : Position() - amount_dict : {stock_id : amount} - """ - value = 0 - for stock_id in amount_dict: - if ( - self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False - and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False - ): - value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id] - return value - - def round_amount_by_trade_unit(self, deal_amount, factor): - """Parameter - deal_amount : float, adjusted amount - factor : float, adjusted factor - return : float, real amount - """ - if not self.trade_w_adj_price: - # the minimal amount is 1. Add 0.1 for solving precision problem. - return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor - return deal_amount - - def _calc_trade_info_by_order(self, order, position): - """ - Calculation of trade info - - :param order: - :param position: Position - :return: trade_val, trade_cost - """ - - trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) - if order.direction == Order.SELL: - # sell - if position is not None: - if np.isclose(order.amount, position.get_stock_amount(order.stock_id)): - # when selling last stock. The amount don't need rounding - order.deal_amount = order.amount - else: - order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) - else: - # TODO: We don't know current position. - # We choose to sell all - order.deal_amount = order.amount - - trade_val = order.deal_amount * trade_price - trade_cost = max(trade_val * self.close_cost, self.min_cost) - elif order.direction == Order.BUY: - # buy - if position is not None: - cash = position.get_cash() - trade_val = order.amount * trade_price - if cash < trade_val * (1 + self.open_cost): - # The money is not enough - order.deal_amount = self.round_amount_by_trade_unit( - cash / (1 + self.open_cost) / trade_price, order.factor - ) - else: - # THe money is enough - order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) - else: - # Unknown amount of money. Just round the amount - order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) - - trade_val = order.deal_amount * trade_price - trade_cost = trade_val * self.open_cost - else: - raise NotImplementedError("order type {} error".format(order.type)) - - return trade_val, trade_cost diff --git a/qlib/backtest/import numpy as np b/qlib/backtest/import numpy as np deleted file mode 100644 index f558d3649..000000000 --- a/qlib/backtest/import numpy as np +++ /dev/null @@ -1,90 +0,0 @@ -class HighFreqOrderNorm(Processor): - def __init__(self, fit_start_time, fit_end_time, feature_save_dir, price_dim=5, order_price_dim=2, volume_dim=1, order_volume_dim=8, day_length=240): - self.fit_start_time = fit_start_time - self.fit_end_time = fit_end_time - self.price_dim = price_dim - self.volume_dim = volume_dim - self.order_price_dim = order_price_dim - self.order_volume_dim = order_volume_dim - self.feature_save_dir = feature_save_dir - self.day_length = day_length - self.names = dict() - column_dim = self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim - fields = [("price", self.price_dim), ("order_price", self.order_price_dim), ("volume", self.volume_dim), ("order_volume", self.order_volume_dim)] - last_dim = 0 - for field, field_dim in fields: - self.names[field] = list(range(last_dim, last_dim + field_dim)) + list((range(column_dim + last_dim, column_dim + last_dim + field_dim))) - last_dim += field_dim - - @profile - def fit(self, df_features): - # fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime") - - - print("end") - if not os.path.exists(self.feature_save_dir): - os.makedirs(self.feature_save_dir) - for name, name_val in self.names.items(): - print(name) - df_values = df_features.iloc(axis=1)[name_val].values - if name == "volume" or name == "order_volume": - df_values = np.log1p(df_values) - self.feature_med = np.nanmedian(df_values) - np.save(self.feature_save_dir + name + "_med.npy", self.feature_med) - df_values = df_values - self.feature_med - self.feature_std = np.nanmedian(np.absolute(df_values)) * 1.4826 + 1e-12 - np.save(self.feature_save_dir + name + "_std.npy", self.feature_std) - df_values = df_values / self.feature_std - np.save(self.feature_save_dir + name + "_vmax.npy", np.nanmax(df_values)) - np.save(self.feature_save_dir + name + "_vmin.npy", np.nanmin(df_values)) - - - def __call__(self, df_features): - df_features.set_index("date", append=True, drop=True, inplace=True) - df_values = df_features.values - df_values_dict = dict() - for name, name_val in self.names.items(): - self.feature_med = np.load(self.feature_save_dir + name + "_med.npy") - self.feature_std = np.load(self.feature_save_dir + name + "_std.npy") - self.feature_vmax = np.load(self.feature_save_dir + name + "_vmax.npy") - self.feature_vmin = np.load(self.feature_save_dir + name + "_vmin.npy") - - df_values = df_features.iloc(axis=1)[name_val].values - if name == "volume" or name == "order_volume": - df_values[:] = np.log1p(df_values) - df_values[:] -= self.feature_med - df_values[:] /= self.feature_std - slice0 = df_values > 3.0 - slice1 = df_values > 3.5 - slice2 = df_values < -3.0 - slice3 = df_values < -3.5 - - df_values[slice0] = ( - 3.0 + (df_values[slice0] - 3.0) / (self.feature_vmax - 3) * 0.5 - ) - df_values[slice1] = 3.5 - df_values[slice2] = ( - -3.0 - (df_values[slice2] + 3.0) / (self.feature_vmin + 3) * 0.5 - ) - df_values[slice3] = -3.5 - df_values_dict[name] = df_values - - idx = df_features.index.droplevel("datetime").drop_duplicates() - idx.set_names(["instrument", "datetime"], inplace=True) - - # Reshape is specifically for adapting to RL high-freq executor - feat = df_values[:, list(range(self.price_dim)) + list(range(self.price_dim * 2, self.price_dim * 2 + self.order_price_dim)) - + list(range((self.price_dim + self.order_price_dim) * 2, (self.price_dim + self.order_price_dim) * 2 + self.volume_dim)) - + list(range((self.price_dim + self.order_price_dim + self.volume_dim) * 2, (self.price_dim + self.order_price_dim + self.volume_dim) * 2 + self.order_volume_dim)) - ].reshape(-1, (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length) - - feat_1 = df_values[:, list(np.arange(self.price_dim) + self.price_dim) + list(np.arange(self.price_dim * 2, self.price_dim * 2 + self.order_price_dim) + self.order_price_dim) - + list(np.arange((self.price_dim + self.order_price_dim) * 2, (self.price_dim + self.order_price_dim) * 2 + self.volume_dim) + self.volume_dim) - + list(np.arange((self.price_dim + self.order_price_dim + self.volume_dim) * 2, (self.price_dim + self.order_price_dim + self.volume_dim) * 2 + self.order_volume_dim) + self.order_volume_dim) - ].reshape(-1, (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length) - df_new_features = pd.DataFrame( - data=np.concatenate((feat, feat_1), axis=1), - index=idx, - columns=range(2 * (self.price_dim + self.order_price_dim + self.volume_dim + self.order_volume_dim) * self.day_length), - ).sort_index() - return df_new_features \ No newline at end of file diff --git a/qlib/backtest/init.py b/qlib/backtest/init.py deleted file mode 100644 index 06dd437db..000000000 --- a/qlib/backtest/init.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .order import Order -from .account import Account -from .position import Position -from .exchange import Exchange -from .report import Report -from .backtest import backtest as backtest_func, get_date_range - -import copy -import numpy as np -import inspect -from ..utils import init_instance_by_config -from ..log import get_module_logger -from ..config import C - -logger = get_module_logger("backtest caller") - - -def init_env_instance_by_config(env): - if isinstance(env, dict): - env_config = copy.copy(env) - if "kwargs" in env_config: - env_kwargs = copy.copy(env_config["kwargs"]): - if "sub_env" in env_kwargs: - env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) - if "sub_strategy" in env_kwargs: - env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) - env_config["kwargs"] = env_kwargs - return init_instance_by_config(env_config) - else: - return env - - -def get_exchange( - exchange=None, - start_time=None, - end_time=None, - codes = "all", - subscribe_fields=[], - open_cost=0.0015, - close_cost=0.0025, - min_cost=5.0, - trade_unit=None, - limit_threshold=None, - deal_price=None, - shift=1, -): - """get_exchange - - Parameters - ---------- - - # exchange related arguments - exchange: Exchange(). - subscribe_fields: list - subscribe fields. - open_cost : float - open transaction cost. - close_cost : float - close transaction cost. - min_cost : float - min transaction cost. - trade_unit : int - 100 for China A. - deal_price: str - dealing price type: 'close', 'open', 'vwap'. - limit_threshold : float - limit move 0.1 (10%) for example, long and short with same limit. - - Returns - ------- - :class: Exchange - an initialized Exchange object - """ - - if trade_unit is None: - trade_unit = C.trade_unit - if limit_threshold is None: - limit_threshold = C.limit_threshold - if deal_price is None: - deal_price = C.deal_price - if exchange is None: - logger.info("Create new exchange") - # handle exception for deal_price - if deal_price[0] != "$": - deal_price = "$" + deal_price - - exchange = Exchange( - start_time=start_time, - end_time=end_time, - codes=codes, - deal_price=deal_price, - subscribe_fields=subscribe_fields, - limit_threshold=limit_threshold, - open_cost=open_cost, - close_cost=close_cost, - trade_unit=trade_unit, - min_cost=min_cost, - ) - else: - return init_instance_by_config(exchange, accept_types=Exchange) - -def backtest(start_time, end_time, strategy, env, account=1e9, **kwargs): - trade_strategy = init_instance_by_config(strategy) - trade_env = init_env_instance_by_config(env) - trade_account = Account(init_cash=account) - - spec = inspect.getfullargspec(get_exchange) - ex_args = {k: v for k, v in kwargs.items() if k in spec.args} - trade_exchange = get_exchange(pred, **ex_args) - -# temp_env = trade_env -# while True: -# if hasattr(temp_env, "trade_exchange"): -# temp_env.reset(trade_exchange=trade_exchange) -# if hasattr(temp_env, "sub_env"): -# temp_env = temp_env.sub_env -# else: -# break - - trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) - trade_state, _reset_info = self.sub_env.get_first_state() - trade_strategy.reset(**_reset_info) - - - while not trade_env.finished(): - _order_list = self.sub_strategy.generate_order(**trade_state) - trade_state, trade_info = self.sub_env.execute(sub_order_list) - - return diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py deleted file mode 100644 index 0d637d9db..000000000 --- a/qlib/backtest/order.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -class Order: - - SELL = 0 - BUY = 1 - - def __init__(self, stock_id, amount, start_time, end_time, direction, factor): - """Parameter - direction : Order.SELL for sell; Order.BUY for buy - stock_id : str - amount : float - trade_date : pd.Timestamp - factor : float - presents the weight factor assigned in Exchange() - """ - # check direction - if direction not in {Order.SELL, Order.BUY}: - raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") - self.stock_id = stock_id - # amount of generated orders - self.amount = amount - # amount of successfully completed orders - self.deal_amount = 0 - self.start_time = start_time - self.end_time = end_time - self.direction = direction - self.factor = factor diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py deleted file mode 100644 index 9945a7e8f..000000000 --- a/qlib/backtest/position.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import pandas as pd -import copy -import pathlib -from .order import Order - -""" -Position module -""" - -""" -current state of position -a typical example is :{ - : { - 'count': , - 'amount': , - 'price': , - 'weight': , - }, -} - -""" - - -class Position: - """Position""" - - def __init__(self, cash=0, position_dict={}, today_account_value=0): - # NOTE: The position dict must be copied!!! - # Otherwise the initial value - self.init_cash = cash - self.position = position_dict.copy() - self.position["cash"] = cash - self.position["today_account_value"] = today_account_value - - def init_stock(self, stock_id, amount, price=None): - self.position[stock_id] = {} - self.position[stock_id]["count"] = 0 # update count in the end of this date - self.position[stock_id]["amount"] = amount - self.position[stock_id]["price"] = price - self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date - - def buy_stock(self, stock_id, trade_val, cost, trade_price): - trade_amount = trade_val / trade_price - if stock_id not in self.position: - self.init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price) - else: - # exist, add amount - self.position[stock_id]["amount"] += trade_amount - - self.position["cash"] -= trade_val + cost - - def sell_stock(self, stock_id, trade_val, cost, trade_price): - trade_amount = trade_val / trade_price - if stock_id not in self.position: - raise KeyError("{} not in current position".format(stock_id)) - else: - # decrease the amount of stock - self.position[stock_id]["amount"] -= trade_amount - # check if to delete - if self.position[stock_id]["amount"] < -1e-5: - raise ValueError( - "only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount) - ) - elif abs(self.position[stock_id]["amount"]) <= 1e-5: - self.del_stock(stock_id) - - self.position["cash"] += trade_val - cost - - def del_stock(self, stock_id): - del self.position[stock_id] - - def update_order(self, order, trade_val, cost, trade_price): - # handle order, order is a order class, defined in exchange.py - if order.direction == Order.BUY: - # BUY - self.buy_stock(order.stock_id, trade_val, cost, trade_price) - elif order.direction == Order.SELL: - # SELL - self.sell_stock(order.stock_id, trade_val, cost, trade_price) - else: - raise NotImplementedError("do not support order direction {}".format(order.direction)) - - def update_stock_price(self, stock_id, price): - self.position[stock_id]["price"] = price - - def update_stock_count(self, stock_id, count): - self.position[stock_id]["count"] = count - - def update_stock_weight(self, stock_id, weight): - self.position[stock_id]["weight"] = weight - - def update_cash(self, cash): - self.position["cash"] = cash - - def calculate_stock_value(self): - stock_list = self.get_stock_list() - value = 0 - for stock_id in stock_list: - value += self.position[stock_id]["amount"] * self.position[stock_id]["price"] - return value - - def calculate_value(self): - value = self.calculate_stock_value() - value += self.position["cash"] - return value - - def get_stock_list(self): - stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"}) - return stock_list - - def get_stock_price(self, code): - return self.position[code]["price"] - - def get_stock_amount(self, code): - return self.position[code]["amount"] - - def get_stock_count(self, code): - return self.position[code]["count"] - - def get_stock_weight(self, code): - return self.position[code]["weight"] - - def get_cash(self): - return self.position["cash"] - - def get_stock_amount_dict(self): - """generate stock amount dict {stock_id : amount of stock} """ - d = {} - stock_list = self.get_stock_list() - for stock_code in stock_list: - d[stock_code] = self.get_stock_amount(code=stock_code) - return d - - def get_stock_weight_dict(self, only_stock=False): - """get_stock_weight_dict - generate stock weight fict {stock_id : value weight of stock in the position} - it is meaningful in the beginning or the end of each trade date - - :param only_stock: If only_stock=True, the weight of each stock in total stock will be returned - If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned - """ - if only_stock: - position_value = self.calculate_stock_value() - else: - position_value = self.calculate_value() - d = {} - stock_list = self.get_stock_list() - for stock_code in stock_list: - d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value - return d - - def add_count_all(self): - stock_list = self.get_stock_list() - for code in stock_list: - self.position[code]["count"] += 1 - - def update_weight_all(self): - weight_dict = self.get_stock_weight_dict() - for stock_code, weight in weight_dict.items(): - self.update_stock_weight(stock_code, weight) - - def save_position(self, path, last_trade_time): - path = pathlib.Path(path) - p = copy.deepcopy(self.position) - cash = pd.Series(dtype=np.float) - cash["init_cash"] = self.init_cash - cash["cash"] = p["cash"] - cash["today_account_value"] = p["today_account_value"] - cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None - cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None - del p["cash"] - del p["today_account_value"] - positions = pd.DataFrame.from_dict(p, orient="index") - with pd.ExcelWriter(path) as writer: - positions.to_excel(writer, sheet_name="position") - cash.to_excel(writer, sheet_name="info") - - def load_position(self, path): - """load position information from a file - should have format below - sheet "position" - columns: ['stock', 'count', 'amount', 'price', 'weight'] - 'count': , - 'amount': , - 'price': , - 'weight': , - - sheet "cash" - index: ['init_cash', 'cash', 'today_account_value'] - 'init_cash': , - 'cash': , - 'today_account_value': - """ - path = pathlib.Path(path) - positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0) - cash_record = pd.read_excel(open(path, "rb"), sheet_name="info", index_col=0) - positions = positions.to_dict(orient="index") - init_cash = cash_record.loc["init_cash"].values[0] - cash = cash_record.loc["cash"].values[0] - today_account_value = cash_record.loc["today_account_value"].values[0] - last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0] - last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0] - - # assign values - self.position = {} - self.init_cash = init_cash - self.position = positions - self.position["cash"] = cash - self.position["today_account_value"] = today_account_value - - last_trade_start_time = None is pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time) - last_trade_end_time = None is pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time) - return last_trade_start_time, last_trade_end_time diff --git a/qlib/backtest/profit_attribution.py b/qlib/backtest/profit_attribution.py deleted file mode 100644 index 20c6f638f..000000000 --- a/qlib/backtest/profit_attribution.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import numpy as np -import pandas as pd -from .position import Position -from ...data import D -from ...config import C -import datetime -from pathlib import Path - - -def get_benchmark_weight( - bench, - start_date=None, - end_date=None, - path=None, -): - """get_benchmark_weight - - get the stock weight distribution of the benchmark - - :param bench: - :param start_date: - :param end_date: - :param path: - - :return: The weight distribution of the the benchmark described by a pandas dataframe - Every row corresponds to a trading day. - Every column corresponds to a stock. - Every cell represents the strategy. - - """ - if not path: - path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv" - # TODO: the storage of weights should be implemented in a more elegent way - # TODO: The benchmark is not consistant with the filename in instruments. - bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"]) - bench_weight_df = bench_weight_df[bench_weight_df["index"] == bench] - bench_weight_df["date"] = pd.to_datetime(bench_weight_df["date"]) - if start_date is not None: - bench_weight_df = bench_weight_df[bench_weight_df.date >= start_date] - if end_date is not None: - bench_weight_df = bench_weight_df[bench_weight_df.date <= end_date] - bench_stock_weight = bench_weight_df.pivot_table(index="date", columns="code", values="weight") / 100.0 - return bench_stock_weight - - -def get_stock_weight_df(positions): - """get_stock_weight_df - :param positions: Given a positions from backtest result. - :return: A weight distribution for the position - """ - stock_weight = [] - index = [] - for date in sorted(positions.keys()): - pos = positions[date] - if isinstance(pos, dict): - pos = Position(position_dict=pos) - index.append(date) - stock_weight.append(pos.get_stock_weight_dict(only_stock=True)) - return pd.DataFrame(stock_weight, index=index) - - -def decompose_portofolio_weight(stock_weight_df, stock_group_df): - """decompose_portofolio_weight - - ''' - :param stock_weight_df: a pandas dataframe to describe the portofolio by weight. - every row corresponds to a day - every column corresponds to a stock. - Here is an example below. - code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \ - date - 2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN - 2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN - .... - :param stock_group_df: a pandas dataframe to describe the stock group. - every row corresponds to a day - every column corresponds to a stock. - the value in the cell repreponds the group id. - Here is a example by for stock_group_df for industry. The value is the industry code - instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \ - datetime - 2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 - 2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 - ... - :return: Two dict will be returned. The group_weight and the stock_weight_in_group. - The key is the group. The value is a Series or Dataframe to describe the weight of group or weight of stock - """ - all_group = np.unique(stock_group_df.values.flatten()) - all_group = all_group[~np.isnan(all_group)] - - group_weight = {} - stock_weight_in_group = {} - for group_key in all_group: - group_mask = stock_group_df == group_key - group_weight[group_key] = stock_weight_df[group_mask].sum(axis=1) - stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide(group_weight[group_key], axis=0) - return group_weight, stock_weight_in_group - - -def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df): - """ - :param stock_weight_df: a pandas dataframe to describe the portofolio by weight. - every row corresponds to a day - every column corresponds to a stock. - Here is an example below. - code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \ - date - 2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN - 2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN - 2016-01-07 0.001555 0.001546 0.002772 0.001393 0.002904 NaN - 2016-01-08 0.001564 0.001527 0.002791 0.001506 0.002948 NaN - 2016-01-11 0.001597 0.001476 0.002738 0.001493 0.003043 NaN - .... - - :param stock_group_df: a pandas dataframe to describe the stock group. - every row corresponds to a day - every column corresponds to a stock. - the value in the cell repreponds the group id. - Here is a example by for stock_group_df for industry. The value is the industry code - instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \ - datetime - 2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 - 2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 - 2016-01-07 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 - 2016-01-08 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 - 2016-01-11 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 - ... - - :param stock_ret_df: a pandas dataframe to describe the stock return. - every row corresponds to a day - every column corresponds to a stock. - the value in the cell repreponds the return of the group. - Here is a example by for stock_ret_df. - instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \ - datetime - 2016-01-05 0.007795 0.022070 0.099099 0.024707 0.009473 0.016216 - 2016-01-06 -0.032597 -0.075205 -0.098361 -0.098985 -0.099707 -0.098936 - 2016-01-07 -0.001142 0.022544 0.100000 0.004225 0.000651 0.047226 - 2016-01-08 -0.025157 -0.047244 -0.038567 -0.098177 -0.099609 -0.074408 - 2016-01-11 0.023460 0.004959 -0.034384 0.018663 0.014461 0.010962 - ... - - :return: It will decompose the portofolio to the group weight and group return. - """ - all_group = np.unique(stock_group_df.values.flatten()) - all_group = all_group[~np.isnan(all_group)] - - group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df) - - group_ret = {} - for group_key in stock_weight_in_group: - stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index) - stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index) - - temp_stock_ret_df = stock_ret_df[ - (stock_ret_df.index >= stock_weight_in_group_start_date) - & (stock_ret_df.index <= stock_weight_in_group_end_date) - ] - - group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1) - # If no weight is assigned, then the return of group will be np.nan - group_ret[group_key][group_weight[group_key] == 0.0] = np.nan - - group_weight_df = pd.DataFrame(group_weight) - group_ret_df = pd.DataFrame(group_ret) - return group_weight_df, group_ret_df - - -def get_daily_bin_group(bench_values, stock_values, group_n): - """get_daily_bin_group - Group the values of the stocks of benchmark into several bins in a day. - Put the stocks into these bins. - - :param bench_values: A series contains the value of stocks in benchmark. - The index is the stock code. - :param stock_values: A series contains the value of stocks of your portofolio - The index is the stock code. - :param group_n: Bins will be produced - - :return: A series with the same size and index as the stock_value. - The value in the series is the group id of the bins. - The No.1 bin contains the biggest values. - """ - stock_group = stock_values.copy() - - # get the bin split points based on the daily proportion of benchmark - split_points = np.percentile(bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1)) - # Modify the biggest uppper bound and smallest lowerbound - split_points[0], split_points[-1] = -np.inf, np.inf - for i, (lb, up) in enumerate(zip(split_points, split_points[1:])): - stock_group.loc[stock_values[(stock_values >= lb) & (stock_values < up)].index] = group_n - i - return stock_group - - -def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, group_n=None): - if group_method == "category": - # use the value of the benchmark as the category - return stock_group_field_df - elif group_method == "bins": - assert group_n is not None - # place the values into `group_n` fields. - # Each bin corresponds to a category. - new_stock_group_df = stock_group_field_df.copy().loc[ - bench_stock_weight_df.index.min() : bench_stock_weight_df.index.max() - ] - for idx, row in (~bench_stock_weight_df.isna()).iterrows(): - bench_values = stock_group_field_df.loc[idx, row[row].index] - new_stock_group_df.loc[idx] = get_daily_bin_group( - bench_values, stock_group_field_df.loc[idx], group_n=group_n - ) - return new_stock_group_df - - -def brinson_pa( - positions, - bench="SH000905", - group_field="industry", - group_method="category", - group_n=None, - deal_price="vwap", -): - """brinson profit attribution - - :param positions: The position produced by the backtest class - :param bench: The benchmark for comparing. TODO: if no benchmark is set, the equal-weighted is used. - :param group_field: The field used to set the group for assets allocation. - `industry` and `market_value` is often used. - :param group_method: 'category' or 'bins'. The method used to set the group for asstes allocation - `bin` will split the value into `group_n` bins and each bins represents a group - :param group_n: . Only used when group_method == 'bins'. - - :return: - A dataframe with three columns: RAA(excess Return of Assets Allocation), RSS(excess Return of Stock Selectino), RTotal(Total excess Return) - Every row corresponds to a trading day, the value corresponds to the next return for this trading day - The middle info of brinson profit attribution - """ - # group_method will decide how to group the group_field. - dates = sorted(positions.keys()) - - start_date, end_date = min(dates), max(dates) - - bench_stock_weight = get_benchmark_weight(bench, start_date, end_date) - - # The attributes for allocation will not - if not group_field.startswith("$"): - group_field = "$" + group_field - if not deal_price.startswith("$"): - deal_price = "$" + deal_price - - # FIXME: In current version. Some attributes(such as market_value) of some - # suspend stock is NAN. So we have to get more date to forward fill the NAN - shift_start_date = start_date - datetime.timedelta(days=250) - instruments = D.list_instruments( - D.instruments(market="all"), - start_time=shift_start_date, - end_time=end_date, - as_list=True, - ) - stock_df = D.features( - instruments, - [group_field, deal_price], - start_time=shift_start_date, - end_time=end_date, - freq="day", - ) - stock_df.columns = [group_field, "deal_price"] - - stock_group_field = stock_df[group_field].unstack().T - # FIXME: some attributes of some suspend stock is NAN. - stock_group_field = stock_group_field.fillna(method="ffill") - stock_group_field = stock_group_field.loc[start_date:end_date] - - stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n) - - deal_price_df = stock_df["deal_price"].unstack().T - deal_price_df = deal_price_df.fillna(method="ffill") - - # NOTE: - # The return will be slightly different from the of the return in the report. - # Here the position are adjusted at the end of the trading day with close - stock_ret = (deal_price_df - deal_price_df.shift(1)) / deal_price_df.shift(1) - stock_ret = stock_ret.shift(-1).loc[start_date:end_date] - - port_stock_weight_df = get_stock_weight_df(positions) - - # decomposing the portofolio - port_group_weight_df, port_group_ret_df = decompose_portofolio(port_stock_weight_df, stock_group, stock_ret) - bench_group_weight_df, bench_group_ret_df = decompose_portofolio(bench_stock_weight, stock_group, stock_ret) - - # if the group return of the portofolio is NaN, replace it with the market - # value - mod_port_group_ret_df = port_group_ret_df.copy() - mod_port_group_ret_df[mod_port_group_ret_df.isna()] = bench_group_ret_df - - Q1 = (bench_group_weight_df * bench_group_ret_df).sum(axis=1) - Q2 = (port_group_weight_df * bench_group_ret_df).sum(axis=1) - Q3 = (bench_group_weight_df * mod_port_group_ret_df).sum(axis=1) - Q4 = (port_group_weight_df * mod_port_group_ret_df).sum(axis=1) - - return ( - pd.DataFrame( - { - "RAA": Q2 - Q1, # The excess profit from the assets allocation - "RSS": Q3 - Q1, # The excess profit from the stocks selection - # The excess profit from the interaction of assets allocation and stocks selection - "RIN": Q4 - Q3 - Q2 + Q1, - "RTotal": Q4 - Q1, # The totoal excess profit - } - ), - { - "port_group_ret": port_group_ret_df, - "port_group_weight": port_group_weight_df, - "bench_group_ret": bench_group_ret_df, - "bench_group_weight": bench_group_weight_df, - "stock_group": stock_group, - "bench_stock_weight": bench_stock_weight, - "port_stock_weight": port_stock_weight_df, - "stock_ret": stock_ret, - }, - ) diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py deleted file mode 100644 index 9a57156f2..000000000 --- a/qlib/backtest/report.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -from collections import OrderedDict -import pandas as pd -import pathlib - - -class Report: - # daily report of the account - # contain those followings: returns, costs turnovers, accounts, cash, bench, value - # update report - def __init__(self): - self.init_vars() - - def init_vars(self): - self.accounts = OrderedDict() # account postion value for each trade date - self.returns = OrderedDict() # daily return rate for each trade date - self.turnovers = OrderedDict() # turnover for each trade date - self.costs = OrderedDict() # trade cost for each trade date - self.values = OrderedDict() # value for each trade date - self.cashes = OrderedDict() - self.latest_report_time = None # pd.TimeStamp - - def is_empty(self): - return len(self.accounts) == 0 - - def get_latest_date(self): - return self.latest_report_time - - def get_latest_account_value(self): - return self.accounts[self.latest_report_time] - - def update_report_record( - self, - trade_time=None, - account_value=None, - cash=None, - return_rate=None, - turnover_rate=None, - cost_rate=None, - stock_value=None, - ): - # check data - if None in [ - trade_time, - account_value, - cash, - return_rate, - turnover_rate, - cost_rate, - stock_value, - ]: - raise ValueError( - "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]" - ) - # update report data - self.accounts[trade_time] = account_value - self.returns[trade_time] = return_rate - self.turnovers[trade_time] = turnover_rate - self.costs[trade_time] = cost_rate - self.values[trade_time] = stock_value - self.cashes[trade_time] = cash - # update latest_report_date - self.latest_report_time = trade_time - # finish daily report update - - def generate_report_dataframe(self): - report = pd.DataFrame() - report["account"] = pd.Series(self.accounts) - report["return"] = pd.Series(self.returns) - report["turnover"] = pd.Series(self.turnovers) - report["cost"] = pd.Series(self.costs) - report["value"] = pd.Series(self.values) - report["cash"] = pd.Series(self.cashes) - report.index.name = "trade_time" - return report - - def save_report(self, path): - r = self.generate_report_dataframe() - r.to_csv(path) - - def load_report(self, path): - """load report from a file - should have format like - columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash'] - :param - path: str/ pathlib.Path() - """ - path = pathlib.Path(path) - r = pd.read_csv(open(path, "rb"), index_col=0) - r.index = pd.DatetimeIndex(r.index) - - index = r.index - self.init_vars() - for trade_time in index: - self.update_report_record( - trade_time=trade_time, - account_value=r.loc[trade_time]["account"], - cash=r.loc[trade_time]["cash"], - return_rate=r.loc[trade_time]["return"], - turnover_rate=r.loc[trade_time]["turnover"], - cost_rate=r.loc[trade_time]["cost"], - stock_value=r.loc[trade_time]["value"], - ) diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py index aa24ffb0c..8796d0057 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/contrib/backtest/__init__.py @@ -2,12 +2,12 @@ # Licensed under the MIT License. from .order import Order -from .account import Account from .position import Position from .exchange import Exchange from .report import Report -from .backtest import backtest as backtest_func, get_date_range +from .backtest import backtest as backtest_func +import copy import numpy as np import inspect from ...utils import init_instance_by_config @@ -17,86 +17,11 @@ from ...config import C logger = get_module_logger("backtest caller") -def get_strategy( - strategy=None, - topk=50, - margin=0.5, - n_drop=5, - risk_degree=0.95, - str_type="dropout", - adjust_dates=None, -): - """get_strategy - - There will be 3 ways to return a stratgy. Please follow the code. - - - Parameters - ---------- - - strategy : Strategy() - strategy used in backtest. - topk : int (Default value: 50) - top-N stocks to buy. - margin : int or float(Default value: 0.5) - - if isinstance(margin, int): - - sell_limit = margin - - - else: - - sell_limit = pred_in_a_day.count() * margin - - buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit). - sell_limit should be no less than topk. - n_drop : int - number of stocks to be replaced in each trading date. - risk_degree: float - 0-1, 0.95 for example, use 95% money to trade. - str_type: 'amount', 'weight' or 'dropout' - strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy. - - Returns - ------- - :class: Strategy - an initialized strategy object - """ - - # There will be 3 ways to return a strategy. - if strategy is None: - # 1) create strategy with param `strategy` - str_cls_dict = { - "amount": "TopkAmountStrategy", - "weight": "TopkWeightStrategy", - "dropout": "TopkDropoutStrategy", - } - logger.info("Create new strategy ") - from .. import strategy as strategy_pool - - str_cls = getattr(strategy_pool, str_cls_dict.get(str_type)) - strategy = str_cls( - topk=topk, - buffer_margin=margin, - n_drop=n_drop, - risk_degree=risk_degree, - adjust_dates=adjust_dates, - ) - elif isinstance(strategy, (dict, str)): - # 2) create strategy with init_instance_by_config - logger.info("Create new strategy ") - strategy = init_instance_by_config(strategy) - - from ..strategy.strategy import BaseStrategy - - # else: nothing happens. 3) Use the strategy directly - if not isinstance(strategy, BaseStrategy): - raise TypeError("Strategy not supported") - return strategy - - def get_exchange( - pred, exchange=None, + start_time=None, + end_time=None, + codes = "all", subscribe_fields=[], open_cost=0.0015, close_cost=0.0025, @@ -104,7 +29,6 @@ def get_exchange( trade_unit=None, limit_threshold=None, deal_price=None, - extract_codes=False, shift=1, ): """get_exchange @@ -128,9 +52,6 @@ def get_exchange( dealing price type: 'close', 'open', 'vwap'. limit_threshold : float limit move 0.1 (10%) for example, long and short with same limit. - extract_codes: bool - will we pass the codes extracted from the pred to the exchange. - NOTE: This will be faster with offline qlib. Returns ------- @@ -149,176 +70,61 @@ def get_exchange( # handle exception for deal_price if deal_price[0] != "$": deal_price = "$" + deal_price - if extract_codes: - codes = sorted(pred.index.get_level_values("instrument").unique()) - else: - codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks - - dates = sorted(pred.index.get_level_values("datetime").unique()) - dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift)) exchange = Exchange( - trade_dates=dates, + start_time=start_time, + end_time=end_time, codes=codes, deal_price=deal_price, subscribe_fields=subscribe_fields, limit_threshold=limit_threshold, open_cost=open_cost, close_cost=close_cost, - min_cost=min_cost, trade_unit=trade_unit, + min_cost=min_cost, ) - return exchange + return exchange + else: + return init_instance_by_config(exchange, accept_types=Exchange) +def init_env_instance_by_config(env): + if isinstance(env, dict): + env_config = copy.copy(env) + if "kwargs" in env_config: + env_kwargs = copy.copy(env_config["kwargs"]) + if "sub_env" in env_kwargs: + env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) + if "sub_strategy" in env_kwargs: + env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) + env_config["kwargs"] = env_kwargs + return init_instance_by_config(env_config) + else: + return env -def get_executor( - executor=None, - trade_exchange=None, - verbose=True, -): - """get_executor +def setup_exchange(root_instance, trade_exchange=None, force=False): + if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args: + if force: + root_instance.reset(trade_exchange=trade_exchange) + else: + if not hasattr(root_instance, "trade_exchange") or root_instance.trade_exchange is None: + root_instance.reset(trade_exchange=trade_exchange) + if hasattr(root_instance, "sub_env"): + setup_exchange(root_instance.sub_env, trade_exchange) + if hasattr(root_instance, "sub_strategy"): + setup_exchange(root_instance.sub_strategy, trade_exchange) + + +def backtest(start_time, end_time, strategy, env, benchmark=None, account=1e9, **kwargs): + trade_strategy = init_instance_by_config(strategy) + trade_env = init_env_instance_by_config(env) - There will be 3 ways to return a executor. Please follow the code. - - Parameters - ---------- - - executor : BaseExecutor - executor used in backtest. - trade_exchange : Exchange - exchange used in executor - verbose : bool - whether to print log. - - Returns - ------- - :class: BaseExecutor - an initialized BaseExecutor object - """ - - # There will be 3 ways to return a executor. - if executor is None: - # 1) create executor with param `executor` - logger.info("Create new executor ") - from ..online.executor import SimulatorExecutor - - executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose) - elif isinstance(executor, (dict, str)): - # 2) create executor with config - logger.info("Create new executor ") - executor = init_instance_by_config(executor) - - from ..online.executor import BaseExecutor - - # 3) Use the executor directly - if not isinstance(executor, BaseExecutor): - raise TypeError("Executor not supported") - return executor - - -# This is the API for compatibility for legacy code -def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs): - """This function will help you set a reasonable Exchange and provide default value for strategy - Parameters - ---------- - - - **backtest workflow related or commmon arguments** - - pred : pandas.DataFrame - predict should has index and one `score` column. - account : float - init account value. - shift : int - whether to shift prediction by one day. - benchmark : str - benchmark code, default is SH000905 CSI 500. - verbose : bool - whether to print log. - return_order : bool - whether to return order list - - - **strategy related arguments** - - strategy : Strategy() - strategy used in backtest. - topk : int (Default value: 50) - top-N stocks to buy. - margin : int or float(Default value: 0.5) - - if isinstance(margin, int): - - sell_limit = margin - - - else: - - sell_limit = pred_in_a_day.count() * margin - - buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit). - sell_limit should be no less than topk. - n_drop : int - number of stocks to be replaced in each trading date. - risk_degree: float - 0-1, 0.95 for example, use 95% money to trade. - str_type: 'amount', 'weight' or 'dropout' - strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy. - - - **exchange related arguments** - - exchange: Exchange() - pass the exchange for speeding up. - subscribe_fields: list - subscribe fields. - open_cost : float - open transaction cost. The default value is 0.002(0.2%). - close_cost : float - close transaction cost. The default value is 0.002(0.2%). - min_cost : float - min transaction cost. - trade_unit : int - 100 for China A. - deal_price: str - dealing price type: 'close', 'open', 'vwap'. - limit_threshold : float - limit move 0.1 (10%) for example, long and short with same limit. - extract_codes: bool - will we pass the codes extracted from the pred to the exchange. - - .. note:: This will be faster with offline qlib. - - - **executor related arguments** - - executor : BaseExecutor() - executor used in backtest. - verbose : bool - whether to print log. - - """ - # check strategy: - spec = inspect.getfullargspec(get_strategy) - str_args = {k: v for k, v in kwargs.items() if k in spec.args} - strategy = get_strategy(**str_args) - - # init exchange: spec = inspect.getfullargspec(get_exchange) - ex_args = {k: v for k, v in kwargs.items() if k in spec.args} - trade_exchange = get_exchange(pred, **ex_args) + exchange_args = {k: v for k, v in kwargs.items() if k in spec.args} + trade_exchange = get_exchange(**exchange_args) - # init executor: - executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose) + setup_exchange(trade_env, trade_exchange) + setup_exchange(trade_strategy, trade_exchange) - # run backtest - report_dict = backtest_func( - pred=pred, - strategy=strategy, - executor=executor, - trade_exchange=trade_exchange, - shift=shift, - verbose=verbose, - account=account, - benchmark=benchmark, - return_order=return_order, - ) - # for compatibility of the old API. return the dict positions + report_dict = backtest_func(start_time, end_time, trade_strategy, trade_env, benchmark, account) - positions = report_dict.get("positions") - report_dict.update({"positions": {k: p.position for k, p in positions.items()}}) return report_dict diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index a614f08b6..c44d26d7b 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -26,10 +26,10 @@ rtn & earning in the Account class Account: - def __init__(self, init_cash, last_trade_date=None): - self.init_vars(init_cash, last_trade_date) + def __init__(self, init_cash, last_trade_time=None): + self.init_vars(init_cash, last_trade_time) - def init_vars(self, init_cash, last_trade_date=None): + def init_vars(self, init_cash, last_trade_time=None): # init cash self.init_cash = init_cash self.current = Position(cash=init_cash) @@ -40,7 +40,7 @@ class Account: self.val = 0 self.report = Report() self.earning = 0 - self.last_trade_date = last_trade_date + self.last_trade_time = last_trade_time def get_positions(self): return self.positions @@ -83,9 +83,10 @@ class Account: self.current.update_order(order, trade_val, cost, trade_price) self.update_state_from_order(order, trade_val, cost, trade_price) - def update_daily_end(self, today, trader): + def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange): """ - today: pd.TimeStamp + start_time: pd.TimeStamp + end_time: pd.TimeStamp quote: pd.DataFrame (code, date), collumns when the end of trade date - update rtn @@ -102,11 +103,11 @@ class Account: profit = 0 for code in stock_list: # if suspend, no new price to be updated, profit is 0 - if trader.check_stock_suspended(code, today): + if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): continue - today_close = trader.get_close(code, today) - profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"] - self.current.update_stock_price(stock_id=code, price=today_close) + bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) + profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"] + self.current.update_stock_price(stock_id=code, price=bar_close) self.rtn += profit # update holding day count self.current.add_count_all() @@ -116,54 +117,54 @@ class Account: # account_value - last_account_value # for the first trade date, account_value - init_cash # self.report.is_empty() to judge is_first_trade_date - # get last_account_value, today_account_value, today_stock_value + # get last_account_value, now_account_value, now_stock_value if self.report.is_empty(): last_account_value = self.init_cash else: last_account_value = self.report.get_latest_account_value() - today_account_value = self.current.calculate_value() - today_stock_value = self.current.calculate_stock_value() - self.earning = today_account_value - last_account_value + now_account_value = self.current.calculate_value() + now_stock_value = self.current.calculate_stock_value() + self.earning = now_account_value - last_account_value # update report for today # judge whether the the trading is begin. # and don't add init account state into report, due to we don't have excess return in those days. self.report.update_report_record( - trade_date=today, - account_value=today_account_value, + trade_time=trade_start_time, + account_value=now_account_value, cash=self.current.position["cash"], return_rate=(self.earning + self.ct) / last_account_value, # here use earning to calculate return, position's view, earning consider cost, true return # in order to make same definition with original backtest in evaluate.py turnover_rate=self.to / last_account_value, cost_rate=self.ct / last_account_value, - stock_value=today_stock_value, + stock_value=now_stock_value, ) - # set today_account_value to position - self.current.position["today_account_value"] = today_account_value + # set now_account_value to position + self.current.position["now_account_value"] = now_account_value self.current.update_weight_all() # update positions # note use deepcopy - self.positions[today] = copy.deepcopy(self.current) + self.positions[trade_start_time] = copy.deepcopy(self.current) # finish today's updation # reset the daily variables self.rtn = 0 self.ct = 0 self.to = 0 - self.last_trade_date = today + self.last_trade_time = (trade_start_time, trade_end_time) def load_account(self, account_path): report = Report() position = Position() - last_trade_date = position.load_position(account_path / "position.xlsx") + last_trade_time = position.load_position(account_path / "position.xlsx") report.load_report(account_path / "report.csv") # assign values self.init_vars(position.init_cash) self.current = position self.report = report - self.last_trade_date = last_trade_date if last_trade_date else None + self.last_trade_time = last_trade_time def save_account(self, account_path): - self.current.save_position(account_path / "position.xlsx", self.last_trade_date) + self.current.save_position(account_path / "position.xlsx", self.last_trade_time) self.report.save_report(account_path / "report.csv") diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index b87d6afe3..8e157a361 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -4,140 +4,24 @@ import numpy as np import pandas as pd -from ...utils import get_date_by_shift, get_date_range -from ...data import D + from .account import Account -from ...config import C -from ...log import get_module_logger -from ...data.dataset.utils import get_level_index -LOG = get_module_logger("backtest") - - -def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order): - """Parameters - ---------- - pred : pandas.DataFrame - predict should has index and one `score` column - Qlib want to support multi-singal strategy in the future. So pd.Series is not used. - strategy : Strategy() - strategy part for backtest - trade_exchange : Exchange() - exchage for backtest - shift : int - whether to shift prediction by one day - verbose : bool - whether to print log - account : float - init account value - benchmark : str/list/pd.Series - `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. - example: - print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) - 2017-01-04 0.011693 - 2017-01-05 0.000721 - 2017-01-06 -0.004322 - 2017-01-09 0.006874 - 2017-01-10 -0.003350 - - `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. - `benchmark` is str, will use the daily change as the 'bench'. - benchmark code, default is SH000905 CSI500 - """ - # Convert format if the input format is not expected - if get_level_index(pred, level="datetime") == 1: - pred = pred.swaplevel().sort_index() - if isinstance(pred, pd.Series): - pred = pred.to_frame("score") +def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account): trade_account = Account(init_cash=account) - _pred_dates = pred.index.get_level_values(level="datetime") - predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max()) - if isinstance(benchmark, pd.Series): - bench = benchmark - else: - _codes = benchmark if isinstance(benchmark, list) else [benchmark] - _temp_result = D.features( - _codes, - ["$close/Ref($close,1)-1"], - predict_dates[0], - get_date_by_shift(predict_dates[-1], shift=shift), - disk_cache=1, - ) - if len(_temp_result) == 0: - raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") - bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean() + trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) + trade_strategy.reset(start_time=start_time, end_time=end_time) - trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift)) - if return_order: - multi_order_list = [] - # trading apart - for pred_date, trade_date in zip(predict_dates, trade_dates): - # for loop predict date and trading date - # print - if verbose: - LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date)) - - # 1. Load the score_series at pred_date - try: - score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate - score_series = score.reset_index(level="datetime", drop=True)[ - "score" - ] # pd.Series(index:stock_id, data: score) - except KeyError: - LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date)) - score_series = None - - if score_series is not None and score_series.count() > 0: # in case of the scores are all None - # 2. Update your strategy (and model) - strategy.update(score_series, pred_date, trade_date) - - # 3. Generate order list - order_list = strategy.generate_order_list( - score_series=score_series, - current=trade_account.current, - trade_exchange=trade_exchange, - pred_date=pred_date, - trade_date=trade_date, - ) - else: - order_list = [] - if return_order: - multi_order_list.append((trade_account, order_list, trade_date)) - # 4. Get result after executing order list - # NOTE: The following operation will modify order.amount. - # NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated - trade_info = executor.execute(trade_account, order_list, trade_date) - - # 5. Update account information according to transaction - update_account(trade_account, trade_info, trade_exchange, trade_date) - - # generate backtest report + trade_state = trade_env.get_init_state() + while not trade_env.finished(): + _order_list = trade_strategy.generate_order_list(**trade_state) + print("_order_list", _order_list) + trade_state, trade_info = trade_env.execute(_order_list) + report_df = trade_account.report.generate_report_dataframe() - report_df["bench"] = bench positions = trade_account.get_positions() - report_dict = {"report_df": report_df, "positions": positions} - if return_order: - report_dict.update({"order_list": multi_order_list}) + return report_dict - -def update_account(trade_account, trade_info, trade_exchange, trade_date): - """Update the account and strategy - Parameters - ---------- - trade_account : Account() - trade_info : list of [Order(), float, float, float] - (order, trade_val, trade_cost, trade_price), trade_info with out factor - trade_exchange : Exchange() - used to get the $close_price at trade_date to update account - trade_date : pd.Timestamp - """ - # update account - for [order, trade_val, trade_cost, trade_price] in trade_info: - if order.deal_amount == 0: - continue - trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) - # at the end of trade date, update the account based the $close_price of stocks. - trade_account.update_daily_end(today=trade_date, trader=trade_exchange) diff --git a/qlib/backtest/env.py b/qlib/contrib/backtest/env.py similarity index 89% rename from qlib/backtest/env.py rename to qlib/contrib/backtest/env.py index 571f33b7e..85a6c1ec3 100644 --- a/qlib/backtest/env.py +++ b/qlib/contrib/backtest/env.py @@ -5,13 +5,13 @@ import json import copy import warnings import pathlib +import numpy as np import pandas as pd -from loguru import Logger -from ...data import D, Cal -from ...utils import get_date_in_file_name -from ...utils import get_pre_trading_date -from ..backtest.order import Order -from ..utils import init_instance_by_config +from ...data.data import Cal +from ...utils import get_sample_freq_calendar +from .order import Order + + class TradeCalendarBase: def _reset_trade_calendar(self, start_time, end_time): @@ -20,10 +20,10 @@ class TradeCalendarBase: if end_time: self.end_time = pd.Timestamp(end_time) if self.start_time and self.end_time: - _calendar, freq, freq_sam = get_sample_freq_calendar(freq=step_bar) + _calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar) self.calendar = _calendar _start_time, _end_time, _start_index, _end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq, freq_sam=freq_sam) - _trade_calendar = self.calendar[_start_index, _end_index + 1] + _trade_calendar = self.calendar[_start_index: _end_index + 1] if _start_time != self.start_time: self.trade_calendar = np.hstack((self.start_time, _trade_calendar, self.end_time)) self.start_index = _start_index - 1 @@ -40,7 +40,7 @@ class TradeCalendarBase: trade_index = trade_index - shift if 0 < trade_index < self.trade_len - 1: trade_start_time = self.trade_calendar[trade_index - 1] - trade_end_time = self.trade_calendar[trade_index] - pd.Timestamp(second=1) + trade_end_time = self.trade_calendar[trade_index] - pd.Timedelta(seconds=1) return trade_start_time, trade_end_time elif trade_index == self.trade_len - 1: trade_start_time = self.trade_calendar[trade_index - 1] @@ -68,7 +68,7 @@ class BaseEnv(TradeCalendarBase): end_time=None, trade_account=None, verbose=False, - **kwargs + **kwargs, ): self.step_bar = step_bar self.verbose = verbose @@ -76,24 +76,24 @@ class BaseEnv(TradeCalendarBase): def _get_position(self): return self.trade_account.current - def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs): if start_time or end_time: self._reset_trade_calendar(start_time=start_time, end_time=end_time) - self.trade_account = trade_account + if trade_account: + self.trade_account = trade_account for k, v in kwargs: if hasattr(self, k): setattr(self, k, v) - def get_first_state(self): + def get_init_state(self): init_state = {"current": self._get_position()} return init_state - def execute(self, order_list, **kwargs): + def execute(self, order_list=None, **kwargs): self.trade_index = self.trade_index + 1 def finished(self): @@ -122,13 +122,13 @@ class SplitEnv(BaseEnv): #if self.track: # yield action #episode_reward = 0 - super(SimulatorEnv, self).execute(**kwargs) + super(SplitEnv, self).execute(**kwargs) trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index) self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account) self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list) trade_state = self.sub_env.get_init_state() while not self.sub_env.finished(): - _order_list = self.sub_strategy.generate_order(**trade_state) + _order_list = self.sub_strategy.generate_order_list(**trade_state) trade_state, trade_info = self.sub_env.execute(order_list=_order_list) #episode_reward += sub_reward _obs = {"current": self._get_position()} @@ -149,11 +149,12 @@ class SimulatorEnv(BaseEnv): verbose=False, **kwargs, ): - super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose) + super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose, **kwargs) - def reset(trade_exchange=None, **kwargs): + def reset(self, trade_exchange=None, **kwargs): super(SimulatorEnv, self).reset(**kwargs) - self.trade_exchange=trade_exchange + if trade_exchange: + self.trade_exchange=trade_exchange def execute(self, order_list, **kwargs): """ @@ -162,7 +163,7 @@ class SimulatorEnv(BaseEnv): if self.finished(): raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") super(SimulatorEnv, self).execute(**kwargs) - ttrade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index) + trade_start_time, trade_end_time = self._get_trade_time(trade_index=self.trade_index) trade_info = [] for order in order_list: if self.trade_exchange.check_order(order) is True: diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py index 178950eeb..62f6c63bd 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/contrib/backtest/exchange.py @@ -8,16 +8,19 @@ import logging import numpy as np import pandas as pd -from ...data import D -from .order import Order +from ...data.data import D from ...config import C, REG_CN +from ...utils import sample_feature from ...log import get_module_logger +from .order import Order + class Exchange: def __init__( self, - trade_dates=None, + start_time=None, + end_time=None, codes="all", deal_price=None, subscribe_fields=[], @@ -30,7 +33,8 @@ class Exchange: ): """__init__ - :param trade_dates: list of pd.Timestamp + :param start_time: start time for backtest + :param end_time: end time for backtest :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50) :param deal_price: str, 'close', 'open', 'vwap' :param subscribe_fields: list, subscribe fields @@ -51,6 +55,8 @@ class Exchange: target on this day). index: MultipleIndex(instrument, pd.Datetime) """ + self.start_time = start_time + self.end_time = end_time if trade_unit is None: trade_unit = C.trade_unit if limit_threshold is None: @@ -91,21 +97,15 @@ class Exchange: self.close_cost = close_cost self.min_cost = min_cost self.limit_threshold = limit_threshold - # TODO: the quote, trade_dates, codes are not necessray. - # It is just for performance consideration. - if trade_dates is not None and len(trade_dates): - start_date, end_date = trade_dates[0], trade_dates[-1] - else: - self.logger.warning("trade_dates have not been assigned, all dates will be loaded") - start_date, end_date = None, None + self.extra_quote = extra_quote - self.set_quote(codes, start_date, end_date) + self.set_quote(codes, start_time, end_time) - def set_quote(self, codes, start_date, end_date): + def set_quote(self, codes, start_time, end_time): if len(codes) == 0: codes = D.instruments() - self.quote = D.features(codes, self.all_fields, start_date, end_date, disk_cache=True).dropna(subset=["$close"]) + self.quote = D.features(codes, self.all_fields, start_time, end_time, disk_cache=True).dropna(subset=["$close"]) self.quote.columns = self.all_fields if self.quote[self.deal_price].isna().any(): @@ -146,35 +146,37 @@ class Exchange: quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) # update quote: pd.DataFrame to dict, for search use - self.quote = quote_df.to_dict("index") + self.quote = quote_df def _update_limit(self, buy_limit, sell_limit): self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False) - def check_stock_limit(self, stock_id, trade_date): + def check_stock_limit(self, stock_id, start_time, end_time): """Parameter stock_id trade_date is limtited """ - return self.quote[(stock_id, trade_date)]["limit"] + return sample_feature(self.quote, stock_id, start_time, end_time, fields="limit", method="any").iloc[0] + - def check_stock_suspended(self, stock_id, trade_date): + def check_stock_suspended(self, stock_id, start_time, end_time): # is suspended - return (stock_id, trade_date) not in self.quote + return sample_feature(self.quote, stock_id, start_time, end_time).empty - def is_stock_tradable(self, stock_id, trade_date): + + def is_stock_tradable(self, stock_id, start_time, end_time): # check if stock can be traded # same as check in check_order - if self.check_stock_suspended(stock_id, trade_date) or self.check_stock_limit(stock_id, trade_date): + if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(stock_id, start_time, end_time): return False else: return True def check_order(self, order): # check limit and suspended - if self.check_stock_suspended(order.stock_id, order.trade_date) or self.check_stock_limit( - order.stock_id, order.trade_date + if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit( + order.stock_id, order.start_time, order.end_time ): return False else: @@ -199,7 +201,7 @@ class Exchange: if trade_account is not None and position is not None: raise ValueError("trade_account and position can only choose one") - trade_price = self.get_deal_price(order.stock_id, order.trade_date) + trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) trade_val, trade_cost = self._calc_trade_info_by_order( order, trade_account.current if trade_account else position ) @@ -214,24 +216,24 @@ class Exchange: return trade_val, trade_cost, trade_price - def get_quote_info(self, stock_id, trade_date): - return self.quote[(stock_id, trade_date)] + def get_quote_info(self, stock_id, start_time, end_time): + return sample_feature(self.quote, stock_id, start_time, end_time) - def get_close(self, stock_id, trade_date): - return self.quote[(stock_id, trade_date)]["$close"] + def get_close(self, stock_id, start_time, end_time): + return sample_feature(self.quote, stock_id, start_time, end_time, fields="$close", method="last").iloc[0] - def get_deal_price(self, stock_id, trade_date): - deal_price = self.quote[(stock_id, trade_date)][self.deal_price] + def get_deal_price(self, stock_id, start_time, end_time): + deal_price = sample_feature(self.quote, stock_id, start_time, end_time, fields=self.deal_price, method="last").iloc[0] if np.isclose(deal_price, 0.0) or np.isnan(deal_price): - self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!") + self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!") self.logger.warning(f"setting deal_price to close price") - deal_price = self.get_close(stock_id, trade_date) + deal_price = self.get_close(stock_id, start_time, end_time) return deal_price - def get_factor(self, stock_id, trade_date): - return self.quote[(stock_id, trade_date)]["$factor"] + def get_factor(self, stock_id, start_time, end_time): + return sample_feature(self.quote, stock_id, start_time, end_time, fields="$factor", method="last").iloc[0] - def generate_amount_position_from_weight_position(self, weight_position, cash, trade_date): + def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): """ The generate the target position according to the weight and the cash. NOTE: All the cash will assigned to the tadable stock. @@ -246,7 +248,7 @@ class Exchange: # calculate the total weight of tradable value tradable_weight = 0.0 for stock_id in weight_position: - if self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date): + if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): # weight_position must be greater than 0 and less than 1 if weight_position[stock_id] < 0 or weight_position[stock_id] > 1: raise ValueError( @@ -260,12 +262,12 @@ class Exchange: amount_dict = {} for stock_id in weight_position: - if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date): + if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): amount_dict[stock_id] = ( cash * weight_position[stock_id] / tradable_weight - // self.get_deal_price(stock_id=stock_id, trade_date=trade_date) + // self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) ) return amount_dict @@ -292,7 +294,7 @@ class Exchange: deal_amount = self.round_amount_by_trade_unit(deal_amount, factor) return -deal_amount - def generate_order_for_target_amount_position(self, target_position, current_position, trade_date): + def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time): """Parameter: target_position : dict { stock_id : amount } current_postion : dict { stock_id : amount} @@ -315,12 +317,12 @@ class Exchange: for stock_id in sorted_ids: # Do not generate order for the nontradable stocks - if not self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date): + if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): continue target_amount = target_position.get(stock_id, 0) current_amount = current_position.get(stock_id, 0) - factor = self.quote[(stock_id, trade_date)]["$factor"] + factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time) deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor) if deal_amount == 0: @@ -332,7 +334,8 @@ class Exchange: stock_id=stock_id, amount=deal_amount, direction=Order.BUY, - trade_date=trade_date, + start_time=start_time, + end_time=end_time, factor=factor, ) ) @@ -343,14 +346,15 @@ class Exchange: stock_id=stock_id, amount=abs(deal_amount), direction=Order.SELL, - trade_date=trade_date, + start_time=start_time, + end_time=end_time, factor=factor, ) ) # return order_list : buy + sell return sell_order_list + buy_order_list - def calculate_amount_position_value(self, amount_dict, trade_date, only_tradable=False): + def calculate_amount_position_value(self, amount_dict, start_time, end_time, only_tradable=False): """Parameter position : Position() amount_dict : {stock_id : amount} @@ -358,10 +362,10 @@ class Exchange: value = 0 for stock_id in amount_dict: if ( - self.check_stock_suspended(stock_id=stock_id, trade_date=trade_date) is False - and self.check_stock_limit(stock_id=stock_id, trade_date=trade_date) is False + self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False + and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False ): - value += self.get_deal_price(stock_id=stock_id, trade_date=trade_date) * amount_dict[stock_id] + value += self.get_deal_price(stock_id=stock_id, start_time=start_time, end_time=end_time) * amount_dict[stock_id] return value def round_amount_by_trade_unit(self, deal_amount, factor): @@ -384,7 +388,7 @@ class Exchange: :return: trade_val, trade_cost """ - trade_price = self.get_deal_price(order.stock_id, order.trade_date) + trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) if order.direction == Order.SELL: # sell if position is not None: diff --git a/qlib/contrib/backtest/interpreter.py b/qlib/contrib/backtest/interpreter.py new file mode 100644 index 000000000..94d6f9ec2 --- /dev/null +++ b/qlib/contrib/backtest/interpreter.py @@ -0,0 +1,15 @@ + +class BaseInterpreter: + @staticmethod + def interpret(**kwargs): + raise NotImplementedError("interpret is not implemented!") + +class ActionInterpreter: + @staticmethod + def interpret(action, **kwargs): + return action + +class StateInterpreter: + @staticmethod + def interpret(state, **kwargs): + return state \ No newline at end of file diff --git a/qlib/contrib/backtest/order.py b/qlib/contrib/backtest/order.py index 740773b2f..0d637d9db 100644 --- a/qlib/contrib/backtest/order.py +++ b/qlib/contrib/backtest/order.py @@ -7,7 +7,7 @@ class Order: SELL = 0 BUY = 1 - def __init__(self, stock_id, amount, trade_date, direction, factor): + def __init__(self, stock_id, amount, start_time, end_time, direction, factor): """Parameter direction : Order.SELL for sell; Order.BUY for buy stock_id : str @@ -24,6 +24,7 @@ class Order: self.amount = amount # amount of successfully completed orders self.deal_amount = 0 - self.trade_date = trade_date + self.start_time = start_time + self.end_time = end_time self.direction = direction self.factor = factor diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 6c269d505..ac1a471f8 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -28,13 +28,13 @@ a typical example is :{ class Position: """Position""" - def __init__(self, cash=0, position_dict={}, today_account_value=0): + def __init__(self, cash=0, position_dict={}, now_account_value=0): # NOTE: The position dict must be copied!!! # Otherwise the initial value self.init_cash = cash self.position = position_dict.copy() self.position["cash"] = cash - self.position["today_account_value"] = today_account_value + self.position["now_account_value"] = now_account_value def init_stock(self, stock_id, amount, price=None): self.position[stock_id] = {} @@ -82,7 +82,7 @@ class Position: # SELL self.sell_stock(order.stock_id, trade_val, cost, trade_price) else: - raise NotImplementedError("do not suppotr order direction {}".format(order.direction)) + raise NotImplementedError("do not support order direction {}".format(order.direction)) def update_stock_price(self, stock_id, price): self.position[stock_id]["price"] = price @@ -109,7 +109,7 @@ class Position: return value def get_stock_list(self): - stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"}) + stock_list = list(set(self.position.keys()) - {"cash", "now_account_value"}) return stock_list def get_stock_price(self, code): @@ -163,16 +163,17 @@ class Position: for stock_code, weight in weight_dict.items(): self.update_stock_weight(stock_code, weight) - def save_position(self, path, last_trade_date): + def save_position(self, path, last_trade_time): path = pathlib.Path(path) p = copy.deepcopy(self.position) cash = pd.Series(dtype=np.float) cash["init_cash"] = self.init_cash cash["cash"] = p["cash"] - cash["today_account_value"] = p["today_account_value"] - cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None + cash["now_account_value"] = p["now_account_value"] + cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None + cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None del p["cash"] - del p["today_account_value"] + del p["now_account_value"] positions = pd.DataFrame.from_dict(p, orient="index") with pd.ExcelWriter(path) as writer: positions.to_excel(writer, sheet_name="position") @@ -189,10 +190,10 @@ class Position: 'weight': , sheet "cash" - index: ['init_cash', 'cash', 'today_account_value'] + index: ['init_cash', 'cash', 'now_account_value'] 'init_cash': , 'cash': , - 'today_account_value': + 'now_account_value': """ path = pathlib.Path(path) positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0) @@ -200,14 +201,17 @@ class Position: positions = positions.to_dict(orient="index") init_cash = cash_record.loc["init_cash"].values[0] cash = cash_record.loc["cash"].values[0] - today_account_value = cash_record.loc["today_account_value"].values[0] - last_trade_date = cash_record.loc["last_trade_date"].values[0] + now_account_value = cash_record.loc["now_account_value"].values[0] + last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0] + last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0] # assign values self.position = {} self.init_cash = init_cash self.position = positions self.position["cash"] = cash - self.position["today_account_value"] = today_account_value + self.position["now_account_value"] = now_account_value - return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date) + last_trade_start_time = None if pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time) + last_trade_end_time = None if pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time) + return last_trade_start_time, last_trade_end_time diff --git a/qlib/contrib/backtest/report.py b/qlib/contrib/backtest/report.py index beb9759d0..9a57156f2 100644 --- a/qlib/contrib/backtest/report.py +++ b/qlib/contrib/backtest/report.py @@ -21,20 +21,20 @@ class Report: self.costs = OrderedDict() # trade cost for each trade date self.values = OrderedDict() # value for each trade date self.cashes = OrderedDict() - self.latest_report_date = None # pd.TimeStamp + self.latest_report_time = None # pd.TimeStamp def is_empty(self): return len(self.accounts) == 0 def get_latest_date(self): - return self.latest_report_date + return self.latest_report_time def get_latest_account_value(self): - return self.accounts[self.latest_report_date] + return self.accounts[self.latest_report_time] def update_report_record( self, - trade_date=None, + trade_time=None, account_value=None, cash=None, return_rate=None, @@ -44,7 +44,7 @@ class Report: ): # check data if None in [ - trade_date, + trade_time, account_value, cash, return_rate, @@ -56,14 +56,14 @@ class Report: "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]" ) # update report data - self.accounts[trade_date] = account_value - self.returns[trade_date] = return_rate - self.turnovers[trade_date] = turnover_rate - self.costs[trade_date] = cost_rate - self.values[trade_date] = stock_value - self.cashes[trade_date] = cash + self.accounts[trade_time] = account_value + self.returns[trade_time] = return_rate + self.turnovers[trade_time] = turnover_rate + self.costs[trade_time] = cost_rate + self.values[trade_time] = stock_value + self.cashes[trade_time] = cash # update latest_report_date - self.latest_report_date = trade_date + self.latest_report_time = trade_time # finish daily report update def generate_report_dataframe(self): @@ -74,7 +74,7 @@ class Report: report["cost"] = pd.Series(self.costs) report["value"] = pd.Series(self.values) report["cash"] = pd.Series(self.cashes) - report.index.name = "date" + report.index.name = "trade_time" return report def save_report(self, path): @@ -94,13 +94,13 @@ class Report: index = r.index self.init_vars() - for date in index: + for trade_time in index: self.update_report_record( - trade_date=date, - account_value=r.loc[date]["account"], - cash=r.loc[date]["cash"], - return_rate=r.loc[date]["return"], - turnover_rate=r.loc[date]["turnover"], - cost_rate=r.loc[date]["cost"], - stock_value=r.loc[date]["value"], + trade_time=trade_time, + account_value=r.loc[trade_time]["account"], + cash=r.loc[trade_time]["cash"], + return_rate=r.loc[trade_time]["return"], + turnover_rate=r.loc[trade_time]["turnover"], + cost_rate=r.loc[trade_time]["cost"], + stock_value=r.loc[trade_time]["value"], ) diff --git a/qlib/contrib/backtest_new/backtest.py b/qlib/contrib/backtest_new/backtest.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py index f0ec0a5d0..678b048c2 100644 --- a/qlib/contrib/strategy/__init__.py +++ b/qlib/contrib/strategy/__init__.py @@ -4,13 +4,13 @@ from .dl_strategy import ( TopkDropoutStrategy, - BaseStrategy, WeightStrategyBase, ) from .rule_strategy import( TWAPStrategy, - SBBEMAStrategy + SBBStrategyBase, + SBBStrategyEMA, ) from .cost_control import ( diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index 001630a95..962936f9f 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. -from .strategy import WeightStrategyBase +from .dl_strategy import WeightStrategyBase import copy diff --git a/qlib/contrib/strategy/dl_strategy.py b/qlib/contrib/strategy/dl_strategy.py index 5f702fe0b..4c7d16eea 100644 --- a/qlib/contrib/strategy/dl_strategy.py +++ b/qlib/contrib/strategy/dl_strategy.py @@ -4,12 +4,12 @@ import numpy as np import pandas as pd from ...utils import sample_feature -from ...strategy.base import DLStrategy -from ...backtest.order import Order +from ...strategy.base import ModelStrategy +from ..backtest.order import Order from .order_generator import OrderGenWInteract -class TopkDropoutStrategy(DLStrategy): +class TopkDropoutStrategy(ModelStrategy): def __init__( self, step_bar, @@ -53,7 +53,7 @@ class TopkDropoutStrategy(DLStrategy): else: strategy will make decision with the tradable state of the stock info and avoid buy and sell them. """ - super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time) + super(TopkDropoutStrategy, self).__init__(step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange) self.topk = topk self.n_drop = n_drop self.method_sell = method_sell @@ -67,9 +67,10 @@ class TopkDropoutStrategy(DLStrategy): self.only_tradable = only_tradable - def reset(trade_exchange=None, **kwargs): + def reset(self, trade_exchange=None, **kwargs): super(TopkDropoutStrategy, self).reset(**kwargs) - self.trade_exchange = trade_exchange + if trade_exchange: + self.trade_exchange = trade_exchange def get_risk_degree(self, trade_index): """get_risk_degree @@ -189,7 +190,7 @@ class TopkDropoutStrategy(DLStrategy): # update cash cash += trade_val - trade_cost # sold - del self.stock_count[code] + self.stock_count[code] = 0 else: # no buy signal, but the stock is kept self.stock_count[code] += 1 @@ -210,10 +211,10 @@ class TopkDropoutStrategy(DLStrategy): # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit for code in buy: # check is stock suspended - if not self.trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date): + if not self.trade_exchange.is_stock_tradable(stock_id=code, start_time=trade_start_time, end_time=trade_end_time): continue # buy order - buy_price = self.trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date) + buy_price = self.trade_exchange.get_deal_price(stock_id=code, start_time=trade_start_time, end_time=trade_end_time) buy_amount = value / buy_price factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time) buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor) @@ -229,8 +230,8 @@ class TopkDropoutStrategy(DLStrategy): self.stock_count[code] = 1 return sell_order_list + buy_order_list -class WeightStrategyBase(DLStrategy): - def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, start_time=None, end_time=None, **kwargs): +class WeightStrategyBase(ModelStrategy): + def __init__(self, step_bar, start_time=None, end_time=None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, **kwargs): super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time) self.trade_exchange = trade_exchange if isinstance(order_generator_cls_or_obj, type): diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index cdbd30c1f..d263f658d 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -4,8 +4,8 @@ """ This order generator is for strategies based on WeightStrategyBase """ -from ...backtest.position import Position -from ...backtest.exchange import Exchange +from ..backtest.position import Position +from ..backtest.exchange import Exchange import pandas as pd import copy diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index dd2e17c54..b51ec9aca 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -4,18 +4,20 @@ import numpy as np import pandas as pd from ...utils import sample_feature +from ...data.data import D from ...strategy.base import RuleStrategy, TradingEnhancement -from ...backtest.order import Order +from ..backtest.order import Order class TWAPStrategy(RuleStrategy, TradingEnhancement): def reset(self, trade_order_list=None, **kwargs): super(TWAPStrategy, self).reset(**kwargs) - TradingEnhancement.reset(trade_order_list=trade_order_list) - self.trade_amount = {} - for order in self.trade_order_list: - self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len + TradingEnhancement.reset(self, trade_order_list=trade_order_list) + if trade_order_list: + self.trade_amount = {} + for order in self.trade_order_list: + self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len def generate_order_list(self, **kwargs): @@ -43,13 +45,15 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): TREND_LONG = 2 def reset(self, trade_order_list=None, **kwargs): - TradingEnhancement.reset(trade_order_list=trade_order_list) - self.trade_amount = {} - self.trade_delay = {} - for order in self.trade_order_list: - self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len - self.trade_trend[(order.stock_id, order.direction)] = TREND_MID super(SBBStrategyBase, self).reset(**kwargs) + TradingEnhancement.reset(self, trade_order_list=trade_order_list) + if trade_order_list: + self.trade_amount = {} + self.trade_trend = {} + for order in self.trade_order_list: + self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len + self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID + def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") @@ -64,7 +68,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): _pred_trend = self._pred_price_trend(order.stock_id) else: _pred_trend = self.trade_trend[(order.stock_id, order.direction)] - if _pred_trend == TREND_MID: + if _pred_trend == self.TREND_MID: _order = Order( stock_id=order.stock_id, amount=self.trade_amount[(order.stock_id, order.direction)], @@ -97,7 +101,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): factor=order.factor, ) order_list.append(_order) - if self.trade_index % 2 == 1 + if self.trade_index % 2 == 1: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend return order_list @@ -110,8 +114,8 @@ class SBBStrategyEMA(SBBStrategyBase): def __init__( self, step_bar, - start_time, - end_time, + start_time=None, + end_time=None, instruments="csi300", freq="day", **kwargs, @@ -121,21 +125,23 @@ class SBBStrategyEMA(SBBStrategyBase): warnings.warn("`instruments` is not set, will load all stocks") self.instruments = "all" if isinstance(instruments, str): - self.instruments = D.instruments(instruments, filter_pipe=self.filter_pipe) + self.instruments = D.instruments(instruments) self.freq = freq - def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None): - super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar) - fields = [("EMA($close, 10) - EMA($close, 20)", "signal")] - signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) - self.signal = D.features(instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq) + def _reset_trade_calendar(self, start_time=None, end_time=None): + super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time) + if self.start_time and self.end_time: + fields = ["EMA($close, 10)-EMA($close, 20)"] + signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) + self.signal = D.features(self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq) + self.signal.columns = ["signal"] def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): _sample_signal = sample_feature(self.signal, stock_id, start_time=pred_start_time, end_time=pred_end_time, fields="signal", method="last") if _sample_signal.empty: - return SBBStrategy.TREND_MID - elif _sample_signal.iloc[0, 0] > 0: - return SBBStrategy.TREND_LONG + return self.TREND_MID + elif _sample_signal.iloc[0] > 0: + return self.TREND_LONG else: - return SBBStrategy.TREND_SHORT \ No newline at end of file + return self.TREND_SHORT \ No newline at end of file diff --git a/qlib/data/data.py b/qlib/data/data.py index 98427637a..a8d5a42ab 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -117,6 +117,7 @@ class CalendarProvider(abc.ABC): flag = f"{freq}_sam_{freq_sam}_future_{future}" if flag in H["c"]: _calendar, _calendar_index = H["c"][flag] + return _calendar, _calendar_index else: flag_raw = f"{freq}_sam_{None}_future_{future}" if flag_raw in H["c"]: @@ -125,6 +126,7 @@ class CalendarProvider(abc.ABC): _calendar = np.array(self.load_calendar(freq, future)) _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search H["c"][flag_raw] = _calendar, _calendar_index + if freq_sam is None: return _calendar, _calendar_index else: @@ -132,6 +134,7 @@ class CalendarProvider(abc.ABC): _calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)} H["c"][flag] = _calendar_sam, _calendar_sam_index return _calendar_sam, _calendar_sam_index + def _uri(self, start_time, end_time, freq, future=False): """Get the uri of calendar generation task.""" @@ -541,8 +544,8 @@ class LocalCalendarProvider(CalendarProvider): with open(fname) as f: return [pd.Timestamp(x.strip()) for x in f] - def calendar(self, start_time=None, end_time=None, freq="day", future=False, freq_sam=None): - _calendar, _ = self._get_calendar(freq=freq, future=future) + def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False): + _calendar, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future) # strip if start_time: start_time = pd.Timestamp(start_time) @@ -764,6 +767,7 @@ class ClientCalendarProvider(CalendarProvider): self.conn = conn def calendar(self, start_time=None, end_time=None, freq="day", future=False): + self.conn.send_request( request_type="calendar", request_content={ diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index cad093af2..193906dcd 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -10,8 +10,9 @@ import pandas as pd from ..utils import get_sample_freq_calendar from ..data.dataset import DatasetH -from ..backtest.order import Order -from ..backtest.env import TradeCalendarBase +from ..data.dataset.utils import get_level_index +from ..contrib.backtest.order import Order +from ..contrib.backtest.env import TradeCalendarBase """ 1. BaseStrategy 的粒度一定是数据粒度的整数倍 @@ -24,26 +25,14 @@ class BaseStrategy(TradeCalendarBase): self.step_bar = step_bar self.reset(start_time=start_time, end_time=end_time, **kwargs) - def reset(self, start_time=None, end_time=None, _calendar=None, **kwargs): + def reset(self, start_time=None, end_time=None, **kwargs): if start_time or end_time : - self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar) + self._reset_trade_calendar(start_time=start_time, end_time=end_time) for k, v in kwargs: if hasattr(self, k): setattr(self, k, v) - - def _get_trade_time(self): - if 0 < self.trade_index < self.trade_len - 1: - trade_start_time = self.trade_calendar[self.trade_index - 1] - trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1) - return trade_start_time, trade_end_time - elif self.trade_index == self.trade_len - 1: - trade_start_time = self.trade_calendar[self.trade_index - 1] - trade_end_time = self.trade_calendar[self.trade_index] - return trade_start_time, trade_end_time - else: - raise RuntimeError("trade_index out of range") def generate_order_list(self, **kwargs): self.trade_index = self.trade_index + 1 @@ -52,20 +41,26 @@ class BaseStrategy(TradeCalendarBase): class RuleStrategy(BaseStrategy): pass -class DLStrategy(BaseStrategy): - def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None): +class ModelStrategy(BaseStrategy): + def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None, **kwargs): self.model = model self.dataset = dataset - self.pred_scores = self.model.predict(dataset) + self.pred_scores = self._convert_index_format(self.model.predict(dataset)) #pred_score_dates = self.pred_scores.index.get_level_values(level="datetime") - super(DLStrategy, self).__init__(step_bar, start_time, end_time) + super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) - def _update_model(self): + def _convert_index_format(self, df): + if get_level_index(df, level="datetime") == 0: + df = df.swaplevel().sort_index() + return df + + def _update_model(self): """update pred score """ pass class TradingEnhancement: - def reset(self, trade_order_list): - self.trade_order_list = trade_order_list + def reset(self, trade_order_list=None): + if trade_order_list: + self.trade_order_list = trade_order_list diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 028e60cc6..0f365956d 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -15,6 +15,7 @@ import bisect import shutil import difflib import hashlib +import warnings import datetime import requests import tempfile @@ -918,37 +919,40 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): else: raise ValueError("sample freq must be xmin, xd, xw, xm") -def get_sample_freq_calendar(start_time=None, end_time=None, freq, **kwargs): +def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs): + from ..data.data import Cal + try: - _calendar = D.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs) + _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs) freq, freq_sam = freq, None except ValueError: freq_sam = freq if freq.endswith(("m", "month", "w", "week", "d", "day")): try: - _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs) + _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs) freq = "min" except ValueError: - _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="day", freq_sam=freq, **kwargs) + _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs) freq = "day" elif freq.endswith(("min", "minute")): - _calendar = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="min", freq_sam=freq, **kwargs) + _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs) freq = "min" else: raise ValueError(f"freq {freq} is not supported") return _calendar, freq, freq_sam def sample_feature(feature, instruments=None, start_time=None, end_time=None, fields=None, method=None, method_kwargs={}): - if instruments and type(instruments) is not list: + if instruments and not isinstance(instruments, list): instruments = [instruments] - if fields and type(fields) is not list: - fields = [fields] selector_inst = slice(None) if instruments is None else instruments selector_datetime = slice(start_time, end_time) - if fields is not None and type(fields) is not list: - fields = [fields] - selector_fields = slice(None) if fields is None else fields - feature = feature.loc[(selector_inst, selector_datetime), selector_fields] + if isinstance(feature, pd.Series): + feature = feature.loc[(selector_inst, selector_datetime)] + if fields: + warnings.warn(f"sample series feature, {fields} is ignored!") + elif isinstance(feature, pd.DataFrame): + selector_fields = slice(None) if fields is None else fields + feature = feature.loc[(selector_inst, selector_datetime), selector_fields] if method: return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) else: