mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
Add inst_processors to D.features
This commit is contained in:
@@ -1,9 +1,16 @@
|
||||
import datetime
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.inst_processor import InstProcessor
|
||||
|
||||
def resample_feature(df: pd.DataFrame, hour: int = 13, minute: int = 1) -> pd.DataFrame:
    """Pick one intraday bar per day for a single instrument.

    Intended to be applied per instrument, e.g. via
    ``df.groupby(level="instrument").apply(resample_feature)``.

    Parameters
    ----------
    df : pd.DataFrame
        Data carrying an ``instrument`` index level next to a datetime level.
    hour : int
        Hour of the bar to keep. Defaults to 13, the value that used to be hard-coded.
    minute : int
        Minute of the bar to keep. Defaults to 1, the value that used to be hard-coded.

    Returns
    -------
    pd.DataFrame
        Rows whose time of day equals ``hour:minute``, indexed by the
        normalized (midnight) timestamp of their day.
    """
    # Remove the instrument level so the remaining index is the datetime one.
    per_instrument = df.droplevel(level="instrument")
    wanted_time = datetime.time(hour, minute)
    selected = per_instrument.loc[per_instrument.index.time == wanted_time]
    # Collapse every kept timestamp to its date so each row represents one day.
    selected.index = selected.index.normalize()
    return selected
