1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

support multi-freq uri

This commit is contained in:
zhupr
2021-08-29 16:21:37 +08:00
committed by you-n-g
parent 6011a21308
commit d1cbf4c3d9
6 changed files with 213 additions and 200 deletions

View File

@@ -33,48 +33,63 @@ def init(default_conf="client", **kwargs):
H.clear()
C.set(default_conf, **kwargs)
# check path if server/local
if C.get_uri_type() == C.LOCAL_URI:
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
)
provider_uri_map = C.provider_uri if isinstance(C.provider_uri, dict) else {None: C.provider_uri}
if C.mount_path is not None:
# mount nfs
for _freq, provider_uri in provider_uri_map.items():
mount_path = C.mount_path if _freq is None else C["mount_path"][_freq]
# check path if server/local
uri_type = C.get_uri_type(provider_uri)
if uri_type == C.LOCAL_URI:
if not Path(provider_uri).exists():
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
)
else:
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
elif uri_type == C.NFS_URI:
mount_path = _mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
if _freq is None:
C["mount_path"] = mount_path
else:
C["mount_path"][_freq] = mount_path
else:
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
elif C.get_uri_type() == C.NFS_URI:
_mount_nfs_uri(C)
else:
raise NotImplementedError(f"This type of URI is not supported")
raise NotImplementedError(f"This type of URI is not supported")
C.register()
if "flask_server" in C:
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
logger.info(f"data_path={C.get_data_path()}")
data_path = (
C.get_data_path()
if isinstance(C["provider_uri"], str)
else {_freq: C.get_data_path(_freq) for _freq in C["provider_uri"].keys()}
)
logger.info(f"data_path={data_path}")
def _mount_nfs_uri(C):
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
LOG = get_module_logger("mount nfs", level=logging.INFO)
# FIXME: the C["provider_uri"] is modified in this function
# If it is not modified, we can pass only provider_uri or mount_path instead of C
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
# If the provider uri looks like this 172.23.233.89//data/csdesign'
# It will be a nfs path. The client provider will be used
if not C["auto_mount"]:
if not os.path.exists(C["mount_path"]):
if not auto_mount:
if not Path(mount_path).exists():
raise FileNotFoundError(
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
)
else:
# Judging system type
sys_type = platform.system()
if "win" in sys_type.lower():
# system: window
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
result = exec_result.read()
if "85" in result:
LOG.warning("already mounted or window mount path already exists")
@@ -82,20 +97,18 @@ def _mount_nfs_uri(C):
raise OSError("not find network path")
elif "error" in result or "错误" in result:
raise OSError("Invalid mount path")
elif C["provider_uri"] in result:
elif provider_uri in result:
LOG.info("window success mount..")
else:
raise OSError(f"unknown error: {result}")
# config mount path
C["mount_path"] = C["mount_path"] + ":\\"
mount_path += ":\\"
else:
# system: linux/Unix/Mac
# check mount
_remote_uri = C["provider_uri"]
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
_mount_path = C["mount_path"]
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
_remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
_mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
_check_level_num = 2
_is_mount = False
while _check_level_num:
@@ -121,11 +134,9 @@ def _mount_nfs_uri(C):
if not _is_mount:
try:
os.makedirs(C["mount_path"], exist_ok=True)
Path(mount_path).mkdir(parents=True, exist_ok=True)
except Exception:
raise OSError(
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
)
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
# check nfs-common
command_res = os.popen("dpkg -l | grep nfs-common")
@@ -136,15 +147,16 @@ def _mount_nfs_uri(C):
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
)
elif command_status == 32512:
# LOG.error("Command error")
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
elif command_status == 0:
LOG.info("Mount finished")
else:
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
return mount_path
def init_from_yaml_conf(conf_path, **kwargs):

View File

