diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index f660a8098..7cce25809 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -14,7 +14,7 @@ from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord # worflow handler function -def workflow(config_path): +def workflow(config_path, experiment_name="workflow"): with open(config_path) as fp: config = yaml.load(fp, Loader=yaml.Loader) @@ -26,12 +26,13 @@ def workflow(config_path): dataset = init_instance_by_config(config.get("task")["dataset"]) # start exp - with R.start(experiment_name="workflow"): - R.log_paramters(**flatten_dict(task)) + with R.start(experiment_name=experiment_name): + # train model + R.log_params(**flatten_dict(config.get("task"))) model.fit(dataset) recorder = R.get_recorder() - # generate records + # generate records: prediction, backtest, and analysis for record in config.get("task")["record"]: if record["class"] == SignalRecord.__name__: srconf = {"model": model, "dataset": dataset, "recorder": recorder} diff --git a/scripts/check_dump_bin.py b/scripts/check_dump_bin.py index 7c2ceccda..7c2e837af 100644 --- a/scripts/check_dump_bin.py +++ b/scripts/check_dump_bin.py @@ -108,9 +108,7 @@ class CheckBin: return self.COMPARE_ERROR def check(self): - """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data - - """ + """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data""" logger.info("start check......") error_list = [] diff --git a/scripts/get_data.py b/scripts/get_data.py index 661e31c5f..4c0595238 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -55,7 +55,9 @@ class GetData: for _file in tqdm(zp.namelist()): zp.extract(_file, str(target_dir.resolve())) - def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"): + def qlib_data( + self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn" + ): """download cn qlib data from remote Parameters diff --git a/setup.py b/setup.py index 8ad124750..d08e378cb 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ REQUIRED = [ "tornado", "joblib>=0.17.0", "fire>=0.3.1", + "ruamel.yaml>=0.16.12", ] # Numpy include diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 2930489a2..04c399342 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -149,7 +149,9 @@ class TestAllFlow(unittest.TestCase): sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri) + GetData().qlib_data( + name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri + ) qlib.init(provider_uri=provider_uri, region=REG_CN) def test_0_train(self): diff --git a/tests/test_dump_data.py b/tests/test_dump_data.py index 01e6a3758..dfa7f8556 100644 --- a/tests/test_dump_data.py +++ b/tests/test_dump_data.py @@ -75,7 +75,9 @@ class TestDumpData(unittest.TestCase): def test_4_dump_features_simple(self): stock = self.STOCK_NAMES[0] - dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS) + dump_data = DumpDataFix( + csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS + ) dump_data.dump() df = D.features([stock], self.QLIB_FIELDS)