1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

refactor qlib conf&init. Fix test bug

This commit is contained in:
Young
2020-09-29 11:38:48 +00:00
committed by you-n-g
parent 34ce3ad9bf
commit 5c2f218cfb
5 changed files with 267 additions and 218 deletions

View File

@@ -10,6 +10,7 @@ import logging
import re
import subprocess
import platform
import yaml
from pathlib import Path
from .utils import can_use_cache
@@ -17,15 +18,13 @@ from .utils import can_use_cache
# init qlib
def init(default_conf="client", **kwargs):
from .config import (
C,
_default_client_config,
_default_server_config,
_default_region_config,
REG_CN,
)
from .config import C, REG_CN, REG_US, QlibConfig
from .data.data import register_all_wrappers
from .log import get_module_logger, set_log_with_config
from .data.cache import H
C.reset()
H.clear()
_logging_config = C.logging_config
if "logging_config" in kwargs:
@@ -37,36 +36,28 @@ def init(default_conf="client", **kwargs):
LOG = get_module_logger("Initialization", level=logging.INFO)
LOG.info(f"default_conf: {default_conf}.")
if default_conf == "server":
base_config = copy.deepcopy(_default_server_config)
elif default_conf == "client":
base_config = copy.deepcopy(_default_client_config)
else:
raise ValueError("Unknown system type")
if base_config:
base_config.update(_default_region_config[kwargs.get("region", REG_CN)])
for k, v in base_config.items():
C[k] = v
C.set_mode(default_conf)
for k, v in kwargs.items():
C[k] = v
if k not in C:
LOG.warning("Unrecognized config %s" % k)
if default_conf == "client":
C["mount_path"] = str(Path(C["mount_path"]).expanduser().resolve())
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
# check redis
if not can_use_cache():
LOG.warning(
f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!"
)
C["expression_cache"] = None
C["dataset_cache"] = None
C.set_region(kwargs.get('region', REG_CN))
C.resolve_path()
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
# check redis
if not can_use_cache():
LOG.warning(
f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!"
)
C["expression_cache"] = None
C["dataset_cache"] = None
# check path if server/local
if re.match("^[^/ ]+:.+", C["provider_uri"]) is None:
C["provider_uri"] = str(Path(C["provider_uri"]).expanduser().resolve())
if C.get_uri_type() == QlibConfig.LOCAL_URI:
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
LOG.error(
@@ -76,120 +67,125 @@ def init(default_conf="client", **kwargs):
)
else:
LOG.warning("auto_path is False, please make sure {} is mounted".format(C["mount_path"]))
elif C.get_uri_type() == QlibConfig.NFS_URI:
_mount_nfs_uri(C)
else:
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
# If the provider uri looks like this 172.23.233.89//data/csdesign'
# It will be a nfs path. The client provider will be used
if not C["auto_mount"]:
if not os.path.exists(C["mount_path"]):
raise FileNotFoundError(
"Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format(
C["mount_path"], mount_command
)
)
else:
# Judging system type
sys_type = platform.system()
if "win" in sys_type.lower():
# system: window
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
result = exec_result.read()
if "85" in result:
LOG.warning("already mounted or window mount path already exists")
elif "53" in result:
raise OSError("not find network path")
elif "error" in result or "错误" in result:
raise OSError("Invalid mount path")
elif C["provider_uri"] in result:
LOG.info("window success mount..")
else:
raise OSError(f"unknown error: {result}")
# config mount path
C["mount_path"] = C["mount_path"] + ":\\"
else:
# system: linux/Unix/Mac
# check mount
_remote_uri = C["provider_uri"]
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
_mount_path = C["mount_path"]
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
_check_level_num = 2
_is_mount = False
while _check_level_num:
with subprocess.Popen(
'mount | grep "{}"'.format(_remote_uri),
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as shell_r:
_command_log = shell_r.stdout.readlines()
if len(_command_log) > 0:
for _c in _command_log:
_temp_mount = _c.decode("utf-8").split(" ")[2]
_temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount
if _temp_mount == _mount_path:
_is_mount = True
break
if _is_mount:
break
_remote_uri = "/".join(_remote_uri.split("/")[:-1])
_mount_path = "/".join(_mount_path.split("/")[:-1])
_check_level_num -= 1
if not _is_mount:
try:
os.makedirs(C["mount_path"], exist_ok=True)
except Exception:
raise OSError(
"Failed to create directory {}, please create {} manually!".format(
C["mount_path"], C["mount_path"]
)
)
# check nfs-common
command_res = os.popen("dpkg -l | grep nfs-common")
command_res = command_res.readlines()
if not command_res:
raise OSError(
"nfs-common is not found, please install it by execute: sudo apt install nfs-common"
)
# manually mount
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
"mount {} on {} error! Needs SUDO! Please mount manually: {}".format(
C["provider_uri"], C["mount_path"], mount_command
)
)
elif command_status == 32512:
# LOG.error("Command error")
raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"]))
elif command_status == 0:
LOG.info("Mount finished")
else:
LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path))
raise NotImplementedError(f"This type of URI is not supported")
LOG.info("qlib successfully initialized based on %s settings." % default_conf)
register_all_wrappers()
try:
if C["auto_mount"]:
LOG.info(f"provider_uri={C['provider_uri']}")
else:
LOG.info(f"mount_path={C['mount_path']}")
except KeyError:
LOG.info(f"provider_uri={C['provider_uri']}")
LOG.info(f"data_path={C.get_data_path()}")
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
def _mount_nfs_uri(C):
    """Make the NFS export ``C["provider_uri"]`` available at ``C["mount_path"]``.

    When ``C["auto_mount"]`` is false, nothing is mounted; the function only
    checks that ``C["mount_path"]`` exists. Otherwise the share is mounted:
    with ``mount -o anon`` on Windows, and with ``sudo mount.nfs`` elsewhere
    (after checking the output of ``mount`` to see whether it is already mounted).

    :param C: the config object; the keys "provider_uri", "mount_path" and
        "auto_mount" are read. On Windows ``C["mount_path"]`` is rewritten to
        the drive root (``<mount_path>:\\``) after mounting.
    :raises FileNotFoundError: ``auto_mount`` is false and the mount path does not exist.
    :raises OSError: the mount directory cannot be created, nfs-common is
        missing, or the mount command reports a known failure.
    """
    from .log import get_module_logger

    LOG = get_module_logger("mount nfs", level=logging.INFO)

    # FIXME: C is modified in this function (C["mount_path"] is rewritten on Windows)
    # If it is not modified, we can pass only provider_uri or mount_path instead of C
    mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
    # If the provider uri looks like this 172.23.233.89//data/csdesign'
    # It will be a nfs path. The client provider will be used
    if not C["auto_mount"]:
        if not os.path.exists(C["mount_path"]):
            raise FileNotFoundError(
                "Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format(
                    C["mount_path"], mount_command
                )
            )
        return

    # Judging system type
    sys_type = platform.system().lower()
    # "darwin" (macOS) contains the substring "win" but must take the Unix branch.
    if "win" in sys_type and sys_type != "darwin":
        # system: window
        with os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":")) as exec_result:
            result = exec_result.read()
        if "85" in result:
            LOG.warning("already mounted or window mount path already exists")
        elif "53" in result:
            raise OSError("not find network path")
        elif "error" in result or "错误" in result:
            raise OSError("Invalid mount path")
        elif C["provider_uri"] in result:
            LOG.info("window success mount..")
        else:
            raise OSError(f"unknown error: {result}")
        # config mount path
        C["mount_path"] = C["mount_path"] + ":\\"
        return

    # system: linux/Unix/Mac
    # check mount
    _remote_uri = C["provider_uri"]
    _remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
    _mount_path = C["mount_path"]
    _mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
    # Look for the share itself first, then for its parent directory
    # (a mounted parent also makes the requested path available).
    _check_level_num = 2
    _is_mount = False
    while _check_level_num:
        with subprocess.Popen(
            'mount | grep "{}"'.format(_remote_uri),
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        ) as shell_r:
            _command_log = shell_r.stdout.readlines()
        for _c in _command_log:
            _fields = _c.decode("utf-8").split(" ")
            # stderr is merged into stdout, so a line may not have the
            # "<remote> on <mount point> ..." layout; skip such lines.
            if len(_fields) < 3:
                continue
            _temp_mount = _fields[2]
            _temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount
            if _temp_mount == _mount_path:
                _is_mount = True
                break
        if _is_mount:
            break
        _remote_uri = "/".join(_remote_uri.split("/")[:-1])
        _mount_path = "/".join(_mount_path.split("/")[:-1])
        _check_level_num -= 1

    if _is_mount:
        LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path))
        return

    try:
        os.makedirs(C["mount_path"], exist_ok=True)
    except Exception as e:
        raise OSError(
            "Failed to create directory {}, please create {} manually!".format(C["mount_path"], C["mount_path"])
        ) from e

    # check nfs-common
    with os.popen("dpkg -l | grep nfs-common") as command_res:
        _nfs_common_lines = command_res.readlines()
    if not _nfs_common_lines:
        raise OSError("nfs-common is not found, please install it by execute: sudo apt install nfs-common")

    # manually mount
    command_status = os.system(mount_command)
    if command_status == 256:
        raise OSError(
            "mount {} on {} error! Needs SUDO! Please mount manually: {}".format(
                C["provider_uri"], C["mount_path"], mount_command
            )
        )
    elif command_status == 32512:
        # LOG.error("Command error")
        raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"]))
    elif command_status == 0:
        LOG.info("Mount finished")
    else:
        # Any other status used to pass silently; report it so a failed mount is visible.
        LOG.warning(
            "mount {} on {} returned unexpected status {}, please check the mount manually: {}".format(
                C["provider_uri"], C["mount_path"], command_status, mount_command
            )
        )
