From 5c2f218cfb142d2011293f14262d6e0bb0af80cc Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 29 Sep 2020 11:38:48 +0000 Subject: [PATCH] refactor qlib conf&init. Fix test bug --- qlib/__init__.py | 254 ++++++++++---------- qlib/config.py | 201 ++++++++++------ qlib/contrib/backtest/profit_attribution.py | 2 +- qlib/data/cache.py | 14 +- qlib/data/data.py | 14 +- 5 files changed, 267 insertions(+), 218 deletions(-) diff --git a/qlib/__init__.py b/qlib/__init__.py index df00ecfc1..d865ff3f2 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -10,6 +10,7 @@ import logging import re import subprocess import platform +import yaml from pathlib import Path from .utils import can_use_cache @@ -17,15 +18,13 @@ from .utils import can_use_cache # init qlib def init(default_conf="client", **kwargs): - from .config import ( - C, - _default_client_config, - _default_server_config, - _default_region_config, - REG_CN, - ) + from .config import C, REG_CN, REG_US, QlibConfig from .data.data import register_all_wrappers from .log import get_module_logger, set_log_with_config + from .data.cache import H + + C.reset() + H.clear() _logging_config = C.logging_config if "logging_config" in kwargs: @@ -37,36 +36,28 @@ def init(default_conf="client", **kwargs): LOG = get_module_logger("Initialization", level=logging.INFO) LOG.info(f"default_conf: {default_conf}.") - if default_conf == "server": - base_config = copy.deepcopy(_default_server_config) - elif default_conf == "client": - base_config = copy.deepcopy(_default_client_config) - else: - raise ValueError("Unknown system type") - if base_config: - base_config.update(_default_region_config[kwargs.get("region", REG_CN)]) - for k, v in base_config.items(): - C[k] = v + + C.set_mode(default_conf) for k, v in kwargs.items(): C[k] = v if k not in C: LOG.warning("Unrecognized config %s" % k) - if default_conf == "client": - C["mount_path"] = str(Path(C["mount_path"]).expanduser().resolve()) - if not (C["expression_cache"] is None and C["dataset_cache"] is None): - # check redis - if not can_use_cache(): - LOG.warning( - f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!" - ) - C["expression_cache"] = None - C["dataset_cache"] = None + C.set_region(kwargs.get('region', REG_CN)) + C.resolve_path() + + if not (C["expression_cache"] is None and C["dataset_cache"] is None): + # check redis + if not can_use_cache(): + LOG.warning( + f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!" + ) + C["expression_cache"] = None + C["dataset_cache"] = None # check path if server/local - if re.match("^[^/ ]+:.+", C["provider_uri"]) is None: - C["provider_uri"] = str(Path(C["provider_uri"]).expanduser().resolve()) + if C.get_uri_type() == QlibConfig.LOCAL_URI: if not os.path.exists(C["provider_uri"]): if C["auto_mount"]: LOG.error( @@ -76,120 +67,125 @@ def init(default_conf="client", **kwargs): ) else: LOG.warning("auto_path is False, please make sure {} is mounted".format(C["mount_path"])) + elif C.get_uri_type() == QlibConfig.NFS_URI: + _mount_nfs_uri(C) else: - mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"]) - # If the provider uri looks like this 172.23.233.89//data/csdesign' - # It will be a nfs path. The client provider will be used - if not C["auto_mount"]: - if not os.path.exists(C["mount_path"]): - raise FileNotFoundError( - "Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format( - C["mount_path"], mount_command - ) - ) - else: - # Judging system type - sys_type = platform.system() - if "win" in sys_type.lower(): - # system: window - exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":")) - result = exec_result.read() - if "85" in result: - LOG.warning("already mounted or window mount path already exists") - elif "53" in result: - raise OSError("not find network path") - elif "error" in result or "错误" in result: - raise OSError("Invalid mount path") - elif C["provider_uri"] in result: - LOG.info("window success mount..") - else: - raise OSError(f"unknown error: {result}") - - # config mount path - C["mount_path"] = C["mount_path"] + ":\\" - else: - # system: linux/Unix/Mac - # check mount - _remote_uri = C["provider_uri"] - _remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri - _mount_path = C["mount_path"] - _mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path - _check_level_num = 2 - _is_mount = False - while _check_level_num: - with subprocess.Popen( - 'mount | grep "{}"'.format(_remote_uri), - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) as shell_r: - _command_log = shell_r.stdout.readlines() - if len(_command_log) > 0: - for _c in _command_log: - _temp_mount = _c.decode("utf-8").split(" ")[2] - _temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount - if _temp_mount == _mount_path: - _is_mount = True - break - if _is_mount: - break - _remote_uri = "/".join(_remote_uri.split("/")[:-1]) - _mount_path = "/".join(_mount_path.split("/")[:-1]) - _check_level_num -= 1 - - if not _is_mount: - try: - os.makedirs(C["mount_path"], exist_ok=True) - except Exception: - raise OSError( - "Failed to create directory {}, please create {} manually!".format( - C["mount_path"], C["mount_path"] - ) - ) - - # check nfs-common - command_res = os.popen("dpkg -l | grep nfs-common") - command_res = command_res.readlines() - if not command_res: - raise OSError( - "nfs-common is not found, please install it by execute: sudo apt install nfs-common" - ) - # manually mount - command_status = os.system(mount_command) - if command_status == 256: - raise OSError( - "mount {} on {} error! Needs SUDO! Please mount manually: {}".format( - C["provider_uri"], C["mount_path"], mount_command - ) - ) - elif command_status == 32512: - # LOG.error("Command error") - raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"])) - elif command_status == 0: - LOG.info("Mount finished") - else: - LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path)) + raise NotImplementedError(f"This type of URI is not supported") LOG.info("qlib successfully initialized based on %s settings." % default_conf) register_all_wrappers() - try: - if C["auto_mount"]: - LOG.info(f"provider_uri={C['provider_uri']}") - else: - LOG.info(f"mount_path={C['mount_path']}") - except KeyError: - LOG.info(f"provider_uri={C['provider_uri']}") + + LOG.info(f"data_path={C.get_data_path()}") if "flask_server" in C: LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") +def _mount_nfs_uri(C): + from .log import get_module_logger + LOG = get_module_logger("mount nfs", level=logging.INFO) + + # FIXME: the C["provider_uri"] is modified in this function + # If it is not modified, we can pass only provider_uri or mount_path instead of C + mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"]) + # If the provider uri looks like this 172.23.233.89//data/csdesign' + # It will be a nfs path. The client provider will be used + if not C["auto_mount"]: + if not os.path.exists(C["mount_path"]): + raise FileNotFoundError( + "Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format( + C["mount_path"], mount_command + ) + ) + else: + # Judging system type + sys_type = platform.system() + if "win" in sys_type.lower(): + # system: window + exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":")) + result = exec_result.read() + if "85" in result: + LOG.warning("already mounted or window mount path already exists") + elif "53" in result: + raise OSError("not find network path") + elif "error" in result or "错误" in result: + raise OSError("Invalid mount path") + elif C["provider_uri"] in result: + LOG.info("window success mount..") + else: + raise OSError(f"unknown error: {result}") + + # config mount path + C["mount_path"] = C["mount_path"] + ":\\" + else: + # system: linux/Unix/Mac + # check mount + _remote_uri = C["provider_uri"] + _remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri + _mount_path = C["mount_path"] + _mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path + _check_level_num = 2 + _is_mount = False + while _check_level_num: + with subprocess.Popen( + 'mount | grep "{}"'.format(_remote_uri), + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as shell_r: + _command_log = shell_r.stdout.readlines() + if len(_command_log) > 0: + for _c in _command_log: + _temp_mount = _c.decode("utf-8").split(" ")[2] + _temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount + if _temp_mount == _mount_path: + _is_mount = True + break + if _is_mount: + break + _remote_uri = "/".join(_remote_uri.split("/")[:-1]) + _mount_path = "/".join(_mount_path.split("/")[:-1]) + _check_level_num -= 1 + + if not _is_mount: + try: + os.makedirs(C["mount_path"], exist_ok=True) + except Exception: + raise OSError( + "Failed to create directory {}, please create {} manually!".format( + C["mount_path"], C["mount_path"] + ) + ) + + # check nfs-common + command_res = os.popen("dpkg -l | grep nfs-common") + command_res = command_res.readlines() + if not command_res: + raise OSError( + "nfs-common is not found, please install it by execute: sudo apt install nfs-common" + ) + # manually mount + command_status = os.system(mount_command) + if command_status == 256: + raise OSError( + "mount {} on {} error! Needs SUDO! Please mount manually: {}".format( + C["provider_uri"], C["mount_path"], mount_command + ) + ) + elif command_status == 32512: + # LOG.error("Command error") + raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"])) + elif command_status == 0: + LOG.info("Mount finished") + else: + LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path)) + + def init_from_yaml_conf(conf_path): """init_from_yaml_conf :param conf_path: A path to the qlib config in yml format """ - import yaml with open(conf_path) as f: config = yaml.load(f, Loader=yaml.FullLoader) diff --git a/qlib/config.py b/qlib/config.py index c599ced79..977f8f29c 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -1,5 +1,62 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +""" +About the configs +================= + +The config will based on _default_config. +Two modes are supported +- client +- server + +""" + +import copy +from pathlib import Path +import re + + +class Config: + + def __init__(self, default_conf): + self.__dict__["_default_config"] = default_conf # avoiding conflictions with __getattr__ + self.reset() + + def __getitem__(self, key): + return self.__dict__["_config"][key] + + def __getattr__(self, attr): + try: + return self.__dict__["_config"][attr] + except KeyError: + return AttributeError(f"No such {attr} in self._config") + + def __setitem__(self, key, value): + self.__dict__["_config"][key] = value + + def __setattr__(self, attr, value): + self.__dict__["_config"][attr] = value + + def __contains__(self, item): + return item in self.__dict__["_config"] + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + def __str__(self): + return str(self.__dict__["_config"]) + + def __repr__(self): + return str(self.__dict__["_config"]) + + def reset(self): + self.__dict__["_config"] = copy.deepcopy(self._default_config) + + def update(self, *args, **kwargs): + self.__dict__["_config"].update(*args, **kwargs) # REGION CONST @@ -70,50 +127,52 @@ _default_config = { }, } -_default_server_config = { - # data provider config - "calendar_provider": "LocalCalendarProvider", - "instrument_provider": "LocalInstrumentProvider", - "feature_provider": "LocalFeatureProvider", - "expression_provider": "LocalExpressionProvider", - "dataset_provider": "LocalDatasetProvider", - "provider": "LocalProvider", - # config it in qlib.init() - "provider_uri": "", - # redis - "redis_host": "127.0.0.1", - "redis_port": 6379, - "redis_task_db": 1, - "kernels": 64, - # cache - "expression_cache": "DiskExpressionCache", - "dataset_cache": "DiskDatasetCache", -} +MODE_CONF = { + 'server': { + # data provider config + "calendar_provider": "LocalCalendarProvider", + "instrument_provider": "LocalInstrumentProvider", + "feature_provider": "LocalFeatureProvider", + "expression_provider": "LocalExpressionProvider", + "dataset_provider": "LocalDatasetProvider", + "provider": "LocalProvider", + # config it in qlib.init() + "provider_uri": "", + # redis + "redis_host": "127.0.0.1", + "redis_port": 6379, + "redis_task_db": 1, + "kernels": 64, + # cache + "expression_cache": "DiskExpressionCache", + "dataset_cache": "DiskDatasetCache", + }, -_default_client_config = { - # data provider config - "calendar_provider": "LocalCalendarProvider", - "instrument_provider": "LocalInstrumentProvider", - "feature_provider": "LocalFeatureProvider", - "expression_provider": "LocalExpressionProvider", - "dataset_provider": "LocalDatasetProvider", - "provider": "LocalProvider", - # config it in user's own code - "provider_uri": "~/.qlib/qlib_data/cn_data", - # cache - # Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled. - "expression_cache": "DiskExpressionCache", - "dataset_cache": "DiskDatasetCache", - "calendar_cache": None, - # client config - "kernels": 16, - "mount_path": "~/.qlib/qlib_data/cn_data", - "auto_mount": False, # The nfs is already mounted on our server[auto_mount: False]. - # The nfs should be auto-mounted by qlib on other - # serversS(such as PAI) [auto_mount:True] - "timeout": 100, - "logging_level": "INFO", - "region": REG_CN, + 'client': { + # data provider config + "calendar_provider": "LocalCalendarProvider", + "instrument_provider": "LocalInstrumentProvider", + "feature_provider": "LocalFeatureProvider", + "expression_provider": "LocalExpressionProvider", + "dataset_provider": "LocalDatasetProvider", + "provider": "LocalProvider", + # config it in user's own code + "provider_uri": "~/.qlib/qlib_data/cn_data", + # cache + # Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled. + "expression_cache": "DiskExpressionCache", + "dataset_cache": "DiskDatasetCache", + "calendar_cache": None, + # client config + "kernels": 16, + "mount_path": None, + "auto_mount": False, # The nfs is already mounted on our server[auto_mount: False]. + # The nfs should be auto-mounted by qlib on other + # serversS(such as PAI) [auto_mount:True] + "timeout": 100, + "logging_level": "INFO", + "region": REG_CN, + } } @@ -131,37 +190,43 @@ _default_region_config = { } -class Config: - def __getitem__(self, key): - return _default_config[key] +class QlibConfig(Config): + # URI_TYPE + LOCAL_URI = 'local' + NFS_URI = 'nfs' - def __getattr__(self, attr): - try: - return _default_config[attr] - except KeyError: - return AttributeError(f"No such {attr} in _default_config") + def set_mode(self, mode): + # raise KeyError + self.update(MODE_CONF[mode]) + # TODO: update region based on kwargs - def __setitem__(self, key, value): - _default_config[key] = value + def set_region(self, region): + # raise KeyError + self.update(_default_region_config[region]) - def __setattr__(self, attr, value): - _default_config[attr] = value + def resolve_path(self): + # resolve path + if self["mount_path"] is not None: + self["mount_path"]= str(Path(self["mount_path"]).expanduser().resolve()) - def __contains__(self, item): - return item in _default_config + if self.get_uri_type() == QlibConfig.LOCAL_URI: + self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve()) - def __getstate__(self): - return _default_config + def get_uri_type(self): + rm = re.match("^[^/ ]+:.+", self["provider_uri"]) + if rm is None: + return QlibConfig.LOCAL_URI + else: + return QlibConfig.NFS_URI - def __setstate__(self, state): - _default_config.update(state) - - def __str__(self): - return str(_default_config) - - def __repr__(self): - return str(_default_config) + def get_data_path(self): + if self.get_uri_type() == QlibConfig.LOCAL_URI: + return self['provider_uri'] + elif self.get_uri_type() == QlibConfig.NFS_URI: + return self['mount_path'] + else: + raise NotImplementedError(f"This type of uri is not supported") # global config -C = Config() +C = QlibConfig(_default_config) diff --git a/qlib/contrib/backtest/profit_attribution.py b/qlib/contrib/backtest/profit_attribution.py index d51fc450e..20c6f638f 100644 --- a/qlib/contrib/backtest/profit_attribution.py +++ b/qlib/contrib/backtest/profit_attribution.py @@ -33,7 +33,7 @@ def get_benchmark_weight( """ if not path: - path = Path(C.mount_path).expanduser() / "raw" / "AIndexMembers" / "weights.csv" + path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv" # TODO: the storage of weights should be implemented in a more elegent way # TODO: The benchmark is not consistant with the filename in instruments. bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"]) diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 481996609..3cfb8dae9 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -388,10 +388,7 @@ class DiskExpressionCache(ExpressionCache): self.r = get_redis_connection() # remote==True means client is using this module, writing behaviour will not be allowed. self.remote = kwargs.get("remote", False) - if self.remote: - self.expr_cache_path = os.path.join(C.mount_path, C.features_cache_dir_name) - else: - self.expr_cache_path = os.path.join(C.provider_uri, C.features_cache_dir_name) + self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name) os.makedirs(self.expr_cache_path, exist_ok=True) def _uri(self, instrument, field, start_time, end_time, freq): @@ -562,10 +559,7 @@ class DiskDatasetCache(DatasetCache): super(DiskDatasetCache, self).__init__(provider) self.r = get_redis_connection() self.remote = kwargs.get("remote", False) - if self.remote: - self.dtst_cache_path = os.path.join(C.mount_path, C.dataset_cache_dir_name) - else: - self.dtst_cache_path = os.path.join(C.provider_uri, C.dataset_cache_dir_name) + self.dtst_cache_path = os.path.join(C.get_data_path(), C.dataset_cache_dir_name) os.makedirs(self.dtst_cache_path, exist_ok=True) @staticmethod @@ -1003,7 +997,7 @@ class DatasetURICache(DatasetCache): # use ClientDatasetProvider feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache) value, expire = MemCacheExpire.get_cache(H["f"], feature_uri) - mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri) + mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) if value is None or expire or not os.path.exists(mnt_feature_uri): df, uri = self.provider.dataset( instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True @@ -1014,7 +1008,7 @@ class DatasetURICache(DatasetCache): # HZ['f'][uri] = df.copy() get_module_logger("cache").debug(f"get feature from {C.dataset_provider}") else: - mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri) + mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) get_module_logger("cache").debug("get feature from uri cache") diff --git a/qlib/data/data.py b/qlib/data/data.py index e2e0f6662..9b72d1b1c 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -502,10 +502,7 @@ class LocalCalendarProvider(CalendarProvider): @property def _uri_cal(self): """Calendar file uri.""" - if self.remote: - return os.path.join(C.mount_path, "calendars", "{}.txt") - else: - return os.path.join(C.provider_uri, "calendars", "{}.txt") + return os.path.join(C.get_data_path(), "calendars", "{}.txt") def _load_calendar(self, freq, future): """Load original calendar timestamp from file. @@ -568,7 +565,7 @@ class LocalInstrumentProvider(InstrumentProvider): @property def _uri_inst(self): """Instrument file uri.""" - return os.path.join(C.provider_uri, "instruments", "{}.txt") + return os.path.join(C.get_data_path(), "instruments", "{}.txt") def _load_instruments(self, market): fname = self._uri_inst.format(market) @@ -637,10 +634,7 @@ class LocalFeatureProvider(FeatureProvider): @property def _uri_data(self): """Static feature file uri.""" - if self.remote: - return os.path.join(C.mount_path, "features", "{}", "{}.{}.bin") - else: - return os.path.join(C.provider_uri, "features", "{}", "{}.{}.bin") + return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin") def feature(self, instrument, field, start_index, end_index, freq): # validate @@ -914,7 +908,7 @@ class ClientDatasetProvider(DatasetProvider): get_module_logger("data").debug("get result") try: # pre-mound nfs, used for demo - mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri) + mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) get_module_logger("data").debug("finish slicing data") if return_uri: