1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 17:41:18 +08:00

update ops register

This commit is contained in:
bxdd
2020-12-21 12:06:42 +00:00
parent 0cdc5e125a
commit 7d97fd39ce
3 changed files with 40 additions and 43 deletions

View File

@@ -13,7 +13,6 @@ from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps
from ..log import get_module_logger
from ..utils import OpsWrapper
try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
@@ -1386,8 +1385,27 @@ class Cov(PairRolling):
super(Cov, self).__init__(feature_left, feature_right, N, "cov")
Operators = OpsWrapper()
class OpsWrapper(object):
"""Ops Wrapper"""
def __init__(self):
self._ops = {}
def register(self, ops_list):
for operator in ops_list:
if operator.__name__ in self._ops:
get_module_logger(self.__class__.__name__).warning(
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
)
self._ops[operator.__name__] = operator
def __getattr__(self, key):
if key not in self._ops:
raise AttributeError("The operator [{}] is not registered".format(key))
return self._ops[key]
Operators = OpsWrapper()
OpsList = [
Ref,
Max,

View File

@@ -728,23 +728,3 @@ def load_dataset(path_or_obj):
elif extension == ".csv":
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
raise ValueError(f"unsupported file type `{extension}`")
#################### Operator Wrapper #####################
class OpsWrapper(object):
"""Ops Wrapper"""
def __init__(self):
self._ops = {}
def register(self, ops_list):
for operator in ops_list:
self._ops[operator.__name__] = operator
def __getattr__(self, key):
if self._ops is {}:
raise AttributeError("Please run qlib.init() first using qlib to register ops")
return self._ops[key]

View File

@@ -1,16 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import numpy as np
import pandas as pd
from pathlib import Path
import qlib
import numpy as np
import pandas as pd
from qlib.data import D
from qlib.data.ops import Operators, ElemOperator, PairOperator
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data
from qlib.tests import TestAutoData
class Diff(ElemOperator):
@@ -61,23 +63,20 @@ class Distance(PairOperator):
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
return np.abs(series_left - series_right)
class TestRegiterCustomOps(TestAutoData):
def test_regiter_custom_ops(self):
OpsList = [Diff, Distance]
Operators.register(OpsList)
instruments = ["SH600000"]
fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
if __name__ == "__main__":
unittest.main(verbosity=10)
# register custom operators
OpsList = [Diff, Distance]
Operators.register(OpsList)
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
instruments = ["SH600000"]
fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
# User could use following code to run test when using line_profiler
# td = TestDataset()
# td.setUpClass()
# td.testTSDataset()