mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
support config custom_ops
This commit is contained in:
@@ -14,6 +14,7 @@ from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_cls_kwargs
|
||||
|
||||
try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
@@ -1496,15 +1497,20 @@ class OpsWrapper:
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
for operator in ops_list:
|
||||
if not issubclass(operator, ExpressionOps):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
|
||||
for _operator in ops_list:
|
||||
if isinstance(_operator, dict):
|
||||
_ops_class, _ = get_cls_kwargs(_operator)
|
||||
else:
|
||||
_ops_class = _operator
|
||||
|
||||
if operator.__name__ in self._ops:
|
||||
if not issubclass(_ops_class, ExpressionOps):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
|
||||
|
||||
if _ops_class.__name__ in self._ops:
|
||||
get_module_logger(self.__class__.__name__).warning(
|
||||
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
|
||||
"The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__)
|
||||
)
|
||||
self._ops[operator.__name__] = operator
|
||||
self._ops[_ops_class.__name__] = _ops_class
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self._ops:
|
||||
|
||||
Reference in New Issue
Block a user