def init_from_yaml_conf(conf_path):
"""init_from_yaml_conf
:param conf_path: A path to the qlib config in yml format
"""
import yaml
with open(conf_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)

View File

@@ -1,5 +1,62 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
About the configs
=================
The config is based on _default_config.
Two modes are supported
- client
- server
"""
import copy
from pathlib import Path
import re
class Config:
    """A resettable configuration container with dict-style and attribute-style access.

    The live values are kept in ``self.__dict__["_config"]``; the defaults given
    at construction are kept in ``self.__dict__["_default_config"]`` and are
    never modified by this class. ``reset()`` restores the live values from a
    deep copy of the defaults.
    """

    def __init__(self, default_conf):
        """
        :param default_conf: mapping holding the default configuration values
        """
        # Written through __dict__ because __setattr__ redirects into `_config`.
        self.__dict__["_default_config"] = default_conf
        self.reset()

    def __getitem__(self, key):
        """Return the value of ``key``; raises KeyError when it is not configured."""
        return self.__dict__["_config"][key]

    def __getattr__(self, attr):
        """Return the config value named ``attr``.

        Only called when normal attribute lookup fails.

        :raises AttributeError: ``attr`` is not a config key (or ``_config`` is
            not set yet). Raising, rather than returning the exception object,
            is what lets ``hasattr``, ``getattr(obj, name, default)`` and
            ``copy.deepcopy`` work on this object.
        """
        try:
            return self.__dict__["_config"][attr]
        except KeyError:
            raise AttributeError(f"No such {attr} in self._config") from None

    def __setitem__(self, key, value):
        """Set the config value of ``key``."""
        self.__dict__["_config"][key] = value

    def __setattr__(self, attr, value):
        """Store every attribute assignment as a config value."""
        self.__dict__["_config"][attr] = value

    def __contains__(self, item):
        """Return whether ``item`` is a config key."""
        return item in self.__dict__["_config"]

    def __getstate__(self):
        """Return the instance dict (defaults and live values) for pickling."""
        return self.__dict__

    def __setstate__(self, state):
        """Restore the instance dict from a pickled state."""
        self.__dict__.update(state)

    def __str__(self):
        return str(self.__dict__["_config"])

    def __repr__(self):
        return str(self.__dict__["_config"])

    def reset(self):
        """Discard all changes and restore a deep copy of the default config."""
        self.__dict__["_config"] = copy.deepcopy(self._default_config)

    def update(self, *args, **kwargs):
        """Update the live config; accepts the same arguments as ``dict.update``."""
        self.__dict__["_config"].update(*args, **kwargs)
# REGION CONST
@@ -70,50 +127,52 @@ _default_config = {
},
}
_default_server_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
"instrument_provider": "LocalInstrumentProvider",
"feature_provider": "LocalFeatureProvider",
"expression_provider": "LocalExpressionProvider",
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
"provider_uri": "",
# redis
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
"kernels": 64,
# cache
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
}
MODE_CONF = {
'server': {
# data provider config
"calendar_provider": "LocalCalendarProvider",
"instrument_provider": "LocalInstrumentProvider",
"feature_provider": "LocalFeatureProvider",
"expression_provider": "LocalExpressionProvider",
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
"provider_uri": "",
# redis
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
"kernels": 64,
# cache
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
},
_default_client_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
"instrument_provider": "LocalInstrumentProvider",
"feature_provider": "LocalFeatureProvider",
"expression_provider": "LocalExpressionProvider",
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in user's own code
"provider_uri": "~/.qlib/qlib_data/cn_data",
# cache
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
"calendar_cache": None,
# client config
"kernels": 16,
"mount_path": "~/.qlib/qlib_data/cn_data",
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
# The nfs should be auto-mounted by qlib on other
# serversS(such as PAI) [auto_mount:True]
"timeout": 100,
"logging_level": "INFO",
"region": REG_CN,
'client': {
# data provider config
"calendar_provider": "LocalCalendarProvider",
"instrument_provider": "LocalInstrumentProvider",
"feature_provider": "LocalFeatureProvider",
"expression_provider": "LocalExpressionProvider",
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in user's own code
"provider_uri": "~/.qlib/qlib_data/cn_data",
# cache
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
"calendar_cache": None,
# client config
"kernels": 16,
"mount_path": None,
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
# The nfs should be auto-mounted by qlib on other
# servers (such as PAI) [auto_mount:True]
"timeout": 100,
"logging_level": "INFO",
"region": REG_CN,
}
}
@@ -131,37 +190,43 @@ _default_region_config = {
}
class Config:
def __getitem__(self, key):
return _default_config[key]
class QlibConfig(Config):
# URI_TYPE
LOCAL_URI = 'local'
NFS_URI = 'nfs'
def __getattr__(self, attr):
try:
return _default_config[attr]
except KeyError:
return AttributeError(f"No such {attr} in _default_config")
def set_mode(self, mode):
    """Overlay the settings of a running mode onto the current config.

    :param mode: a key of ``MODE_CONF`` ('server' or 'client')
    :raises KeyError: ``mode`` is not a key of ``MODE_CONF``
    """
    # raise KeyError
    self.update(MODE_CONF[mode])
    # TODO: update region based on kwargs
def __setitem__(self, key, value):
_default_config[key] = value
def set_region(self, region):
    """Overlay the region-specific settings onto the current config.

    :param region: a key of ``_default_region_config``
    :raises KeyError: ``region`` is not a key of ``_default_region_config``
    """
    # raise KeyError
    self.update(_default_region_config[region])
def __setattr__(self, attr, value):
_default_config[attr] = value
def resolve_path(self):
# resolve path
if self["mount_path"] is not None:
self["mount_path"]= str(Path(self["mount_path"]).expanduser().resolve())
def __contains__(self, item):
return item in _default_config
if self.get_uri_type() == QlibConfig.LOCAL_URI:
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
def __getstate__(self):
return _default_config
def get_uri_type(self):
    """Classify ``self["provider_uri"]`` as a local path or an NFS location.

    A URI of the form ``<host>:<path>`` (no "/" or space before the first
    ":") is treated as NFS; everything else is local. On Windows, a path
    with a drive (e.g. ``C:\\data``) is local even though it matches that
    form; as a consequence a single-letter host name cannot be used in an
    NFS URI on Windows.

    :return: ``QlibConfig.NFS_URI`` or ``QlibConfig.LOCAL_URI``
    """
    uri = self["provider_uri"]
    # pathlib reports a drive only for Windows paths; on POSIX `drive` is
    # always empty, so the classification there is unchanged.
    if Path(uri).drive:
        return QlibConfig.LOCAL_URI
    if re.match("^[^/ ]+:.+", uri) is None:
        return QlibConfig.LOCAL_URI
    return QlibConfig.NFS_URI
def __setstate__(self, state):
_default_config.update(state)
def __str__(self):
return str(_default_config)
def __repr__(self):
return str(_default_config)
def get_data_path(self):
    """Return the directory the data is read from.

    For a local URI this is ``self['provider_uri']``; for an NFS URI it is
    ``self['mount_path']``.

    :raises NotImplementedError: the URI type is neither local nor NFS
    """
    uri_type = self.get_uri_type()
    if uri_type == QlibConfig.LOCAL_URI:
        return self['provider_uri']
    if uri_type == QlibConfig.NFS_URI:
        return self['mount_path']
    raise NotImplementedError(f"This type of uri is not supported")
# global config
C = Config()
C = QlibConfig(_default_config)

View File

@@ -33,7 +33,7 @@ def get_benchmark_weight(
"""
if not path:
path = Path(C.mount_path).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegant way
# TODO: The benchmark is not consistent with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])

