diff --git a/qlib/config.py b/qlib/config.py index 4dedf75d0..15c13a015 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -195,7 +195,10 @@ MODE_CONF = { "timeout": 100, "logging_level": logging.INFO, "region": REG_CN, - ## Custom Operator + # custom operator + # each element of custom_ops should be Type[ExpressionOps] or dict + # if element of custom_ops is Type[ExpressionOps], it represents the custom operator class + # if element of custom_ops is dict, it represents the config of custom operator and should include `class` and `module_path` keys. "custom_ops": [], }, } diff --git a/qlib/data/ops.py b/qlib/data/ops.py index cbc101f47..e044533d3 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -10,10 +10,12 @@ import abc import numpy as np import pandas as pd +from typing import Union, List, Type from scipy.stats import percentileofscore from .base import Expression, ExpressionOps from ..log import get_module_logger +from ..utils import get_cls_kwargs try: from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi @@ -1495,16 +1497,34 @@ class OpsWrapper: def reset(self): self._ops = {} - def register(self, ops_list): - for operator in ops_list: - if not issubclass(operator, ExpressionOps): - raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator)) + def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]): + """register operator - if operator.__name__ in self._ops: + Parameters + ---------- + ops_list : List[Union[Type[ExpressionOps], dict]] + - if type(ops_list) is List[Type[ExpressionOps]], each element of ops_list represents the operator class, which should be the subclass of `ExpressionOps`. + - if type(ops_list) is List[dict], each element of ops_list represents the config of operator, which has the following format: + { + "class": class_name, + "module_path": path, + } + Note: `class` should be the class name of operator, `module_path` should be a python module or path of file. + """ + for _operator in ops_list: + if isinstance(_operator, dict): + _ops_class, _ = get_cls_kwargs(_operator) + else: + _ops_class = _operator + + if not issubclass(_ops_class, ExpressionOps): + raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class)) + + if _ops_class.__name__ in self._ops: get_module_logger(self.__class__.__name__).warning( - "The custom operator [{}] will override the qlib default definition".format(operator.__name__) + "The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__) ) - self._ops[operator.__name__] = operator + self._ops[_ops_class.__name__] = _ops_class def __getattr__(self, key): if key not in self._ops: