1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

add register ops config

This commit is contained in:
bxdd
2021-01-20 18:44:53 +09:00
parent 3dda2cb379
commit 6daaa79519
4 changed files with 119 additions and 21 deletions

View File

@@ -187,6 +187,8 @@ MODE_CONF = {
"timeout": 100,
"logging_level": "INFO",
"region": REG_CN,
## Custom Operator
"custom_ops": [],
},
}

View File

@@ -1396,27 +1396,6 @@ class Cov(PairRolling):
super(Cov, self).__init__(feature_left, feature_right, N, "cov")
class OpsWrapper(object):
"""Ops Wrapper"""
def __init__(self):
self._ops = {}
def register(self, ops_list):
for operator in ops_list:
if operator.__name__ in self._ops:
get_module_logger(self.__class__.__name__).warning(
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
)
self._ops[operator.__name__] = operator
def __getattr__(self, key):
if key not in self._ops:
raise AttributeError("The operator [{0}] is not registered, and all dict is {1}".format(key, self._ops))
return self._ops[key]
Operators = OpsWrapper()
OpsList = [
Ref,
Max,
@@ -1465,4 +1444,37 @@ OpsList = [
If,
]
class OpsWrapper(object):
"""Ops Wrapper"""
def __init__(self):
self._ops = {}
def register(self, ops_list):
for operator in ops_list:
if not issubclass(operator, ExpressionOps):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
if operator.__name__ in self._ops:
get_module_logger(self.__class__.__name__).warning(
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
)
self._ops[operator.__name__] = operator
def __getattr__(self, key):
if key not in self._ops:
raise AttributeError("The operator [{0}] is not registered".format(key))
return self._ops[key]
Operators = OpsWrapper()
Operators.register(OpsList)
def register_custom_ops(C):
"""register custom operator"""
logger = get_module_logger("ops")
if getattr(C, "custom_ops", None) is not None:
Operators.register(C.custom_ops)
logger.debug("register custom operator {}".format(C.custom_ops))

View File

@@ -783,10 +783,12 @@ def set_config(config_c, default_conf="client", **kwargs):
def config_based_on_c(config_c):
from ..data.ops import register_custom_ops
from ..data.data import register_all_wrappers
from ..workflow import R, QlibRecorder
from ..workflow.utils import experiment_exit_handler
register_custom_ops(config_c)
register_all_wrappers(config_c)
# set up QlibRecorder
exp_manager = init_instance_by_config(config_c["exp_manager"])

View File

@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import numpy as np
import qlib
from qlib.data import D
from qlib.data.ops import ElemOperator, PairOperator
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data
from qlib.tests import TestAutoData
from qlib.tests.data import GetData
class Diff(ElemOperator):
"""Feature First Difference
Parameters
----------
feature : Expression
feature instance
Returns
----------
Expression
a feature instance with first difference
"""
def __init__(self, feature):
super(Diff, self).__init__(feature, "diff")
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.diff()
def get_extended_window_size(self):
lft_etd, rght_etd = self.feature.get_extended_window_size()
return lft_etd + 1, rght_etd
class Distance(PairOperator):
"""Feature Distance
Parameters
----------
feature : Expression
feature instance
Returns
----------
Expression
a feature instance with distance
"""
def __init__(self, feature_left, feature_right):
super(Distance, self).__init__(feature_left, feature_right, "distance")
def _load_internal(self, instrument, start_index, end_index, freq):
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
return np.abs(series_left - series_right)
class TestRegiterCustomOps(TestAutoData):
@classmethod
def setUpClass(cls) -> None:
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
)
qlib.init(provider_uri=provider_uri, custom_ops=[Diff, Distance], region=REG_CN)
def test_regiter_custom_ops(self):
instruments = ["SH600000"]
fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
if __name__ == "__main__":
unittest.main()