@@ -17,6 +17,7 @@ import copy
import logging
import multiprocessing
from pathlib import Path
from typing import Union
class Config:
@@ -82,25 +83,16 @@ _default_config = {
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
# "provider_uri" str or dict:
# # str
# "~/.qlib/stock_data/cn_data"
# # dict
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
# NOTE: provider_uri priority
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
# 3. qlib.init: provider_uri
"provider_uri": "",
# backend_freq_config is dict: {"day": "1d dataset uri", "1min": "1min dataset uri"}
# If `backend_freq_config` is not None && is not empty && "freq" in `backend_freq_config.keys()`
# use `backend_freq_config` as backend-uri
# else:
# use `provider_uri` as backend-uri
# Examples:
# qlib.init(provider_uri="qlib_data/1d", backend_freq_config={"1min": "qlib_data/1min"})
# # using provider_uri
# D.features(D.instruments("all"), ["$close"], freq="day")
# # using backend_freq_config["1min"]
# D.features(D.instruments("all"), ["$close"], freq="1min")
# ########################
# qlib.init(provider_uri="qlib_data/1d", backend_freq_config={"1min": "qlib_data/1min", "day": "qlib_data/day"})
# # using backend_freq_config["day"]
# D.features(D.instruments("all"), ["$close"], freq="day")
# # raise ValueError
# D.features(D.instruments("all"), ["$close"], freq="week")
"backend_freq_config": None,
# cache
"expression_cache": None,
"dataset_cache": None,
@@ -262,28 +254,55 @@ class QlibConfig(Config):
def resolve_path(self):
# resolve path
if self["mount_path"] is not None:
self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve())
_mount_path = self["mount_path"]
_provider_uri = self["provider_uri"]
if isinstance(_provider_uri, dict):
if _mount_path is not None:
# check provider_uri and mount_path
assert isinstance(_mount_path, dict), f"type(provider_uri) != type(mount_path); {_mount_path}"
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
self["mount_path"] = {_freq: str(Path(_path).expanduser().resolve()) for _freq, _path in _mount_path}
for _freq, _uri in _provider_uri.items():
if self.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
self["provider_uri"][_freq] = str(Path(_uri).expanduser().resolve())
elif isinstance(_provider_uri, str):
if _mount_path is not None:
self["mount_path"] = str(Path(_mount_path).expanduser().resolve())
if self.get_uri_type() == QlibConfig.LOCAL_URI:
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
self["provider_uri"] = str(Path(_provider_uri).expanduser().resolve())
else:
raise TypeError(
f"The types supported by provider_uri are [str, dict], " f"not {type(_provider_uri)}: {_provider_uri}"
)
def get_uri_type(self):
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
is_nfs_or_win = (
re.match("^[^/]+:.+", self["provider_uri"]) is not None
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
@staticmethod
def get_uri_type(uri: Union[str, Path]):
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def get_data_path(self):
if self.get_uri_type() == QlibConfig.LOCAL_URI:
return self["provider_uri"]
elif self.get_uri_type() == QlibConfig.NFS_URI:
return self["mount_path"]
def get_data_path(self, freq: str = None) -> Path:
if freq is None and not isinstance(self["provider_uri"], str):
raise ValueError(f"type(provider_uri) == dict, freq cannot be None; provider_uri: {self['provider_uri']}")
_provider_uri = self["provider_uri"] if isinstance(self["provider_uri"], str) else self["provider_uri"][freq]
_mount_path = (
self["mount_path"]
if self["mount_path"] is None or isinstance(self["mount_path"], str)
else self["mount_path"][freq]
)
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
return Path(_provider_uri)
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
return Path(_mount_path)
else:
raise NotImplementedError(f"This type of uri is not supported")
@@ -318,7 +337,8 @@ class QlibConfig(Config):
# check redis
if not can_use_cache():
logger.warning(
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), "
f"cache will not be used!"
)
self["expression_cache"] = None
self["dataset_cache"] = None

View File

