1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00

black and doc

This commit is contained in:
wangwenxi.handsome
2021-07-16 13:55:49 +00:00
parent 567841e1c6
commit 6ad52e8cf5
2 changed files with 35 additions and 28 deletions

View File

@@ -150,12 +150,7 @@ class Exchange:
if len(self.codes) == 0:
self.codes = D.instruments()
self.quote_df = D.features(
self.codes,
self.all_fields,
self.start_time,
self.end_time,
freq=self.freq,
disk_cache=True
self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
).dropna(subset=["$close"])
self.quote_df.columns = self.all_fields
@@ -177,10 +172,9 @@ class Exchange:
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
# Use normal price
self.trade_w_adj_price = False
# update limit
self._update_limit(self.limit_threshold)
# concat extra_quote
if self.extra_quote is not None:
# process extra_quote
@@ -199,7 +193,7 @@ class Exchange:
self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.")
if "limit_buy" not in self.extra_quote.columns:
self.extra_quote["limit_buy"] = False
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0)
@@ -208,8 +202,7 @@ class Exchange:
LT_NONE = "none" # none
def _get_limit_type(self, limit_threshold):
"""get limit type
"""
"""get limit type"""
if isinstance(limit_threshold, Tuple):
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
@@ -603,7 +596,6 @@ class Exchange:
class BaseQuote:
def __init__(self, quote_df: pd.DataFrame):
self.logger = get_module_logger("online operator", level=logging.INFO)
@@ -617,10 +609,17 @@ class BaseQuote:
"""
raise NotImplementedError(f"Please implement the `get_all_stock` method")
def get_data(self, stock_id: str, start_time, end_time, fields: Union[str, list]=None, method=None):
def get_data(
self,
stock_id: Union[str, list],
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
fields: Union[str, list] = None,
method: Union[str, Callable] = None,
):
"""get the specific fields of stock data during start time and end_time,
and apply method to the data.
Example:
.. code-block::
$close $volume
@@ -637,8 +636,15 @@ class BaseQuote:
2010-01-12 2788.688232 164587.937500
2010-01-13 2790.604004 145460.453125
print(get_data(stock_id=["SH600000", "SH600655"], start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
$close $volume
instrument
SH600000 87.433578 28117442.0
SH600655 2699.567383 158193.328125
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
$close 87.433578
$volume 28117442.0
@@ -649,27 +655,26 @@ class BaseQuote:
Parameters
----------
stock_id: Union[str, list]
start_time : pd.Timestamp|str
start_time : Union[pd.Timestamp, str]
closed start time for backtest
end_time : pd.Timestamp|str
end_time : Union[pd.Timestamp, str]
closed end time for backtest
fields : Union[str, List]
the columns of data to fetch
method : Union[str, Callable]
the method apply to data.
e.g ["None", "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last]
the method apply to data.
e.g ["None", "last", "all", "sum", "mean", "any", qlib/utils/resam.py/ts_data_last]
Return
----------
Union[None, float, pd.Series]
The resampled Series/value, return None when the resampled data is empty.
Union[None, float, pd.Series, pd.DataFrame]
The resampled DataFrame/Series/value, return None when the resampled data is empty.
"""
raise NotImplementedError(f"Please implement the `get_data` method")
raise NotImplementedError(f"Please implement the `get_data` method")
class PandasQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame):
super().__init__(quote_df=quote_df)
quote_dict = {}
@@ -680,10 +685,10 @@ class PandasQuote(BaseQuote):
def get_all_stock(self):
return self.data.keys()
def get_data(self, stock_id, start_time, end_time, fields = None, method = None):
if(fields is None):
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
if fields is None:
return resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
elif(isinstance(fields, (str, list))):
elif isinstance(fields, (str, list)):
return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
else:
raise ValueError(f"fields must be None, str or list")
raise ValueError(f"fields must be None, str or list")

View File

@@ -687,7 +687,9 @@ class FileOrderStrategy(BaseStrategy):
- This class provides an interface for user to read orders from csv files.
"""
def __init__(self, file: Union[IO, str, Path], trade_range: Union[Tuple[int, int], TradeRange]= None, *args, **kwargs):
def __init__(
self, file: Union[IO, str, Path], trade_range: Union[Tuple[int, int], TradeRange] = None, *args, **kwargs
):
"""
Parameters