1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

black formatting

This commit is contained in:
王雪
2021-01-20 20:29:59 +08:00
committed by you-n-g
parent 5ad1b4cc33
commit 784e73bceb
2 changed files with 39 additions and 115 deletions

View File

@@ -59,9 +59,7 @@ class CalendarProvider(abc.ABC):
list
calendar list
"""
raise NotImplementedError(
"Subclass of CalendarProvider must implement `calendar` method"
)
raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method")
def locate_index(self, start_time, end_time, freq, future):
"""Locate the start time index and end time index in a calendar under certain frequency.
@@ -183,9 +181,7 @@ class InstrumentProvider(abc.ABC):
return config
@abc.abstractmethod
def list_instruments(
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
):
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
"""List the instruments based on a certain stockpool config.
Parameters
@@ -204,13 +200,9 @@ class InstrumentProvider(abc.ABC):
dict or list
instruments list or dictionary with time spans
"""
raise NotImplementedError(
"Subclass of InstrumentProvider must implement `list_instruments` method"
)
raise NotImplementedError("Subclass of InstrumentProvider must implement `list_instruments` method")
def _uri(
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
):
def _uri(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
return hash_args(instruments, start_time, end_time, freq, as_list)
# instruments type
@@ -241,9 +233,7 @@ class InstrumentProvider(abc.ABC):
)
_df_list.append(_df.iloc[:, [0, -1]])
df = pd.concat(_df_list, sort=False).sort_values("save_inst")
df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(
axis=1, method="ffill"
)
df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill")
_instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
setattr(self, "_instruments_map", _instruments_map)
return _instruments_map.get(instrument, instrument)
@@ -277,9 +267,7 @@ class FeatureProvider(abc.ABC):
pd.Series
data of a certain feature
"""
raise NotImplementedError(
"Subclass of FeatureProvider must implement `feature` method"
)
raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method")
class ExpressionProvider(abc.ABC):
@@ -300,14 +288,11 @@ class ExpressionProvider(abc.ABC):
self.expression_instance_cache[field] = expression
except NameError as e:
get_module_logger("data").exception(
"ERROR: field [%s] contains invalid operator/variable [%s]"
% (str(field), str(e).split()[1])
"ERROR: field [%s] contains invalid operator/variable [%s]" % (str(field), str(e).split()[1])
)
raise
except SyntaxError:
get_module_logger("data").exception(
"ERROR: field [%s] contains invalid syntax" % str(field)
)
get_module_logger("data").exception("ERROR: field [%s] contains invalid syntax" % str(field))
raise
return expression
@@ -333,9 +318,7 @@ class ExpressionProvider(abc.ABC):
pd.Series
data of a certain expression
"""
raise NotImplementedError(
"Subclass of ExpressionProvider must implement `Expression` method"
)
raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method")
class DatasetProvider(abc.ABC):
@@ -366,9 +349,7 @@ class DatasetProvider(abc.ABC):
pd.DataFrame
a pandas dataframe with <instrument, datetime> index.
"""
raise NotImplementedError(
"Subclass of DatasetProvider must implement `Dataset` method"
)
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")
def _uri(
self,
@@ -398,9 +379,7 @@ class DatasetProvider(abc.ABC):
whether to skip(0)/use(1)/replace(2) disk_cache.
"""
return DiskDatasetCache._uri(
instruments, fields, start_time, end_time, freq, disk_cache
)
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
@staticmethod
def get_instruments_d(instruments, freq):
@@ -412,9 +391,7 @@ class DatasetProvider(abc.ABC):
if isinstance(instruments, dict):
if "market" in instruments:
# dict of stockpool config
instruments_d = Inst.list_instruments(
instruments=instruments, freq=freq, as_list=False
)
instruments_d = Inst.list_instruments(instruments=instruments, freq=freq, as_list=False)
else:
# dict of instruments and timestamp
instruments_d = instruments
@@ -504,9 +481,7 @@ class DatasetProvider(abc.ABC):
return data
@staticmethod
def expression_calculator(
inst, start_time, end_time, freq, column_names, spans=None, g_config=None
):
def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None):
"""
Calculate the expressions for one instrument, return a df result.
If the expression has been calculated before, load from cache.
@@ -571,9 +546,7 @@ class LocalCalendarProvider(CalendarProvider):
fname = self._uri_cal.format(freq + "_future")
# if the future calendar does not exist, fall back to the current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(
f"{freq}_future.txt not exists, return current calendar!"
)
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
@@ -635,9 +608,7 @@ class LocalInstrumentProvider(InstrumentProvider):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def list_instruments(
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
):
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
@@ -658,20 +629,14 @@ class LocalInstrumentProvider(InstrumentProvider):
)
for inst, spans in _instruments.items()
}
_instruments_filtered = {
key: value for key, value in _instruments_filtered.items() if value
}
_instruments_filtered = {key: value for key, value in _instruments_filtered.items() if value}
# filter
filter_pipe = instruments["filter_pipe"]
for filter_config in filter_pipe:
from . import filter as F
filter_t = getattr(F, filter_config["filter_type"]).from_config(
filter_config
)
_instruments_filtered = filter_t(
_instruments_filtered, start_time, end_time, freq
)
filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config)
_instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq)
# as list
if as_list:
return list(_instruments_filtered)
@@ -698,9 +663,7 @@ class LocalFeatureProvider(FeatureProvider):
instrument = Inst.convert_instruments(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning(
"WARN: data not found for %s.%s" % (instrument, field)
)
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
@@ -721,13 +684,9 @@ class LocalExpressionProvider(ExpressionProvider):
expression = self.get_expression_instance(field)
start_time = pd.Timestamp(start_time)
end_time = pd.Timestamp(end_time)
_, _, start_index, end_index = Cal.locate_index(
start_time, end_time, freq, future=False
)
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
lft_etd, rght_etd = expression.get_extended_window_size()
series = expression.load(
instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq
)
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
# Ensure that each column type is consistent
# FIXME:
# 1) The stock data is currently float. If there are other types of data, this part needs to be re-implemented.
@@ -759,16 +718,12 @@ class LocalDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(
instruments_d, column_names, start_time, end_time, freq
)
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
return data
@staticmethod
def multi_cache_walker(
instruments, fields, start_time=None, end_time=None, freq="day"
):
def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq="day"):
"""
This method is used to prepare the expression cache for the client.
Then the client will load the data from expression cache by itself.
@@ -836,9 +791,7 @@ class ClientCalendarProvider(CalendarProvider):
"future": future,
},
msg_queue=self.queue,
msg_proc_func=lambda response_content: [
pd.Timestamp(c) for c in response_content
],
msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content],
)
result = self.queue.get(timeout=C["timeout"])
return result
@@ -857,14 +810,11 @@ class ClientInstrumentProvider(InstrumentProvider):
def set_conn(self, conn):
self.conn = conn
def list_instruments(
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
):
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
def inst_msg_proc_func(response_content):
if isinstance(response_content, dict):
instrument = {
i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t]
for i, t in response_content.items()
i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] for i, t in response_content.items()
}
else:
instrument = response_content
@@ -950,9 +900,7 @@ class ClientDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(
instruments_d, column_names, start_time, end_time, freq
)
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
if return_uri:
return data, feature_uri
else:
@@ -984,12 +932,8 @@ class ClientDatasetProvider(DatasetProvider):
get_module_logger("data").debug("get result")
try:
# pre-mounted NFS, used for demo
mnt_feature_uri = os.path.join(
C.get_data_path(), C.dataset_cache_dir_name, feature_uri
)
df = DiskDatasetCache.read_data_from_cache(
mnt_feature_uri, start_time, end_time, fields
)
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("data").debug("finish slicing data")
if return_uri:
return df, feature_uri
@@ -1007,9 +951,7 @@ class BaseProvider:
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
return Cal.calendar(start_time, end_time, freq, future=future)
def instruments(
self, market="all", filter_pipe=None, start_time=None, end_time=None
):
def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None):
if start_time is not None or end_time is not None:
get_module_logger("Provider").warning(
"The instruments corresponds to a stock pool. "
@@ -1017,9 +959,7 @@ class BaseProvider:
)
return InstrumentProvider.instruments(market, filter_pipe)
def list_instruments(
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
):
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
return Inst.list_instruments(instruments, start_time, end_time, freq, as_list)
def features(
@@ -1045,9 +985,7 @@ class BaseProvider:
if C.disable_disk_cache:
disk_cache = False
try:
return DatasetD.dataset(
instruments, fields, start_time, end_time, freq, disk_cache
)
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
except TypeError:
return DatasetD.dataset(instruments, fields, start_time, end_time, freq)
@@ -1068,9 +1006,7 @@ class LocalProvider(BaseProvider):
elif type == "feature":
return DatasetD._uri(**kwargs)
def features_uri(
self, instruments, fields, start_time, end_time, freq, disk_cache=1
):
def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1):
"""features_uri
Return the uri of the generated cache of features/dataset
@@ -1082,9 +1018,7 @@ class LocalProvider(BaseProvider):
:param end_time:
:param freq:
"""
return DatasetD._dataset_uri(
instruments, fields, start_time, end_time, freq, disk_cache
)
return DatasetD._dataset_uri(instruments, fields, start_time, end_time, freq, disk_cache)
class ClientProvider(BaseProvider):
@@ -1148,9 +1082,7 @@ def register_all_wrappers():
_calendar_provider = init_instance_by_config(C.calendar_provider, module)
if getattr(C, "calendar_cache", None) is not None:
_calendar_provider = init_instance_by_config(
C.calendar_cache, module, provide=_calendar_provider
)
_calendar_provider = init_instance_by_config(C.calendar_cache, module, provide=_calendar_provider)
register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
@@ -1166,19 +1098,13 @@ def register_all_wrappers():
# This provider is unnecessary in client provider
_eprovider = init_instance_by_config(C.expression_provider, module)
if getattr(C, "expression_cache", None) is not None:
_eprovider = init_instance_by_config(
C.expression_cache, module, provider=_eprovider
)
_eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider)
register_wrapper(ExpressionD, _eprovider, "qlib.data")
logger.debug(
f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}"
)
logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}")
_dprovider = init_instance_by_config(C.dataset_provider, module)
if getattr(C, "dataset_cache", None) is not None:
_dprovider = init_instance_by_config(
C.dataset_cache, module, provider=_dprovider
)
_dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider)
register_wrapper(DatasetD, _dprovider, "qlib.data")
logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}")

View File

@@ -38,9 +38,7 @@ class QlibRecorder:
try:
yield run
except Exception as e:
self.end_exp(
Recorder.STATUS_FA
) # end the experiment if something went wrong
self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong
raise e
self.end_exp(Recorder.STATUS_FI)