View File

@@ -388,10 +388,7 @@ class DiskExpressionCache(ExpressionCache):
self.r = get_redis_connection()
# remote==True means client is using this module, writing behaviour will not be allowed.
self.remote = kwargs.get("remote", False)
if self.remote:
self.expr_cache_path = os.path.join(C.mount_path, C.features_cache_dir_name)
else:
self.expr_cache_path = os.path.join(C.provider_uri, C.features_cache_dir_name)
self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name)
os.makedirs(self.expr_cache_path, exist_ok=True)
def _uri(self, instrument, field, start_time, end_time, freq):
@@ -562,10 +559,7 @@ class DiskDatasetCache(DatasetCache):
super(DiskDatasetCache, self).__init__(provider)
self.r = get_redis_connection()
self.remote = kwargs.get("remote", False)
if self.remote:
self.dtst_cache_path = os.path.join(C.mount_path, C.dataset_cache_dir_name)
else:
self.dtst_cache_path = os.path.join(C.provider_uri, C.dataset_cache_dir_name)
self.dtst_cache_path = os.path.join(C.get_data_path(), C.dataset_cache_dir_name)
os.makedirs(self.dtst_cache_path, exist_ok=True)
@staticmethod
@@ -1003,7 +997,7 @@ class DatasetURICache(DatasetCache):
# use ClientDatasetProvider
feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache)
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri)
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
if value is None or expire or not os.path.exists(mnt_feature_uri):
df, uri = self.provider.dataset(
instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True
@@ -1014,7 +1008,7 @@ class DatasetURICache(DatasetCache):
# HZ['f'][uri] = df.copy()
get_module_logger("cache").debug(f"get feature from {C.dataset_provider}")
else:
mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri)
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("cache").debug("get feature from uri cache")

