From 567841e1c663964b41e6d4bcfb0689540c43d2b5 Mon Sep 17 00:00:00 2001 From: "wangwenxi.handsome" Date: Fri, 16 Jul 2021 12:56:49 +0000 Subject: [PATCH] get qlib data in exchange --- qlib/backtest/exchange.py | 310 +++++++++++++++++--------------------- 1 file changed, 139 insertions(+), 171 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 2e865d591..82f57462e 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -102,11 +102,11 @@ class Exchange: # TODO: the quote, trade_dates, codes are not necessray. # It is just for performance consideration. - self.limit_type = BaseQuote._get_limit_type(limit_threshold) + self.limit_type = self._get_limit_type(limit_threshold) if limit_threshold is None: if C.region == REG_CN: self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold") - elif self.limit_type == BaseQuote.LT_FLT and abs(limit_threshold) > 0.1: + elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1: if C.region == REG_CN: self.logger.warning(f"limit_threshold may not be set to a reasonable value") @@ -128,7 +128,7 @@ class Exchange: # $change is for calculating the limit of the stock necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} - if self.limit_type == BaseQuote.LT_TP_EXP: + if self.limit_type == self.LT_TP_EXP: for exp in limit_threshold: necessary_fields.add(exp) all_fields = list(necessary_fields | set(subscribe_fields)) @@ -140,22 +140,98 @@ class Exchange: self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold self.volume_threshold = volume_threshold self.extra_quote = extra_quote + self.get_quote_from_qlib() - # init quote - self.quote = PandasQuote( - start_time = self.start_time, - end_time = self.end_time, - freq = self.freq, - codes = self.codes, - all_fields = self.all_fields, - limit_threshold = self.limit_threshold, - buy_price = self.buy_price, - sell_price = self.sell_price, - extra_quote = self.extra_quote, - ) - self.trade_w_adj_price = self.quote.get_trade_w_adj_price() - if(self.trade_w_adj_price and (self.trade_unit is not None)): - self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.") + # init quote by quote_df + self.quote = PandasQuote(self.quote_df) + + def get_quote_from_qlib(self): + # get stock data from qlib + if len(self.codes) == 0: + self.codes = D.instruments() + self.quote_df = D.features( + self.codes, + self.all_fields, + self.start_time, + self.end_time, + freq=self.freq, + disk_cache=True + ).dropna(subset=["$close"]) + self.quote_df.columns = self.all_fields + + # check buy_price data and sell_price data + for attr in "buy_price", "sell_price": + pstr = getattr(self, attr) # price string + if self.quote_df[pstr].isna().any(): + self.logger.warning("{} field data contains nan.".format(pstr)) + + # update trade_w_adj_price + if self.quote_df["$factor"].isna().any(): + # The 'factor.day.bin' file not exists, and `factor` field contains `nan` + # Use adjusted price + self.trade_w_adj_price = True + self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") + if self.trade_unit is not None: + self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.") + else: + # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` + # Use normal price + self.trade_w_adj_price = False + + # update limit + self._update_limit(self.limit_threshold) + + # concat extra_quote + if self.extra_quote is not None: + # process extra_quote + if "$close" not in self.extra_quote: + raise ValueError("$close is necessray in extra_quote") + for attr in "buy_price", "sell_price": + pstr = getattr(self, attr) # price string + if pstr not in self.extra_quote.columns: + self.extra_quote[pstr] = self.extra_quote["$close"] + self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.") + if "$factor" not in self.extra_quote.columns: + self.extra_quote["$factor"] = 1.0 + self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") + if "limit_sell" not in self.extra_quote.columns: + self.extra_quote["limit_sell"] = False + self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") + if "limit_buy" not in self.extra_quote.columns: + self.extra_quote["limit_buy"] = False + self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") + assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"} + self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0) + + LT_TP_EXP = "(exp)" # Tuple[str, str] + LT_FLT = "float" # float + LT_NONE = "none" # none + + def _get_limit_type(self, limit_threshold): + """get limit type + """ + if isinstance(limit_threshold, Tuple): + return self.LT_TP_EXP + elif isinstance(limit_threshold, float): + return self.LT_FLT + elif limit_threshold is None: + return self.LT_NONE + else: + raise NotImplementedError(f"This type of `limit_threshold` is not supported") + + def _update_limit(self, limit_threshold): + # check limit_threshold + limit_type = self._get_limit_type(limit_threshold) + if limit_type == self.LT_NONE: + self.quote_df["limit_buy"] = False + self.quote_df["limit_sell"] = False + elif limit_type == self.LT_TP_EXP: + # set limit + self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]] + self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]] + elif limit_type == self.LT_FLT: + self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) + self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130 def check_stock_limit(self, stock_id, start_time, end_time, direction=None): """ @@ -528,184 +604,79 @@ class Exchange: class BaseQuote: - def __init__(self): + def __init__(self, quote_df: pd.DataFrame): self.logger = get_module_logger("online operator", level=logging.INFO) - def _update_limit(self, limit_threshold): - """add limitation information to data based on limit_threshold - """ - raise NotImplementedError(f"Please implement the `_update_limit` method") - - def get_trade_w_adj_price(self): - """return whether use the trade price with adjusted weight - """ - raise NotImplementedError(f"Please implement the `get_trade_w_adj_price` method") - def get_all_stock(self): """return all stock codes + + Return + ------ + Union[list, Dict.keys(), set, tuple] + all stock codes """ raise NotImplementedError(f"Please implement the `get_all_stock` method") - def get_data(self, stock_id, start_time, end_time, fields=None, method=None): + def get_data(self, stock_id: str, start_time, end_time, fields: Union[str, list]=None, method=None): """get the specific fields of stock data during start time and end_time, - and apply method to the data, please refer to resam_ts_data - """ - raise NotImplementedError(f"Please implement the `get_data` method") + and apply method to the data. + + Example: + .. code-block:: + $close $volume + instrument datetime + SH600000 2010-01-04 86.778313 16162960.0 + 2010-01-05 87.433578 28117442.0 + 2010-01-06 85.713585 23632884.0 + 2010-01-07 83.788803 20813402.0 + 2010-01-08 84.730675 16044853.0 - LT_TP_EXP = "(exp)" # Tuple[str, str] - LT_FLT = "float" # float - LT_NONE = "none" # none + SH600655 2010-01-04 2699.567383 158193.328125 + 2010-01-08 2612.359619 77501.406250 + 2010-01-11 2712.982422 160852.390625 + 2010-01-12 2788.688232 164587.937500 + 2010-01-13 2790.604004 145460.453125 - @staticmethod - def _get_limit_type(limit_threshold): - """get limit type - """ - if isinstance(limit_threshold, Tuple): - return BaseQuote.LT_TP_EXP - elif isinstance(limit_threshold, float): - return BaseQuote.LT_FLT - elif limit_threshold is None: - return BaseQuote.LT_NONE - else: - raise NotImplementedError(f"This type of `limit_threshold` is not supported") + print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) + + $close 87.433578 + $volume 28117442.0 + print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last")) -class PandasQuote(BaseQuote): - """ - """ - - def __init__( - self, - start_time, - end_time, - freq, - codes, - all_fields: List[str], - limit_threshold: Union[Tuple[str, str], float, None], - buy_price: str, - sell_price: str, - extra_quote: pd.DataFrame, - ): - """init stock data based on pandas + 87.433578 Parameters ---------- + stock_id: Union[str, list] start_time : pd.Timestamp|str closed start time for backtest end_time : pd.Timestamp|str closed end time for backtest - freq : str - frequency of data - codes : [type] - all stock code - all_fields : List[str] - all subscribe fields in qlib - limit_threshold : Union[Tuple[str, str], float, None] - 1) `None`: no limitation - 2) float, 0.1 for example, default None - 3) Tuple[str, str]: (, - ) - `False` value indicates the stock is tradable - `True` value indicates the stock is limited and not tradable - buy_price : str - the data field for buying stock - sell_price : str - the data field for selling stock - extra_quote : pd.DataFrame - columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy']. - The limit indicates that the etf is tradable on a specific day. - Necessary fields: - $close is for calculating the total value at end of each day. - Optional fields: - $volume is only necessary when we limit the trade amount or caculate PA(vwap) indicator - $vwap is only necessary when we use the $vwap price as the deal price - $factor is for rounding to the trading unit - limit_sell will be set to False by default(False indicates we can sell this - target on this day). - limit_buy will be set to False by default(False indicates we can buy this - target on this day). - index: MultipleIndex(instrument, pd.Datetime) + fields : Union[str, List] + the columns of data to fetch + method : Union[str, Callable] + the method apply to data. + e.g ["None", "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last] + + Return + ---------- + Union[None, float, pd.Series] + The resampled Series/value, return None when the resampled data is empty. """ - super().__init__() + raise NotImplementedError(f"Please implement the `get_data` method") - # get stock data from qlib - if len(codes) == 0: - codes = D.instruments() - self.data = D.features( - codes, - all_fields, - start_time, - end_time, - freq=freq, - disk_cache=True - ).dropna(subset=["$close"]) - self.data.columns = all_fields - # check buy_price data and sell_price data - self.buy_price = buy_price - self.sell_price = sell_price - for attr in "buy_price", "sell_price": - pstr = getattr(self, attr) # price string - if self.data[pstr].isna().any(): - self.logger.warning("{} field data contains nan.".format(pstr)) - - # update trade_w_adj_price - if self.data["$factor"].isna().any(): - # The 'factor.day.bin' file not exists, and `factor` field contains `nan` - # Use adjusted price - self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") - self.trade_w_adj_price = True - else: - # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` - # Use normal price - self.trade_w_adj_price = False - - # update limit - self._update_limit(limit_threshold) - - # concat extra_quote - quote_df = self.data - if extra_quote is not None: - # process extra_quote - if "$close" not in extra_quote: - raise ValueError("$close is necessray in extra_quote") - for attr in "buy_price", "sell_price": - pstr = getattr(self, attr) # price string - if pstr not in extra_quote.columns: - extra_quote[pstr] = extra_quote["$close"] - self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.") - if "$factor" not in extra_quote.columns: - extra_quote["$factor"] = 1.0 - self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") - if "limit_sell" not in extra_quote.columns: - extra_quote["limit_sell"] = False - self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") - if "limit_buy" not in extra_quote.columns: - extra_quote["limit_buy"] = False - self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") - assert set(extra_quote.columns) == set(quote_df.columns) - {"$change"} - quote_df = pd.concat([quote_df, extra_quote], sort=False, axis=0) +class PandasQuote(BaseQuote): + def __init__(self, quote_df: pd.DataFrame): + super().__init__(quote_df=quote_df) quote_dict = {} for stock_id, stock_val in quote_df.groupby(level="instrument"): quote_dict[stock_id] = stock_val.droplevel(level="instrument") self.data = quote_dict - def _update_limit(self, limit_threshold): - # check limit_threshold - limit_type = self._get_limit_type(limit_threshold) - if limit_type == self.LT_NONE: - self.data["limit_buy"] = False - self.data["limit_sell"] = False - elif limit_type == self.LT_TP_EXP: - # set limit - self.data["limit_buy"] = self.data[limit_threshold[0]] - self.data["limit_sell"] = self.data[limit_threshold[1]] - elif limit_type == self.LT_FLT: - self.data["limit_buy"] = self.data["$change"].ge(limit_threshold) - self.data["limit_sell"] = self.data["$change"].le(-limit_threshold) # pylint: disable=E1130 - def get_all_stock(self): return self.data.keys() @@ -715,7 +686,4 @@ class PandasQuote(BaseQuote): elif(isinstance(fields, (str, list))): return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method) else: - raise ValueError(f"fields must be None, str or list") - - def get_trade_w_adj_price(self): - return self.trade_w_adj_price \ No newline at end of file + raise ValueError(f"fields must be None, str or list") \ No newline at end of file