mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix
This commit is contained in:
@@ -164,13 +164,14 @@ class ExtendedEnvBuilder(venv.EnvBuilder):
|
||||
|
||||
# function to check cuda version on the machine, this case is for the model TFT
|
||||
def check_cuda(folders):
|
||||
path = "/usr/local/cuda/version.txt"
|
||||
path = "/usr/local/cuda/version.txt" # TODO: FIX ME, this will not work on other os systems.
|
||||
exclude_tft = True
|
||||
if os.path.exists(path):
|
||||
with open(path, "w") as f:
|
||||
with open(path, "r") as f:
|
||||
if "10.1" in str(f.read()) or "10.0" in str(f.read()):
|
||||
exclude_tft = False
|
||||
if exclude_tft and "TFT" in folders:
|
||||
sys.stderr.write("Compatible CUDA version not found! Removing TFT from the workflow...\n")
|
||||
del folders["TFT"]
|
||||
return folders
|
||||
|
||||
|
||||
Reference in New Issue
Block a user