mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
Add DataPathManager to QlibConfig && modify inst_processors to supports list only
This commit is contained in:
@@ -4,9 +4,8 @@ import pandas as pd
|
||||
from qlib.data.inst_processor import InstProcessor
|
||||
|
||||
|
||||
class ResampleProcessor(InstProcessor):
|
||||
def __init__(self, freq: str, hour: int, minute: int):
|
||||
self.freq = freq
|
||||
class Resample1minProcessor(InstProcessor):
|
||||
def __init__(self, hour: int, minute: int):
|
||||
self.hour = hour
|
||||
self.minute = minute
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ qlib_init:
|
||||
day: "~/.qlib/qlib_data/cn_data"
|
||||
1min: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
dataset_cache: null
|
||||
maxtasksperchild: 1
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
@@ -16,10 +18,9 @@ data_handler_config: &data_handler_config
|
||||
label: day
|
||||
feature: 1min
|
||||
# with label as reference
|
||||
sample_benchmark: label
|
||||
sample_config:
|
||||
inst_processor:
|
||||
feature:
|
||||
- class: ResampleProcessor
|
||||
- class: Resample1minProcessor
|
||||
module_path: features_sample.py
|
||||
kwargs:
|
||||
freq: 1d
|
||||
|
||||
@@ -33,41 +33,30 @@ def init(default_conf="client", **kwargs):
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
provider_uri_map = C.provider_uri if isinstance(C.provider_uri, dict) else {None: C.provider_uri}
|
||||
if C.mount_path is not None:
|
||||
# mount nfs
|
||||
for _freq, provider_uri in provider_uri_map.items():
|
||||
mount_path = C.mount_path if _freq is None else C["mount_path"][_freq]
|
||||
# check path if server/local
|
||||
uri_type = C.get_uri_type(provider_uri)
|
||||
if uri_type == C.LOCAL_URI:
|
||||
if not Path(provider_uri).exists():
|
||||
if C["auto_mount"]:
|
||||
logger.error(
|
||||
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
|
||||
elif uri_type == C.NFS_URI:
|
||||
mount_path = _mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
|
||||
if _freq is None:
|
||||
C["mount_path"] = mount_path
|
||||
# mount nfs
|
||||
for _freq, provider_uri in C.provider_uri.items():
|
||||
mount_path = C["mount_path"][_freq]
|
||||
# check path if server/local
|
||||
uri_type = C.dpm.get_uri_type(provider_uri)
|
||||
if uri_type == C.LOCAL_URI:
|
||||
if not Path(provider_uri).exists():
|
||||
if C["auto_mount"]:
|
||||
logger.error(
|
||||
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
C["mount_path"][_freq] = mount_path
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
|
||||
elif uri_type == C.NFS_URI:
|
||||
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
|
||||
C.register()
|
||||
|
||||
if "flask_server" in C:
|
||||
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
logger.info("qlib successfully initialized based on %s settings." % default_conf)
|
||||
data_path = (
|
||||
C.get_data_path()
|
||||
if isinstance(C["provider_uri"], str)
|
||||
else {_freq: C.get_data_path(_freq) for _freq in C["provider_uri"].keys()}
|
||||
)
|
||||
logger.info(f"data_path={data_path}")
|
||||
logger.info(f"data_path={C.dpm.provider_uri}")
|
||||
|
||||
|
||||
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
@@ -102,8 +91,6 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
else:
|
||||
raise OSError(f"unknown error: {result}")
|
||||
|
||||
# config mount path
|
||||
mount_path += ":\\"
|
||||
else:
|
||||
# system: linux/Unix/Mac
|
||||
# check mount
|
||||
@@ -156,7 +143,6 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
LOG.info("Mount finished")
|
||||
else:
|
||||
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
|
||||
return mount_path
|
||||
|
||||
|
||||
def init_from_yaml_conf(conf_path, **kwargs):
|
||||
|
||||
105
qlib/config.py
105
qlib/config.py
@@ -15,6 +15,7 @@ import os
|
||||
import re
|
||||
import copy
|
||||
import logging
|
||||
import platform
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -238,11 +239,43 @@ class QlibConfig(Config):
|
||||
# URI_TYPE
|
||||
LOCAL_URI = "local"
|
||||
NFS_URI = "nfs"
|
||||
DEFAULT_FREQ = "__DEFAULT_FREQ"
|
||||
|
||||
def __init__(self, default_conf):
|
||||
super().__init__(default_conf)
|
||||
self._registered = False
|
||||
|
||||
class DataPathManager:
|
||||
def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
|
||||
self.provider_uri = provider_uri
|
||||
self.mount_path = mount_path
|
||||
|
||||
@staticmethod
|
||||
def get_uri_type(uri: Union[str, Path]):
|
||||
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
|
||||
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
|
||||
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_path(self, freq: str = None) -> Path:
|
||||
if freq is None or freq not in self.provider_uri:
|
||||
freq = QlibConfig.DEFAULT_FREQ
|
||||
_provider_uri = self.provider_uri[freq]
|
||||
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
|
||||
return Path(_provider_uri)
|
||||
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
|
||||
if "win" in platform.system().lower():
|
||||
# windows, mount_path is the drive
|
||||
return Path(f"{self.mount_path[freq]}:\\")
|
||||
return Path(self.mount_path[freq])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
|
||||
def set_mode(self, mode):
|
||||
# raise KeyError
|
||||
self.update(MODE_CONF[mode])
|
||||
@@ -252,59 +285,39 @@ class QlibConfig(Config):
|
||||
# raise KeyError
|
||||
self.update(_default_region_config[region])
|
||||
|
||||
@property
|
||||
def dpm(self):
|
||||
return self.DataPathManager(self["provider_uri"], self["mount_path"])
|
||||
|
||||
def resolve_path(self):
|
||||
# resolve path
|
||||
_mount_path = self["mount_path"]
|
||||
_provider_uri = self["provider_uri"]
|
||||
if isinstance(_provider_uri, dict):
|
||||
if _mount_path is not None:
|
||||
# check provider_uri and mount_path
|
||||
assert isinstance(_mount_path, dict), f"type(provider_uri) != type(mount_path); {_mount_path}"
|
||||
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
|
||||
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
|
||||
self["mount_path"] = {_freq: str(Path(_path).expanduser().resolve()) for _freq, _path in _mount_path}
|
||||
for _freq, _uri in _provider_uri.items():
|
||||
if self.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
self["provider_uri"][_freq] = str(Path(_uri).expanduser().resolve())
|
||||
elif isinstance(_provider_uri, str):
|
||||
if _mount_path is not None:
|
||||
self["mount_path"] = str(Path(_mount_path).expanduser().resolve())
|
||||
if _provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if not isinstance(_provider_uri, dict):
|
||||
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
|
||||
if not isinstance(_mount_path, dict):
|
||||
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
|
||||
|
||||
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
|
||||
self["provider_uri"] = str(Path(_provider_uri).expanduser().resolve())
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The types supported by provider_uri are [str, dict], " f"not {type(_provider_uri)}: {_provider_uri}"
|
||||
# check provider_uri and mount_path
|
||||
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
|
||||
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
|
||||
|
||||
# resolve
|
||||
for _freq, _uri in _provider_uri.items():
|
||||
# provider_uri
|
||||
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
|
||||
# mount_path
|
||||
_mount_path[_freq] = (
|
||||
_mount_path[_freq]
|
||||
if _mount_path[_freq] is None
|
||||
else str(Path(_mount_path[_freq]).expanduser().resolve())
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_uri_type(uri: Union[str, Path]):
|
||||
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
|
||||
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
|
||||
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_path(self, freq: str = None) -> Path:
|
||||
if freq is None and not isinstance(self["provider_uri"], str):
|
||||
raise ValueError(f"type(provider_uri) == dict, freq cannot be None; provider_uri: {self['provider_uri']}")
|
||||
_provider_uri = self["provider_uri"] if isinstance(self["provider_uri"], str) else self["provider_uri"][freq]
|
||||
_mount_path = (
|
||||
self["mount_path"]
|
||||
if self["mount_path"] is None or isinstance(self["mount_path"], str)
|
||||
else self["mount_path"][freq]
|
||||
)
|
||||
|
||||
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
|
||||
return Path(_provider_uri)
|
||||
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
|
||||
return Path(_mount_path)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
self["provider_uri"] = _provider_uri
|
||||
self["mount_path"] = _mount_path
|
||||
|
||||
def set(self, default_conf="client", **kwargs):
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache
|
||||
|
||||
@@ -35,7 +35,7 @@ def get_benchmark_weight(
|
||||
|
||||
"""
|
||||
if not path:
|
||||
path = Path(C.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
# TODO: the storage of weights should be implemented in a more elegent way
|
||||
# TODO: The benchmark is not consistant with the filename in instruments.
|
||||
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
|
||||
|
||||
@@ -58,8 +58,7 @@ class Alpha360(DataHandlerLP):
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
sample_config=None,
|
||||
sample_benchmark=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -74,8 +73,7 @@ class Alpha360(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"sample_config": sample_config,
|
||||
"sample_benchmark": sample_benchmark,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -148,8 +146,7 @@ class Alpha158(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
sample_config=None,
|
||||
sample_benchmark=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -164,8 +161,7 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"sample_config": sample_config,
|
||||
"sample_benchmark": sample_benchmark,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
|
||||
@@ -319,7 +319,7 @@ class BaseProviderCache:
|
||||
|
||||
@staticmethod
|
||||
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
|
||||
cache_dir = Path(C.get_data_path(freq)).joinpath(dir_name)
|
||||
cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name)
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
@@ -356,7 +356,7 @@ class ExpressionCache(BaseProviderCache):
|
||||
"""
|
||||
raise NotImplementedError("Implement this method if you want to use expression cache")
|
||||
|
||||
def update(self, cache_uri: Union[str, Path]):
|
||||
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
|
||||
"""Update expression cache to latest calendar.
|
||||
|
||||
Overide this method to define how to update expression cache corresponding to users' own cache mechanism.
|
||||
@@ -365,6 +365,7 @@ class ExpressionCache(BaseProviderCache):
|
||||
----------
|
||||
cache_uri : str or Path
|
||||
the complete uri of expression cache file (include dir path).
|
||||
freq : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -385,7 +386,7 @@ class DatasetCache(BaseProviderCache):
|
||||
HDF_KEY = "df"
|
||||
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get feature dataset.
|
||||
|
||||
@@ -419,7 +420,7 @@ class DatasetCache(BaseProviderCache):
|
||||
raise NotImplementedError("Implement this function to match your own cache mechanism")
|
||||
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get feature dataset using cache.
|
||||
|
||||
@@ -428,7 +429,7 @@ class DatasetCache(BaseProviderCache):
|
||||
raise NotImplementedError("Implement this method if you want to use dataset feature cache")
|
||||
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get a uri of feature dataset using cache.
|
||||
specially:
|
||||
@@ -441,7 +442,7 @@ class DatasetCache(BaseProviderCache):
|
||||
"Implement this method if you want to use dataset feature cache as a cache file for client"
|
||||
)
|
||||
|
||||
def update(self, cache_uri: Union[str, Path]):
|
||||
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
|
||||
"""Update dataset cache to latest calendar.
|
||||
|
||||
Overide this method to define how to update dataset cache corresponding to users' own cache mechanism.
|
||||
@@ -450,6 +451,7 @@ class DatasetCache(BaseProviderCache):
|
||||
----------
|
||||
cache_uri : str or Path
|
||||
the complete uri of dataset cache file (include dir path).
|
||||
freq : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -542,7 +544,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
|
||||
if not series.empty:
|
||||
# This expresion is empty, we don't generate any cache for it.
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:expression-{_cache_uri}"):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"):
|
||||
self.gen_expression_cache(
|
||||
expression_data=series,
|
||||
cache_path=cache_path,
|
||||
@@ -578,18 +580,16 @@ class DiskExpressionCache(ExpressionCache):
|
||||
r = np.hstack([df.index[0], expression_data]).astype("<f")
|
||||
r.tofile(str(cache_path))
|
||||
|
||||
def update(self, sid, cache_uri):
|
||||
def update(self, sid, cache_uri, freq: str = "day"):
|
||||
|
||||
# FIXME: when updating the cache, the type of C.provider_uri must be str
|
||||
assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str"
|
||||
cp_cache_uri = self.get_cache_dir().joinpath(sid).joinpath(cache_uri)
|
||||
cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri)
|
||||
meta_path = cp_cache_uri.with_suffix(".meta")
|
||||
if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]):
|
||||
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
|
||||
self.clear_cache(cp_cache_uri)
|
||||
return 2
|
||||
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:expression-{cache_uri}"):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:expression-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
instrument = d["info"]["instrument"]
|
||||
@@ -651,7 +651,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@staticmethod
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
def get_cache_dir(self, freq: str = None) -> Path:
|
||||
@@ -692,7 +692,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
return df
|
||||
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
|
||||
if disk_cache == 0:
|
||||
@@ -701,7 +701,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if not inst_processors:
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
@@ -724,7 +724,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
if self.check_cache_exists(cache_path):
|
||||
if disk_cache == 1:
|
||||
# use cache
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
CacheUtils.visit(cache_path)
|
||||
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
|
||||
elif disk_cache == 2:
|
||||
@@ -734,7 +734,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
if gen_flag:
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
features = self.gen_dataset_cache(
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
@@ -747,7 +747,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
return features
|
||||
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, server only checks the expression cache.
|
||||
@@ -775,12 +775,12 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
if self.check_cache_exists(cache_path):
|
||||
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
CacheUtils.visit(cache_path)
|
||||
return _cache_uri
|
||||
else:
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
self.gen_dataset_cache(
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
@@ -854,7 +854,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
index_data += start_index
|
||||
return index_data
|
||||
|
||||
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=None):
|
||||
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):
|
||||
"""gen_dataset_cache
|
||||
|
||||
.. note:: This function does not consider the cache read write lock. Please
|
||||
@@ -949,10 +949,8 @@ class DiskDatasetCache(DatasetCache):
|
||||
# the fields of the cached features are converted to the original fields
|
||||
return features.swaplevel("datetime", "instrument")
|
||||
|
||||
def update(self, cache_uri):
|
||||
# FIXME: when updating the cache, the type of C.provider_uri must be str
|
||||
assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str"
|
||||
cp_cache_uri = self.get_cache_dir().joinpath(cache_uri)
|
||||
def update(self, cache_uri, freq: str = "day"):
|
||||
cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri)
|
||||
meta_path = cp_cache_uri.with_suffix(".meta")
|
||||
if not self.check_cache_exists(cp_cache_uri):
|
||||
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
|
||||
@@ -960,7 +958,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
return 2
|
||||
|
||||
im = DiskDatasetCache.IndexManager(cp_cache_uri)
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:dataset-{cache_uri}"):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
instruments = d["info"]["instruments"]
|
||||
@@ -1073,14 +1071,14 @@ class SimpleDatasetCache(DatasetCache):
|
||||
except KeyError as e:
|
||||
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
|
||||
return hash_args(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
|
||||
)
|
||||
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, data_set cache is configured but will not be used.
|
||||
@@ -1115,11 +1113,11 @@ class SimpleDatasetCache(DatasetCache):
|
||||
class DatasetURICache(DatasetCache):
|
||||
"""Prepared cache mechanism for server."""
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs):
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
|
||||
if "local" in C.dataset_provider.lower():
|
||||
@@ -1151,7 +1149,7 @@ class DatasetURICache(DatasetCache):
|
||||
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
|
||||
mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
|
||||
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
|
||||
if value is None or expire or not mnt_feature_uri.exists():
|
||||
df, uri = self.provider.dataset(
|
||||
instruments,
|
||||
|
||||
@@ -61,7 +61,7 @@ class ProviderBackendMixin:
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
|
||||
freq = kwargs.get("freq", "day")
|
||||
if freq not in provider_uri_map:
|
||||
provider_uri_map[freq] = C.get_data_path(freq)
|
||||
provider_uri_map[freq] = C.dpm.get_data_path(freq)
|
||||
backend_kwargs["provider_uri"] = provider_uri_map[freq]
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
@@ -353,7 +353,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=None):
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]):
|
||||
"""Get dataset data.
|
||||
|
||||
Parameters
|
||||
@@ -386,7 +386,7 @@ class DatasetProvider(abc.ABC):
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=1,
|
||||
inst_processors=None,
|
||||
inst_processors=[],
|
||||
**kwargs,
|
||||
):
|
||||
"""Get task uri, used when generating rabbitmq task in qlib_server
|
||||
@@ -449,7 +449,7 @@ class DatasetProvider(abc.ABC):
|
||||
return [ExpressionD.get_expression_instance(f) for f in fields]
|
||||
|
||||
@staticmethod
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=None):
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):
|
||||
"""
|
||||
Load and process the data, return the data set.
|
||||
- default using multi-kernel method.
|
||||
@@ -513,7 +513,7 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
@staticmethod
|
||||
def expression_calculator(
|
||||
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=None
|
||||
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
|
||||
):
|
||||
"""
|
||||
Calculate the expressions for one instrument, return a df result.
|
||||
@@ -544,10 +544,10 @@ class DatasetProvider(abc.ABC):
|
||||
mask |= (data.index >= begin) & (data.index <= end)
|
||||
data = data[mask]
|
||||
|
||||
if inst_processors is not None:
|
||||
for _processor in inst_processors if isinstance(inst_processors, (list, tuple, set)) else [inst_processors]:
|
||||
_processor = init_instance_by_config(_processor, accept_types=InstProcessor)
|
||||
data = _processor(data)
|
||||
for _processor in inst_processors:
|
||||
if _processor:
|
||||
_processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)
|
||||
data = _processor_obj(data)
|
||||
return data
|
||||
|
||||
|
||||
@@ -719,7 +719,7 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
inst_processors=None,
|
||||
inst_processors=[],
|
||||
):
|
||||
instruments_d = self.get_instruments_d(instruments, freq)
|
||||
column_names = self.get_column_names(fields)
|
||||
@@ -874,7 +874,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
freq="day",
|
||||
disk_cache=0,
|
||||
return_uri=False,
|
||||
inst_processors=None,
|
||||
inst_processors=[],
|
||||
):
|
||||
if Inst.get_inst_type(instruments) == Inst.DICT:
|
||||
get_module_logger("data").warning(
|
||||
@@ -914,7 +914,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors)
|
||||
if return_uri:
|
||||
return data, feature_uri
|
||||
else:
|
||||
@@ -953,7 +953,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
get_module_logger("data").debug("get result")
|
||||
try:
|
||||
# pre-mound nfs, used for demo
|
||||
mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
|
||||
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
|
||||
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
|
||||
get_module_logger("data").debug("finish slicing data")
|
||||
if return_uri:
|
||||
@@ -991,7 +991,7 @@ class BaseProvider:
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=None,
|
||||
inst_processors=None,
|
||||
inst_processors=[],
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
@@ -156,8 +156,7 @@ class QlibDataLoader(DLWParser):
|
||||
filter_pipe: List = None,
|
||||
swap_level: bool = True,
|
||||
freq: Union[str, dict] = "day",
|
||||
sample_benchmark: str = None,
|
||||
sample_config: dict = None,
|
||||
inst_processor: dict = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
@@ -168,6 +167,11 @@ class QlibDataLoader(DLWParser):
|
||||
Filter pipe for the instruments
|
||||
swap_level :
|
||||
Whether to swap level of MultiIndex
|
||||
freq: dict or str
|
||||
If type(config) == dict and type(freq) == str, load config data using freq.
|
||||
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
|
||||
inst_processor: dict
|
||||
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
|
||||
"""
|
||||
if filter_pipe is not None:
|
||||
assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list."
|
||||
@@ -181,9 +185,9 @@ class QlibDataLoader(DLWParser):
|
||||
self.freq = freq
|
||||
|
||||
# sample
|
||||
self.sample_config = sample_config if sample_config is not None else {}
|
||||
self.sample_benchmark = sample_benchmark
|
||||
self.can_sample = False
|
||||
self.inst_processor = inst_processor if inst_processor is not None else {}
|
||||
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
if self.is_group:
|
||||
@@ -192,12 +196,9 @@ class QlibDataLoader(DLWParser):
|
||||
for _gp in config.keys():
|
||||
if _gp not in freq:
|
||||
raise ValueError(f"freq(={freq}) missing group(={_gp})")
|
||||
assert sample_config, f"freq(={self.freq}), sample_config(={sample_config}) cannot be None/empty"
|
||||
assert isinstance(sample_config, dict), f"sample_config(={sample_config}) must be dict"
|
||||
assert (
|
||||
self.sample_benchmark and self.sample_benchmark in self.fields
|
||||
), f"sample_benchmark not to specification"
|
||||
self.can_sample = True
|
||||
self.inst_processor
|
||||
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
|
||||
|
||||
def load_group_df(
|
||||
self,
|
||||
@@ -216,33 +217,15 @@ class QlibDataLoader(DLWParser):
|
||||
elif self.filter_pipe is not None:
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
freq = self.freq[gp_name] if self.can_sample else self.freq
|
||||
inst_processor = self.sample_config.get(gp_name, None) if self.can_sample else None
|
||||
df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processor)
|
||||
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
|
||||
df = D.features(
|
||||
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
|
||||
)
|
||||
df.columns = names
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
return df
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if self.is_group:
|
||||
group = {
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
|
||||
for grp, (exprs, names) in self.fields.items()
|
||||
}
|
||||
if self.can_sample:
|
||||
# reindex: alignment to index of sample_benchmark
|
||||
for grp, _df in group.items():
|
||||
if grp == self.sample_benchmark:
|
||||
continue
|
||||
else:
|
||||
group[grp] = _df.reindex(group[self.sample_benchmark].index)
|
||||
df = pd.concat(group, axis=1)
|
||||
else:
|
||||
exprs, names = self.fields
|
||||
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
return df
|
||||
|
||||
|
||||
class StaticDataLoader(DataLoader):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@@ -18,13 +19,5 @@ class InstProcessor:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ResampleProcessor(InstProcessor):
|
||||
"""resample data"""
|
||||
|
||||
def __init__(self, freq: str, func: str, *args, **kwargs):
|
||||
self.freq = freq
|
||||
self.func = func
|
||||
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
return getattr(df.resample(self.freq, level="datetime"), self.func)().dropna(how="all")
|
||||
def __str__(self):
|
||||
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"
|
||||
|
||||
Reference in New Issue
Block a user