1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

Add DataPathManager to QlibConfig && modify inst_processors to support lists only

This commit is contained in:
zhupr
2021-09-02 11:45:37 +08:00
committed by you-n-g
parent 707399a245
commit e84cc23589
10 changed files with 148 additions and 179 deletions

View File

@@ -4,9 +4,8 @@ import pandas as pd
from qlib.data.inst_processor import InstProcessor
class ResampleProcessor(InstProcessor):
def __init__(self, freq: str, hour: int, minute: int):
self.freq = freq
class Resample1minProcessor(InstProcessor):
def __init__(self, hour: int, minute: int):
self.hour = hour
self.minute = minute

View File

@@ -3,6 +3,8 @@ qlib_init:
day: "~/.qlib/qlib_data/cn_data"
1min: "~/.qlib/qlib_data/cn_data_1min"
region: cn
dataset_cache: null
maxtasksperchild: 1
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
@@ -16,10 +18,9 @@ data_handler_config: &data_handler_config
label: day
feature: 1min
# with label as reference
sample_benchmark: label
sample_config:
inst_processor:
feature:
- class: ResampleProcessor
- class: Resample1minProcessor
module_path: features_sample.py
kwargs:
freq: 1d

View File

@@ -33,41 +33,30 @@ def init(default_conf="client", **kwargs):
H.clear()
C.set(default_conf, **kwargs)
provider_uri_map = C.provider_uri if isinstance(C.provider_uri, dict) else {None: C.provider_uri}
if C.mount_path is not None:
# mount nfs
for _freq, provider_uri in provider_uri_map.items():
mount_path = C.mount_path if _freq is None else C["mount_path"][_freq]
# check path if server/local
uri_type = C.get_uri_type(provider_uri)
if uri_type == C.LOCAL_URI:
if not Path(provider_uri).exists():
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
)
else:
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
elif uri_type == C.NFS_URI:
mount_path = _mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
if _freq is None:
C["mount_path"] = mount_path
# mount nfs
for _freq, provider_uri in C.provider_uri.items():
mount_path = C["mount_path"][_freq]
# check path if server/local
uri_type = C.dpm.get_uri_type(provider_uri)
if uri_type == C.LOCAL_URI:
if not Path(provider_uri).exists():
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
)
else:
C["mount_path"][_freq] = mount_path
else:
raise NotImplementedError(f"This type of URI is not supported")
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
elif uri_type == C.NFS_URI:
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
else:
raise NotImplementedError(f"This type of URI is not supported")
C.register()
if "flask_server" in C:
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
data_path = (
C.get_data_path()
if isinstance(C["provider_uri"], str)
else {_freq: C.get_data_path(_freq) for _freq in C["provider_uri"].keys()}
)
logger.info(f"data_path={data_path}")
logger.info(f"data_path={C.dpm.provider_uri}")
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
@@ -102,8 +91,6 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
else:
raise OSError(f"unknown error: {result}")
# config mount path
mount_path += ":\\"
else:
# system: linux/Unix/Mac
# check mount
@@ -156,7 +143,6 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
LOG.info("Mount finished")
else:
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
return mount_path
def init_from_yaml_conf(conf_path, **kwargs):

View File

