From 7d97fd39ce70c75369ac2324e710c16697ee4c60 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 21 Dec 2020 12:06:42 +0000 Subject: [PATCH] update ops register --- qlib/data/ops.py | 22 +++++++++- qlib/utils/__init__.py | 20 --------- .../test_register_custom_ops.py | 41 +++++++++---------- 3 files changed, 40 insertions(+), 43 deletions(-) rename examples/workflow_with_custom_ops.py => tests/test_register_custom_ops.py (66%) diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 6836b49b4..861a69bf9 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -13,7 +13,6 @@ from scipy.stats import percentileofscore from .base import Expression, ExpressionOps from ..log import get_module_logger -from ..utils import OpsWrapper try: from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi @@ -1386,8 +1385,27 @@ class Cov(PairRolling): super(Cov, self).__init__(feature_left, feature_right, N, "cov") -Operators = OpsWrapper() +class OpsWrapper(object): + """Ops Wrapper""" + def __init__(self): + self._ops = {} + + def register(self, ops_list): + for operator in ops_list: + if operator.__name__ in self._ops: + get_module_logger(self.__class__.__name__).warning( + "The custom operator [{}] will override the qlib default definition".format(operator.__name__) + ) + self._ops[operator.__name__] = operator + + def __getattr__(self, key): + if key not in self._ops: + raise AttributeError("The operator [{}] is not registered".format(key)) + return self._ops[key] + + +Operators = OpsWrapper() OpsList = [ Ref, Max, diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index f7b406f58..b08f9426d 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -728,23 +728,3 @@ def load_dataset(path_or_obj): elif extension == ".csv": return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1]) raise ValueError(f"unsupported file type `{extension}`") - - -#################### Operator Wrapper ##################### - - -class OpsWrapper(object): - """Ops Wrapper""" - - def __init__(self): - self._ops = {} - - def register(self, ops_list): - - for operator in ops_list: - self._ops[operator.__name__] = operator - - def __getattr__(self, key): - if self._ops is {}: - raise AttributeError("Please run qlib.init() first using qlib to register ops") - return self._ops[key] diff --git a/examples/workflow_with_custom_ops.py b/tests/test_register_custom_ops.py similarity index 66% rename from examples/workflow_with_custom_ops.py rename to tests/test_register_custom_ops.py index f75df6166..e109fffae 100644 --- a/examples/workflow_with_custom_ops.py +++ b/tests/test_register_custom_ops.py @@ -1,16 +1,18 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import sys +import unittest +import numpy as np +import pandas as pd from pathlib import Path import qlib -import numpy as np -import pandas as pd from qlib.data import D from qlib.data.ops import Operators, ElemOperator, PairOperator from qlib.config import REG_CN from qlib.utils import exists_qlib_data +from qlib.tests import TestAutoData class Diff(ElemOperator): @@ -61,23 +63,20 @@ class Distance(PairOperator): series_right = self.feature_right.load(instrument, start_index, end_index, freq) return np.abs(series_left - series_right) +class TestRegiterCustomOps(TestAutoData): + + def test_regiter_custom_ops(self): + OpsList = [Diff, Distance] + Operators.register(OpsList) + instruments = ["SH600000"] + fields = ["Diff($close)", "Distance($close, Ref($close, 1))"] + print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day")) + if __name__ == "__main__": + unittest.main(verbosity=10) - # register custom operators - OpsList = [Diff, Distance] - Operators.register(OpsList) - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) - from get_data import GetData - - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - - qlib.init(provider_uri=provider_uri, region=REG_CN) - - instruments = ["SH600000"] - fields = ["Diff($close)", "Distance($close, Ref($close, 1))"] - print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day")) + # User could use following code to run test when using line_profiler + # td = TestDataset() + # td.setUpClass() + # td.testTSDataset()