@@ -16,6 +16,7 @@ def get_benchmark_weight(
start_date=None,
end_date=None,
path=None,
freq="day",
):
"""get_benchmark_weight
@@ -25,6 +26,7 @@ def get_benchmark_weight(
:param start_date:
:param end_date:
:param path:
:param freq:
:return: The weight distribution of the the benchmark described by a pandas dataframe
Every row corresponds to a trading day.
@@ -33,7 +35,7 @@ def get_benchmark_weight(
"""
if not path:
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
path = Path(C.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegent way
# TODO: The benchmark is not consistant with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
@@ -222,6 +224,7 @@ def brinson_pa(
group_method="category",
group_n=None,
deal_price="vwap",
freq="day",
):
"""brinson profit attribution
@@ -243,7 +246,7 @@ def brinson_pa(
start_date, end_date = min(dates), max(dates)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq)
# The attributes for allocation will not
if not group_field.startswith("$"):
@@ -259,13 +262,14 @@ def brinson_pa(
start_time=shift_start_date,
end_time=end_date,
as_list=True,
freq=freq,
)
stock_df = D.features(
instruments,
[group_field, deal_price],
start_time=shift_start_date,
end_time=end_date,
freq="day",
freq=freq,
)
stock_df.columns = [group_field, "deal_price"]

View File

@@ -17,6 +17,7 @@ import abc
from pathlib import Path
import numpy as np
import pandas as pd
from typing import Union, Iterable
from collections import OrderedDict
from ..config import C
@@ -216,12 +217,14 @@ class CacheUtils:
redis_lock.reset_all(r)
@staticmethod
def visit(cache_path):
def visit(cache_path: Union[str, Path]):
# FIXME: Because read_lock was canceled when reading the cache, multiple processes may have read and write exceptions here
try:
with open(cache_path + ".meta", "rb") as f:
cache_path = Path(cache_path)
meta_path = cache_path.with_suffix(".meta")
with meta_path.open("rb") as f:
d = pickle.load(f)
with open(cache_path + ".meta", "wb") as f:
with meta_path.open("wb") as f:
try:
d["meta"]["last_visit"] = str(time.time())
d["meta"]["visits"] = d["meta"]["visits"] + 1
@@ -249,17 +252,17 @@ class CacheUtils:
@staticmethod
@contextlib.contextmanager
def reader_lock(redis_t, lock_name):
lock_name = f"{C.provider_uri}:{lock_name}"
current_cache_rlock = redis_lock.Lock(redis_t, "%s-rlock" % lock_name)
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name)
def reader_lock(redis_t, lock_name: str):
current_cache_rlock = redis_lock.Lock(redis_t, f"{lock_name}-rlock")
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock")
lock_reader = f"{lock_name}-reader"
# make sure only one reader is entering
current_cache_rlock.acquire(timeout=60)
try:
current_cache_readers = redis_t.get("%s-reader" % lock_name)
current_cache_readers = redis_t.get(lock_reader)
if current_cache_readers is None or int(current_cache_readers) == 0:
CacheUtils.acquire(current_cache_wlock, lock_name)
redis_t.incr("%s-reader" % lock_name)
redis_t.incr(lock_reader)
finally:
current_cache_rlock.release()
try:
@@ -268,9 +271,9 @@ class CacheUtils:
# make sure only one reader is leaving
current_cache_rlock.acquire(timeout=60)
try:
redis_t.decr("%s-reader" % lock_name)
if int(redis_t.get("%s-reader" % lock_name)) == 0:
redis_t.delete("%s-reader" % lock_name)
redis_t.decr(lock_reader)
if int(redis_t.get(lock_reader)) == 0:
redis_t.delete(lock_reader)
current_cache_wlock.reset()
finally:
current_cache_rlock.release()
@@ -278,8 +281,7 @@ class CacheUtils:
@staticmethod
@contextlib.contextmanager
def writer_lock(redis_t, lock_name):
lock_name = f"{C.provider_uri}:{lock_name}"
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name, id=CacheUtils.LOCK_ID)
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID)
CacheUtils.acquire(current_cache_wlock, lock_name)
try:
yield
@@ -297,6 +299,30 @@ class BaseProviderCache:
def __getattr__(self, attr):
return getattr(self.provider, attr)
@staticmethod
def check_cache_exists(cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta")) -> bool:
cache_path = Path(cache_path)
for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]:
if not p.exists():
return False
return True
@staticmethod
def clear_cache(cache_path: Union[str, Path]):
for p in [
cache_path,
cache_path.with_suffix(".meta"),
cache_path.with_suffix(".index"),
]:
if p.exists():
p.unlink()
@staticmethod
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
cache_dir = Path(C.get_data_path(freq)).joinpath(dir_name)
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
class ExpressionCache(BaseProviderCache):
"""Expression cache mechanism base class.
@@ -330,14 +356,14 @@ class ExpressionCache(BaseProviderCache):
"""
raise NotImplementedError("Implement this method if you want to use expression cache")
def update(self, cache_uri):
def update(self, cache_uri: Union[str, Path]):
"""Update expression cache to latest calendar.
Overide this method to define how to update expression cache corresponding to users' own cache mechanism.
Parameters
----------
cache_uri : str
cache_uri : str or Path
the complete uri of expression cache file (include dir path).
Returns
@@ -403,14 +429,14 @@ class DatasetCache(BaseProviderCache):
"Implement this method if you want to use dataset feature cache as a cache file for client"
)
def update(self, cache_uri):
def update(self, cache_uri: Union[str, Path]):
"""Update dataset cache to latest calendar.
Overide this method to define how to update dataset cache corresponding to users' own cache mechanism.
Parameters
----------
cache_uri : str
cache_uri : str or Path
the complete uri of dataset cache file (include dir path).
Returns
@@ -452,25 +478,19 @@ class DiskExpressionCache(ExpressionCache):
self.r = get_redis_connection()
# remote==True means client is using this module, writing behaviour will not be allowed.
self.remote = kwargs.get("remote", False)
self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name)
os.makedirs(self.expr_cache_path, exist_ok=True)
def get_cache_dir(self, freq: str = None) -> Path:
return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq)
def _uri(self, instrument, field, start_time, end_time, freq):
field = remove_fields_space(field)
instrument = str(instrument).lower()
return hash_args(instrument, field, freq)
@staticmethod
def check_cache_exists(cache_path):
for p in [cache_path, cache_path + ".meta"]:
if not Path(p).exists():
return False
return True
def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
_cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq)
_instrument_dir = os.path.join(self.expr_cache_path, instrument.lower())
cache_path = os.path.join(_instrument_dir, _cache_uri)
_instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
cache_path = _instrument_dir.joinpath(_cache_uri)
# get calendar
from .data import Cal
@@ -478,7 +498,7 @@ class DiskExpressionCache(ExpressionCache):
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
if self.check_cache_exists(cache_path):
if self.check_cache_exists(cache_path, suffix_list=[".meta"]):
"""
In most cases, we do not need reader_lock.
Because updating data is a small probability event compare to reading data.
@@ -502,8 +522,7 @@ class DiskExpressionCache(ExpressionCache):
# normalize field
field = remove_fields_space(field)
# cache unavailable, generate the cache
if not os.path.exists(_instrument_dir):
os.makedirs(_instrument_dir, exist_ok=True)
_instrument_dir.mkdir(parents=True, exist_ok=True)
if not isinstance(eval(parse_field(field)), Feature):
# When the expression is not a raw feature
# generate expression cache if the feature is not a Feature
@@ -511,7 +530,7 @@ class DiskExpressionCache(ExpressionCache):
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
if not series.empty:
# This expresion is empty, we don't generate any cache for it.
with CacheUtils.writer_lock(self.r, "expression-%s" % _cache_uri):
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:expression-{_cache_uri}"):
self.gen_expression_cache(
expression_data=series,
cache_path=cache_path,
@@ -527,14 +546,6 @@ class DiskExpressionCache(ExpressionCache):
# If the expression is a raw feature(such as $close, $open)
return self.provider.expression(instrument, field, start_time, end_time, freq)
@staticmethod
def clear_cache(cache_path):
meta_path = cache_path + ".meta"
for p in [cache_path, meta_path]:
p = Path(p)
if p.exists():
p.unlink()
def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update):
"""use bin file to save like feature-data."""
# Make sure the cache runs right when the directory is deleted
@@ -544,27 +555,30 @@ class DiskExpressionCache(ExpressionCache):
"meta": {"last_visit": time.time(), "visits": 1},
}
self.logger.debug(f"generating expression cache: {meta}")
os.makedirs(self.expr_cache_path, exist_ok=True)
self.clear_cache(cache_path)
meta_path = cache_path + ".meta"
meta_path = cache_path.with_suffix(".meta")
with open(meta_path, "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(meta, f)
os.chmod(meta_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
df = expression_data.to_frame()
r = np.hstack([df.index[0], expression_data]).astype("<f")
r.tofile(str(cache_path))
def update(self, sid, cache_uri):
cp_cache_uri = os.path.join(self.expr_cache_path, sid, cache_uri)
if not self.check_cache_exists(cp_cache_uri):
# FIXME: when updating the cache, the type of C.provider_uri must be str
assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str"
cp_cache_uri = self.get_cache_dir().joinpath(sid).joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
self.clear_cache(cp_cache_uri)
return 2
with CacheUtils.writer_lock(self.r, "expression-%s" % cache_uri):
with open(cp_cache_uri + ".meta", "rb") as f:
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:expression-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instrument = d["info"]["instrument"]
field = d["info"]["field"]
@@ -611,7 +625,7 @@ class DiskExpressionCache(ExpressionCache):
f.write(data)
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with open(cp_cache_uri + ".meta", "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(d, f)
return 0
@@ -623,22 +637,16 @@ class DiskDatasetCache(DatasetCache):
super(DiskDatasetCache, self).__init__(provider)
self.r = get_redis_connection()
self.remote = kwargs.get("remote", False)
self.dtst_cache_path = os.path.join(C.get_data_path(), C.dataset_cache_dir_name)
os.makedirs(self.dtst_cache_path, exist_ok=True)
@staticmethod
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache)
@staticmethod
def check_cache_exists(cache_path):
for p in [cache_path, cache_path + ".index", cache_path + ".meta"]:
if not Path(p).exists():
return False
return True
def get_cache_dir(self, freq: str = None) -> Path:
return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq)
@classmethod
def read_data_from_cache(cls, cache_path, start_time, end_time, fields):
def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields):
"""read_cache_from
This function can read data from the disk cache dataset
@@ -681,7 +689,7 @@ class DiskDatasetCache(DatasetCache):
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
)
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
features = pd.DataFrame()
gen_flag = False
@@ -689,7 +697,7 @@ class DiskDatasetCache(DatasetCache):
if self.check_cache_exists(cache_path):
if disk_cache == 1:
# use cache
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
elif disk_cache == 2:
@@ -699,7 +707,7 @@ class DiskDatasetCache(DatasetCache):
if gen_flag:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
features = self.gen_dataset_cache(
cache_path=cache_path, instruments=instruments, fields=fields, freq=freq
)
@@ -719,16 +727,16 @@ class DiskDatasetCache(DatasetCache):
_cache_uri = self._uri(
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
)
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
if self.check_cache_exists(cache_path):
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.reader_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
return _cache_uri
else:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"):
self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq)
return _cache_uri
@@ -740,8 +748,9 @@ class DiskDatasetCache(DatasetCache):
KEY = "df"
def __init__(self, cache_path):
self.index_path = cache_path + ".index"
def __init__(self, cache_path: Union[str, Path]):
self.index_path = cache_path.with_suffix(".index")
self._data = None
self.logger = get_module_logger(self.__class__.__name__)
@@ -757,7 +766,7 @@ class DiskDatasetCache(DatasetCache):
self._data.sort_index(inplace=True)
self._data.to_hdf(self.index_path, key=self.KEY, mode="w", format="table")
# The index should be readable for all users
os.chmod(self.index_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
self.index_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
def sync_from_disk(self):
# The file will not be closed directly if we read_hdf from the disk directly
@@ -795,15 +804,7 @@ class DiskDatasetCache(DatasetCache):
index_data += start_index
return index_data
@staticmethod
def clear_cache(cache_path):
meta_path = cache_path + ".meta"
for p in [cache_path, meta_path, cache_path + ".index", cache_path + ".data"]:
p = Path(p)
if p.exists():
p.unlink()
def gen_dataset_cache(self, cache_path, instruments, fields, freq):
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq):
"""gen_dataset_cache
.. note:: This function does not consider the cache read write lock. Please
@@ -844,11 +845,11 @@ class DiskDatasetCache(DatasetCache):
# get calendar
from .data import Cal
cache_path = Path(cache_path)
_calendar = Cal.calendar(freq=freq)
self.logger.debug(f"Generating dataset cache {cache_path}")
# Make sure the cache runs right when the directory is deleted
# while running
os.makedirs(self.dtst_cache_path, exist_ok=True)
self.clear_cache(cache_path)
features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq)
@@ -860,7 +861,7 @@ class DiskDatasetCache(DatasetCache):
features = features.swaplevel("instrument", "datetime").sort_index()
# write cache data
with pd.HDFStore(cache_path + ".data") as store:
with pd.HDFStore(str(cache_path.with_suffix(".data"))) as store:
cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns))
orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns)))
cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map)
@@ -879,9 +880,9 @@ class DiskDatasetCache(DatasetCache):
},
"meta": {"last_visit": time.time(), "visits": 1},
}
with open(cache_path + ".meta", "wb") as f:
with cache_path.with_suffix(".meta").open("wb") as f:
pickle.dump(meta, f)
os.chmod(cache_path + ".meta", stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
# write index file
im = DiskDatasetCache.IndexManager(cache_path)
index_data = im.build_index_from_data(features)
@@ -890,21 +891,23 @@ class DiskDatasetCache(DatasetCache):
# rename the file after the cache has been generated
# this doesn't work well on windows, but our server won't use windows
# temporarily
os.replace(cache_path + ".data", cache_path)
cache_path.with_suffix(".data").rename(cache_path)
# the fields of the cached features are converted to the original fields
return features.swaplevel("datetime", "instrument")
def update(self, cache_uri):
cp_cache_uri = os.path.join(self.dtst_cache_path, cache_uri)
# FIXME: when updating the cache, the type of C.provider_uri must be str
assert isinstance(C.provider_uri, str), "when updating the cache, the type of C.provider_uri must be str"
cp_cache_uri = self.get_cache_dir().joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
self.clear_cache(cp_cache_uri)
return 2
im = DiskDatasetCache.IndexManager(cp_cache_uri)
with CacheUtils.writer_lock(self.r, "dataset-%s" % cache_uri):
with open(cp_cache_uri + ".meta", "rb") as f:
with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path())}:dataset-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instruments = d["info"]["instruments"]
fields = d["info"]["fields"]
@@ -995,7 +998,7 @@ class DiskDatasetCache(DatasetCache):
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with open(cp_cache_uri + ".meta", "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(d, f)
return 0
@@ -1006,26 +1009,25 @@ class SimpleDatasetCache(DatasetCache):
def __init__(self, provider):
super(SimpleDatasetCache, self).__init__(provider)
try:
self.local_cache_path = C["local_cache_path"]
self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve()
except KeyError as e:
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
local_cache_path = str(Path(self.local_cache_path).expanduser().resolve())
return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, local_cache_path)
return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path))
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
os.makedirs(os.path.expanduser(self.local_cache_path), exist_ok=True)
cache_file = os.path.join(
self.local_cache_path, self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache)
self.local_cache_path.mkdir(exist_ok=True, parents=True)
cache_file = self.local_cache_path.joinpath(
self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache)
)
gen_flag = False
if os.path.exists(cache_file):
if cache_file.exists():
if disk_cache == 1:
# use cache
df = pd.read_pickle(cache_file)
@@ -1061,8 +1063,8 @@ class DatasetURICache(DatasetCache):
# use ClientDatasetProvider
feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache)
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
if value is None or expire or not os.path.exists(mnt_feature_uri):
mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
if value is None or expire or not mnt_feature_uri.exists():
df, uri = self.provider.dataset(
instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True
)
@@ -1072,7 +1074,6 @@ class DatasetURICache(DatasetCache):
# HZ['f'][uri] = df.copy()
get_module_logger("cache").debug(f"get feature from {C.dataset_provider}")
else:
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("cache").debug("get feature from uri cache")

