mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
update ops register
This commit is contained in:
82
tests/test_register_custom_ops.py
Normal file
82
tests/test_register_custom_ops.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.data.ops import Operators, ElemOperator, PairOperator
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class Diff(ElemOperator):
|
||||
"""Feature First Difference
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with first difference
|
||||
"""
|
||||
|
||||
def __init__(self, feature):
|
||||
super(Diff, self).__init__(feature, "diff")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.diff()
|
||||
|
||||
def get_extended_window_size(self):
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
return lft_etd + 1, rght_etd
|
||||
|
||||
|
||||
class Distance(PairOperator):
|
||||
"""Feature Distance
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with distance
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right):
|
||||
super(Distance, self).__init__(feature_left, feature_right, "distance")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
return np.abs(series_left - series_right)
|
||||
|
||||
class TestRegiterCustomOps(TestAutoData):
|
||||
|
||||
def test_regiter_custom_ops(self):
|
||||
OpsList = [Diff, Distance]
|
||||
Operators.register(OpsList)
|
||||
instruments = ["SH600000"]
|
||||
fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
|
||||
print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
# User could use following code to run test when using line_profiler
|
||||
# td = TestDataset()
|
||||
# td.setUpClass()
|
||||
# td.testTSDataset()
|
||||
Reference in New Issue
Block a user