From f67b99a30e6890519bd532a96452687b3f7d795c Mon Sep 17 00:00:00 2001 From: "wangwenxi.handsome" Date: Sun, 15 Aug 2021 12:45:29 +0000 Subject: [PATCH] update exchange --- qlib/backtest/exchange.py | 10 +- qlib/backtest/high_performance_ds.py | 135 ++++++++++++++++++++++++++- qlib/backtest/report.py | 3 + tests/backtest/test_file_strategy.py | 15 +-- 4 files changed, 150 insertions(+), 13 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index d36675b01..9327e6f15 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -21,7 +21,7 @@ from ..config import C, REG_CN from ..utils.resam import resam_ts_data, ts_data_last from ..log import get_module_logger from .order import Order, OrderDir, OrderHelper -from .high_performance_ds import PandasQuote +from .high_performance_ds import PandasQuote, NumpyQuote class Exchange: @@ -39,7 +39,7 @@ class Exchange: close_cost=0.0025, min_cost=5, extra_quote=None, - quote_cls=PandasQuote, + quote_cls=NumpyQuote, **kwargs, ): """__init__ @@ -725,9 +725,9 @@ class Exchange: """ max_trade_amount = 0 if cash >= self.min_cost: - # critical_amount means the stock transaction amount when the service fee is equal to min_cost. - critical_amount = self.min_cost / self.open_cost + self.min_cost - if cash >= critical_amount: + # critical_price means the stock transaction price when the service fee is equal to min_cost. + critical_price = self.min_cost / self.open_cost + self.min_cost + if cash >= critical_price: # the service fee is equal to open_cost * trade_amount max_trade_amount = cash / (1 + self.open_cost) / trade_price else: diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index eabe84a0a..9bf2ca2b8 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -3,13 +3,16 @@ import logging +from qlib.data.base import Feature from typing import List, Text, Tuple, Union, Callable, Iterable, Dict from collections import OrderedDict import inspect +import bisect import pandas as pd +import numpy as np -from ..utils.resam import resam_ts_data +from ..utils.resam import resam_ts_data, ts_data_last from ..log import get_module_logger @@ -112,6 +115,136 @@ class PandasQuote(BaseQuote): else: raise ValueError(f"fields must be None, str or list") + def _if_single_data(self, start_time, end_time): + if end_time - start_time < np.timedelta64(1, 'm'): + return True + if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0: + return True + if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0: + return True + return False + + +class NumpyQuote(BaseQuote): + def __init__(self, quote_df: pd.DataFrame): + """NumpyQuote + + Parameters + ---------- + quote_df : pd.DataFrame + the init dataframe from qlib. + + Variables + self.data: Dict[stock_id, np.array] + each stock has one two-dimensional np.array to represent data. + self.columns: Dict[str, int] + map column name to column id in self.data. + self.dates: Dict[stock_id, Dict[pd.Timestap, int]] + map timestap to row id in self.data. + self.dates_list: Dict[stock_id, List[pd.Timestap]] + the dates of each stock for searching. + """ + super().__init__(quote_df=quote_df) + # init data + columns = quote_df.columns.values + self.columns = dict(zip(columns, range(len(columns)))) + self.data, self.dates, self.dates_list = self._to_numpy(quote_df) + + # lru + self.muti_lru = {} + + def _to_numpy(self, quote_df): + """convert dataframe to numpy. + """ + quote_dict = {} + date_dict = {} + date_list = {} + for stock_id, stock_val in quote_df.groupby(level="instrument"): + quote_dict[stock_id] = stock_val.values + date_dict[stock_id] = stock_val.index.get_level_values("datetime") + date_list[stock_id] = list(date_dict[stock_id]) + for stock_id in date_dict: + date_dict[stock_id] = dict(zip(date_dict[stock_id], range(len(date_dict[stock_id])))) + return quote_dict, date_dict, date_list + + def get_all_stock(self): + return self.data.keys() + + def get_data(self, stock_id, start_time, end_time, fields=None, method=None): + # check stock id + if stock_id not in self.get_all_stock(): + return None + + # get single data + if self._if_single_data(start_time, end_time): + if start_time not in self.dates[stock_id]: + return None + if fields is None: + # it used for check if data is None + return self.data[stock_id][self.dates[stock_id][start_time]] + else: + return self.data[stock_id][self.dates[stock_id][start_time]][self.columns[fields]] + # get muti row data + else: + # check lru + if (start_time, end_time, fields, method) in self.muti_lru: + return self.muti_lru[(start_time, end_time, fields, method)] + + start_id = bisect.bisect_left(self.dates_list[stock_id], start_time) + end_id = bisect.bisect_right(self.dates_list[stock_id], end_time) + if start_id == end_id: + return None + # it used for check if data is None + if fields is None: + return self.data[stock_id][start_id: end_id] + agg_stock_data = self._agg_data(self.data[stock_id][start_id: end_id, self.columns[fields]], method) + + # result lru + self.muti_lru[(start_time, end_time, fields, method)] = agg_stock_data + return agg_stock_data + + def _agg_data(self, data, method): + """Agg data by specific method. + """ + if method == "sum": + return data.sum() + if method == "mean": + return data.mean() + if method == "last": + return data[-1] + if method == "all": + return data.all() + if method == "any": + return data.any() + if method == ts_data_last: + valid_data = data[data != np.NaN] + if len(valid_data) == 0: + return None + else: + return valid_data[0] + + def _if_single_data(self, start_time, end_time): + """Is there only one piece of data to obtaine. + + Parameters + ---------- + start_time : Union[pd.Timestamp, str] + closed start time for data. + end_time : Union[pd.Timestamp, str] + closed end time for data. + Returns + ------- + bool + True means one piece of data to obtaine. + """ + if end_time - start_time < np.timedelta64(1, 'm'): + return True + if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0: + return True + if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0: + return True + return False + class BaseSingleMetric: """ diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 2d188dd18..9f957c0ac 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -389,6 +389,9 @@ class Indicator: if price_s is None: return None, None + if isinstance(price_s, (int, float)): + price_s = pd.Series(price_s, index=[trade_start_time]) + # NOTE: there are some zeros in the trading price. These cases are known meaningless # for aligning the previous logic, remove it. price_s = price_s[~(price_s < 1e-08)] # remove zero and negative values. diff --git a/tests/backtest/test_file_strategy.py b/tests/backtest/test_file_strategy.py index 9229581ac..945f142c6 100644 --- a/tests/backtest/test_file_strategy.py +++ b/tests/backtest/test_file_strategy.py @@ -29,13 +29,13 @@ class FileStrTest(TestAutoData): # test cash limit for buying ["20200103", self.TEST_INST, "1000", "buy"], # test min_cost for buying - ["20200103", self.TEST_INST, "1", "buy"], + ["20200106", self.TEST_INST, "1", "buy"], # test held stock limit for selling - ["20200106", self.TEST_INST, "1000", "sell"], + ["20200107", self.TEST_INST, "1000", "sell"], # test cash limit for buying - ["20200107", self.TEST_INST, "1000", "buy"], + ["20200108", self.TEST_INST, "1000", "buy"], # test min_cost for selling - ["20200108", self.TEST_INST, "1", "sell"], + ["20200109", self.TEST_INST, "1", "sell"], # test selling all stocks ["20200110", self.TEST_INST, str(self.DEAL_NUM_FOR_1000), "sell"], ] @@ -94,10 +94,11 @@ class FileStrTest(TestAutoData): # ffr valid ffr_dict = indicator_dict["1day"]["ffr"].to_dict() ffr_dict = {str(date).split()[0]: ffr_dict[date] for date in ffr_dict} - assert ffr_dict["2020-01-03"] == 0 - assert ffr_dict["2020-01-06"] == self.DEAL_NUM_FOR_1000 / 1000 + assert ffr_dict["2020-01-03"] == self.DEAL_NUM_FOR_1000 / 1000 + assert ffr_dict["2020-01-06"] == 0 assert ffr_dict["2020-01-07"] == self.DEAL_NUM_FOR_1000 / 1000 - assert ffr_dict["2020-01-08"] == 0 + assert ffr_dict["2020-01-08"] == self.DEAL_NUM_FOR_1000 / 1000 + assert ffr_dict["2020-01-09"] == 0 assert ffr_dict["2020-01-10"] == 1 self.EXAMPLE_FILE.unlink()