mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
update ops register
This commit is contained in:
@@ -13,7 +13,6 @@ from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps
|
||||
from ..log import get_module_logger
|
||||
from ..utils import OpsWrapper
|
||||
|
||||
try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
@@ -1386,8 +1385,27 @@ class Cov(PairRolling):
|
||||
super(Cov, self).__init__(feature_left, feature_right, N, "cov")
|
||||
|
||||
|
||||
Operators = OpsWrapper()
|
||||
class OpsWrapper(object):
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
for operator in ops_list:
|
||||
if operator.__name__ in self._ops:
|
||||
get_module_logger(self.__class__.__name__).warning(
|
||||
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
|
||||
)
|
||||
self._ops[operator.__name__] = operator
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self._ops:
|
||||
raise AttributeError("The operator [{}] is not registered".format(key))
|
||||
return self._ops[key]
|
||||
|
||||
|
||||
Operators = OpsWrapper()
|
||||
OpsList = [
|
||||
Ref,
|
||||
Max,
|
||||
|
||||
@@ -728,23 +728,3 @@ def load_dataset(path_or_obj):
|
||||
elif extension == ".csv":
|
||||
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
|
||||
raise ValueError(f"unsupported file type `{extension}`")
|
||||
|
||||
|
||||
#################### Operator Wrapper #####################
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
|
||||
for operator in ops_list:
|
||||
self._ops[operator.__name__] = operator
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._ops is {}:
|
||||
raise AttributeError("Please run qlib.init() first using qlib to register ops")
|
||||
return self._ops[key]
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
from qlib.data.ops import Operators, ElemOperator, PairOperator
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class Diff(ElemOperator):
|
||||
@@ -61,23 +63,20 @@ class Distance(PairOperator):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
return np.abs(series_left - series_right)
|
||||
|
||||
class TestRegiterCustomOps(TestAutoData):
|
||||
|
||||
def test_regiter_custom_ops(self):
|
||||
OpsList = [Diff, Distance]
|
||||
Operators.register(OpsList)
|
||||
instruments = ["SH600000"]
|
||||
fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
|
||||
print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
# register custom operators
|
||||
OpsList = [Diff, Distance]
|
||||
Operators.register(OpsList)
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
instruments = ["SH600000"]
|
||||
fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
|
||||
print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
|
||||
# User could use following code to run test when using line_profiler
|
||||
# td = TestDataset()
|
||||
# td.setUpClass()
|
||||
# td.testTSDataset()
|
||||
Reference in New Issue
Block a user