diff --git a/qlib/data/ops.py b/qlib/data/ops.py index cbc101f47..0edc5e03f 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -14,6 +14,7 @@ from scipy.stats import percentileofscore from .base import Expression, ExpressionOps from ..log import get_module_logger +from ..utils import get_cls_kwargs try: from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi @@ -1496,15 +1497,20 @@ class OpsWrapper: self._ops = {} def register(self, ops_list): - for operator in ops_list: - if not issubclass(operator, ExpressionOps): - raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator)) + for _operator in ops_list: + if isinstance(_operator, dict): + _ops_class, _ = get_cls_kwargs(_operator) + else: + _ops_class = _operator - if operator.__name__ in self._ops: + if not issubclass(_ops_class, ExpressionOps): + raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class)) + + if _ops_class.__name__ in self._ops: get_module_logger(self.__class__.__name__).warning( - "The custom operator [{}] will override the qlib default definition".format(operator.__name__) + "The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__) ) - self._ops[operator.__name__] = operator + self._ops[_ops_class.__name__] = _ops_class def __getattr__(self, key): if key not in self._ops: