From 00277699c7bd94031f9bfcd694284b16ba7931aa Mon Sep 17 00:00:00 2001 From: zhupr Date: Wed, 18 Nov 2020 21:55:41 +0800 Subject: [PATCH] update tests && fix typo --- .../estimator/analyze_from_estimator.ipynb | 4 ++-- examples/train_and_backtest.py | 2 +- examples/train_backtest_analyze.ipynb | 4 ++-- scripts/data_collector/yahoo/collector.py | 2 +- scripts/dump_bin.py | 1 - scripts/get_data.py | 8 +++---- tests/dataset_tests/test_dataset.py | 2 +- tests/test_all_pipeline.py | 2 +- tests/test_dump_data.py | 22 +++++++++---------- tests/test_get_data.py | 2 +- 10 files changed, 24 insertions(+), 25 deletions(-) diff --git a/examples/estimator/analyze_from_estimator.ipynb b/examples/estimator/analyze_from_estimator.ipynb index 6554eba29..2ed63bf22 100644 --- a/examples/estimator/analyze_from_estimator.ipynb +++ b/examples/estimator/analyze_from_estimator.ipynb @@ -41,7 +41,7 @@ " print(f\"Qlib data is not found in {provider_uri}\")\n", " sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n", " from get_data import GetData\n", - " GetData().qlib_data_cn(target_dir=provider_uri)\n", + " GetData().qlib_data(target_dir=provider_uri)\n", "qlib.init(provider_uri=provider_uri, region=REG_CN)" ] }, @@ -219,4 +219,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/examples/train_and_backtest.py b/examples/train_and_backtest.py index 39cae20b1..045587f52 100644 --- a/examples/train_and_backtest.py +++ b/examples/train_and_backtest.py @@ -26,7 +26,7 @@ if __name__ == "__main__": sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data_cn(target_dir=provider_uri) + GetData().qlib_data(target_dir=provider_uri) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/examples/train_backtest_analyze.ipynb b/examples/train_backtest_analyze.ipynb index e70fe17b4..21d3605a6 100644 --- a/examples/train_backtest_analyze.ipynb +++ b/examples/train_backtest_analyze.ipynb @@ -37,7 +37,7 @@ " print(f\"Qlib data is not found in {provider_uri}\")\n", " sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n", " from get_data import GetData\n", - " GetData().qlib_data_cn(target_dir=provider_uri)\n", + " GetData().qlib_data(target_dir=provider_uri)\n", "qlib.init(provider_uri=provider_uri, region=REG_CN)" ] }, @@ -335,4 +335,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index f5b62ded1..69c7f8f15 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -313,7 +313,7 @@ class YahooCollectorUS(YahooCollector): def get_stock_list(self): logger.info("get US stock symbols......") - symbols = get_us_stock_symbols(qlib_data_path="/data1/data/yahoo_staock_data/backup/us_1d_qlib") + [ + symbols = get_us_stock_symbols() + [ "^GSPC", "^NDX", "^DJI", diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 94e970808..abe75d2f0 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -287,7 +287,6 @@ class DumpDataAll(DumpDataBase): logger.info("end of features dump.\n") def dump(self): - print("dump 2") self._get_all_date() self._dump_calendars() self._dump_instruments() diff --git a/scripts/get_data.py b/scripts/get_data.py index f870f405c..661e31c5f 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -55,7 +55,7 @@ class GetData: for _file in tqdm(zp.namelist()): zp.extract(_file, str(target_dir.resolve())) - def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", intervel="1d", region="cn"): + def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"): """download cn qlib data from remote Parameters @@ -63,10 +63,10 @@ class GetData: target_dir: str data save directory name: str - dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data_us + dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data version: str data version, value from [v0, v1, ..., latest], by default latest - intervel: str + interval: str data freq, value from [1d], by default 1d region: str data region, value from [cn, us], by default cn @@ -80,7 +80,7 @@ class GetData: # TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system if region.lower() == "us": logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system") - file_name = f"{name}_{region.lower()}_{intervel}_{version}.zip" + file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip" self._download_data(file_name.lower(), target_dir) def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"): diff --git a/tests/dataset_tests/test_dataset.py b/tests/dataset_tests/test_dataset.py index 5a70fee49..9d282b167 100644 --- a/tests/dataset_tests/test_dataset.py +++ b/tests/dataset_tests/test_dataset.py @@ -18,7 +18,7 @@ class TestDataset(unittest.TestCase): sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri) + GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri) qlib.init(provider_uri=provider_uri, region=REG_CN) def testCSI300(self): diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index e3ede382b..886fb31f3 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -149,7 +149,7 @@ class TestAllFlow(unittest.TestCase): sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri) + GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri) qlib.init(provider_uri=provider_uri, region=REG_CN) def test_0_train(self): diff --git a/tests/test_dump_data.py b/tests/test_dump_data.py index dbf4fb082..01e6a3758 100644 --- a/tests/test_dump_data.py +++ b/tests/test_dump_data.py @@ -14,7 +14,7 @@ from qlib.data import D sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData -from dump_bin import DumpData +from dump_bin import DumpDataAll, DumpDataFix DATA_DIR = Path(__file__).parent.joinpath("test_dump_data") @@ -36,7 +36,7 @@ class TestDumpData(unittest.TestCase): @classmethod def setUpClass(cls) -> None: GetData().csv_data_cn(SOURCE_DIR) - TestDumpData.DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR) + TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS) TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv"))) provider_uri = str(QLIB_DIR.resolve()) qlib.init( @@ -49,8 +49,10 @@ class TestDumpData(unittest.TestCase): def tearDownClass(cls) -> None: shutil.rmtree(str(DATA_DIR.resolve())) - def test_0_dump_calendars(self): - self.DUMP_DATA.dump_calendars() + def test_0_dump_bin(self): + self.DUMP_DATA.dump() + + def test_1_dump_calendars(self): ori_calendars = set( map( pd.Timestamp, @@ -60,23 +62,21 @@ class TestDumpData(unittest.TestCase): res_calendars = set(D.calendar()) assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed" - def test_1_dump_instruments(self): - self.DUMP_DATA.dump_instruments() + def test_2_dump_instruments(self): ori_ins = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv"))) res_ins = set(D.list_instruments(D.instruments("all"), as_list=True)) assert len(ori_ins - res_ins) == len(ori_ins - res_ins) == 0, "dump instruments failed" - def test_2_dump_features(self): - self.DUMP_DATA.dump_features(include_fields=self.FIELDS) + def test_3_dump_features(self): df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS) TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :] self.assertFalse(df.dropna().empty, "features data failed") self.assertListEqual(list(df.columns), self.QLIB_FIELDS, "features columns failed") - def test_3_dump_features_simple(self): + def test_4_dump_features_simple(self): stock = self.STOCK_NAMES[0] - dump_data = DumpData(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR) - dump_data.dump_features(include_fields=self.FIELDS, calendar_path=QLIB_DIR.joinpath("calendars", "day.txt")) + dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS) + dump_data.dump() df = D.features([stock], self.QLIB_FIELDS) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index d0f5ca591..732d866dd 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase): def test_0_qlib_data(self): - GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=QLIB_DIR) + GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest") df = D.features(D.instruments("csi300"), self.FIELDS) self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed") self.assertFalse(df.dropna().empty, "get qlib data failed")