View File

@@ -5,16 +5,11 @@
from __future__ import division
from __future__ import print_function
import os
import re
import abc
import copy
import time
import queue
import bisect
import logging
import importlib
import traceback
import numpy as np
import pandas as pd
from multiprocessing import Pool
@@ -23,7 +18,7 @@ from .cache import H
from ..config import C
from .ops import Operators
from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
from ..utils import parse_field, hash_args, normalize_cache_fields, code_to_fname
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
@@ -48,19 +43,14 @@ class ProviderBackendMixin:
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
# NOTE: uri priority
# 1. backend_obj.kwargs["provider_uri"]
# 2. backend_obj.kwargs["backend_freq_config"]
# 3. C.backend_freq_config, or qlib.init(backend_freq_config={})
# 4. C.provider_uri, or qlib.init(provider_uri="")
provider_uri_map = backend_kwargs.setdefault("backend_freq_config", {})
# NOTE: provider_uri priority
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
# 3. qlib.init: provider_uri
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
freq = kwargs.get("freq", "day")
if C.backend_freq_config is not None:
if freq not in provider_uri_map:
provider_uri_map[freq] = C.backend_freq_config.get(freq, C.get_data_path())
else:
if freq not in provider_uri_map:
provider_uri_map[freq] = C.get_data_path()
if freq not in provider_uri_map:
provider_uri_map[freq] = C.get_data_path(freq)
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
@@ -548,11 +538,6 @@ class LocalCalendarProvider(CalendarProvider):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
def _uri_cal(self):
"""Calendar file uri."""
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
def load_calendar(self, freq, future):
"""Load original calendar timestamp from file.
@@ -612,11 +597,6 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""
@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
@@ -665,11 +645,6 @@ class LocalFeatureProvider(FeatureProvider):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
def _uri_data(self):
"""Static feature file uri."""
return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin")
def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
@@ -937,7 +912,7 @@ class ClientDatasetProvider(DatasetProvider):
get_module_logger("data").debug("get result")
try:
# pre-mound nfs, used for demo
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("data").debug("finish slicing data")
if return_uri:

View File

@@ -43,8 +43,9 @@ def get_redis_connection():
#################### Data ####################
def read_bin(file_path, start_index, end_index):
with open(file_path, "rb") as f:
def read_bin(file_path: Union[str, Path], start_index, end_index):
file_path = Path(file_path.expanduser().resolve())
with file_path.open("rb") as f:
# read start_index
ref_start_index = int(np.frombuffer(f.read(4), dtype="<f")[0])
si = max(ref_start_index, start_index)