From 10747a3219cc59474613184cb4bafdd5d202ed7d Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 27 Nov 2020 21:19:27 +0800 Subject: [PATCH] Fix --- examples/run_all_model.py | 5 +++-- qlib/workflow/cli.py | 12 ++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/run_all_model.py b/examples/run_all_model.py index ee98177c2..f40f11444 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -28,7 +28,8 @@ from qlib.utils import exists_qlib_data # init qlib provider_uri = "~/.qlib/qlib_data/cn_data" -exp_path = str(Path(os.getcwd()).resolve() / "run_all_model_records") +exp_folder_name = "run_all_model_records" +exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name) exp_manager = { "class": "MLflowExpManager", "module_path": "qlib.workflow.expm", @@ -253,7 +254,7 @@ def run(times=1, models=None, exclude=False): for i in range(times): sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n") errs = execute( - f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} '{exp_manager}'" + f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}" ) if errs is not None: _errs = errors.get(fn, {}) diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index 451337343..e0c957f60 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys +import sys, os from pathlib import Path import qlib import fire import pandas as pd import ruamel.yaml as yaml +from qlib.config import C from qlib.model.trainer import task_train @@ -41,7 +42,7 @@ def sys_config(config, config_path): # worflow handler function -def workflow(config_path, experiment_name="workflow", exp_manager=None): +def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"): with open(config_path) as fp: config = yaml.load(fp, Loader=yaml.Loader) @@ -50,10 +51,9 @@ def workflow(config_path, experiment_name="workflow", exp_manager=None): provider_uri = config.get("provider_uri") region = config.get("region") - if exp_manager: - qlib.init(provider_uri=provider_uri, region=region, exp_manager=exp_manager) - else: - qlib.init(provider_uri=provider_uri, region=region) + exp_manager = C["exp_manager"] + exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder + qlib.init(provider_uri=provider_uri, region=region, exp_manager=exp_manager) task_train(config, experiment_name=experiment_name)