From 0eee4a0f2e28c74a5bb48f4b21f7223e9988de86 Mon Sep 17 00:00:00 2001 From: bxdd Date: Wed, 23 Jun 2021 15:56:36 +0800 Subject: [PATCH 1/2] support config custom_ops --- qlib/data/ops.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/qlib/data/ops.py b/qlib/data/ops.py index cbc101f47..0edc5e03f 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -14,6 +14,7 @@ from scipy.stats import percentileofscore from .base import Expression, ExpressionOps from ..log import get_module_logger +from ..utils import get_cls_kwargs try: from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi @@ -1496,15 +1497,20 @@ class OpsWrapper: self._ops = {} def register(self, ops_list): - for operator in ops_list: - if not issubclass(operator, ExpressionOps): - raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator)) + for _operator in ops_list: + if isinstance(_operator, dict): + _ops_class, _ = get_cls_kwargs(_operator) + else: + _ops_class = _operator - if operator.__name__ in self._ops: + if not issubclass(_ops_class, ExpressionOps): + raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class)) + + if _ops_class.__name__ in self._ops: get_module_logger(self.__class__.__name__).warning( - "The custom operator [{}] will override the qlib default definition".format(operator.__name__) + "The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__) ) - self._ops[operator.__name__] = operator + self._ops[_ops_class.__name__] = _ops_class def __getattr__(self, key): if key not in self._ops: From 8d0b67334174943456f287f6ff12e608061f7961 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 24 Jun 2021 15:00:45 +0800 Subject: [PATCH 2/2] add custom_ops docstring --- qlib/config.py | 5 ++++- qlib/data/ops.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/qlib/config.py b/qlib/config.py index 4dedf75d0..15c13a015 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -195,7 +195,10 @@ MODE_CONF = { "timeout": 100, "logging_level": logging.INFO, "region": REG_CN, - ## Custom Operator + # custom operator + # each element of custom_ops should be Type[ExpressionOps] or dict + # if element of custom_ops is Type[ExpressionOps], it represents the custom operator class + # if element of custom_ops is dict, it represents the config of custom operator and should include `class` and `module_path` keys. "custom_ops": [], }, } diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 0edc5e03f..e044533d3 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -10,6 +10,7 @@ import abc import numpy as np import pandas as pd +from typing import Union, List, Type from scipy.stats import percentileofscore from .base import Expression, ExpressionOps @@ -1496,7 +1497,20 @@ class OpsWrapper: def reset(self): self._ops = {} - def register(self, ops_list): + def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]): + """register operator + + Parameters + ---------- + ops_list : List[Union[Type[ExpressionOps], dict]] + - if type(ops_list) is List[Type[ExpressionOps]], each element of ops_list represents the operator class, which should be the subclass of `ExpressionOps`. + - if type(ops_list) is List[dict], each element of ops_list represents the config of operator, which has the following format: + { + "class": class_name, + "module_path": path, + } + Note: `class` should be the class name of operator, `module_path` should be a python module or path of file. + """ for _operator in ops_list: if isinstance(_operator, dict): _ops_class, _ = get_cls_kwargs(_operator)