diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 82f57462e..7733891fe 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -150,12 +150,7 @@ class Exchange: if len(self.codes) == 0: self.codes = D.instruments() self.quote_df = D.features( - self.codes, - self.all_fields, - self.start_time, - self.end_time, - freq=self.freq, - disk_cache=True + self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True ).dropna(subset=["$close"]) self.quote_df.columns = self.all_fields @@ -177,10 +172,9 @@ class Exchange: # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` # Use normal price self.trade_w_adj_price = False - # update limit self._update_limit(self.limit_threshold) - + # concat extra_quote if self.extra_quote is not None: # process extra_quote @@ -199,7 +193,7 @@ class Exchange: self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") if "limit_buy" not in self.extra_quote.columns: self.extra_quote["limit_buy"] = False - self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") + self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"} self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0) @@ -208,8 +202,7 @@ class Exchange: LT_NONE = "none" # none def _get_limit_type(self, limit_threshold): - """get limit type - """ + """get limit type""" if isinstance(limit_threshold, Tuple): return self.LT_TP_EXP elif isinstance(limit_threshold, float): @@ -603,7 +596,6 @@ class Exchange: class BaseQuote: - def __init__(self, quote_df: pd.DataFrame): self.logger = get_module_logger("online operator", level=logging.INFO) @@ -617,10 +609,17 @@ class BaseQuote: """ raise NotImplementedError(f"Please implement the `get_all_stock` method") - def get_data(self, stock_id: str, start_time, end_time, fields: Union[str, list]=None, method=None): + def get_data( + self, + stock_id: Union[str, list], + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + fields: Union[str, list] = None, + method: Union[str, Callable] = None, + ): """get the specific fields of stock data during start time and end_time, and apply method to the data. - + Example: .. code-block:: $close $volume @@ -637,8 +636,15 @@ class BaseQuote: 2010-01-12 2788.688232 164587.937500 2010-01-13 2790.604004 145460.453125 + print(get_data(stock_id=["SH600000", "SH600655"], start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) + + $close $volume + instrument + SH600000 87.433578 28117442.0 + SH600655 2699.567383 158193.328125 + print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) - + $close 87.433578 $volume 28117442.0 @@ -649,27 +655,26 @@ class BaseQuote: Parameters ---------- stock_id: Union[str, list] - start_time : pd.Timestamp|str + start_time : Union[pd.Timestamp, str] closed start time for backtest - end_time : pd.Timestamp|str + end_time : Union[pd.Timestamp, str] closed end time for backtest fields : Union[str, List] the columns of data to fetch method : Union[str, Callable] - the method apply to data. - e.g ["None", "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last] + the method apply to data. + e.g ["None", "last", "all", "sum", "mean", "any", qlib/utils/resam.py/ts_data_last] Return ---------- - Union[None, float, pd.Series] - The resampled Series/value, return None when the resampled data is empty. + Union[None, float, pd.Series, pd.DataFrame] + The resampled DataFrame/Series/value, return None when the resampled data is empty. """ - raise NotImplementedError(f"Please implement the `get_data` method") + raise NotImplementedError(f"Please implement the `get_data` method") class PandasQuote(BaseQuote): - def __init__(self, quote_df: pd.DataFrame): super().__init__(quote_df=quote_df) quote_dict = {} @@ -680,10 +685,10 @@ class PandasQuote(BaseQuote): def get_all_stock(self): return self.data.keys() - def get_data(self, stock_id, start_time, end_time, fields = None, method = None): - if(fields is None): + def get_data(self, stock_id, start_time, end_time, fields=None, method=None): + if fields is None: return resam_ts_data(self.data[stock_id], start_time, end_time, method=method) - elif(isinstance(fields, (str, list))): + elif isinstance(fields, (str, list)): return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method) else: - raise ValueError(f"fields must be None, str or list") \ No newline at end of file + raise ValueError(f"fields must be None, str or list") diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 56884cd48..970734df5 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -687,7 +687,9 @@ class FileOrderStrategy(BaseStrategy): - This class provides an interface for user to read orders from csv files. """ - def __init__(self, file: Union[IO, str, Path], trade_range: Union[Tuple[int, int], TradeRange]= None, *args, **kwargs): + def __init__( + self, file: Union[IO, str, Path], trade_range: Union[Tuple[int, int], TradeRange] = None, *args, **kwargs + ): """ Parameters