diff --git a/qlib/__init__.py b/qlib/__init__.py index 6f76bbcaa..476a851d9 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -33,48 +33,63 @@ def init(default_conf="client", **kwargs): H.clear() C.set(default_conf, **kwargs) - # check path if server/local - if C.get_uri_type() == C.LOCAL_URI: - if not os.path.exists(C["provider_uri"]): - if C["auto_mount"]: - logger.error( - f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist." - ) + provider_uri_map = C.provider_uri if isinstance(C.provider_uri, dict) else {None: C.provider_uri} + if C.mount_path is not None: + # mount nfs + for _freq, provider_uri in provider_uri_map.items(): + mount_path = C.mount_path if _freq is None else C["mount_path"][_freq] + # check path if server/local + uri_type = C.get_uri_type(provider_uri) + if uri_type == C.LOCAL_URI: + if not Path(provider_uri).exists(): + if C["auto_mount"]: + logger.error( + f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist." + ) + else: + logger.warning(f"auto_path is False, please make sure {mount_path} is mounted") + elif uri_type == C.NFS_URI: + mount_path = _mount_nfs_uri(provider_uri, mount_path, C["auto_mount"]) + if _freq is None: + C["mount_path"] = mount_path + else: + C["mount_path"][_freq] = mount_path else: - logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted") - elif C.get_uri_type() == C.NFS_URI: - _mount_nfs_uri(C) - else: - raise NotImplementedError(f"This type of URI is not supported") + raise NotImplementedError(f"This type of URI is not supported") C.register() if "flask_server" in C: logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") logger.info("qlib successfully initialized based on %s settings." % default_conf) - logger.info(f"data_path={C.get_data_path()}") + data_path = ( + C.get_data_path() + if isinstance(C["provider_uri"], str) + else {_freq: C.get_data_path(_freq) for _freq in C["provider_uri"].keys()} + ) + logger.info(f"data_path={data_path}") -def _mount_nfs_uri(C): +def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): LOG = get_module_logger("mount nfs", level=logging.INFO) # FIXME: the C["provider_uri"] is modified in this function # If it is not modified, we can pass only provider_uri or mount_path instead of C - mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"]) + mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path) # If the provider uri looks like this 172.23.233.89//data/csdesign' # It will be a nfs path. The client provider will be used - if not C["auto_mount"]: - if not os.path.exists(C["mount_path"]): + if not auto_mount: + if not Path(mount_path).exists(): raise FileNotFoundError( - f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`" + f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`" ) else: # Judging system type sys_type = platform.system() if "win" in sys_type.lower(): # system: window - exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":")) + exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":")) result = exec_result.read() if "85" in result: LOG.warning("already mounted or window mount path already exists") @@ -82,20 +97,18 @@ def _mount_nfs_uri(C): raise OSError("not find network path") elif "error" in result or "错误" in result: raise OSError("Invalid mount path") - elif C["provider_uri"] in result: + elif provider_uri in result: LOG.info("window success mount..") else: raise OSError(f"unknown error: {result}") # config mount path - C["mount_path"] = C["mount_path"] + ":\\" + mount_path += ":\\" else: # system: linux/Unix/Mac # check mount - _remote_uri = C["provider_uri"] - _remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri - _mount_path = C["mount_path"] - _mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path + _remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri + _mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path _check_level_num = 2 _is_mount = False while _check_level_num: @@ -121,11 +134,9 @@ def _mount_nfs_uri(C): if not _is_mount: try: - os.makedirs(C["mount_path"], exist_ok=True) + Path(mount_path).mkdir(parents=True, exist_ok=True) except Exception: - raise OSError( - f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!" - ) + raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!") # check nfs-common command_res = os.popen("dpkg -l | grep nfs-common") @@ -136,15 +147,16 @@ def _mount_nfs_uri(C): command_status = os.system(mount_command) if command_status == 256: raise OSError( - f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}" + f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}" ) elif command_status == 32512: # LOG.error("Command error") - raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error") + raise OSError(f"mount {provider_uri} on {mount_path} error! Command error") elif command_status == 0: LOG.info("Mount finished") else: LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted") + return mount_path def init_from_yaml_conf(conf_path, **kwargs): diff --git a/qlib/config.py b/qlib/config.py index be29180f8..4e2b27f8c 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -17,6 +17,7 @@ import copy import logging import multiprocessing from pathlib import Path +from typing import Union class Config: @@ -82,25 +83,16 @@ _default_config = { "dataset_provider": "LocalDatasetProvider", "provider": "LocalProvider", # config it in qlib.init() + # "provider_uri" str or dict: + # # str + # "~/.qlib/stock_data/cn_data" + # # dict + # {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"} + # NOTE: provider_uri priority: + # 1. backend_config: backend_obj["kwargs"]["provider_uri"] + # 2. backend_config: backend_obj["kwargs"]["provider_uri_map"] + # 3. qlib.init: provider_uri "provider_uri": "", - # backend_freq_config is dict: {"day": "1d dataset uri", "1min": "1min dataset uri"} - # If `backend_freq_config` is not None && is not empty && "freq" in `backend_freq_config.keys()` - # use `backend_freq_config` as backend-uri - # else: - # use `provider_uri` as backend-uri - # Examples: - # qlib.init(provider_uri="qlib_data/1d", backend_freq_config={"1min": "qlib_data/1min"}) - # # using provider_uri - # D.features(D.instruments("all"), ["$close"], freq="day") - # # using backend_freq_config["1min"] - # D.features(D.instruments("all"), ["$close"], freq="1min") - # ######################## - # qlib.init(provider_uri="qlib_data/1d", backend_freq_config={"1min": "qlib_data/1min", "day": "qlib_data/day"}) - # # using backend_freq_config["day"] - # D.features(D.instruments("all"), ["$close"], freq="day") - # # raise ValueError - # D.features(D.instruments("all"), ["$close"], freq="week") - "backend_freq_config": None, # cache "expression_cache": None, "dataset_cache": None, @@ -262,28 +254,55 @@ class QlibConfig(Config): def resolve_path(self): # resolve path - if self["mount_path"] is not None: - self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve()) + _mount_path = self["mount_path"] + _provider_uri = self["provider_uri"] + if isinstance(_provider_uri, dict): + if _mount_path is not None: + # check provider_uri and mount_path + assert isinstance(_mount_path, dict), f"type(provider_uri) != type(mount_path); {_mount_path}" + _miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys()) + assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}" + self["mount_path"] = {_freq: str(Path(_path).expanduser().resolve()) for _freq, _path in _mount_path} + for _freq, _uri in _provider_uri.items(): + if self.get_uri_type(_uri) == QlibConfig.LOCAL_URI: + self["provider_uri"][_freq] = str(Path(_uri).expanduser().resolve()) + elif isinstance(_provider_uri, str): + if _mount_path is not None: + self["mount_path"] = str(Path(_mount_path).expanduser().resolve()) - if self.get_uri_type() == QlibConfig.LOCAL_URI: - self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve()) + if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI: + self["provider_uri"] = str(Path(_provider_uri).expanduser().resolve()) + else: + raise TypeError( + f"The types supported by provider_uri are [str, dict], " f"not {type(_provider_uri)}: {_provider_uri}" + ) - def get_uri_type(self): - is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:' - is_nfs_or_win = ( - re.match("^[^/]+:.+", self["provider_uri"]) is not None - ) # such as 'host:/data/' (User may define short hostname by themselves or use localhost) + @staticmethod + def get_uri_type(uri: Union[str, Path]): + uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve()) + is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:' + # such as 'host:/data/' (User may define short hostname by themselves or use localhost) + is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None if is_nfs_or_win and not is_win: return QlibConfig.NFS_URI else: return QlibConfig.LOCAL_URI - def get_data_path(self): - if self.get_uri_type() == QlibConfig.LOCAL_URI: - return self["provider_uri"] - elif self.get_uri_type() == QlibConfig.NFS_URI: - return self["mount_path"] + def get_data_path(self, freq: str = None) -> Path: + if freq is None and not isinstance(self["provider_uri"], str): + raise ValueError(f"type(provider_uri) == dict, freq cannot be None; provider_uri: {self['provider_uri']}") + _provider_uri = self["provider_uri"] if isinstance(self["provider_uri"], str) else self["provider_uri"][freq] + _mount_path = ( + self["mount_path"] + if self["mount_path"] is None or isinstance(self["mount_path"], str) + else self["mount_path"][freq] + ) + + if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI: + return Path(_provider_uri) + elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI: + return Path(_mount_path) else: raise NotImplementedError(f"This type of uri is not supported") @@ -318,7 +337,8 @@ class QlibConfig(Config): # check redis if not can_use_cache(): logger.warning( - f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!" + f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), " + f"cache will not be used!" ) self["expression_cache"] = None self["dataset_cache"] = None diff --git a/qlib/contrib/backtest/profit_attribution.py b/qlib/contrib/backtest/profit_attribution.py index 20c6f638f..2fdebc591 100644 --- a/qlib/contrib/backtest/profit_attribution.py +++ b/qlib/contrib/backtest/profit_attribution.py @@ -16,6 +16,7 @@ def get_benchmark_weight( start_date=None, end_date=None, path=None, + freq="day", ): """get_benchmark_weight @@ -25,6 +26,7 @@ def get_benchmark_weight( :param start_date: :param end_date: :param path: + :param freq: :return: The weight distribution of the the benchmark described by a pandas dataframe Every row corresponds to a trading day. @@ -33,7 +35,7 @@ def get_benchmark_weight( """ if not path: - path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv" + path = Path(C.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv" # TODO: the storage of weights should be implemented in a more elegent way # TODO: The benchmark is not consistant with the filename in instruments. bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"]) @@ -222,6 +224,7 @@ def brinson_pa( group_method="category", group_n=None, deal_price="vwap", + freq="day", ): """brinson profit attribution @@ -243,7 +246,7 @@ def brinson_pa( start_date, end_date = min(dates), max(dates) - bench_stock_weight = get_benchmark_weight(bench, start_date, end_date) + bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq) # The attributes for allocation will not if not group_field.startswith("$"): @@ -259,13 +262,14 @@ def brinson_pa( start_time=shift_start_date, end_time=end_date, as_list=True, + freq=freq, ) stock_df = D.features( instruments, [group_field, deal_price], start_time=shift_start_date, end_time=end_date, - freq="day", + freq=freq, ) stock_df.columns = [group_field, "deal_price"] diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 3e0680e74..4475a16f3 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -17,6 +17,7 @@ import abc from pathlib import Path import numpy as np import pandas as pd +from typing import Union, Iterable from collections import OrderedDict from ..config import C @@ -216,12 +217,14 @@ class CacheUtils: redis_lock.reset_all(r) @staticmethod - def visit(cache_path): + def visit(cache_path: Union[str, Path]): # FIXME: Because read_lock was canceled when reading the cache, multiple processes may have read and write exceptions here try: - with open(cache_path + ".meta", "rb") as f: + cache_path = Path(cache_path) + meta_path = cache_path.with_suffix(".meta") + with meta_path.open("rb") as f: d = pickle.load(f) - with open(cache_path + ".meta", "wb") as f: + with meta_path.open("wb") as f: try: d["meta"]["last_visit"] = str(time.time()) d["meta"]["visits"] = d["meta"]["visits"] + 1 @@ -249,17 +252,17 @@ class CacheUtils: @staticmethod @contextlib.contextmanager - def reader_lock(redis_t, lock_name): - lock_name = f"{C.provider_uri}:{lock_name}" - current_cache_rlock = redis_lock.Lock(redis_t, "%s-rlock" % lock_name) - current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name) + def reader_lock(redis_t, lock_name: str): + current_cache_rlock = redis_lock.Lock(redis_t, f"{lock_name}-rlock") + current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock") + lock_reader = f"{lock_name}-reader" # make sure only one reader is entering current_cache_rlock.acquire(timeout=60) try: - current_cache_readers = redis_t.get("%s-reader" % lock_name) + current_cache_readers = redis_t.get(lock_reader) if current_cache_readers is None or int(current_cache_readers) == 0: CacheUtils.acquire(current_cache_wlock, lock_name) - redis_t.incr("%s-reader" % lock_name) + redis_t.incr(lock_reader) finally: current_cache_rlock.release() try: @@ -268,9 +271,9 @@ class CacheUtils: # make sure only one reader is leaving current_cache_rlock.acquire(timeout=60) try: - redis_t.decr("%s-reader" % lock_name) - if int(redis_t.get("%s-reader" % lock_name)) == 0: - redis_t.delete("%s-reader" % lock_name) + redis_t.decr(lock_reader) + if int(redis_t.get(lock_reader)) == 0: + redis_t.delete(lock_reader) current_cache_wlock.reset() finally: current_cache_rlock.release() @@ -278,8 +281,7 @@ class CacheUtils: @staticmethod @contextlib.contextmanager def writer_lock(redis_t, lock_name): - lock_name = f"{C.provider_uri}:{lock_name}" - current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name, id=CacheUtils.LOCK_ID) + current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID) CacheUtils.acquire(current_cache_wlock, lock_name) try: yield @@ -297,6 +299,30 @@ class BaseProviderCache: def __getattr__(self, attr): return getattr(self.provider, attr) + @staticmethod + def check_cache_exists(cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta")) -> bool: + cache_path = Path(cache_path) + for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]: + if not p.exists(): + return False + return True + + @staticmethod + def clear_cache(cache_path: Union[str, Path]): + for p in [ + cache_path, + cache_path.with_suffix(".meta"), + cache_path.with_suffix(".index"), + ]: + if p.exists(): + p.unlink() + + @staticmethod + def get_cache_dir(dir_name: str, freq: str = None) -> Path: + cache_dir = Path(C.get_data_path(freq)).joinpath(dir_name) + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + class ExpressionCache(BaseProviderCache): """Expression cache mechanism base class. @@ -330,14 +356,14 @@ class ExpressionCache(BaseProviderCache): """ raise NotImplementedError("Implement this method if you want to use expression cache") - def update(self, cache_uri): + def update(self, cache_uri: Union[str, Path]): """Update expression cache to latest calendar. Overide this method to define how to update expression cache corresponding to users' own cache mechanism. Parameters ---------- - cache_uri : str + cache_uri : str or Path the complete uri of expression cache file (include dir path). Returns @@ -403,14 +429,14 @@ class DatasetCache(BaseProviderCache): "Implement this method if you want to use dataset feature cache as a cache file for client" ) - def update(self, cache_uri): + def update(self, cache_uri: Union[str, Path]): """Update dataset cache to latest calendar. Overide this method to define how to update dataset cache corresponding to users' own cache mechanism. Parameters ---------- - cache_uri : str + cache_uri : str or Path the complete uri of dataset cache file (include dir path). Returns @@ -452,25 +478,19 @@ class DiskExpressionCache(ExpressionCache): self.r = get_redis_connection() # remote==True means client is using this module, writing behaviour will not be allowed. self.remote = kwargs.get("remote", False) - self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name) - os.makedirs(self.expr_cache_path, exist_ok=True) + + def get_cache_dir(self, freq: str = None) -> Path: + return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq) def _uri(self, instrument, field, start_time, end_time, freq): field = remove_fields_space(field) instrument = str(instrument).lower() return hash_args(instrument, field, freq) - @staticmethod - def check_cache_exists(cache_path): - for p in [cache_path, cache_path + ".meta"]: - if not Path(p).exists(): - return False - return True - def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"): _cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq) - _instrument_dir = os.path.join(self.expr_cache_path, instrument.lower()) - cache_path = os.path.join(_instrument_dir, _cache_uri) + _instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower()) + cache_path = _instrument_dir.joinpath(_cache_uri) # get calendar from .data import Cal @@ -478,7 +498,7 @@ class DiskExpressionCache(ExpressionCache): _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False) - if self.check_cache_exists(cache_path): + if self.check_cache_exists(cache_path, suffix_list=[".meta"]): """ In most cases, we do not need reader_lock. Because updating data is a small probability event compare to reading data. @@ -502,8 +522,7 @@ class DiskExpressionCache(ExpressionCache): # normalize field field = remove_fields_space(field) # cache unavailable, generate the cache - if not os.path.exists(_instrument_dir): - os.makedirs(_instrument_dir, exist_ok=True) + _instrument_dir.mkdir(parents=True, exist_ok=True) if not isinstance(eval(parse_field(field)), Feature): # When the expression is not a raw feature # generate expression cache if the feature is not a Feature @@ -511,7 +530,7 @@ class DiskExpressionCache(ExpressionCache): series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq) if not series.empty: # This expresion is empty, we don't generate any cache for it. - with CacheUtils.writer_lock(self.r, "expression-%s" % _cache_uri): + with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:expression-{_cache_uri}"): self.gen_expression_cache( expression_data=series, cache_path=cache_path, @@ -527,14 +546,6 @@ class DiskExpressionCache(ExpressionCache): # If the expression is a raw feature(such as $close, $open) return self.provider.expression(instrument, field, start_time, end_time, freq) - @staticmethod - def clear_cache(cache_path): - meta_path = cache_path + ".meta" - for p in [cache_path, meta_path]: - p = Path(p) - if p.exists(): - p.unlink() - def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update): """use bin file to save like feature-data.""" # Make sure the cache runs right when the directory is deleted @@ -544,27 +555,30 @@ class DiskExpressionCache(ExpressionCache): "meta": {"last_visit": time.time(), "visits": 1}, } self.logger.debug(f"generating expression cache: {meta}") - os.makedirs(self.expr_cache_path, exist_ok=True) self.clear_cache(cache_path) - meta_path = cache_path + ".meta" + meta_path = cache_path.with_suffix(".meta") - with open(meta_path, "wb") as f: + with meta_path.open("wb") as f: pickle.dump(meta, f) - os.chmod(meta_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) + meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) df = expression_data.to_frame() r = np.hstack([df.index[0], expression_data]).astype(" Path: + return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq) @classmethod - def read_data_from_cache(cls, cache_path, start_time, end_time, fields): + def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields): """read_cache_from This function can read data from the disk cache dataset @@ -681,7 +689,7 @@ class DiskDatasetCache(DatasetCache): instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache ) - cache_path = os.path.join(self.dtst_cache_path, _cache_uri) + cache_path = self.get_cache_dir(freq).joinpath(_cache_uri) features = pd.DataFrame() gen_flag = False @@ -689,7 +697,7 @@ class DiskDatasetCache(DatasetCache): if self.check_cache_exists(cache_path): if disk_cache == 1: # use cache - with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri): + with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): CacheUtils.visit(cache_path) features = self.read_data_from_cache(cache_path, start_time, end_time, fields) elif disk_cache == 2: @@ -699,7 +707,7 @@ class DiskDatasetCache(DatasetCache): if gen_flag: # cache unavailable, generate the cache - with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri): + with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): features = self.gen_dataset_cache( cache_path=cache_path, instruments=instruments, fields=fields, freq=freq ) @@ -719,16 +727,16 @@ class DiskDatasetCache(DatasetCache): _cache_uri = self._uri( instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache ) - cache_path = os.path.join(self.dtst_cache_path, _cache_uri) + cache_path = self.get_cache_dir(freq).joinpath(_cache_uri) if self.check_cache_exists(cache_path): self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly") - with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri): + with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): CacheUtils.visit(cache_path) return _cache_uri else: # cache unavailable, generate the cache - with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri): + with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq) return _cache_uri @@ -740,8 +748,9 @@ class DiskDatasetCache(DatasetCache): KEY = "df" - def __init__(self, cache_path): - self.index_path = cache_path + ".index" + def __init__(self, cache_path: Union[str, Path]): + + self.index_path = cache_path.with_suffix(".index") self._data = None self.logger = get_module_logger(self.__class__.__name__) @@ -757,7 +766,7 @@ class DiskDatasetCache(DatasetCache): self._data.sort_index(inplace=True) self._data.to_hdf(self.index_path, key=self.KEY, mode="w", format="table") # The index should be readable for all users - os.chmod(self.index_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) + self.index_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) def sync_from_disk(self): # The file will not be closed directly if we read_hdf from the disk directly @@ -795,15 +804,7 @@ class DiskDatasetCache(DatasetCache): index_data += start_index return index_data - @staticmethod - def clear_cache(cache_path): - meta_path = cache_path + ".meta" - for p in [cache_path, meta_path, cache_path + ".index", cache_path + ".data"]: - p = Path(p) - if p.exists(): - p.unlink() - - def gen_dataset_cache(self, cache_path, instruments, fields, freq): + def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq): """gen_dataset_cache .. note:: This function does not consider the cache read write lock. Please @@ -844,11 +845,11 @@ class DiskDatasetCache(DatasetCache): # get calendar from .data import Cal + cache_path = Path(cache_path) _calendar = Cal.calendar(freq=freq) self.logger.debug(f"Generating dataset cache {cache_path}") # Make sure the cache runs right when the directory is deleted # while running - os.makedirs(self.dtst_cache_path, exist_ok=True) self.clear_cache(cache_path) features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq) @@ -860,7 +861,7 @@ class DiskDatasetCache(DatasetCache): features = features.swaplevel("instrument", "datetime").sort_index() # write cache data - with pd.HDFStore(cache_path + ".data") as store: + with pd.HDFStore(str(cache_path.with_suffix(".data"))) as store: cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns)) orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns))) cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map) @@ -879,9 +880,9 @@ class DiskDatasetCache(DatasetCache): }, "meta": {"last_visit": time.time(), "visits": 1}, } - with open(cache_path + ".meta", "wb") as f: + with cache_path.with_suffix(".meta").open("wb") as f: pickle.dump(meta, f) - os.chmod(cache_path + ".meta", stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) + cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) # write index file im = DiskDatasetCache.IndexManager(cache_path) index_data = im.build_index_from_data(features) @@ -890,21 +891,23 @@ class DiskDatasetCache(DatasetCache): # rename the file after the cache has been generated # this doesn't work well on windows, but our server won't use windows # temporarily - os.replace(cache_path + ".data", cache_path) + cache_path.with_suffix(".data").rename(cache_path) # the fields of the cached features are converted to the original fields return features.swaplevel("datetime", "instrument") def update(self, cache_uri): - cp_cache_uri = os.path.join(self.dtst_cache_path, cache_uri) - + # FIXME: when updating the cache, the type of C.provider_uri must be str + assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str" + cp_cache_uri = self.get_cache_dir().joinpath(cache_uri) + meta_path = cp_cache_uri.with_suffix(".meta") if not self.check_cache_exists(cp_cache_uri): self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed") self.clear_cache(cp_cache_uri) return 2 im = DiskDatasetCache.IndexManager(cp_cache_uri) - with CacheUtils.writer_lock(self.r, "dataset-%s" % cache_uri): - with open(cp_cache_uri + ".meta", "rb") as f: + with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:dataset-{cache_uri}"): + with meta_path.open("rb") as f: d = pickle.load(f) instruments = d["info"]["instruments"] fields = d["info"]["fields"] @@ -995,7 +998,7 @@ class DiskDatasetCache(DatasetCache): # update meta file d["info"]["last_update"] = str(new_calendar[-1]) - with open(cp_cache_uri + ".meta", "wb") as f: + with meta_path.open("wb") as f: pickle.dump(d, f) return 0 @@ -1006,26 +1009,25 @@ class SimpleDatasetCache(DatasetCache): def __init__(self, provider): super(SimpleDatasetCache, self).__init__(provider) try: - self.local_cache_path = C["local_cache_path"] + self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve() except KeyError as e: self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism") def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs): instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq) - local_cache_path = str(Path(self.local_cache_path).expanduser().resolve()) - return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, local_cache_path) + return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path)) def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1): if disk_cache == 0: # In this case, data_set cache is configured but will not be used. return self.provider.dataset(instruments, fields, start_time, end_time, freq) - os.makedirs(os.path.expanduser(self.local_cache_path), exist_ok=True) - cache_file = os.path.join( - self.local_cache_path, self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache) + self.local_cache_path.mkdir(exist_ok=True, parents=True) + cache_file = self.local_cache_path.joinpath( + self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache) ) gen_flag = False - if os.path.exists(cache_file): + if cache_file.exists(): if disk_cache == 1: # use cache df = pd.read_pickle(cache_file) @@ -1061,8 +1063,8 @@ class DatasetURICache(DatasetCache): # use ClientDatasetProvider feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache) value, expire = MemCacheExpire.get_cache(H["f"], feature_uri) - mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) - if value is None or expire or not os.path.exists(mnt_feature_uri): + mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri) + if value is None or expire or not mnt_feature_uri.exists(): df, uri = self.provider.dataset( instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True ) @@ -1072,7 +1074,6 @@ class DatasetURICache(DatasetCache): # HZ['f'][uri] = df.copy() get_module_logger("cache").debug(f"get feature from {C.dataset_provider}") else: - mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) get_module_logger("cache").debug("get feature from uri cache") diff --git a/qlib/data/data.py b/qlib/data/data.py index 8ec4355c6..82f8c3745 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -5,16 +5,11 @@ from __future__ import division from __future__ import print_function -import os import re import abc import copy -import time import queue import bisect -import logging -import importlib -import traceback import numpy as np import pandas as pd from multiprocessing import Pool @@ -23,7 +18,7 @@ from .cache import H from ..config import C from .ops import Operators from ..log import get_module_logger -from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname +from ..utils import parse_field, hash_args, normalize_cache_fields, code_to_fname from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path @@ -48,19 +43,14 @@ class ProviderBackendMixin: # default provider_uri map if "provider_uri" not in backend_kwargs: # if the user has no uri configured, use: uri = uri_map[freq] - # NOTE: uri priority - # 1. backend_obj.kwargs["provider_uri"] - # 2. backend_obj.kwargs["backend_freq_config"] - # 3. C.backend_freq_config, or qlib.init(backend_freq_config={}) - # 4. C.provider_uri, or qlib.init(provider_uri="") - provider_uri_map = backend_kwargs.setdefault("backend_freq_config", {}) + # NOTE: provider_uri priority: + # 1. backend_config: backend_obj["kwargs"]["provider_uri"] + # 2. backend_config: backend_obj["kwargs"]["provider_uri_map"] + # 3. qlib.init: provider_uri + provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {}) freq = kwargs.get("freq", "day") - if C.backend_freq_config is not None: - if freq not in provider_uri_map: - provider_uri_map[freq] = C.backend_freq_config.get(freq, C.get_data_path()) - else: - if freq not in provider_uri_map: - provider_uri_map[freq] = C.get_data_path() + if freq not in provider_uri_map: + provider_uri_map[freq] = C.get_data_path(freq) backend_kwargs["provider_uri"] = provider_uri_map[freq] backend.setdefault("kwargs", {}).update(**kwargs) return init_instance_by_config(backend) @@ -548,11 +538,6 @@ class LocalCalendarProvider(CalendarProvider): super(LocalCalendarProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) - @property - def _uri_cal(self): - """Calendar file uri.""" - return os.path.join(C.get_data_path(), "calendars", "{}.txt") - def load_calendar(self, freq, future): """Load original calendar timestamp from file. @@ -612,11 +597,6 @@ class LocalInstrumentProvider(InstrumentProvider): Provide instrument data from local data source. """ - @property - def _uri_inst(self): - """Instrument file uri.""" - return os.path.join(C.get_data_path(), "instruments", "{}.txt") - def _load_instruments(self, market, freq): return self.backend_obj(market=market, freq=freq).data @@ -665,11 +645,6 @@ class LocalFeatureProvider(FeatureProvider): super(LocalFeatureProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) - @property - def _uri_data(self): - """Static feature file uri.""" - return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin") - def feature(self, instrument, field, start_index, end_index, freq): # validate field = str(field).lower()[1:] @@ -937,7 +912,7 @@ class ClientDatasetProvider(DatasetProvider): get_module_logger("data").debug("get result") try: # pre-mound nfs, used for demo - mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) + mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri) df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) get_module_logger("data").debug("finish slicing data") if return_uri: diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 3934a1992..7504f8d61 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -43,8 +43,9 @@ def get_redis_connection(): #################### Data #################### -def read_bin(file_path, start_index, end_index): - with open(file_path, "rb") as f: +def read_bin(file_path: Union[str, Path], start_index, end_index): + file_path = Path(file_path.expanduser().resolve()) + with file_path.open("rb") as f: # read start_index ref_start_index = int(np.frombuffer(f.read(4), dtype="