diff --git a/examples/benchmarks/LightGBM/features_sample.py b/examples/benchmarks/LightGBM/features_sample.py index 72fffb79d..71342ad54 100644 --- a/examples/benchmarks/LightGBM/features_sample.py +++ b/examples/benchmarks/LightGBM/features_sample.py @@ -4,9 +4,8 @@ import pandas as pd from qlib.data.inst_processor import InstProcessor -class ResampleProcessor(InstProcessor): - def __init__(self, freq: str, hour: int, minute: int): - self.freq = freq +class Resample1minProcessor(InstProcessor): + def __init__(self, hour: int, minute: int): self.hour = hour self.minute = minute diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml index dde58f293..b37ffed4c 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml @@ -3,6 +3,8 @@ qlib_init: day: "~/.qlib/qlib_data/cn_data" 1min: "~/.qlib/qlib_data/cn_data_1min" region: cn + dataset_cache: null + maxtasksperchild: 1 market: &market csi300 benchmark: &benchmark SH000300 data_handler_config: &data_handler_config @@ -16,10 +18,9 @@ data_handler_config: &data_handler_config label: day feature: 1min # with label as reference - sample_benchmark: label - sample_config: + inst_processor: feature: - - class: ResampleProcessor + - class: Resample1minProcessor module_path: features_sample.py kwargs: freq: 1d diff --git a/qlib/__init__.py b/qlib/__init__.py index 476a851d9..347674a12 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -33,41 +33,30 @@ def init(default_conf="client", **kwargs): H.clear() C.set(default_conf, **kwargs) - provider_uri_map = C.provider_uri if isinstance(C.provider_uri, dict) else {None: C.provider_uri} - if C.mount_path is not None: - # mount nfs - for _freq, provider_uri in provider_uri_map.items(): - mount_path = C.mount_path if _freq is None else C["mount_path"][_freq] - # check path if server/local - uri_type = C.get_uri_type(provider_uri) - if uri_type == C.LOCAL_URI: - if not Path(provider_uri).exists(): - if C["auto_mount"]: - logger.error( - f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist." - ) - else: - logger.warning(f"auto_path is False, please make sure {mount_path} is mounted") - elif uri_type == C.NFS_URI: - mount_path = _mount_nfs_uri(provider_uri, mount_path, C["auto_mount"]) - if _freq is None: - C["mount_path"] = mount_path + # mount nfs + for _freq, provider_uri in C.provider_uri.items(): + mount_path = C["mount_path"][_freq] + # check path if server/local + uri_type = C.dpm.get_uri_type(provider_uri) + if uri_type == C.LOCAL_URI: + if not Path(provider_uri).exists(): + if C["auto_mount"]: + logger.error( + f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist." + ) else: - C["mount_path"][_freq] = mount_path - else: - raise NotImplementedError(f"This type of URI is not supported") + logger.warning(f"auto_path is False, please make sure {mount_path} is mounted") + elif uri_type == C.NFS_URI: + _mount_nfs_uri(provider_uri, mount_path, C["auto_mount"]) + else: + raise NotImplementedError(f"This type of URI is not supported") C.register() if "flask_server" in C: logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") logger.info("qlib successfully initialized based on %s settings." % default_conf) - data_path = ( - C.get_data_path() - if isinstance(C["provider_uri"], str) - else {_freq: C.get_data_path(_freq) for _freq in C["provider_uri"].keys()} - ) - logger.info(f"data_path={data_path}") + logger.info(f"data_path={C.dpm.provider_uri}") def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): @@ -102,8 +91,6 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): else: raise OSError(f"unknown error: {result}") - # config mount path - mount_path += ":\\" else: # system: linux/Unix/Mac # check mount @@ -156,7 +143,6 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): LOG.info("Mount finished") else: LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted") - return mount_path def init_from_yaml_conf(conf_path, **kwargs): diff --git a/qlib/config.py b/qlib/config.py index 4e2b27f8c..c74e43717 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -15,6 +15,7 @@ import os import re import copy import logging +import platform import multiprocessing from pathlib import Path from typing import Union @@ -238,11 +239,43 @@ class QlibConfig(Config): # URI_TYPE LOCAL_URI = "local" NFS_URI = "nfs" + DEFAULT_FREQ = "__DEFAULT_FREQ" def __init__(self, default_conf): super().__init__(default_conf) self._registered = False + class DataPathManager: + def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]): + self.provider_uri = provider_uri + self.mount_path = mount_path + + @staticmethod + def get_uri_type(uri: Union[str, Path]): + uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve()) + is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:' + # such as 'host:/data/' (User may define short hostname by themselves or use localhost) + is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None + + if is_nfs_or_win and not is_win: + return QlibConfig.NFS_URI + else: + return QlibConfig.LOCAL_URI + + def get_data_path(self, freq: str = None) -> Path: + if freq is None or freq not in self.provider_uri: + freq = QlibConfig.DEFAULT_FREQ + _provider_uri = self.provider_uri[freq] + if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI: + return Path(_provider_uri) + elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI: + if "win" in platform.system().lower(): + # windows, mount_path is the drive + return Path(f"{self.mount_path[freq]}:\\") + return Path(self.mount_path[freq]) + else: + raise NotImplementedError(f"This type of uri is not supported") + def set_mode(self, mode): # raise KeyError self.update(MODE_CONF[mode]) @@ -252,59 +285,39 @@ class QlibConfig(Config): # raise KeyError self.update(_default_region_config[region]) + @property + def dpm(self): + return self.DataPathManager(self["provider_uri"], self["mount_path"]) + def resolve_path(self): # resolve path _mount_path = self["mount_path"] _provider_uri = self["provider_uri"] - if isinstance(_provider_uri, dict): - if _mount_path is not None: - # check provider_uri and mount_path - assert isinstance(_mount_path, dict), f"type(provider_uri) != type(mount_path); {_mount_path}" - _miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys()) - assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}" - self["mount_path"] = {_freq: str(Path(_path).expanduser().resolve()) for _freq, _path in _mount_path} - for _freq, _uri in _provider_uri.items(): - if self.get_uri_type(_uri) == QlibConfig.LOCAL_URI: - self["provider_uri"][_freq] = str(Path(_uri).expanduser().resolve()) - elif isinstance(_provider_uri, str): - if _mount_path is not None: - self["mount_path"] = str(Path(_mount_path).expanduser().resolve()) + if _provider_uri is None: + raise ValueError("provider_uri cannot be None") + if not isinstance(_provider_uri, dict): + _provider_uri = {self.DEFAULT_FREQ: _provider_uri} + if not isinstance(_mount_path, dict): + _mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()} - if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI: - self["provider_uri"] = str(Path(_provider_uri).expanduser().resolve()) - else: - raise TypeError( - f"The types supported by provider_uri are [str, dict], " f"not {type(_provider_uri)}: {_provider_uri}" + # check provider_uri and mount_path + _miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys()) + assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}" + + # resolve + for _freq, _uri in _provider_uri.items(): + # provider_uri + if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI: + _provider_uri[_freq] = str(Path(_uri).expanduser().resolve()) + # mount_path + _mount_path[_freq] = ( + _mount_path[_freq] + if _mount_path[_freq] is None + else str(Path(_mount_path[_freq]).expanduser().resolve()) ) - @staticmethod - def get_uri_type(uri: Union[str, Path]): - uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve()) - is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:' - # such as 'host:/data/' (User may define short hostname by themselves or use localhost) - is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None - - if is_nfs_or_win and not is_win: - return QlibConfig.NFS_URI - else: - return QlibConfig.LOCAL_URI - - def get_data_path(self, freq: str = None) -> Path: - if freq is None and not isinstance(self["provider_uri"], str): - raise ValueError(f"type(provider_uri) == dict, freq cannot be None; provider_uri: {self['provider_uri']}") - _provider_uri = self["provider_uri"] if isinstance(self["provider_uri"], str) else self["provider_uri"][freq] - _mount_path = ( - self["mount_path"] - if self["mount_path"] is None or isinstance(self["mount_path"], str) - else self["mount_path"][freq] - ) - - if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI: - return Path(_provider_uri) - elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI: - return Path(_mount_path) - else: - raise NotImplementedError(f"This type of uri is not supported") + self["provider_uri"] = _provider_uri + self["mount_path"] = _mount_path def set(self, default_conf="client", **kwargs): from .utils import set_log_with_config, get_module_logger, can_use_cache diff --git a/qlib/contrib/backtest/profit_attribution.py b/qlib/contrib/backtest/profit_attribution.py index 2fdebc591..df5dd965d 100644 --- a/qlib/contrib/backtest/profit_attribution.py +++ b/qlib/contrib/backtest/profit_attribution.py @@ -35,7 +35,7 @@ def get_benchmark_weight( """ if not path: - path = Path(C.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv" + path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv" # TODO: the storage of weights should be implemented in a more elegent way # TODO: The benchmark is not consistant with the filename in instruments. bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"]) diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 1254c0d26..ccd753006 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -58,8 +58,7 @@ class Alpha360(DataHandlerLP): fit_start_time=None, fit_end_time=None, filter_pipe=None, - sample_config=None, - sample_benchmark=None, + inst_processor=None, **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -74,8 +73,7 @@ class Alpha360(DataHandlerLP): }, "filter_pipe": filter_pipe, "freq": freq, - "sample_config": sample_config, - "sample_benchmark": sample_benchmark, + "inst_processor": inst_processor, }, } @@ -148,8 +146,7 @@ class Alpha158(DataHandlerLP): fit_end_time=None, process_type=DataHandlerLP.PTYPE_A, filter_pipe=None, - sample_config=None, - sample_benchmark=None, + inst_processor=None, **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -164,8 +161,7 @@ class Alpha158(DataHandlerLP): }, "filter_pipe": filter_pipe, "freq": freq, - "sample_config": sample_config, - "sample_benchmark": sample_benchmark, + "inst_processor": inst_processor, }, } super().__init__( diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 271343a01..16506e0ea 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -319,7 +319,7 @@ class BaseProviderCache: @staticmethod def get_cache_dir(dir_name: str, freq: str = None) -> Path: - cache_dir = Path(C.get_data_path(freq)).joinpath(dir_name) + cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name) cache_dir.mkdir(parents=True, exist_ok=True) return cache_dir @@ -356,7 +356,7 @@ class ExpressionCache(BaseProviderCache): """ raise NotImplementedError("Implement this method if you want to use expression cache") - def update(self, cache_uri: Union[str, Path]): + def update(self, cache_uri: Union[str, Path], freq: str = "day"): """Update expression cache to latest calendar. Overide this method to define how to update expression cache corresponding to users' own cache mechanism. @@ -365,6 +365,7 @@ class ExpressionCache(BaseProviderCache): ---------- cache_uri : str or Path the complete uri of expression cache file (include dir path). + freq : str Returns ------- @@ -385,7 +386,7 @@ class DatasetCache(BaseProviderCache): HDF_KEY = "df" def dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] ): """Get feature dataset. @@ -419,7 +420,7 @@ class DatasetCache(BaseProviderCache): raise NotImplementedError("Implement this function to match your own cache mechanism") def _dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] ): """Get feature dataset using cache. @@ -428,7 +429,7 @@ class DatasetCache(BaseProviderCache): raise NotImplementedError("Implement this method if you want to use dataset feature cache") def _dataset_uri( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] ): """Get a uri of feature dataset using cache. specially: @@ -441,7 +442,7 @@ class DatasetCache(BaseProviderCache): "Implement this method if you want to use dataset feature cache as a cache file for client" ) - def update(self, cache_uri: Union[str, Path]): + def update(self, cache_uri: Union[str, Path], freq: str = "day"): """Update dataset cache to latest calendar. Overide this method to define how to update dataset cache corresponding to users' own cache mechanism. @@ -450,6 +451,7 @@ class DatasetCache(BaseProviderCache): ---------- cache_uri : str or Path the complete uri of dataset cache file (include dir path). + freq : str Returns ------- @@ -542,7 +544,7 @@ class DiskExpressionCache(ExpressionCache): series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq) if not series.empty: # This expresion is empty, we don't generate any cache for it. - with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:expression-{_cache_uri}"): + with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"): self.gen_expression_cache( expression_data=series, cache_path=cache_path, @@ -578,18 +580,16 @@ class DiskExpressionCache(ExpressionCache): r = np.hstack([df.index[0], expression_data]).astype(" Path: @@ -692,7 +692,7 @@ class DiskDatasetCache(DatasetCache): return df def _dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[] ): if disk_cache == 0: @@ -701,7 +701,7 @@ class DiskDatasetCache(DatasetCache): instruments, fields, start_time, end_time, freq, inst_processors=inst_processors ) # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date - if not inst_processors: + if inst_processors: raise ValueError( f"{self.__class__.__name__} does not support inst_processor. " f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`" @@ -724,7 +724,7 @@ class DiskDatasetCache(DatasetCache): if self.check_cache_exists(cache_path): if disk_cache == 1: # use cache - with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): + with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"): CacheUtils.visit(cache_path) features = self.read_data_from_cache(cache_path, start_time, end_time, fields) elif disk_cache == 2: @@ -734,7 +734,7 @@ class DiskDatasetCache(DatasetCache): if gen_flag: # cache unavailable, generate the cache - with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): + with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"): features = self.gen_dataset_cache( cache_path=cache_path, instruments=instruments, @@ -747,7 +747,7 @@ class DiskDatasetCache(DatasetCache): return features def _dataset_uri( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[] ): if disk_cache == 0: # In this case, server only checks the expression cache. @@ -775,12 +775,12 @@ class DiskDatasetCache(DatasetCache): if self.check_cache_exists(cache_path): self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly") - with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): + with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"): CacheUtils.visit(cache_path) return _cache_uri else: # cache unavailable, generate the cache - with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): + with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"): self.gen_dataset_cache( cache_path=cache_path, instruments=instruments, @@ -854,7 +854,7 @@ class DiskDatasetCache(DatasetCache): index_data += start_index return index_data - def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=None): + def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]): """gen_dataset_cache .. note:: This function does not consider the cache read write lock. Please @@ -949,10 +949,8 @@ class DiskDatasetCache(DatasetCache): # the fields of the cached features are converted to the original fields return features.swaplevel("datetime", "instrument") - def update(self, cache_uri): - # FIXME: when updating the cache, the type of C.provider_uri must be str - assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str" - cp_cache_uri = self.get_cache_dir().joinpath(cache_uri) + def update(self, cache_uri, freq: str = "day"): + cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri) meta_path = cp_cache_uri.with_suffix(".meta") if not self.check_cache_exists(cp_cache_uri): self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed") @@ -960,7 +958,7 @@ class DiskDatasetCache(DatasetCache): return 2 im = DiskDatasetCache.IndexManager(cp_cache_uri) - with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:dataset-{cache_uri}"): + with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"): with meta_path.open("rb") as f: d = pickle.load(f) instruments = d["info"]["instruments"] @@ -1073,14 +1071,14 @@ class SimpleDatasetCache(DatasetCache): except KeyError as e: self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism") - def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs): + def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs): instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq) return hash_args( instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors ) def _dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] ): if disk_cache == 0: # In this case, data_set cache is configured but will not be used. @@ -1115,11 +1113,11 @@ class SimpleDatasetCache(DatasetCache): class DatasetURICache(DatasetCache): """Prepared cache mechanism for server.""" - def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs): + def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs): return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors) def dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[] ): if "local" in C.dataset_provider.lower(): @@ -1151,7 +1149,7 @@ class DatasetURICache(DatasetCache): instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors ) value, expire = MemCacheExpire.get_cache(H["f"], feature_uri) - mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri) + mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri) if value is None or expire or not mnt_feature_uri.exists(): df, uri = self.provider.dataset( instruments, diff --git a/qlib/data/data.py b/qlib/data/data.py index e8fb52849..1830b1073 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -61,7 +61,7 @@ class ProviderBackendMixin: provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {}) freq = kwargs.get("freq", "day") if freq not in provider_uri_map: - provider_uri_map[freq] = C.get_data_path(freq) + provider_uri_map[freq] = C.dpm.get_data_path(freq) backend_kwargs["provider_uri"] = provider_uri_map[freq] backend.setdefault("kwargs", {}).update(**kwargs) return init_instance_by_config(backend) @@ -353,7 +353,7 @@ class DatasetProvider(abc.ABC): """ @abc.abstractmethod - def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=None): + def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]): """Get dataset data. Parameters @@ -386,7 +386,7 @@ class DatasetProvider(abc.ABC): end_time=None, freq="day", disk_cache=1, - inst_processors=None, + inst_processors=[], **kwargs, ): """Get task uri, used when generating rabbitmq task in qlib_server @@ -449,7 +449,7 @@ class DatasetProvider(abc.ABC): return [ExpressionD.get_expression_instance(f) for f in fields] @staticmethod - def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=None): + def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]): """ Load and process the data, return the data set. - default using multi-kernel method. @@ -513,7 +513,7 @@ class DatasetProvider(abc.ABC): @staticmethod def expression_calculator( - inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=None + inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[] ): """ Calculate the expressions for one instrument, return a df result. @@ -544,10 +544,10 @@ class DatasetProvider(abc.ABC): mask |= (data.index >= begin) & (data.index <= end) data = data[mask] - if inst_processors is not None: - for _processor in inst_processors if isinstance(inst_processors, (list, tuple, set)) else [inst_processors]: - _processor = init_instance_by_config(_processor, accept_types=InstProcessor) - data = _processor(data) + for _processor in inst_processors: + if _processor: + _processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor) + data = _processor_obj(data) return data @@ -719,7 +719,7 @@ class LocalDatasetProvider(DatasetProvider): start_time=None, end_time=None, freq="day", - inst_processors=None, + inst_processors=[], ): instruments_d = self.get_instruments_d(instruments, freq) column_names = self.get_column_names(fields) @@ -874,7 +874,7 @@ class ClientDatasetProvider(DatasetProvider): freq="day", disk_cache=0, return_uri=False, - inst_processors=None, + inst_processors=[], ): if Inst.get_inst_type(instruments) == Inst.DICT: get_module_logger("data").warning( @@ -914,7 +914,7 @@ class ClientDatasetProvider(DatasetProvider): start_time = cal[0] end_time = cal[-1] - data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) + data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors) if return_uri: return data, feature_uri else: @@ -953,7 +953,7 @@ class ClientDatasetProvider(DatasetProvider): get_module_logger("data").debug("get result") try: # pre-mound nfs, used for demo - mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri) + mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri) df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) get_module_logger("data").debug("finish slicing data") if return_uri: @@ -991,7 +991,7 @@ class BaseProvider: end_time=None, freq="day", disk_cache=None, - inst_processors=None, + inst_processors=[], ): """ Parameters: diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 4b4df2eb6..080f3dd4d 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -156,8 +156,7 @@ class QlibDataLoader(DLWParser): filter_pipe: List = None, swap_level: bool = True, freq: Union[str, dict] = "day", - sample_benchmark: str = None, - sample_config: dict = None, + inst_processor: dict = None, ): """ Parameters @@ -168,6 +167,11 @@ class QlibDataLoader(DLWParser): Filter pipe for the instruments swap_level : Whether to swap level of MultiIndex + freq: dict or str + If type(config) == dict and type(freq) == str, load config data using freq. + If type(config) == dict and type(freq) == dict, load config[] data using freq[] + inst_processor: dict + If inst_processor is not None and type(config) == dict; load config[] data using inst_processor[] """ if filter_pipe is not None: assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list." @@ -181,9 +185,9 @@ class QlibDataLoader(DLWParser): self.freq = freq # sample - self.sample_config = sample_config if sample_config is not None else {} - self.sample_benchmark = sample_benchmark - self.can_sample = False + self.inst_processor = inst_processor if inst_processor is not None else {} + assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict" + super().__init__(config) if self.is_group: @@ -192,12 +196,9 @@ class QlibDataLoader(DLWParser): for _gp in config.keys(): if _gp not in freq: raise ValueError(f"freq(={freq}) missing group(={_gp})") - assert sample_config, f"freq(={self.freq}), sample_config(={sample_config}) cannot be None/empty" - assert isinstance(sample_config, dict), f"sample_config(={sample_config}) must be dict" assert ( - self.sample_benchmark and self.sample_benchmark in self.fields - ), f"sample_benchmark not to specification" - self.can_sample = True + self.inst_processor + ), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty" def load_group_df( self, @@ -216,33 +217,15 @@ class QlibDataLoader(DLWParser): elif self.filter_pipe is not None: warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") - freq = self.freq[gp_name] if self.can_sample else self.freq - inst_processor = self.sample_config.get(gp_name, None) if self.can_sample else None - df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processor) + freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq + df = D.features( + instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, []) + ) df.columns = names if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df - def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: - if self.is_group: - group = { - grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp) - for grp, (exprs, names) in self.fields.items() - } - if self.can_sample: - # reindex: alignment to index of sample_benchmark - for grp, _df in group.items(): - if grp == self.sample_benchmark: - continue - else: - group[grp] = _df.reindex(group[self.sample_benchmark].index) - df = pd.concat(group, axis=1) - else: - exprs, names = self.fields - df = self.load_group_df(instruments, exprs, names, start_time, end_time) - return df - class StaticDataLoader(DataLoader): """ diff --git a/qlib/data/inst_processor.py b/qlib/data/inst_processor.py index 34c2d517b..27b356722 100644 --- a/qlib/data/inst_processor.py +++ b/qlib/data/inst_processor.py @@ -1,4 +1,5 @@ import abc +import json import pandas as pd @@ -18,13 +19,5 @@ class InstProcessor: """ pass - -class ResampleProcessor(InstProcessor): - """resample data""" - - def __init__(self, freq: str, func: str, *args, **kwargs): - self.freq = freq - self.func = func - - def __call__(self, df: pd.DataFrame, *args, **kwargs): - return getattr(df.resample(self.freq, level="datetime"), self.func)().dropna(how="all") + def __str__(self): + return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"