@@ -15,6 +15,7 @@ import os
import re
import copy
import logging
import platform
import multiprocessing
from pathlib import Path
from typing import Union
@@ -238,11 +239,43 @@ class QlibConfig(Config):
# URI_TYPE
LOCAL_URI = "local"
NFS_URI = "nfs"
DEFAULT_FREQ = "__DEFAULT_FREQ"
def __init__(self, default_conf):
super().__init__(default_conf)
self._registered = False
class DataPathManager:
    """Map data frequencies to the local directory that holds their data.

    Both ``provider_uri`` and ``mount_path`` are stored exactly as passed in.
    ``get_data_path`` indexes them by frequency key, so they are presumably
    already normalized to dicts by the owning config (see ``resolve_path``)
    -- TODO confirm against callers.
    """

    def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
        self.provider_uri = provider_uri
        self.mount_path = mount_path

    @staticmethod
    def _is_windows(system_name: str = None) -> bool:
        """Return True when ``system_name`` denotes a Windows-like platform.

        Parameters
        ----------
        system_name : str, optional
            A value as returned by ``platform.system()``; when None the
            current platform is queried.

        Notes
        -----
        A bare ``"win" in name`` test also matches macOS, because
        ``platform.system()`` returns ``"Darwin"`` there. Darwin is excluded
        explicitly; every other name keeps the previous substring behaviour.
        """
        name = (platform.system() if system_name is None else system_name).lower()
        return "win" in name and name != "darwin"

    @staticmethod
    def get_uri_type(uri: Union[str, Path]):
        """Classify ``uri`` as ``QlibConfig.NFS_URI`` or ``QlibConfig.LOCAL_URI``.

        Non-string inputs are expanded and resolved to an absolute path string
        before matching.
        """
        uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
        is_win = re.match(r"^[a-zA-Z]:.*", uri) is not None  # such as 'C:\\data', 'D:'
        # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
        is_nfs_or_win = re.match(r"^[^/]+:.+", uri) is not None

        # A "<something>:<rest>" prefix that is not a drive letter is treated as NFS.
        if is_nfs_or_win and not is_win:
            return QlibConfig.NFS_URI
        else:
            return QlibConfig.LOCAL_URI

    def get_data_path(self, freq: str = None) -> Path:
        """Return the directory that holds the data for ``freq``.

        Parameters
        ----------
        freq : str, optional
            Frequency key. When it is None or not present in
            ``provider_uri``, ``QlibConfig.DEFAULT_FREQ`` is used instead.

        Returns
        -------
        Path
            The provider uri itself for local uris; for NFS uris the mount
            path (formatted as ``<mount_path>:\\`` on Windows, where the
            mount path is the drive).

        Raises
        ------
        NotImplementedError
            If the uri type is neither local nor NFS.
        """
        if freq is None or freq not in self.provider_uri:
            freq = QlibConfig.DEFAULT_FREQ
        _provider_uri = self.provider_uri[freq]
        _uri_type = self.get_uri_type(_provider_uri)
        if _uri_type == QlibConfig.LOCAL_URI:
            return Path(_provider_uri)
        elif _uri_type == QlibConfig.NFS_URI:
            if self._is_windows():
                # windows, mount_path is the drive
                return Path(f"{self.mount_path[freq]}:\\")
            return Path(self.mount_path[freq])
        else:
            raise NotImplementedError("This type of uri is not supported")
