diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index f736dbd7d..a23092a2e 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -2,6 +2,9 @@ import unittest from .data import GetData from .. import init from ..constant import REG_CN +from qlib.data.filter import NameDFilter +from qlib.data import D +from qlib.data.data import Cal, DatasetD class TestAutoData(unittest.TestCase): @@ -51,3 +54,24 @@ class TestAutoData(unittest.TestCase): dataset_cache=None, **cls._setup_kwargs, ) + + +class TestOperatorData(TestAutoData): + @classmethod + def setUpClass(cls, enable_1d_type="simple", enable_1min=False) -> None: + # use default data + super().setUpClass(enable_1d_type, enable_1min) + nameDFilter = NameDFilter(name_rule_re="SH600110") + instruments = D.instruments("csi300", filter_pipe=[nameDFilter]) + start_time = "2005-01-04" + end_time = "2005-12-31" + freq = "day" + + instruments_d = DatasetD.get_instruments_d(instruments, freq) + cls.instruments_d = instruments_d + cal = Cal.calendar(start_time, end_time, freq) + cls.cal = cal + cls.start_time = cal[0] + cls.end_time = cal[-1] + cls.inst = list(instruments_d.keys())[0] + cls.spans = list(instruments_d.values())[0] diff --git a/tests/ops/test_elem_operator.py b/tests/ops/test_elem_operator.py new file mode 100644 index 000000000..0e21e5354 --- /dev/null +++ b/tests/ops/test_elem_operator.py @@ -0,0 +1,33 @@ +import unittest + +from qlib.data import DatasetProvider +from qlib.tests import TestOperatorData +from qlib.config import C + + +class TestOperatorDataSetting(TestOperatorData): + def test_setting(self): + self.assertEqual(len(self.instruments_d), 1) + self.assertGreater(len(self.cal), 0) + + +class TestElementOperator(TestOperatorData): + def setUp(self) -> None: + freq = "day" + expressions = [ + "$change", + "Abs($change)", + ] + columns = ["change", "abs"] + self.data = DatasetProvider.expression_calculator( + self.inst, self.start_time, self.end_time, freq, expressions, self.spans, C, [] + ) + self.data.columns = columns + + def test_abs(self): + abs_values = self.data["abs"] + self.assertGreater(abs_values[2], 0) + + +if __name__ == "__main__": + unittest.main()