|
||||
|
||||
class ResampleProcessor(InstProcessor):
    """Keep the bar at a fixed time of day and re-index it by date."""

    def __init__(self, freq: str, hour: int, minute: int):
        """Store the configuration.

        Parameters
        ----------
        freq : str
            Stored on the instance; not read by ``__call__``.
        hour : int
            Hour of the bar to keep.
        minute : int
            Minute of the bar to keep.
        """
        self.freq = freq
        self.hour = hour
        self.minute = minute

    def __call__(self, df: pd.DataFrame, *args, **kwargs):
        """Return the rows of ``df`` stamped ``hour:minute``, indexed by normalized date."""
        wanted_time = datetime.time(self.hour, self.minute)
        at_wanted_time = df.index.time == wanted_time
        selected = df.loc[at_wanted_time]
        # Drop the time-of-day part so every kept row is keyed by its day.
        selected.index = selected.index.normalize()
        return selected
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
backend_freq_config:
|
||||
provider_uri:
|
||||
day: "~/.qlib/qlib_data/cn_data"
|
||||
1min: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
@@ -19,13 +18,14 @@ data_handler_config: &data_handler_config
|
||||
# with label as reference
|
||||
sample_benchmark: label
|
||||
sample_config:
|
||||
# using pandas.DataFrame.resample
|
||||
feature: resample("1d", level="datetime").last()
|
||||
# or
|
||||
# using custom function, df.groupby(level="instrument").apply(<user func>)
|
||||
# feature:
|
||||
# module_path: features_sample.py
|
||||
# func: resample_feature
|
||||
feature:
  # Instrument processor(s) for this group; the class is defined in features_sample.py.
  - class: ResampleProcessor
    module_path: features_sample.py
    kwargs:
      freq: 1d
      hour: 13
      minute: 1
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
|
||||
@@ -384,7 +384,9 @@ class DatasetCache(BaseProviderCache):
|
||||
|
||||
HDF_KEY = "df"
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
):
|
||||
"""Get feature dataset.
|
||||
|
||||
.. note:: Same interface as `dataset` method in dataset provider
|
||||
@@ -395,13 +397,19 @@ class DatasetCache(BaseProviderCache):
|
||||
"""
|
||||
if disk_cache == 0:
|
||||
# skip cache
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
else:
|
||||
# use and replace cache
|
||||
try:
|
||||
return self._dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return self._dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
except NotImplementedError:
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs):
|
||||
"""Get dataset cache file uri.
|
||||
@@ -410,14 +418,18 @@ class DatasetCache(BaseProviderCache):
|
||||
"""
|
||||
raise NotImplementedError("Implement this function to match your own cache mechanism")
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
):
|
||||
"""Get feature dataset using cache.
|
||||
|
||||
Override this method to define how to get feature dataset corresponding to users' own cache mechanism.
|
||||
"""
|
||||
raise NotImplementedError("Implement this method if you want to use dataset feature cache")
|
||||
|
||||
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
):
|
||||
"""Get a uri of feature dataset using cache.
|
||||
specially:
|
||||
disk_cache=1 means using data set cache and return the uri of cache file.
|
||||
@@ -639,8 +651,8 @@ class DiskDatasetCache(DatasetCache):
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@staticmethod
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
|
||||
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
def get_cache_dir(self, freq: str = None) -> Path:
|
||||
return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq)
|
||||
@@ -679,14 +691,29 @@ class DiskDatasetCache(DatasetCache):
|
||||
df = pd.DataFrame(columns=fields)
|
||||
return df
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
|
||||
):
|
||||
|
||||
if disk_cache == 0:
|
||||
# In this case, data_set cache is configured but will not be used.
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
# FIXME: once resampled data has been cached, reading the cache back and truncating it at end_time yields incomplete dates
|
||||
if not inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
_cache_uri = self._uri(
|
||||
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq=freq,
|
||||
disk_cache=disk_cache,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
|
||||
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
|
||||
@@ -709,13 +736,19 @@ class DiskDatasetCache(DatasetCache):
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
features = self.gen_dataset_cache(
|
||||
cache_path=cache_path, instruments=instruments, fields=fields, freq=freq
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
freq=freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
if not features.empty:
|
||||
features = features.sort_index().loc(axis=0)[:, start_time:end_time]
|
||||
return features
|
||||
|
||||
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, server only checks the expression cache.
|
||||
# The client will load the cache data by itself.
|
||||
@@ -723,9 +756,20 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
|
||||
return ""
|
||||
|
||||
# FIXME: once resampled data has been cached, reading the cache back and truncating it at end_time yields incomplete dates
|
||||
if not inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
_cache_uri = self._uri(
|
||||
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq=freq,
|
||||
disk_cache=disk_cache,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
|
||||
|
||||
@@ -737,7 +781,13 @@ class DiskDatasetCache(DatasetCache):
|
||||
else:
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq)
|
||||
self.gen_dataset_cache(
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
freq=freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
return _cache_uri
|
||||
|
||||
class IndexManager:
|
||||
@@ -804,7 +854,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
index_data += start_index
|
||||
return index_data
|
||||
|
||||
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq):
|
||||
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=None):
|
||||
"""gen_dataset_cache
|
||||
|
||||
.. note:: This function does not consider the cache read write lock. Please
|
||||
@@ -839,6 +889,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
:param instruments: The instruments to store the cache.
|
||||
:param fields: The fields to store the cache.
|
||||
:param freq: The freq to store the cache.
|
||||
:param inst_processors: Instrument processors.
|
||||
|
||||
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
|
||||
"""
|
||||
@@ -852,7 +903,9 @@ class DiskDatasetCache(DatasetCache):
|
||||
# while running
|
||||
self.clear_cache(cache_path)
|
||||
|
||||
features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq)
|
||||
features = self.provider.dataset(
|
||||
instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
if features.empty:
|
||||
return features
|
||||
@@ -877,6 +930,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
"fields": cache_columns,
|
||||
"freq": freq,
|
||||
"last_update": str(_calendar[-1]), # The last_update to store the cache
|
||||
"inst_processors": inst_processors, # The last_update to store the cache
|
||||
},
|
||||
"meta": {"last_visit": time.time(), "visits": 1},
|
||||
}
|
||||
@@ -913,6 +967,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
fields = d["info"]["fields"]
|
||||
freq = d["info"]["freq"]
|
||||
last_update_time = d["info"]["last_update"]
|
||||
inst_processors = d["info"]["inst_processors"]
|
||||
index_data = im.get_index()
|
||||
|
||||
self.logger.debug("Updating dataset: {}".format(d))
|
||||
@@ -963,7 +1018,12 @@ class DiskDatasetCache(DatasetCache):
|
||||
)
|
||||
|
||||
data = self.provider.dataset(
|
||||
instruments, fields, whole_calendar[current_index - rm_n_period], new_calendar[-1], freq
|
||||
instruments,
|
||||
fields,
|
||||
whole_calendar[current_index - rm_n_period],
|
||||
new_calendar[-1],
|
||||
freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
|
||||
if not data.empty:
|
||||
@@ -1013,17 +1073,23 @@ class SimpleDatasetCache(DatasetCache):
|
||||
except KeyError as e:
|
||||
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
|
||||
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
|
||||
return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path))
|
||||
return hash_args(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
|
||||
)
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, data_set cache is configured but will not be used.
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
self.local_cache_path.mkdir(exist_ok=True, parents=True)
|
||||
cache_file = self.local_cache_path.joinpath(
|
||||
self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache)
|
||||
self._uri(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
)
|
||||
gen_flag = False
|
||||
|
||||
@@ -1039,7 +1105,9 @@ class SimpleDatasetCache(DatasetCache):
|
||||
gen_flag = True
|
||||
|
||||
if gen_flag:
|
||||
data = self.provider.dataset(instruments, normalize_cache_fields(fields), start_time, end_time, freq)
|
||||
data = self.provider.dataset(
|
||||
instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
data.to_pickle(cache_file)
|
||||
return self.cache_to_origin_data(data, fields)
|
||||
|
||||
@@ -1047,26 +1115,53 @@ class SimpleDatasetCache(DatasetCache):
|
||||
class DatasetURICache(DatasetCache):
|
||||
"""Prepared cache mechanism for server."""
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
|
||||
):
|
||||
|
||||
if "local" in C.dataset_provider.lower():
|
||||
# use LocalDatasetProvider
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
if disk_cache == 0:
|
||||
# do not use data_set cache, load data from remote expression cache directly
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq, disk_cache, return_uri=False)
|
||||
|
||||
return self.provider.dataset(
|
||||
instruments,
|
||||
fields,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
disk_cache,
|
||||
return_uri=False,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
# FIXME: once resampled data has been cached, reading the cache back and truncating it at end_time yields incomplete dates
|
||||
if not inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
# use ClientDatasetProvider
|
||||
feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache)
|
||||
feature_uri = self._uri(
|
||||
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
|
||||
mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
|
||||
if value is None or expire or not mnt_feature_uri.exists():
|
||||
df, uri = self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True
|
||||
instruments,
|
||||
fields,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
disk_cache,
|
||||
return_uri=True,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
# cache uri
|
||||
MemCacheExpire.set_cache(H["f"], uri, uri)
|
||||
|
||||
@@ -13,15 +13,26 @@ import bisect
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
from typing import Iterable, Union
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
from .ops import Operators
|
||||
from ..log import get_module_logger
|
||||
from ..utils import parse_field, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .ops import Operators
|
||||
from .inst_processor import InstProcessor
|
||||
|
||||
from ..log import get_module_logger
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
init_instance_by_config,
|
||||
register_wrapper,
|
||||
get_module_by_module_path,
|
||||
parse_field,
|
||||
hash_args,
|
||||
normalize_cache_fields,
|
||||
code_to_fname,
|
||||
)
|
||||
|
||||
|
||||
class ProviderBackendMixin:
|
||||
@@ -342,7 +353,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=None):
|
||||
"""Get dataset data.
|
||||
|
||||
Parameters
|
||||
@@ -357,6 +368,8 @@ class DatasetProvider(abc.ABC):
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency.
|
||||
inst_processors: Iterable[Union[dict, InstProcessor]]
|
||||
the operations performed on each instrument
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -373,6 +386,7 @@ class DatasetProvider(abc.ABC):
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=1,
|
||||
inst_processors=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Get task uri, used when generating rabbitmq task in qlib_server
|
||||
@@ -393,7 +407,8 @@ class DatasetProvider(abc.ABC):
|
||||
whether to skip(0)/use(1)/replace(2) disk_cache.
|
||||
|
||||
"""
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
# TODO: qlib-server support inst_processors
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors)
|
||||
|
||||
@staticmethod
|
||||
def get_instruments_d(instruments, freq):
|
||||
@@ -434,7 +449,7 @@ class DatasetProvider(abc.ABC):
|
||||
return [ExpressionD.get_expression_instance(f) for f in fields]
|
||||
|
||||
@staticmethod
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq):
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=None):
|
||||
"""
|
||||
Load and process the data, return the data set.
|
||||
- default using multi-kernel method.
|
||||
@@ -460,6 +475,7 @@ class DatasetProvider(abc.ABC):
|
||||
normalize_column_names,
|
||||
spans,
|
||||
C,
|
||||
inst_processors,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -474,6 +490,7 @@ class DatasetProvider(abc.ABC):
|
||||
normalize_column_names,
|
||||
None,
|
||||
C,
|
||||
inst_processors,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -495,7 +512,9 @@ class DatasetProvider(abc.ABC):
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None):
|
||||
def expression_calculator(
|
||||
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=None
|
||||
):
|
||||
"""
|
||||
Calculate the expressions for one instrument, return a df result.
|
||||
If the expression has been calculated before, load from cache.
|
||||
@@ -519,13 +538,17 @@ class DatasetProvider(abc.ABC):
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
if spans is None:
|
||||
return data
|
||||
else:
|
||||
if spans is not None:
|
||||
mask = np.zeros(len(data), dtype=bool)
|
||||
for begin, end in spans:
|
||||
mask |= (data.index >= begin) & (data.index <= end)
|
||||
return data[mask]
|
||||
data = data[mask]
|
||||
|
||||
if inst_processors is not None:
|
||||
for _processor in inst_processors if isinstance(inst_processors, (list, tuple, set)) else [inst_processors]:
|
||||
_processor = init_instance_by_config(_processor, accept_types=InstProcessor)
|
||||
data = _processor(data)
|
||||
return data
|
||||
|
||||
|
||||
class LocalCalendarProvider(CalendarProvider):
|
||||
@@ -689,7 +712,15 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
|
||||
def dataset(
|
||||
self,
|
||||
instruments,
|
||||
fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
inst_processors=None,
|
||||
):
|
||||
instruments_d = self.get_instruments_d(instruments, freq)
|
||||
column_names = self.get_column_names(fields)
|
||||
cal = Cal.calendar(start_time, end_time, freq)
|
||||
@@ -698,7 +729,9 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(
|
||||
instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -841,6 +874,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
freq="day",
|
||||
disk_cache=0,
|
||||
return_uri=False,
|
||||
inst_processors=None,
|
||||
):
|
||||
if Inst.get_inst_type(instruments) == Inst.DICT:
|
||||
get_module_logger("data").warning(
|
||||
@@ -893,6 +927,13 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
- using single-process implementation.
|
||||
|
||||
"""
|
||||
# TODO: support inst_processors, need to change the code of qlib-server at the same time
|
||||
# FIXME: once resampled data has been cached, reading the cache back and truncating it at end_time yields incomplete dates
|
||||
if not inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
self.conn.send_request(
|
||||
request_type="feature",
|
||||
request_content={
|
||||
@@ -950,6 +991,7 @@ class BaseProvider:
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=None,
|
||||
inst_processors=None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@@ -964,9 +1006,11 @@ class BaseProvider:
|
||||
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
|
||||
fields = list(fields) # In case of tuple.
|
||||
try:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return DatasetD.dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
except TypeError:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors)
|
||||
|
||||
|
||||
class LocalProvider(BaseProvider):
|
||||
|
||||
@@ -181,7 +181,7 @@ class QlibDataLoader(DLWParser):
|
||||
self.freq = freq
|
||||
|
||||
# sample
|
||||
self.sample_config = sample_config
|
||||
self.sample_config = sample_config if sample_config is not None else {}
|
||||
self.sample_benchmark = sample_benchmark
|
||||
self.can_sample = False
|
||||
super().__init__(config)
|
||||
@@ -192,30 +192,12 @@ class QlibDataLoader(DLWParser):
|
||||
for _gp in config.keys():
|
||||
if _gp not in freq:
|
||||
raise ValueError(f"freq(={freq}) missing group(={_gp})")
|
||||
if len(set(freq.values())) == 1:
|
||||
self.freq = list(freq.values())[0]
|
||||
else:
|
||||
assert self.sample_config, f"freq(={self.freq}), sample_config cannot be None/empty"
|
||||
assert isinstance(self.sample_config, dict), f"sample_config(={self.sample_config}) must be dict"
|
||||
assert (
|
||||
self.sample_benchmark and self.sample_benchmark in self.fields
|
||||
), f"sample_benchmark not to specification"
|
||||
self.can_sample = True
|
||||
|
||||
def _get_sample_method(self, gp_name: str) -> Union[str, Type]:
|
||||
_method = self.sample_config.get(gp_name, None)
|
||||
if _method is None:
|
||||
return _method
|
||||
if isinstance(_method, str):
|
||||
# pandas.DataFrame.resample
|
||||
if not _method.startswith("resample"):
|
||||
raise ValueError(f"sample method error, only pandas.DataFrame.resample is supported")
|
||||
elif isinstance(_method, dict):
|
||||
# module_path && func name
|
||||
_method, _ = get_callable_kwargs(_method)
|
||||
else:
|
||||
raise TypeError(f"sample_method only supports [str, dict], currently it is {_method}")
|
||||
return _method
|
||||
assert sample_config, f"freq(={self.freq}), sample_config(={sample_config}) cannot be None/empty"
|
||||
assert isinstance(sample_config, dict), f"sample_config(={sample_config}) must be dict"
|
||||
assert (
|
||||
self.sample_benchmark and self.sample_benchmark in self.fields
|
||||
), f"sample_benchmark not to specification"
|
||||
self.can_sample = True
|
||||
|
||||
def load_group_df(
|
||||
self,
|
||||
@@ -235,17 +217,10 @@ class QlibDataLoader(DLWParser):
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
freq = self.freq[gp_name] if self.can_sample else self.freq
|
||||
df = D.features(instruments, exprs, start_time, end_time, freq)
|
||||
df = D.features(
|
||||
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.sample_config.get(freq, None)
|
||||
)
|
||||
df.columns = names
|
||||
|
||||
if self.can_sample and self.sample_benchmark != gp_name:
|
||||
sample_method = self._get_sample_method(gp_name)
|
||||
if sample_method is None:
|
||||
warnings.warn(f"{gp_name} sample_method is None")
|
||||
if isinstance(sample_method, str):
|
||||
df = eval(f"df.groupby(level='instrument').{sample_method}")
|
||||
else:
|
||||
df = df.groupby(level="instrument").apply(sample_method)
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
return df
|
||||
@@ -256,11 +231,13 @@ class QlibDataLoader(DLWParser):
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
|
||||
for grp, (exprs, names) in self.fields.items()
|
||||
}
|
||||
for grp, _df in group.items():
|
||||
if grp == self.sample_benchmark:
|
||||
continue
|
||||
else:
|
||||
group[grp] = _df.reindex(group[self.sample_benchmark].index)
|
||||
if self.can_sample:
|
||||
# reindex: alignment to index of sample_benchmark
|
||||
for grp, _df in group.items():
|
||||
if grp == self.sample_benchmark:
|
||||
continue
|
||||
else:
|
||||
group[grp] = _df.reindex(group[self.sample_benchmark].index)
|
||||
df = pd.concat(group, axis=1)
|
||||
else:
|
||||
exprs, names = self.fields
|
||||
|
||||
30
qlib/data/inst_processor.py
Normal file
30
qlib/data/inst_processor.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import abc
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class InstProcessor(abc.ABC):
    """Abstract base class for processors applied to one instrument's data.

    Deriving from ``abc.ABC`` makes ``@abc.abstractmethod`` effective: a
    subclass that does not implement ``__call__`` cannot be instantiated.
    """

    @abc.abstractmethod
    def __call__(self, df: pd.DataFrame, *args, **kwargs):
        """
        process the data

        NOTE: **The processor could change the content of `df` inplace !!!!! **
        User should keep a copy of data outside

        Parameters
        ----------
        df : pd.DataFrame
            The raw_df of handler or result from previous processor.
        """
|
||||
|
||||
|
||||
class ResampleProcessor(InstProcessor):
    """Resample data along the ``datetime`` level and aggregate each bin."""

    def __init__(self, freq: str, func: str, *args, **kwargs):
        """Store the configuration.

        Parameters
        ----------
        freq : str
            Rule passed to ``DataFrame.resample``.
        func : str
            Name of the resampler method to call (looked up with ``getattr``).
        *args, **kwargs
            Accepted but not used.
        """
        self.freq = freq
        self.func = func

    def __call__(self, df: pd.DataFrame, *args, **kwargs):
        """Return ``df`` resampled at ``freq``, aggregated by ``func``, without all-NaN rows."""
        resampler = df.resample(self.freq, level="datetime")
        aggregate = getattr(resampler, self.func)
        aggregated = aggregate()
        # Bins with no data produce all-NaN rows; discard them.
        return aggregated.dropna(how="all")
|
||||
@@ -1,10 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import typing
|
||||
import dill
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -18,6 +17,7 @@ class Serializable:
|
||||
|
||||
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
|
||||
default_dump_all = False # if dump all things
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = self.default_dump_all
|
||||
@@ -45,8 +45,6 @@ class Serializable:
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
|
||||
"""
|
||||
configure the serializable object
|
||||
|
||||
Reference in New Issue
Block a user