diff --git a/docs/component/data.rst b/docs/component/data.rst index 1218d0c1b..854ab1c27 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -195,6 +195,7 @@ Feature - `ExpressionOps` `ExpressionOps` will use operator for feature construction. To know more about ``Operator``, please refer to `Operator API <../reference/api.html#module-qlib.data.ops>`_. + Also, ``Qlib`` supports users to define their own custom ``Operator``, an example has been given in ``tests/test_register_ops.py``. To know more about ``Feature``, please refer to `Feature API <../reference/api.html#module-qlib.data.base>`_. diff --git a/qlib/__init__.py b/qlib/__init__.py index c7a310562..23dbd7038 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -6,90 +6,45 @@ __version__ = "0.6.1.99" import os -import re -import sys -import copy import yaml import logging import platform import subprocess -from pathlib import Path -from .utils import can_use_cache, init_instance_by_config, check_qlib_data -from .workflow.utils import experiment_exit_handler # init qlib def init(default_conf="client", **kwargs): - from .config import C, REG_CN, REG_US, QlibConfig - from .data.data import register_all_wrappers - from .log import get_module_logger, set_log_with_config + from .config import C + from .log import get_module_logger from .data.cache import H - from .workflow import R, QlibRecorder - C.reset() H.clear() - _logging_config = C.logging_config - if "logging_config" in kwargs: - _logging_config = kwargs["logging_config"] - - # set global config - if _logging_config: - set_log_with_config(_logging_config) - # FIXME: this logger ignored the level in config - LOG = get_module_logger("Initialization", level=logging.INFO) - LOG.info(f"default_conf: {default_conf}.") + logger = get_module_logger("Initialization", level=logging.INFO) - C.set_mode(default_conf) - C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN)) - - for k, v in kwargs.items(): - if k not in C: - LOG.warning("Unrecognized config %s" % k) - else: - C[k] = v - - C.resolve_path() - - if not (C["expression_cache"] is None and C["dataset_cache"] is None): - # check redis - if not can_use_cache(): - LOG.warning( - f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!" - ) - C["expression_cache"] = None - C["dataset_cache"] = None + C.set(default_conf, **kwargs) # check path if server/local - if C.get_uri_type() == QlibConfig.LOCAL_URI: + if C.get_uri_type() == C.LOCAL_URI: if not os.path.exists(C["provider_uri"]): if C["auto_mount"]: - LOG.error( + logger.error( f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist." ) else: - LOG.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted") - elif C.get_uri_type() == QlibConfig.NFS_URI: + logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted") + elif C.get_uri_type() == C.NFS_URI: _mount_nfs_uri(C) else: raise NotImplementedError(f"This type of URI is not supported") - LOG.info("qlib successfully initialized based on %s settings." % default_conf) - register_all_wrappers() - - LOG.info(f"data_path={C.get_data_path()}") + C.register() if "flask_server" in C: - LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") - - # set up QlibRecorder - exp_manager = init_instance_by_config(C["exp_manager"]) - qr = QlibRecorder(exp_manager) - R.register(qr) - # clean up experiment when python program ends - experiment_exit_handler() - check_qlib_data(C) + logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") + logger.info("qlib successfully initialized based on %s settings." % default_conf) + logger.info(f"data_path={C.get_data_path()}") def _mount_nfs_uri(C): diff --git a/qlib/config.py b/qlib/config.py index 31b34bacd..a65d41041 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -11,11 +11,12 @@ Two modes are supported """ -import copy -from pathlib import Path -import re import os +import re +import copy +import logging import multiprocessing +from pathlib import Path class Config: @@ -59,6 +60,9 @@ class Config: def update(self, *args, **kwargs): self.__dict__["_config"].update(*args, **kwargs) + def set_conf_from_C(self, config_c): + self.update(**config_c.__dict__["_config"]) + # REGION CONST REG_CN = "cn" @@ -184,6 +188,8 @@ MODE_CONF = { "timeout": 100, "logging_level": "INFO", "region": REG_CN, + ## Custom Operator + "custom_ops": [], }, } @@ -207,6 +213,10 @@ class QlibConfig(Config): LOCAL_URI = "local" NFS_URI = "nfs" + def __init__(self, default_conf): + super().__init__(default_conf) + self._registered = False + def set_mode(self, mode): # raise KeyError self.update(MODE_CONF[mode]) @@ -243,6 +253,64 @@ class QlibConfig(Config): else: raise NotImplementedError(f"This type of uri is not supported") + def set(self, default_conf="client", **kwargs): + from .utils import set_log_with_config, get_module_logger, can_use_cache + + self.reset() + + _logging_config = self.logging_config + if "logging_config" in kwargs: + _logging_config = kwargs["logging_config"] + + # set global config + if _logging_config: + set_log_with_config(_logging_config) + + # FIXME: this logger ignored the level in config + logger = get_module_logger("Initialization", level=logging.INFO) + logger.info(f"default_conf: {default_conf}.") + + self.set_mode(default_conf) + self.set_region(kwargs.get("region", self["region"] if "region" in self else REG_CN)) + + for k, v in kwargs.items(): + if k not in self: + logger.warning("Unrecognized config %s" % k) + self[k] = v + + self.resolve_path() + + if not (self["expression_cache"] is None and self["dataset_cache"] is None): + # check redis + if not can_use_cache(): + logger.warning( + f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!" + ) + self["expression_cache"] = None + self["dataset_cache"] = None + + def register(self): + from .utils import init_instance_by_config + from .data.ops import register_custom_ops + from .data.data import register_all_wrappers + from .workflow import R, QlibRecorder + from .workflow.utils import experiment_exit_handler + + register_custom_ops(self) + register_all_wrappers(self) + # set up QlibRecorder + exp_manager = init_instance_by_config(self["exp_manager"]) + qr = QlibRecorder(exp_manager) + R.register(qr) + # clean up experiment when python program ends + experiment_exit_handler() + + self._registered = True + + @property + def registered(self): + return self._registered + # global config C = QlibConfig(_default_config) diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 6433127f4..243736ddc 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -33,7 +33,7 @@ from ..utils import ( from ..log import get_module_logger from .base import Feature -from .ops import * +from .ops import Operators class QlibCacheException(RuntimeError): diff --git a/qlib/data/data.py b/qlib/data/data.py index f16e14b7a..ece3c3641 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -19,17 +19,12 @@ from multiprocessing import Pool from .cache import H from ..config import C -from .ops import * +from .ops import Operators from ..log import get_module_logger from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache -from ..utils import ( - Wrapper, - init_instance_by_config, - register_wrapper, - get_module_by_module_path, -) +from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path class CalendarProvider(abc.ABC): @@ -471,11 +466,10 @@ class DatasetProvider(abc.ABC): """ # FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods - global C - C = g_config # NOTE: This place is compatible with windows, windows multi-process is spawn - if getattr(ExpressionD, "_provider", None) is None: - register_all_wrappers() + if not C.registered: + C.set_conf_from_C(g_config) + C.register() obj = dict() for field in column_names: @@ -1058,7 +1052,7 @@ DatasetD: DatasetProviderWrapper = Wrapper() D: BaseProviderWrapper = Wrapper() -def register_all_wrappers(): +def register_all_wrappers(C): """register_all_wrappers""" logger = get_module_logger("data") module = get_module_by_module_path("qlib.data") diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 7c13d345f..91f7349d2 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -23,53 +23,6 @@ except ImportError as err: ) raise -__all__ = ( - "Ref", - "Max", - "Min", - "Sum", - "Mean", - "Std", - "Var", - "Skew", - "Kurt", - "Med", - "Mad", - "Slope", - "Rsquare", - "Resi", - "Rank", - "Quantile", - "Count", - "EMA", - "WMA", - "Corr", - "Cov", - "Delta", - "Abs", - "Sign", - "Log", - "Power", - "Add", - "Sub", - "Mul", - "Div", - "Greater", - "Less", - "And", - "Or", - "Not", - "Gt", - "Ge", - "Lt", - "Le", - "Eq", - "Ne", - "Mask", - "IdxMax", - "IdxMin", - "If", -) np.seterr(invalid="ignore") @@ -1441,3 +1394,87 @@ class Cov(PairRolling): def __init__(self, feature_left, feature_right, N): super(Cov, self).__init__(feature_left, feature_right, N, "cov") + + +OpsList = [ + Ref, + Max, + Min, + Sum, + Mean, + Std, + Var, + Skew, + Kurt, + Med, + Mad, + Slope, + Rsquare, + Resi, + Rank, + Quantile, + Count, + EMA, + WMA, + Corr, + Cov, + Delta, + Abs, + Sign, + Log, + Power, + Add, + Sub, + Mul, + Div, + Greater, + Less, + And, + Or, + Not, + Gt, + Ge, + Lt, + Le, + Eq, + Ne, + Mask, + IdxMax, + IdxMin, + If, +] + + +class OpsWrapper(object): + """Ops Wrapper""" + + def __init__(self): + self._ops = {} + + def register(self, ops_list): + for operator in ops_list: + if not issubclass(operator, ExpressionOps): + raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator)) + + if operator.__name__ in self._ops: + get_module_logger(self.__class__.__name__).warning( + "The custom operator [{}] will override the qlib default definition".format(operator.__name__) + ) + self._ops[operator.__name__] = operator + + def __getattr__(self, key): + if key not in self._ops: + raise AttributeError("The operator [{0}] is not registered".format(key)) + return self._ops[key] + + +Operators = OpsWrapper() +Operators.register(OpsList) + + +def register_custom_ops(C): + """register custom operator""" + logger = get_module_logger("ops") + if getattr(C, "custom_ops", None) is not None: + Operators.register(C.custom_ops) + logger.debug("register custom operator {}".format(C.custom_ops)) diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index af8dc6c1a..f92e72787 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -7,6 +7,9 @@ from ..config import REG_CN class TestAutoData(unittest.TestCase): + + _setup_kwargs = {} + @classmethod def setUpClass(cls) -> None: # use default data @@ -21,4 +24,4 @@ class TestAutoData(unittest.TestCase): target_dir=provider_uri, delete_old=False, ) - init(provider_uri=provider_uri, region=REG_CN) + init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index c75d6db96..799ab377a 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -26,9 +26,8 @@ import pandas as pd from pathlib import Path from typing import Union, Tuple -from .. import __version__ as qlib_version from ..config import C -from ..log import get_module_logger +from ..log import get_module_logger, set_log_with_config log = get_module_logger("utils") @@ -163,7 +162,7 @@ def parse_field(field): # - $open+$close -> Feature("open")+Feature("close") if not isinstance(field, str): field = str(field) - return re.sub(r"\$(\w+)", r'Feature("\1")', field) + return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field)) def get_module_by_module_path(module_path): diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py new file mode 100644 index 000000000..cb172b2bb --- /dev/null +++ b/tests/test_register_ops.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import unittest +import numpy as np + +import qlib +from qlib.data import D +from qlib.data.ops import ElemOperator, PairOperator +from qlib.config import REG_CN +from qlib.utils import exists_qlib_data +from qlib.tests import TestAutoData +from qlib.tests.data import GetData + + +class Diff(ElemOperator): + """Feature First Difference + Parameters + ---------- + feature : Expression + feature instance + Returns + ---------- + Expression + a feature instance with first difference + """ + + def __init__(self, feature): + super(Diff, self).__init__(feature, "diff") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.diff() + + def get_extended_window_size(self): + lft_etd, rght_etd = self.feature.get_extended_window_size() + return lft_etd + 1, rght_etd + + +class Distance(PairOperator): + """Feature Distance + Parameters + ---------- + feature : Expression + feature instance + Returns + ---------- + Expression + a feature instance with distance + """ + + def __init__(self, feature_left, feature_right): + super(Distance, self).__init__(feature_left, feature_right, "distance") + + def _load_internal(self, instrument, start_index, end_index, freq): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + return np.abs(series_left - series_right) + + +class TestRegiterCustomOps(TestAutoData): + @classmethod + def setUpClass(cls) -> None: + cls._setup_kwargs.update({"custom_ops": [Diff, Distance]}) + super().setUpClass() + + def test_regiter_custom_ops(self): + instruments = ["SH600000"] + fields = ["Diff($close)", "Distance($close, Ref($close, 1))"] + print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day")) + + +if __name__ == "__main__": + unittest.main()