From 7ce485f2b75a3816a7897b7cd99b7595d48dc672 Mon Sep 17 00:00:00 2001 From: zhupr Date: Wed, 30 Sep 2020 11:14:21 +0800 Subject: [PATCH] Fix tests --- examples/estimator/analyze_from_estimator.ipynb | 2 +- examples/train_and_backtest.py | 2 +- examples/train_backtest_analyze.ipynb | 2 +- tests/test_get_data.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/estimator/analyze_from_estimator.ipynb b/examples/estimator/analyze_from_estimator.ipynb index 3b4e7b703..6554eba29 100644 --- a/examples/estimator/analyze_from_estimator.ipynb +++ b/examples/estimator/analyze_from_estimator.ipynb @@ -41,7 +41,7 @@ " print(f\"Qlib data is not found in {provider_uri}\")\n", " sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n", " from get_data import GetData\n", - " GetData().qlib_data_cn(provider_uri)\n", + " GetData().qlib_data_cn(target_dir=provider_uri)\n", "qlib.init(provider_uri=provider_uri, region=REG_CN)" ] }, diff --git a/examples/train_and_backtest.py b/examples/train_and_backtest.py index 216d37592..46d4346f5 100644 --- a/examples/train_and_backtest.py +++ b/examples/train_and_backtest.py @@ -26,7 +26,7 @@ if __name__ == "__main__": sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data_cn(provider_uri) + GetData().qlib_data_cn(target_dir=provider_uri) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/examples/train_backtest_analyze.ipynb b/examples/train_backtest_analyze.ipynb index fed729114..e70fe17b4 100644 --- a/examples/train_backtest_analyze.ipynb +++ b/examples/train_backtest_analyze.ipynb @@ -37,7 +37,7 @@ " print(f\"Qlib data is not found in {provider_uri}\")\n", " sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n", " from get_data import GetData\n", - " GetData().qlib_data_cn(provider_uri)\n", + " GetData().qlib_data_cn(target_dir=provider_uri)\n", "qlib.init(provider_uri=provider_uri, region=REG_CN)" ] }, diff --git a/tests/test_get_data.py b/tests/test_get_data.py index 935e7982d..4efc8b7d3 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase): def test_0_qlib_data(self): - GetData().qlib_data_cn(QLIB_DIR) + GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=QLIB_DIR) df = D.features(D.instruments("csi300"), self.FIELDS) self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed") self.assertFalse(df.dropna().empty, "get qlib data failed")