diff --git a/examples/benchmarks/TFT/workflow_config_tft.yaml b/examples/benchmarks/TFT/workflow_config_tft.yaml index 1396400cb..d8ee14e71 100644 --- a/examples/benchmarks/TFT/workflow_config_tft.yaml +++ b/examples/benchmarks/TFT/workflow_config_tft.yaml @@ -1,4 +1,5 @@ - +sys: + rel_path: . provider_uri: "~/.qlib/qlib_data/cn_data" region: cn market: &market csi300 @@ -28,7 +29,7 @@ port_analysis_config: &port_analysis_config task: model: class: TFTModel - module_path: qlib.examples.benchmarks.TFT + module_path: tft dataset: class: DatasetH module_path: qlib.data.dataset diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index a946af9a7..2e087877b 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -13,11 +13,43 @@ from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord +def get_path_list(path): + if isinstance(path, str): + return [path] + else: + return [p for p in path] + + +def sys_config(config, config_path): + """ + Configure the `sys` section + + Parameters + ---------- + config : dict + configuration of the workflow + config_path : str + configuration of the path + """ + sys_config = config.get("sys", {}) + + # abspath + for p in get_path_list(sys_config.get("path", [])): + sys.path.append(p) + + # relative path to config path + for p in get_path_list(sys_config.get("rel_path", [])): + sys.path.append(str(Path(config_path).parent.resolve().absolute() / p)) + + # worflow handler function def workflow(config_path, experiment_name="workflow"): with open(config_path) as fp: config = yaml.load(fp, Loader=yaml.Loader) + # config the `sys` section + sys_config(config, config_path) + provider_uri = config.get("provider_uri") region = config.get("region") qlib.init(provider_uri=provider_uri, region=region)