1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00
Files
qlib/tests/test_register_custom_ops.py
2020-12-21 12:06:42 +00:00

83 lines
2.2 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import numpy as np
import pandas as pd
from pathlib import Path
import qlib
from qlib.data import D
from qlib.data.ops import Operators, ElemOperator, PairOperator
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data
from qlib.tests import TestAutoData
class Diff(ElemOperator):
    """Feature First Difference

    Computes ``feature[t] - feature[t - 1]`` for every bar.

    Parameters
    ----------
    feature : Expression
        feature instance

    Returns
    ----------
    Expression
        a feature instance with first difference
    """

    def __init__(self, feature):
        super(Diff, self).__init__(feature, "diff")

    def _load_internal(self, instrument, start_index, end_index, freq):
        # pandas' diff() yields NaN for the first element of the loaded
        # window; the extra left bar requested in get_extended_window_size
        # is what supplies that predecessor.
        series = self.feature.load(instrument, start_index, end_index, freq)
        return series.diff()

    def get_longest_back_rolling(self):
        # diff() reads exactly one extra historical bar (like Ref(feature, 1)),
        # so the look-back of the wrapped feature grows by one.
        return self.feature.get_longest_back_rolling() + 1

    def get_extended_window_size(self):
        # One more bar on the left is needed to compute the first difference;
        # the right-hand extension of the wrapped feature is unchanged.
        lft_etd, rght_etd = self.feature.get_extended_window_size()
        return lft_etd + 1, rght_etd
class Distance(PairOperator):
    """Feature Distance

    Element-wise absolute difference ``|feature_left - feature_right|``.

    Parameters
    ----------
    feature_left : Expression or number
        left operand
    feature_right : Expression or number
        right operand

    Returns
    ----------
    Expression
        a feature instance with distance
    """

    def __init__(self, feature_left, feature_right):
        super(Distance, self).__init__(feature_left, feature_right, "distance")

    @staticmethod
    def _load_operand(operand, instrument, start_index, end_index, freq):
        """Load ``operand`` when it is loadable; return plain values unchanged."""
        # NOTE(review): a numeric literal in an expression string such as
        # "Distance($close, 10)" presumably reaches this operator as a plain
        # number rather than an Expression -- verify against qlib's field parser.
        if hasattr(operand, "load"):
            return operand.load(instrument, start_index, end_index, freq)
        return operand

    def _load_internal(self, instrument, start_index, end_index, freq):
        series_left = self._load_operand(self.feature_left, instrument, start_index, end_index, freq)
        series_right = self._load_operand(self.feature_right, instrument, start_index, end_index, freq)
        return np.abs(series_left - series_right)
class TestRegiterCustomOps(TestAutoData):
    def test_regiter_custom_ops(self):
        """Register the custom operators and check that they evaluate consistently.

        ``Distance($close, Ref($close, 1))`` is ``|$close - Ref($close, 1)|``,
        which must match ``abs(Diff($close))`` row for row.
        """
        OpsList = [Diff, Distance]
        Operators.register(OpsList)
        instruments = ["SH600000"]
        fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
        df = D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day")

        self.assertFalse(df.empty)
        self.assertEqual(df.shape[1], len(fields))

        # Columns are positional in the same order as ``fields``.
        diff_col = df.iloc[:, 0].values
        distance_col = df.iloc[:, 1].values
        np.testing.assert_allclose(distance_col, np.abs(diff_col), equal_nan=True)
if __name__ == "__main__":
    unittest.main(verbosity=10)

    # NOTE: unittest.main() calls sys.exit() when it finishes, so nothing
    # placed after it runs. To profile with line_profiler, replace the call
    # above with a direct invocation of this file's test, e.g.:
    #     TestRegiterCustomOps.setUpClass()
    #     TestRegiterCustomOps("test_regiter_custom_ops").test_regiter_custom_ops()