mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add register ops config
This commit is contained in:
@@ -187,6 +187,8 @@ MODE_CONF = {
|
||||
"timeout": 100,
|
||||
"logging_level": "INFO",
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
"custom_ops": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1396,27 +1396,6 @@ class Cov(PairRolling):
|
||||
super(Cov, self).__init__(feature_left, feature_right, N, "cov")
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
for operator in ops_list:
|
||||
if operator.__name__ in self._ops:
|
||||
get_module_logger(self.__class__.__name__).warning(
|
||||
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
|
||||
)
|
||||
self._ops[operator.__name__] = operator
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self._ops:
|
||||
raise AttributeError("The operator [{0}] is not registered, and all dict is {1}".format(key, self._ops))
|
||||
return self._ops[key]
|
||||
|
||||
|
||||
Operators = OpsWrapper()
|
||||
OpsList = [
|
||||
Ref,
|
||||
Max,
|
||||
@@ -1465,4 +1444,37 @@ OpsList = [
|
||||
If,
|
||||
]
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
for operator in ops_list:
|
||||
if not issubclass(operator, ExpressionOps):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
|
||||
|
||||
if operator.__name__ in self._ops:
|
||||
get_module_logger(self.__class__.__name__).warning(
|
||||
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
|
||||
)
|
||||
self._ops[operator.__name__] = operator
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self._ops:
|
||||
raise AttributeError("The operator [{0}] is not registered".format(key))
|
||||
return self._ops[key]
|
||||
|
||||
|
||||
Operators = OpsWrapper()
|
||||
Operators.register(OpsList)
|
||||
|
||||
|
||||
def register_custom_ops(C):
|
||||
"""register custom operator"""
|
||||
logger = get_module_logger("ops")
|
||||
if getattr(C, "custom_ops", None) is not None:
|
||||
Operators.register(C.custom_ops)
|
||||
logger.debug("register custom operator {}".format(C.custom_ops))
|
||||
|
||||
@@ -783,10 +783,12 @@ def set_config(config_c, default_conf="client", **kwargs):
|
||||
|
||||
|
||||
def config_based_on_c(config_c):
|
||||
from ..data.ops import register_custom_ops
|
||||
from ..data.data import register_all_wrappers
|
||||
from ..workflow import R, QlibRecorder
|
||||
from ..workflow.utils import experiment_exit_handler
|
||||
|
||||
register_custom_ops(config_c)
|
||||
register_all_wrappers(config_c)
|
||||
# set up QlibRecorder
|
||||
exp_manager = init_instance_by_config(config_c["exp_manager"])
|
||||
|
||||
82
tests/test_register_ops.py
Normal file
82
tests/test_register_ops.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class Diff(ElemOperator):
|
||||
"""Feature First Difference
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with first difference
|
||||
"""
|
||||
|
||||
def __init__(self, feature):
|
||||
super(Diff, self).__init__(feature, "diff")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.diff()
|
||||
|
||||
def get_extended_window_size(self):
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
return lft_etd + 1, rght_etd
|
||||
|
||||
|
||||
class Distance(PairOperator):
|
||||
"""Feature Distance
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with distance
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right):
|
||||
super(Distance, self).__init__(feature_left, feature_right, "distance")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
return np.abs(series_left - series_right)
|
||||
|
||||
|
||||
class TestRegiterCustomOps(TestAutoData):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
|
||||
)
|
||||
qlib.init(provider_uri=provider_uri, custom_ops=[Diff, Distance], region=REG_CN)
|
||||
|
||||
def test_regiter_custom_ops(self):
|
||||
instruments = ["SH600000"]
|
||||
fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
|
||||
print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user