def set_mode(self, mode):
# raise KeyError
self.update(MODE_CONF[mode])
@@ -252,59 +285,39 @@ class QlibConfig(Config):
# raise KeyError
self.update(_default_region_config[region])
@property
def dpm(self):
# Data path manager: wraps the current "provider_uri" and "mount_path" config
# values in a DataPathManager. A new instance is built on every access, so it
# reflects whatever those two entries hold at call time.
return self.DataPathManager(self["provider_uri"], self["mount_path"])
def resolve_path(self):
# resolve path
_mount_path = self["mount_path"]
_provider_uri = self["provider_uri"]
if isinstance(_provider_uri, dict):
if _mount_path is not None:
# check provider_uri and mount_path
assert isinstance(_mount_path, dict), f"type(provider_uri) != type(mount_path); {_mount_path}"
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
self["mount_path"] = {_freq: str(Path(_path).expanduser().resolve()) for _freq, _path in _mount_path}
for _freq, _uri in _provider_uri.items():
if self.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
self["provider_uri"][_freq] = str(Path(_uri).expanduser().resolve())
elif isinstance(_provider_uri, str):
if _mount_path is not None:
self["mount_path"] = str(Path(_mount_path).expanduser().resolve())
if _provider_uri is None:
raise ValueError("provider_uri cannot be None")
if not isinstance(_provider_uri, dict):
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
if not isinstance(_mount_path, dict):
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
self["provider_uri"] = str(Path(_provider_uri).expanduser().resolve())
else:
raise TypeError(
f"The types supported by provider_uri are [str, dict], " f"not {type(_provider_uri)}: {_provider_uri}"
# check provider_uri and mount_path
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
# resolve
for _freq, _uri in _provider_uri.items():
# provider_uri
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
# mount_path
_mount_path[_freq] = (
_mount_path[_freq]
if _mount_path[_freq] is None
else str(Path(_mount_path[_freq]).expanduser().resolve())
)
@staticmethod
def get_uri_type(uri: Union[str, Path]):
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def get_data_path(self, freq: str = None) -> Path:
if freq is None and not isinstance(self["provider_uri"], str):
raise ValueError(f"type(provider_uri) == dict, freq cannot be None; provider_uri: {self['provider_uri']}")
_provider_uri = self["provider_uri"] if isinstance(self["provider_uri"], str) else self["provider_uri"][freq]
_mount_path = (
self["mount_path"]
if self["mount_path"] is None or isinstance(self["mount_path"], str)
else self["mount_path"][freq]
)
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
return Path(_provider_uri)
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
return Path(_mount_path)
else:
raise NotImplementedError(f"This type of uri is not supported")
self["provider_uri"] = _provider_uri
self["mount_path"] = _mount_path
def set(self, default_conf="client", **kwargs):
from .utils import set_log_with_config, get_module_logger, can_use_cache

View File

@@ -35,7 +35,7 @@ def get_benchmark_weight(
"""
if not path:
path = Path(C.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegant way
# TODO: The benchmark is not consistent with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])

View File

@@ -58,8 +58,7 @@ class Alpha360(DataHandlerLP):
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
sample_config=None,
sample_benchmark=None,
inst_processor=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -74,8 +73,7 @@ class Alpha360(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"sample_config": sample_config,
"sample_benchmark": sample_benchmark,
"inst_processor": inst_processor,
},
}
@@ -148,8 +146,7 @@ class Alpha158(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
sample_config=None,
sample_benchmark=None,
inst_processor=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -164,8 +161,7 @@ class Alpha158(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"sample_config": sample_config,
"sample_benchmark": sample_benchmark,
"inst_processor": inst_processor,
},
}
super().__init__(

View File

@@ -319,7 +319,7 @@ class BaseProviderCache:
@staticmethod
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
cache_dir = Path(C.get_data_path(freq)).joinpath(dir_name)
cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name)
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
@@ -356,7 +356,7 @@ class ExpressionCache(BaseProviderCache):
"""
raise NotImplementedError("Implement this method if you want to use expression cache")
def update(self, cache_uri: Union[str, Path]):
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
"""Update expression cache to latest calendar.
Override this method to define how to update the expression cache according to the user's own cache mechanism.
@@ -365,6 +365,7 @@ class ExpressionCache(BaseProviderCache):
----------
cache_uri : str or Path
the complete uri of expression cache file (include dir path).
freq : str
Returns
-------
@@ -385,7 +386,7 @@ class DatasetCache(BaseProviderCache):
HDF_KEY = "df"
def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get feature dataset.
@@ -419,7 +420,7 @@ class DatasetCache(BaseProviderCache):
raise NotImplementedError("Implement this function to match your own cache mechanism")
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get feature dataset using cache.
@@ -428,7 +429,7 @@ class DatasetCache(BaseProviderCache):
raise NotImplementedError("Implement this method if you want to use dataset feature cache")
def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get a uri of feature dataset using cache.
specially:
@@ -441,7 +442,7 @@ class DatasetCache(BaseProviderCache):
"Implement this method if you want to use dataset feature cache as a cache file for client"
)
def update(self, cache_uri: Union[str, Path]):
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
"""Update dataset cache to latest calendar.
Override this method to define how to update the dataset cache according to the user's own cache mechanism.
@@ -450,6 +451,7 @@ class DatasetCache(BaseProviderCache):
----------
cache_uri : str or Path
the complete uri of dataset cache file (include dir path).
freq : str
Returns
-------
@@ -542,7 +544,7 @@ class DiskExpressionCache(ExpressionCache):
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
if not series.empty:
# If the expression is empty, we don't generate any cache for it.
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:expression-{_cache_uri}"):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"):
self.gen_expression_cache(
expression_data=series,
cache_path=cache_path,
@@ -578,18 +580,16 @@ class DiskExpressionCache(ExpressionCache):
r = np.hstack([df.index[0], expression_data]).astype("<f")
r.tofile(str(cache_path))
def update(self, sid, cache_uri):
def update(self, sid, cache_uri, freq: str = "day"):
# FIXME: when updating the cache, the type of C.provider_uri must be str
assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str"
cp_cache_uri = self.get_cache_dir().joinpath(sid).joinpath(cache_uri)
cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
self.clear_cache(cp_cache_uri)
return 2
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:expression-{cache_uri}"):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:expression-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instrument = d["info"]["instrument"]
@@ -651,7 +651,7 @@ class DiskDatasetCache(DatasetCache):
self.remote = kwargs.get("remote", False)
@staticmethod
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
def get_cache_dir(self, freq: str = None) -> Path:
@@ -692,7 +692,7 @@ class DiskDatasetCache(DatasetCache):
return df
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if disk_cache == 0:
@@ -701,7 +701,7 @@ class DiskDatasetCache(DatasetCache):
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if not inst_processors:
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
@@ -724,7 +724,7 @@ class DiskDatasetCache(DatasetCache):
if self.check_cache_exists(cache_path):
if disk_cache == 1:
# use cache
with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
elif disk_cache == 2:
@@ -734,7 +734,7 @@ class DiskDatasetCache(DatasetCache):
if gen_flag:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
features = self.gen_dataset_cache(
cache_path=cache_path,
instruments=instruments,
@@ -747,7 +747,7 @@ class DiskDatasetCache(DatasetCache):
return features
def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if disk_cache == 0:
# In this case, server only checks the expression cache.
@@ -775,12 +775,12 @@ class DiskDatasetCache(DatasetCache):
if self.check_cache_exists(cache_path):
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
return _cache_uri
else:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
self.gen_dataset_cache(
cache_path=cache_path,
instruments=instruments,
@@ -854,7 +854,7 @@ class DiskDatasetCache(DatasetCache):
index_data += start_index
return index_data
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=None):
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):
"""gen_dataset_cache
.. note:: This function does not consider the cache read write lock. Please
@@ -949,10 +949,8 @@ class DiskDatasetCache(DatasetCache):
# the fields of the cached features are converted to the original fields
return features.swaplevel("datetime", "instrument")
def update(self, cache_uri):
# FIXME: when updating the cache, the type of C.provider_uri must be str
assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str"
cp_cache_uri = self.get_cache_dir().joinpath(cache_uri)
def update(self, cache_uri, freq: str = "day"):
cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
@@ -960,7 +958,7 @@ class DiskDatasetCache(DatasetCache):
return 2
im = DiskDatasetCache.IndexManager(cp_cache_uri)
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:dataset-{cache_uri}"):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instruments = d["info"]["instruments"]
@@ -1073,14 +1071,14 @@ class SimpleDatasetCache(DatasetCache):
except KeyError as e:
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
return hash_args(
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
)
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
@@ -1115,11 +1113,11 @@ class SimpleDatasetCache(DatasetCache):
class DatasetURICache(DatasetCache):
"""Prepared cache mechanism for server."""
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if "local" in C.dataset_provider.lower():
@@ -1151,7 +1149,7 @@ class DatasetURICache(DatasetCache):
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
)
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
if value is None or expire or not mnt_feature_uri.exists():
df, uri = self.provider.dataset(
instruments,

View File

@@ -61,7 +61,7 @@ class ProviderBackendMixin:
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
freq = kwargs.get("freq", "day")
if freq not in provider_uri_map:
provider_uri_map[freq] = C.get_data_path(freq)
provider_uri_map[freq] = C.dpm.get_data_path(freq)
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
@@ -353,7 +353,7 @@ class DatasetProvider(abc.ABC):
"""
@abc.abstractmethod
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=None):
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]):
"""Get dataset data.
Parameters
@@ -386,7 +386,7 @@ class DatasetProvider(abc.ABC):
end_time=None,
freq="day",
disk_cache=1,
inst_processors=None,
inst_processors=[],
**kwargs,
):
"""Get task uri, used when generating rabbitmq task in qlib_server
@@ -449,7 +449,7 @@ class DatasetProvider(abc.ABC):
return [ExpressionD.get_expression_instance(f) for f in fields]
@staticmethod
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=None):
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):
"""
Load and process the data, return the data set.
- default using multi-kernel method.
@@ -513,7 +513,7 @@ class DatasetProvider(abc.ABC):
@staticmethod
def expression_calculator(
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=None
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
):
"""
Calculate the expressions for one instrument, return a df result.
@@ -544,10 +544,10 @@ class DatasetProvider(abc.ABC):
mask |= (data.index >= begin) & (data.index <= end)
data = data[mask]
if inst_processors is not None:
for _processor in inst_processors if isinstance(inst_processors, (list, tuple, set)) else [inst_processors]:
_processor = init_instance_by_config(_processor, accept_types=InstProcessor)
data = _processor(data)
for _processor in inst_processors:
if _processor:
_processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)
data = _processor_obj(data)
return data
@@ -719,7 +719,7 @@ class LocalDatasetProvider(DatasetProvider):
start_time=None,
end_time=None,
freq="day",
inst_processors=None,
inst_processors=[],
):
instruments_d = self.get_instruments_d(instruments, freq)
column_names = self.get_column_names(fields)
@@ -874,7 +874,7 @@ class ClientDatasetProvider(DatasetProvider):
freq="day",
disk_cache=0,
return_uri=False,
inst_processors=None,
inst_processors=[],
):
if Inst.get_inst_type(instruments) == Inst.DICT:
get_module_logger("data").warning(
@@ -914,7 +914,7 @@ class ClientDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors)
if return_uri:
return data, feature_uri
else:
@@ -953,7 +953,7 @@ class ClientDatasetProvider(DatasetProvider):
get_module_logger("data").debug("get result")
try:
# pre-mounted nfs, used for demo
mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("data").debug("finish slicing data")
if return_uri:
@@ -991,7 +991,7 @@ class BaseProvider:
end_time=None,
freq="day",
disk_cache=None,
inst_processors=None,
inst_processors=[],
):
"""
Parameters:

View File

@@ -156,8 +156,7 @@ class QlibDataLoader(DLWParser):
filter_pipe: List = None,
swap_level: bool = True,
freq: Union[str, dict] = "day",
sample_benchmark: str = None,
sample_config: dict = None,
inst_processor: dict = None,
):
"""
Parameters
@@ -168,6 +167,11 @@ class QlibDataLoader(DLWParser):
Filter pipe for the instruments
swap_level :
Whether to swap level of MultiIndex
freq: dict or str
If type(config) == dict and type(freq) == str, load config data using freq.
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
inst_processor: dict
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
"""
if filter_pipe is not None:
assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list."
@@ -181,9 +185,9 @@ class QlibDataLoader(DLWParser):
self.freq = freq
# sample
self.sample_config = sample_config if sample_config is not None else {}
self.sample_benchmark = sample_benchmark
self.can_sample = False
self.inst_processor = inst_processor if inst_processor is not None else {}
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
super().__init__(config)
if self.is_group:
@@ -192,12 +196,9 @@ class QlibDataLoader(DLWParser):
for _gp in config.keys():
if _gp not in freq:
raise ValueError(f"freq(={freq}) missing group(={_gp})")
assert sample_config, f"freq(={self.freq}), sample_config(={sample_config}) cannot be None/empty"
assert isinstance(sample_config, dict), f"sample_config(={sample_config}) must be dict"
assert (
self.sample_benchmark and self.sample_benchmark in self.fields
), f"sample_benchmark not to specification"
self.can_sample = True
self.inst_processor
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
def load_group_df(
self,
@@ -216,33 +217,15 @@ class QlibDataLoader(DLWParser):
elif self.filter_pipe is not None:
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
freq = self.freq[gp_name] if self.can_sample else self.freq
inst_processor = self.sample_config.get(gp_name, None) if self.can_sample else None
df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processor)
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
df = D.features(
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
)
df.columns = names
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
return df
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:
group = {
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
for grp, (exprs, names) in self.fields.items()
}
if self.can_sample:
# reindex: alignment to index of sample_benchmark
for grp, _df in group.items():
if grp == self.sample_benchmark:
continue
else:
group[grp] = _df.reindex(group[self.sample_benchmark].index)
df = pd.concat(group, axis=1)
else:
exprs, names = self.fields
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
return df
class StaticDataLoader(DataLoader):
"""

View File

@@ -1,4 +1,5 @@
import abc
import json
import pandas as pd
@@ -18,13 +19,5 @@ class InstProcessor:
"""
pass
class ResampleProcessor(InstProcessor):
"""resample data"""
def __init__(self, freq: str, func: str, *args, **kwargs):
self.freq = freq
self.func = func
def __call__(self, df: pd.DataFrame, *args, **kwargs):
return getattr(df.resample(self.freq, level="datetime"), self.func)().dropna(how="all")
def __str__(self):
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"