1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

update exchange

This commit is contained in:
wangwenxi.handsome
2021-08-15 12:45:29 +00:00
committed by you-n-g
parent 2da6a8c770
commit f67b99a30e
4 changed files with 150 additions and 13 deletions

View File

@@ -21,7 +21,7 @@ from ..config import C, REG_CN
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from .order import Order, OrderDir, OrderHelper
from .high_performance_ds import PandasQuote
from .high_performance_ds import PandasQuote, NumpyQuote
class Exchange:
@@ -39,7 +39,7 @@ class Exchange:
close_cost=0.0025,
min_cost=5,
extra_quote=None,
quote_cls=PandasQuote,
quote_cls=NumpyQuote,
**kwargs,
):
"""__init__
@@ -725,9 +725,9 @@ class Exchange:
"""
max_trade_amount = 0
if cash >= self.min_cost:
# critical_amount means the stock transaction amount when the service fee is equal to min_cost.
critical_amount = self.min_cost / self.open_cost + self.min_cost
if cash >= critical_amount:
# critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / self.open_cost + self.min_cost
if cash >= critical_price:
# the service fee is equal to open_cost * trade_amount
max_trade_amount = cash / (1 + self.open_cost) / trade_price
else:

View File

@@ -3,13 +3,16 @@
import logging
from qlib.data.base import Feature
from typing import List, Text, Tuple, Union, Callable, Iterable, Dict
from collections import OrderedDict
import inspect
import bisect
import pandas as pd
import numpy as np
from ..utils.resam import resam_ts_data
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
@@ -112,6 +115,136 @@ class PandasQuote(BaseQuote):
else:
raise ValueError(f"fields must be None, str or list")
def _if_single_data(self, start_time, end_time):
if end_time - start_time < np.timedelta64(1, 'm'):
return True
if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0:
return True
if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0:
return True
return False
class NumpyQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame):
"""NumpyQuote
Parameters
----------
quote_df : pd.DataFrame
the init dataframe from qlib.
Variables
self.data: Dict[stock_id, np.array]
each stock has one two-dimensional np.array to represent data.
self.columns: Dict[str, int]
map column name to column id in self.data.
self.dates: Dict[stock_id, Dict[pd.Timestap, int]]
map timestap to row id in self.data.
self.dates_list: Dict[stock_id, List[pd.Timestap]]
the dates of each stock for searching.
"""
super().__init__(quote_df=quote_df)
# init data
columns = quote_df.columns.values
self.columns = dict(zip(columns, range(len(columns))))
self.data, self.dates, self.dates_list = self._to_numpy(quote_df)
# lru
self.muti_lru = {}
def _to_numpy(self, quote_df):
"""convert dataframe to numpy.
"""
quote_dict = {}
date_dict = {}
date_list = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val.values
date_dict[stock_id] = stock_val.index.get_level_values("datetime")
date_list[stock_id] = list(date_dict[stock_id])
for stock_id in date_dict:
date_dict[stock_id] = dict(zip(date_dict[stock_id], range(len(date_dict[stock_id]))))
return quote_dict, date_dict, date_list
def get_all_stock(self):
return self.data.keys()
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
# check stock id
if stock_id not in self.get_all_stock():
return None
# get single data
if self._if_single_data(start_time, end_time):
if start_time not in self.dates[stock_id]:
return None
if fields is None:
# it used for check if data is None
return self.data[stock_id][self.dates[stock_id][start_time]]
else:
return self.data[stock_id][self.dates[stock_id][start_time]][self.columns[fields]]
# get muti row data
else:
# check lru
if (start_time, end_time, fields, method) in self.muti_lru:
return self.muti_lru[(start_time, end_time, fields, method)]
start_id = bisect.bisect_left(self.dates_list[stock_id], start_time)
end_id = bisect.bisect_right(self.dates_list[stock_id], end_time)
if start_id == end_id:
return None
# it used for check if data is None
if fields is None:
return self.data[stock_id][start_id: end_id]
agg_stock_data = self._agg_data(self.data[stock_id][start_id: end_id, self.columns[fields]], method)
# result lru
self.muti_lru[(start_time, end_time, fields, method)] = agg_stock_data
return agg_stock_data
def _agg_data(self, data, method):
"""Agg data by specific method.
"""
if method == "sum":
return data.sum()
if method == "mean":
return data.mean()
if method == "last":
return data[-1]
if method == "all":
return data.all()
if method == "any":
return data.any()
if method == ts_data_last:
valid_data = data[data != np.NaN]
if len(valid_data) == 0:
return None
else:
return valid_data[0]
def _if_single_data(self, start_time, end_time):
"""Is there only one piece of data to obtaine.
Parameters
----------
start_time : Union[pd.Timestamp, str]
closed start time for data.
end_time : Union[pd.Timestamp, str]
closed end time for data.
Returns
-------
bool
True means one piece of data to obtaine.
"""
if end_time - start_time < np.timedelta64(1, 'm'):
return True
if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0:
return True
if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0:
return True
return False
class BaseSingleMetric:
"""

View File

@@ -389,6 +389,9 @@ class Indicator:
if price_s is None:
return None, None
if isinstance(price_s, (int, float)):
price_s = pd.Series(price_s, index=[trade_start_time])
# NOTE: there are some zeros in the trading price. These cases are known meaningless
# for aligning the previous logic, remove it.
price_s = price_s[~(price_s < 1e-08)] # remove zero and negative values.

View File

@@ -29,13 +29,13 @@ class FileStrTest(TestAutoData):
# test cash limit for buying
["20200103", self.TEST_INST, "1000", "buy"],
# test min_cost for buying
["20200103", self.TEST_INST, "1", "buy"],
["20200106", self.TEST_INST, "1", "buy"],
# test held stock limit for selling
["20200106", self.TEST_INST, "1000", "sell"],
["20200107", self.TEST_INST, "1000", "sell"],
# test cash limit for buying
["20200107", self.TEST_INST, "1000", "buy"],
["20200108", self.TEST_INST, "1000", "buy"],
# test min_cost for selling
["20200108", self.TEST_INST, "1", "sell"],
["20200109", self.TEST_INST, "1", "sell"],
# test selling all stocks
["20200110", self.TEST_INST, str(self.DEAL_NUM_FOR_1000), "sell"],
]
@@ -94,10 +94,11 @@ class FileStrTest(TestAutoData):
# ffr valid
ffr_dict = indicator_dict["1day"]["ffr"].to_dict()
ffr_dict = {str(date).split()[0]: ffr_dict[date] for date in ffr_dict}
assert ffr_dict["2020-01-03"] == 0
assert ffr_dict["2020-01-06"] == self.DEAL_NUM_FOR_1000 / 1000
assert ffr_dict["2020-01-03"] == self.DEAL_NUM_FOR_1000 / 1000
assert ffr_dict["2020-01-06"] == 0
assert ffr_dict["2020-01-07"] == self.DEAL_NUM_FOR_1000 / 1000
assert ffr_dict["2020-01-08"] == 0
assert ffr_dict["2020-01-08"] == self.DEAL_NUM_FOR_1000 / 1000
assert ffr_dict["2020-01-09"] == 0
assert ffr_dict["2020-01-10"] == 1
self.EXAMPLE_FILE.unlink()