View File

@@ -502,10 +502,7 @@ class LocalCalendarProvider(CalendarProvider):
@property
def _uri_cal(self):
"""Calendar file uri."""
if self.remote:
return os.path.join(C.mount_path, "calendars", "{}.txt")
else:
return os.path.join(C.provider_uri, "calendars", "{}.txt")
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
def _load_calendar(self, freq, future):
"""Load original calendar timestamp from file.
@@ -568,7 +565,7 @@ class LocalInstrumentProvider(InstrumentProvider):
@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.provider_uri, "instruments", "{}.txt")
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market):
fname = self._uri_inst.format(market)
@@ -637,10 +634,7 @@ class LocalFeatureProvider(FeatureProvider):
@property
def _uri_data(self):
"""Static feature file uri."""
if self.remote:
return os.path.join(C.mount_path, "features", "{}", "{}.{}.bin")
else:
return os.path.join(C.provider_uri, "features", "{}", "{}.{}.bin")
return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin")
def feature(self, instrument, field, start_index, end_index, freq):
# validate
@@ -914,7 +908,7 @@ class ClientDatasetProvider(DatasetProvider):
get_module_logger("data").debug("get result")
try:
# pre-mounted nfs, used for demo
mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri)
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("data").debug("finish slicing data")
if return_uri: