diff --git a/qlib/config.py b/qlib/config.py index 1737c5b37..eb68a504d 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -187,6 +187,8 @@ MODE_CONF = { "timeout": 100, "logging_level": "INFO", "region": REG_CN, + ## Custom Operator + "custom_ops": [], }, } diff --git a/qlib/data/ops.py b/qlib/data/ops.py index c988db3b5..91f7349d2 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -1396,27 +1396,6 @@ class Cov(PairRolling): super(Cov, self).__init__(feature_left, feature_right, N, "cov") -class OpsWrapper(object): - """Ops Wrapper""" - - def __init__(self): - self._ops = {} - - def register(self, ops_list): - for operator in ops_list: - if operator.__name__ in self._ops: - get_module_logger(self.__class__.__name__).warning( - "The custom operator [{}] will override the qlib default definition".format(operator.__name__) - ) - self._ops[operator.__name__] = operator - - def __getattr__(self, key): - if key not in self._ops: - raise AttributeError("The operator [{0}] is not registered, and all dict is {1}".format(key, self._ops)) - return self._ops[key] - - -Operators = OpsWrapper() OpsList = [ Ref, Max, @@ -1465,4 +1444,37 @@ OpsList = [ If, ] + +class OpsWrapper(object): + """Ops Wrapper""" + + def __init__(self): + self._ops = {} + + def register(self, ops_list): + for operator in ops_list: + if not issubclass(operator, ExpressionOps): + raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator)) + + if operator.__name__ in self._ops: + get_module_logger(self.__class__.__name__).warning( + "The custom operator [{}] will override the qlib default definition".format(operator.__name__) + ) + self._ops[operator.__name__] = operator + + def __getattr__(self, key): + if key not in self._ops: + raise AttributeError("The operator [{0}] is not registered".format(key)) + return self._ops[key] + + +Operators = OpsWrapper() Operators.register(OpsList) + + +def register_custom_ops(C): + """register custom operator""" + logger = get_module_logger("ops") + if getattr(C, "custom_ops", None) is not None: + Operators.register(C.custom_ops) + logger.debug("register custom operator {}".format(C.custom_ops)) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 7e029d05e..9e2e92a2b 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -783,10 +783,12 @@ def set_config(config_c, default_conf="client", **kwargs): def config_based_on_c(config_c): + from ..data.ops import register_custom_ops from ..data.data import register_all_wrappers from ..workflow import R, QlibRecorder from ..workflow.utils import experiment_exit_handler + register_custom_ops(config_c) register_all_wrappers(config_c) # set up QlibRecorder exp_manager = init_instance_by_config(config_c["exp_manager"]) diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py new file mode 100644 index 000000000..b147f916f --- /dev/null +++ b/tests/test_register_ops.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import unittest +import numpy as np + +import qlib +from qlib.data import D +from qlib.data.ops import ElemOperator, PairOperator +from qlib.config import REG_CN +from qlib.utils import exists_qlib_data +from qlib.tests import TestAutoData +from qlib.tests.data import GetData + + +class Diff(ElemOperator): + """Feature First Difference + Parameters + ---------- + feature : Expression + feature instance + Returns + ---------- + Expression + a feature instance with first difference + """ + + def __init__(self, feature): + super(Diff, self).__init__(feature, "diff") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.diff() + + def get_extended_window_size(self): + lft_etd, rght_etd = self.feature.get_extended_window_size() + return lft_etd + 1, rght_etd + + +class Distance(PairOperator): + """Feature Distance + Parameters + ---------- + feature : Expression + feature instance + Returns + ---------- + Expression + a feature instance with distance + """ + + def __init__(self, feature_left, feature_right): + super(Distance, self).__init__(feature_left, feature_right, "distance") + + def _load_internal(self, instrument, start_index, end_index, freq): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + return np.abs(series_left - series_right) + + +class TestRegiterCustomOps(TestAutoData): + @classmethod + def setUpClass(cls) -> None: + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + + GetData().qlib_data( + name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri + ) + qlib.init(provider_uri=provider_uri, custom_ops=[Diff, Distance], region=REG_CN) + + def test_regiter_custom_ops(self): + instruments = ["SH600000"] + fields = ["Diff($close)", "Distance($close, Ref($close, 1))"] + print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day")) + + +if __name__ == "__main__": + unittest.main()