diff --git a/examples/run_all_model.py b/examples/run_all_model.py index 6f12434da..b09750674 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -164,13 +164,14 @@ class ExtendedEnvBuilder(venv.EnvBuilder): # function to check cuda version on the machine, this case is for the model TFT def check_cuda(folders): - path = "/usr/local/cuda/version.txt" + path = "/usr/local/cuda/version.txt" # TODO: FIX ME, this will not work on other os systems. exclude_tft = True if os.path.exists(path): - with open(path, "w") as f: + with open(path, "r") as f: if "10.1" in str(f.read()) or "10.0" in str(f.read()): exclude_tft = False if exclude_tft and "TFT" in folders: + sys.stderr.write("Compatible CUDA version not found! Removing TFT from the workflow...\n") del folders["TFT"] return folders diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index 2e087877b..ecec8d3d7 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -27,9 +27,9 @@ def sys_config(config, config_path): Parameters ---------- config : dict - configuration of the workflow + configuration of the workflow. config_path : str - configuration of the path + configuration of the path. """ sys_config = config.get("sys", {})