diff --git a/scripts/get_data.py b/scripts/get_data.py index d20a251ed..8345afed7 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -53,23 +53,25 @@ class GetData: for _file in tqdm(zp.namelist()): zp.extract(_file, str(target_dir.resolve())) - def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="latest"): + def qlib_data_cn(self, name="qlib_data_cn", target_dir="~/.qlib/qlib_data/cn_data", version="latest"): """download cn qlib data from remote Parameters ---------- target_dir: str data save directory + name: str + dataset name, value from [qlib_data_cn, qlib_data_cn_simple], by default qlib_data_cn version: str data version, value from [v0, v1, ..., latest], by default latest Examples --------- - python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version v1 + python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version latest ------- """ - file_name = f"qlib_data_cn_{version}.zip" + file_name = f"{name}_{version}.zip" self._download_data(file_name, target_dir) def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"): diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index b2f95bc26..815f76745 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -141,14 +141,14 @@ class TestAllFlow(unittest.TestCase): @classmethod def setUpClass(cls) -> None: # use default data - provier_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provier_uri): - print(f"Qlib data is not found in {provier_uri}") + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data_cn(provier_uri) - qlib.init(provier_uri=provier_uri, region=REG_CN) + GetData().qlib_data_cn(provider_uri) + qlib.init(provider_uri=provider_uri, region=REG_CN) def test_0_train(self): TestAllFlow.PRED_SCORE, model_pearsonr = train()