mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
for IDE auto-complete with global Wrapper
R, D, Cal, Inst, FeatureD, ExpressionD, DatasetD, D
This commit is contained in:
@@ -25,7 +25,12 @@ from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
init_instance_by_config,
|
||||
register_wrapper,
|
||||
get_module_by_module_path,
|
||||
)
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
@@ -54,7 +59,9 @@ class CalendarProvider(abc.ABC):
|
||||
list
|
||||
calendar list
|
||||
"""
|
||||
raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method")
|
||||
raise NotImplementedError(
|
||||
"Subclass of CalendarProvider must implement `calendar` method"
|
||||
)
|
||||
|
||||
def locate_index(self, start_time, end_time, freq, future):
|
||||
"""Locate the start time index and end time index in a calendar under certain frequency.
|
||||
@@ -176,7 +183,9 @@ class InstrumentProvider(abc.ABC):
|
||||
return config
|
||||
|
||||
@abc.abstractmethod
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
def list_instruments(
|
||||
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
|
||||
):
|
||||
"""List the instruments based on a certain stockpool config.
|
||||
|
||||
Parameters
|
||||
@@ -195,9 +204,13 @@ class InstrumentProvider(abc.ABC):
|
||||
dict or list
|
||||
instruments list or dictionary with time spans
|
||||
"""
|
||||
raise NotImplementedError("Subclass of InstrumentProvider must implement `list_instruments` method")
|
||||
raise NotImplementedError(
|
||||
"Subclass of InstrumentProvider must implement `list_instruments` method"
|
||||
)
|
||||
|
||||
def _uri(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
def _uri(
|
||||
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
|
||||
):
|
||||
return hash_args(instruments, start_time, end_time, freq, as_list)
|
||||
|
||||
# instruments type
|
||||
@@ -221,10 +234,16 @@ class InstrumentProvider(abc.ABC):
|
||||
_df_list = []
|
||||
# FIXME: each process will read these files
|
||||
for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"):
|
||||
_df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
|
||||
_df = pd.read_csv(
|
||||
_path,
|
||||
sep="\t",
|
||||
names=["inst", "start_datetime", "end_datetime", "save_inst"],
|
||||
)
|
||||
_df_list.append(_df.iloc[:, [0, -1]])
|
||||
df = pd.concat(_df_list, sort=False).sort_values("save_inst")
|
||||
df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill")
|
||||
df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(
|
||||
axis=1, method="ffill"
|
||||
)
|
||||
_instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
|
||||
setattr(self, "_instruments_map", _instruments_map)
|
||||
return _instruments_map.get(instrument, instrument)
|
||||
@@ -258,7 +277,9 @@ class FeatureProvider(abc.ABC):
|
||||
pd.Series
|
||||
data of a certain feature
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method")
|
||||
raise NotImplementedError(
|
||||
"Subclass of FeatureProvider must implement `feature` method"
|
||||
)
|
||||
|
||||
|
||||
class ExpressionProvider(abc.ABC):
|
||||
@@ -279,11 +300,14 @@ class ExpressionProvider(abc.ABC):
|
||||
self.expression_instance_cache[field] = expression
|
||||
except NameError as e:
|
||||
get_module_logger("data").exception(
|
||||
"ERROR: field [%s] contains invalid operator/variable [%s]" % (str(field), str(e).split()[1])
|
||||
"ERROR: field [%s] contains invalid operator/variable [%s]"
|
||||
% (str(field), str(e).split()[1])
|
||||
)
|
||||
raise
|
||||
except SyntaxError:
|
||||
get_module_logger("data").exception("ERROR: field [%s] contains invalid syntax" % str(field))
|
||||
get_module_logger("data").exception(
|
||||
"ERROR: field [%s] contains invalid syntax" % str(field)
|
||||
)
|
||||
raise
|
||||
return expression
|
||||
|
||||
@@ -309,7 +333,9 @@ class ExpressionProvider(abc.ABC):
|
||||
pd.Series
|
||||
data of a certain expression
|
||||
"""
|
||||
raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method")
|
||||
raise NotImplementedError(
|
||||
"Subclass of ExpressionProvider must implement `Expression` method"
|
||||
)
|
||||
|
||||
|
||||
class DatasetProvider(abc.ABC):
|
||||
@@ -340,7 +366,9 @@ class DatasetProvider(abc.ABC):
|
||||
pd.DataFrame
|
||||
a pandas dataframe with <instrument, datetime> index.
|
||||
"""
|
||||
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")
|
||||
raise NotImplementedError(
|
||||
"Subclass of DatasetProvider must implement `Dataset` method"
|
||||
)
|
||||
|
||||
def _uri(
|
||||
self,
|
||||
@@ -370,7 +398,9 @@ class DatasetProvider(abc.ABC):
|
||||
whether to skip(0)/use(1)/replace(2) disk_cache.
|
||||
|
||||
"""
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return DiskDatasetCache._uri(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_instruments_d(instruments, freq):
|
||||
@@ -382,7 +412,9 @@ class DatasetProvider(abc.ABC):
|
||||
if isinstance(instruments, dict):
|
||||
if "market" in instruments:
|
||||
# dict of stockpool config
|
||||
instruments_d = Inst.list_instruments(instruments=instruments, freq=freq, as_list=False)
|
||||
instruments_d = Inst.list_instruments(
|
||||
instruments=instruments, freq=freq, as_list=False
|
||||
)
|
||||
else:
|
||||
# dict of instruments and timestamp
|
||||
instruments_d = instruments
|
||||
@@ -472,7 +504,9 @@ class DatasetProvider(abc.ABC):
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None):
|
||||
def expression_calculator(
|
||||
inst, start_time, end_time, freq, column_names, spans=None, g_config=None
|
||||
):
|
||||
"""
|
||||
Calculate the expressions for one instrument, return a df result.
|
||||
If the expression has been calculated before, load from cache.
|
||||
@@ -537,7 +571,9 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
fname = self._uri_cal.format(freq + "_future")
|
||||
# if future calendar not exists, return current calendar
|
||||
if not os.path.exists(fname):
|
||||
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
|
||||
get_module_logger("data").warning(
|
||||
f"{freq}_future.txt not exists, return current calendar!"
|
||||
)
|
||||
fname = self._uri_cal.format(freq)
|
||||
else:
|
||||
fname = self._uri_cal.format(freq)
|
||||
@@ -588,14 +624,20 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("instruments not exists for market " + market)
|
||||
_instruments = dict()
|
||||
df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
|
||||
df = pd.read_csv(
|
||||
fname,
|
||||
sep="\t",
|
||||
names=["inst", "start_datetime", "end_datetime", "save_inst"],
|
||||
)
|
||||
df["start_datetime"] = pd.to_datetime(df["start_datetime"])
|
||||
df["end_datetime"] = pd.to_datetime(df["end_datetime"])
|
||||
for row in df.itertuples(index=False):
|
||||
_instruments.setdefault(row[0], []).append((row[1], row[2]))
|
||||
return _instruments
|
||||
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
def list_instruments(
|
||||
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
|
||||
):
|
||||
market = instruments["market"]
|
||||
if market in H["i"]:
|
||||
_instruments = H["i"][market]
|
||||
@@ -616,14 +658,20 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
)
|
||||
for inst, spans in _instruments.items()
|
||||
}
|
||||
_instruments_filtered = {key: value for key, value in _instruments_filtered.items() if value}
|
||||
_instruments_filtered = {
|
||||
key: value for key, value in _instruments_filtered.items() if value
|
||||
}
|
||||
# filter
|
||||
filter_pipe = instruments["filter_pipe"]
|
||||
for filter_config in filter_pipe:
|
||||
from . import filter as F
|
||||
|
||||
filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config)
|
||||
_instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq)
|
||||
filter_t = getattr(F, filter_config["filter_type"]).from_config(
|
||||
filter_config
|
||||
)
|
||||
_instruments_filtered = filter_t(
|
||||
_instruments_filtered, start_time, end_time, freq
|
||||
)
|
||||
# as list
|
||||
if as_list:
|
||||
return list(_instruments_filtered)
|
||||
@@ -650,7 +698,9 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
instrument = Inst.convert_instruments(instrument)
|
||||
uri_data = self._uri_data.format(instrument.lower(), field, freq)
|
||||
if not os.path.exists(uri_data):
|
||||
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
|
||||
get_module_logger("data").warning(
|
||||
"WARN: data not found for %s.%s" % (instrument, field)
|
||||
)
|
||||
return pd.Series(dtype=np.float32)
|
||||
# raise ValueError('uri_data not found: ' + uri_data)
|
||||
# load
|
||||
@@ -671,9 +721,13 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
expression = self.get_expression_instance(field)
|
||||
start_time = pd.Timestamp(start_time)
|
||||
end_time = pd.Timestamp(end_time)
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
|
||||
_, _, start_index, end_index = Cal.locate_index(
|
||||
start_time, end_time, freq, future=False
|
||||
)
|
||||
lft_etd, rght_etd = expression.get_extended_window_size()
|
||||
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
|
||||
series = expression.load(
|
||||
instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq
|
||||
)
|
||||
# Ensure that each column type is consistent
|
||||
# FIXME:
|
||||
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
|
||||
@@ -705,12 +759,16 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(
|
||||
instruments_d, column_names, start_time, end_time, freq
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq="day"):
|
||||
def multi_cache_walker(
|
||||
instruments, fields, start_time=None, end_time=None, freq="day"
|
||||
):
|
||||
"""
|
||||
This method is used to prepare the expression cache for the client.
|
||||
Then the client will load the data from expression cache by itself.
|
||||
@@ -778,7 +836,9 @@ class ClientCalendarProvider(CalendarProvider):
|
||||
"future": future,
|
||||
},
|
||||
msg_queue=self.queue,
|
||||
msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content],
|
||||
msg_proc_func=lambda response_content: [
|
||||
pd.Timestamp(c) for c in response_content
|
||||
],
|
||||
)
|
||||
result = self.queue.get(timeout=C["timeout"])
|
||||
return result
|
||||
@@ -797,11 +857,14 @@ class ClientInstrumentProvider(InstrumentProvider):
|
||||
def set_conn(self, conn):
|
||||
self.conn = conn
|
||||
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
def list_instruments(
|
||||
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
|
||||
):
|
||||
def inst_msg_proc_func(response_content):
|
||||
if isinstance(response_content, dict):
|
||||
instrument = {
|
||||
i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] for i, t in response_content.items()
|
||||
i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t]
|
||||
for i, t in response_content.items()
|
||||
}
|
||||
else:
|
||||
instrument = response_content
|
||||
@@ -887,7 +950,9 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(
|
||||
instruments_d, column_names, start_time, end_time, freq
|
||||
)
|
||||
if return_uri:
|
||||
return data, feature_uri
|
||||
else:
|
||||
@@ -919,8 +984,12 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
get_module_logger("data").debug("get result")
|
||||
try:
|
||||
# pre-mound nfs, used for demo
|
||||
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
|
||||
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
|
||||
mnt_feature_uri = os.path.join(
|
||||
C.get_data_path(), C.dataset_cache_dir_name, feature_uri
|
||||
)
|
||||
df = DiskDatasetCache.read_data_from_cache(
|
||||
mnt_feature_uri, start_time, end_time, fields
|
||||
)
|
||||
get_module_logger("data").debug("finish slicing data")
|
||||
if return_uri:
|
||||
return df, feature_uri
|
||||
@@ -938,7 +1007,9 @@ class BaseProvider:
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
|
||||
def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None):
|
||||
def instruments(
|
||||
self, market="all", filter_pipe=None, start_time=None, end_time=None
|
||||
):
|
||||
if start_time is not None or end_time is not None:
|
||||
get_module_logger("Provider").warning(
|
||||
"The instruments corresponds to a stock pool. "
|
||||
@@ -946,7 +1017,9 @@ class BaseProvider:
|
||||
)
|
||||
return InstrumentProvider.instruments(market, filter_pipe)
|
||||
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
def list_instruments(
|
||||
self, instruments, start_time=None, end_time=None, freq="day", as_list=False
|
||||
):
|
||||
return Inst.list_instruments(instruments, start_time, end_time, freq, as_list)
|
||||
|
||||
def features(
|
||||
@@ -972,7 +1045,9 @@ class BaseProvider:
|
||||
if C.disable_disk_cache:
|
||||
disk_cache = False
|
||||
try:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return DatasetD.dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache
|
||||
)
|
||||
except TypeError:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq)
|
||||
|
||||
@@ -993,7 +1068,9 @@ class LocalProvider(BaseProvider):
|
||||
elif type == "feature":
|
||||
return DatasetD._uri(**kwargs)
|
||||
|
||||
def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1):
|
||||
def features_uri(
|
||||
self, instruments, fields, start_time, end_time, freq, disk_cache=1
|
||||
):
|
||||
"""features_uri
|
||||
|
||||
Return the uri of the generated cache of features/dataset
|
||||
@@ -1005,7 +1082,9 @@ class LocalProvider(BaseProvider):
|
||||
:param end_time:
|
||||
:param freq:
|
||||
"""
|
||||
return DatasetD._dataset_uri(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return DatasetD._dataset_uri(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache
|
||||
)
|
||||
|
||||
|
||||
class ClientProvider(BaseProvider):
|
||||
@@ -1035,12 +1114,31 @@ class ClientProvider(BaseProvider):
|
||||
DatasetD.set_conn(self.client)
|
||||
|
||||
|
||||
Cal = Wrapper()
|
||||
Inst = Wrapper()
|
||||
FeatureD = Wrapper()
|
||||
ExpressionD = Wrapper()
|
||||
DatasetD = Wrapper()
|
||||
D = Wrapper()
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
CalendarProviderWrapper = Annotated[CalendarProvider, Wrapper]
|
||||
InstrumentProviderWrapper = Annotated[InstrumentProvider, Wrapper]
|
||||
FeatureProviderWrapper = Annotated[FeatureProvider, Wrapper]
|
||||
ExpressionProviderWrapper = Annotated[ExpressionProvider, Wrapper]
|
||||
DatasetProviderWrapper = Annotated[DatasetProvider, Wrapper]
|
||||
BaseProviderWrapper = Annotated[BaseProvider, Wrapper]
|
||||
else:
|
||||
CalendarProviderWrapper = CalendarProvider
|
||||
InstrumentProviderWrapper = InstrumentProvider
|
||||
FeatureProviderWrapper = FeatureProvider
|
||||
ExpressionProviderWrapper = ExpressionProvider
|
||||
DatasetProviderWrapper = DatasetProvider
|
||||
BaseProviderWrapper = BaseProvider
|
||||
|
||||
Cal: CalendarProviderWrapper = Wrapper()
|
||||
Inst: InstrumentProviderWrapper = Wrapper()
|
||||
FeatureD: FeatureProviderWrapper = Wrapper()
|
||||
ExpressionD: ExpressionProviderWrapper = Wrapper()
|
||||
DatasetD: DatasetProviderWrapper = Wrapper()
|
||||
D: BaseProviderWrapper = Wrapper()
|
||||
|
||||
|
||||
def register_all_wrappers():
|
||||
@@ -1050,7 +1148,9 @@ def register_all_wrappers():
|
||||
|
||||
_calendar_provider = init_instance_by_config(C.calendar_provider, module)
|
||||
if getattr(C, "calendar_cache", None) is not None:
|
||||
_calendar_provider = init_instance_by_config(C.calendar_cache, module, provide=_calendar_provider)
|
||||
_calendar_provider = init_instance_by_config(
|
||||
C.calendar_cache, module, provide=_calendar_provider
|
||||
)
|
||||
register_wrapper(Cal, _calendar_provider, "qlib.data")
|
||||
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
|
||||
|
||||
@@ -1066,13 +1166,19 @@ def register_all_wrappers():
|
||||
# This provider is unnecessary in client provider
|
||||
_eprovider = init_instance_by_config(C.expression_provider, module)
|
||||
if getattr(C, "expression_cache", None) is not None:
|
||||
_eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider)
|
||||
_eprovider = init_instance_by_config(
|
||||
C.expression_cache, module, provider=_eprovider
|
||||
)
|
||||
register_wrapper(ExpressionD, _eprovider, "qlib.data")
|
||||
logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}")
|
||||
logger.debug(
|
||||
f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}"
|
||||
)
|
||||
|
||||
_dprovider = init_instance_by_config(C.dataset_provider, module)
|
||||
if getattr(C, "dataset_cache", None) is not None:
|
||||
_dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider)
|
||||
_dprovider = init_instance_by_config(
|
||||
C.dataset_cache, module, provider=_dprovider
|
||||
)
|
||||
register_wrapper(DatasetD, _dprovider, "qlib.data")
|
||||
logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}")
|
||||
|
||||
|
||||
@@ -38,7 +38,9 @@ class QlibRecorder:
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong
|
||||
self.end_exp(
|
||||
Recorder.STATUS_FA
|
||||
) # end the experiment if something went wrong
|
||||
raise e
|
||||
self.end_exp(Recorder.STATUS_FI)
|
||||
|
||||
@@ -461,5 +463,14 @@ class QlibRecorder:
|
||||
self.get_exp().get_recorder().set_tags(**kwargs)
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
|
||||
else:
|
||||
QlibRecorderWrapper = QlibRecorder
|
||||
|
||||
# global record
|
||||
R = Wrapper()
|
||||
R: QlibRecorderWrapper = Wrapper()
|
||||
|
||||
Reference in New Issue
Block a user