mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
add PandasQuote
This commit is contained in:
@@ -102,10 +102,11 @@ class Exchange:
|
||||
|
||||
# TODO: the quote, trade_dates, codes are not necessray.
|
||||
# It is just for performance consideration.
|
||||
self.limit_type = BaseQuote._get_limit_type(limit_threshold)
|
||||
if limit_threshold is None:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||
elif self._get_limit_type(limit_threshold) == self.LT_FLT and abs(limit_threshold) > 0.1:
|
||||
elif self.limit_type == BaseQuote.LT_FLT and abs(limit_threshold) > 0.1:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
@@ -127,10 +128,9 @@ class Exchange:
|
||||
# $change is for calculating the limit of the stock
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self._get_limit_type(limit_threshold) == self.LT_TP_EXP:
|
||||
if self.limit_type == BaseQuote.LT_TP_EXP:
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
subscribe_fields = list(necessary_fields | set(subscribe_fields))
|
||||
all_fields = list(necessary_fields | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
@@ -140,94 +140,22 @@ class Exchange:
|
||||
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
|
||||
self.volume_threshold = volume_threshold
|
||||
self.extra_quote = extra_quote
|
||||
self.set_quote(codes, start_time, end_time)
|
||||
|
||||
def set_quote(self, codes, start_time, end_time):
|
||||
if len(codes) == 0:
|
||||
codes = D.instruments()
|
||||
|
||||
self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(
|
||||
subset=["$close"]
|
||||
# init quote
|
||||
self.quote = PandasQuote(
|
||||
start_time = self.start_time,
|
||||
end_time = self.end_time,
|
||||
freq = self.freq,
|
||||
codes = self.codes,
|
||||
all_fields = self.all_fields,
|
||||
limit_threshold = self.limit_threshold,
|
||||
buy_price = self.buy_price,
|
||||
sell_price = self.sell_price,
|
||||
extra_quote = self.extra_quote,
|
||||
)
|
||||
self.quote.columns = self.all_fields
|
||||
|
||||
for attr in "buy_price", "sell_price":
|
||||
pstr = getattr(self, attr) # price string
|
||||
if self.quote[pstr].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
|
||||
if self.quote["$factor"].isna().any():
|
||||
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
|
||||
# Use adjusted price
|
||||
self.trade_w_adj_price = True
|
||||
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
|
||||
if self.trade_unit is not None:
|
||||
self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.")
|
||||
|
||||
else:
|
||||
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
|
||||
# Use normal price
|
||||
self.trade_w_adj_price = False
|
||||
|
||||
# update limit
|
||||
self._update_limit()
|
||||
|
||||
quote_df = self.quote
|
||||
if self.extra_quote is not None:
|
||||
# process extra_quote
|
||||
if "$close" not in self.extra_quote:
|
||||
raise ValueError("$close is necessray in extra_quote")
|
||||
for attr in "buy_price", "sell_price":
|
||||
pstr = getattr(self, attr) # price string
|
||||
if pstr not in self.extra_quote.columns:
|
||||
self.extra_quote[pstr] = self.extra_quote["$close"]
|
||||
self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.")
|
||||
if "$factor" not in self.extra_quote.columns:
|
||||
self.extra_quote["$factor"] = 1.0
|
||||
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
|
||||
if "limit_sell" not in self.extra_quote.columns:
|
||||
self.extra_quote["limit_sell"] = False
|
||||
self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.")
|
||||
if "limit_buy" not in self.extra_quote.columns:
|
||||
self.extra_quote["limit_buy"] = False
|
||||
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
|
||||
|
||||
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
|
||||
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
|
||||
|
||||
self.quote = quote_dict
|
||||
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
|
||||
def _get_limit_type(self, limit_threshold):
|
||||
if isinstance(limit_threshold, Tuple):
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
elif limit_threshold is None:
|
||||
return self.LT_NONE
|
||||
else:
|
||||
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
|
||||
|
||||
def _update_limit(self):
|
||||
# check limit_threshold
|
||||
lt_type = self._get_limit_type(self.limit_threshold)
|
||||
if lt_type == self.LT_NONE:
|
||||
self.quote["limit_buy"] = False
|
||||
self.quote["limit_sell"] = False
|
||||
elif lt_type == self.LT_TP_EXP:
|
||||
# set limit
|
||||
self.quote["limit_buy"] = self.quote[self.limit_threshold[0]]
|
||||
self.quote["limit_sell"] = self.quote[self.limit_threshold[1]]
|
||||
elif lt_type == self.LT_FLT:
|
||||
self.quote["limit_buy"] = self.quote["$change"].ge(self.limit_threshold)
|
||||
self.quote["limit_sell"] = self.quote["$change"].le(-self.limit_threshold) # pylint: disable=E1130
|
||||
self.trade_w_adj_price = self.quote.get_trade_w_adj_price()
|
||||
if(self.trade_w_adj_price and (self.trade_unit is not None)):
|
||||
self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.")
|
||||
|
||||
def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
|
||||
"""
|
||||
@@ -241,20 +169,20 @@ class Exchange:
|
||||
|
||||
"""
|
||||
if direction is None:
|
||||
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
|
||||
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
|
||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all")
|
||||
sell_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all")
|
||||
return buy_limit or sell_limit
|
||||
elif direction == Order.BUY:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
|
||||
return self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all")
|
||||
elif direction == Order.SELL:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
|
||||
return self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all")
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
# is suspended
|
||||
if stock_id in self.quote:
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=None) is None
|
||||
if stock_id in self.quote.get_all_stock():
|
||||
return self.quote.get_data(stock_id, start_time, end_time) is None
|
||||
else:
|
||||
return True
|
||||
|
||||
@@ -313,13 +241,13 @@ class Exchange:
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time, method=ts_data_last):
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=method)
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method)
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time, method=ts_data_last):
|
||||
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=method)
|
||||
return self.quote.get_data(stock_id, start_time, end_time, fields="$close", method=method)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time, method="sum"):
|
||||
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method=method)
|
||||
return self.quote.get_data(stock_id, start_time, end_time, fields="$volume", method=method)
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method=ts_data_last):
|
||||
if direction == OrderDir.SELL:
|
||||
@@ -328,7 +256,7 @@ class Exchange:
|
||||
pstr = self.buy_price
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
deal_price = resam_ts_data(self.quote[stock_id][pstr], start_time, end_time, method=method)
|
||||
deal_price = self.quote.get_data(stock_id, start_time, end_time, fields=pstr, method=method)
|
||||
if method is not None and (np.isclose(deal_price, 0.0) or np.isnan(deal_price)):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
@@ -343,9 +271,9 @@ class Exchange:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
if stock_id not in self.quote:
|
||||
if stock_id not in self.quote.get_all_stock():
|
||||
return None
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
|
||||
return self.quote.get_data(stock_id, start_time, end_time, fields="$factor", method=ts_data_last)
|
||||
|
||||
def generate_amount_position_from_weight_position(
|
||||
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
|
||||
@@ -596,3 +524,145 @@ class Exchange:
|
||||
# cache to avoid recreate the same instance
|
||||
self._order_helper = OrderHelper(self)
|
||||
return self._order_helper
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
|
||||
def __init__(self):
|
||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||
|
||||
def _update_limit(self, limit_threshold):
|
||||
raise NotImplementedError(f"Please implement the `_update_limit` method")
|
||||
|
||||
def get_trade_w_adj_price(self):
|
||||
raise NotImplementedError(f"Please implement the `get_trade_w_adj_price` method")
|
||||
|
||||
def get_all_stock(self):
|
||||
raise NotImplementedError(f"Please implement the `get_all_stock` method")
|
||||
|
||||
def get_data(self, stock_id, start_time, end_time, fields, method):
|
||||
raise NotImplementedError(f"Please implement the `get_data` method")
|
||||
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
|
||||
@staticmethod
|
||||
def _get_limit_type(limit_threshold):
|
||||
if isinstance(limit_threshold, Tuple):
|
||||
return BaseQuote.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return BaseQuote.LT_FLT
|
||||
elif limit_threshold is None:
|
||||
return BaseQuote.LT_NONE
|
||||
else:
|
||||
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
|
||||
|
||||
|
||||
class PandasQuote(BaseQuote):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
codes,
|
||||
all_fields,
|
||||
limit_threshold,
|
||||
buy_price,
|
||||
sell_price,
|
||||
extra_quote
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
# get stock data from qlib
|
||||
if len(codes) == 0:
|
||||
codes = D.instruments()
|
||||
self.data = D.features(
|
||||
codes,
|
||||
all_fields,
|
||||
start_time,
|
||||
end_time,
|
||||
freq=freq,
|
||||
disk_cache=True
|
||||
).dropna(subset=["$close"])
|
||||
self.data.columns = all_fields
|
||||
|
||||
# check buy_price data and sell_price data
|
||||
self.buy_price = buy_price
|
||||
self.sell_price = sell_price
|
||||
for attr in "buy_price", "sell_price":
|
||||
pstr = getattr(self, attr) # price string
|
||||
if self.data[pstr].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
|
||||
# update trade_w_adj_price
|
||||
if self.data["$factor"].isna().any():
|
||||
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
|
||||
# Use adjusted price
|
||||
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
|
||||
self.trade_w_adj_price = True
|
||||
else:
|
||||
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
|
||||
# Use normal price
|
||||
self.trade_w_adj_price = False
|
||||
|
||||
# update limit
|
||||
self._update_limit(limit_threshold)
|
||||
|
||||
# concat extra_quote
|
||||
quote_df = self.data
|
||||
if extra_quote is not None:
|
||||
# process extra_quote
|
||||
if "$close" not in extra_quote:
|
||||
raise ValueError("$close is necessray in extra_quote")
|
||||
for attr in "buy_price", "sell_price":
|
||||
pstr = getattr(self, attr) # price string
|
||||
if pstr not in extra_quote.columns:
|
||||
extra_quote[pstr] = extra_quote["$close"]
|
||||
self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.")
|
||||
if "$factor" not in extra_quote.columns:
|
||||
extra_quote["$factor"] = 1.0
|
||||
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
|
||||
if "limit_sell" not in extra_quote.columns:
|
||||
extra_quote["limit_sell"] = False
|
||||
self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.")
|
||||
if "limit_buy" not in extra_quote.columns:
|
||||
extra_quote["limit_buy"] = False
|
||||
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
|
||||
assert set(extra_quote.columns) == set(quote_df.columns) - {"$change"}
|
||||
quote_df = pd.concat([quote_df, extra_quote], sort=False, axis=0)
|
||||
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
|
||||
self.data = quote_dict
|
||||
|
||||
def _update_limit(self, limit_threshold):
|
||||
# check limit_threshold
|
||||
limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_type == self.LT_NONE:
|
||||
self.data["limit_buy"] = False
|
||||
self.data["limit_sell"] = False
|
||||
elif limit_type == self.LT_TP_EXP:
|
||||
# set limit
|
||||
self.data["limit_buy"] = self.data[limit_threshold[0]]
|
||||
self.data["limit_sell"] = self.data[limit_threshold[1]]
|
||||
elif limit_type == self.LT_FLT:
|
||||
self.data["limit_buy"] = self.data["$change"].ge(limit_threshold)
|
||||
self.data["limit_sell"] = self.data["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
|
||||
def get_all_stock(self):
|
||||
return self.data.keys()
|
||||
|
||||
def get_data(self, stock_id, start_time, end_time, fields = None, method = None):
|
||||
if(fields is None):
|
||||
return resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
|
||||
elif(isinstance(fields, (str, list))):
|
||||
return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
|
||||
else:
|
||||
raise ValueError(f"fields must be None, str or list")
|
||||
|
||||
def get_trade_w_adj_price(self):
|
||||
return self.trade_w_adj_price
|
||||
Reference in New Issue
Block a user