mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix test_all_pipeline.py
This commit is contained in:
@@ -53,23 +53,25 @@ class GetData:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="latest"):
|
||||
def qlib_data_cn(self, name="qlib_data_cn", target_dir="~/.qlib/qlib_data/cn_data", version="latest"):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
name: str
|
||||
dataset name, value from [qlib_data_cn, qlib_data_cn_simple], by default qlib_data_cn
|
||||
version: str
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version v1
|
||||
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version latest
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = f"qlib_data_cn_{version}.zip"
|
||||
file_name = f"{name}_{version}.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
|
||||
@@ -141,14 +141,14 @@ class TestAllFlow(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provier_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provier_uri):
|
||||
print(f"Qlib data is not found in {provier_uri}")
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(provier_uri)
|
||||
qlib.init(provier_uri=provier_uri, region=REG_CN)
|
||||
GetData().qlib_data_cn(provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def test_0_train(self):
|
||||
TestAllFlow.PRED_SCORE, model_pearsonr = train()
|
||||
|
||||
Reference in New Issue
Block a user