From 21eb86e5cb2df4df95d36b75f8ed8931c953baa3 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 26 Nov 2020 11:54:06 +0800 Subject: [PATCH] Update run_all_model --- examples/benchmarks/TFT/README.md | 8 ++-- examples/run_all_model.py | 67 ++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/examples/benchmarks/TFT/README.md b/examples/benchmarks/TFT/README.md index a64ca0129..e9e44db1a 100644 --- a/examples/benchmarks/TFT/README.md +++ b/examples/benchmarks/TFT/README.md @@ -5,8 +5,10 @@ **GitHub**: https://github.com/google-research/google-research/tree/master/tft ## Run the Workflow -Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. Please be **aware** that this script can only support Python 3.5 - 3.8. +Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. ### Notes -1. The model must run in GPU, or an error will be raised. -2. New datasets should be registered in ``data_formatters``, for detail please visit the source. +1. Please be **aware** that this script can only support `Python 3.5 - 3.8`, and `Cuda 10.0 or 10.1`. +2. Please remember to install `cudatoolkit==10.1` and `cudnn==7.6` on your machine. +3. The model must run in GPU, or an error will be raised. +4. New datasets should be registered in ``data_formatters``, for detail please visit the source. diff --git a/examples/run_all_model.py b/examples/run_all_model.py index b448a1857..6f12434da 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -10,6 +10,7 @@ import shutil import tempfile import statistics from pathlib import Path +from operator import xor from subprocess import Popen, PIPE from threading import Thread from pprint import pprint @@ -161,6 +162,19 @@ class ExtendedEnvBuilder(venv.EnvBuilder): self.install_script(context, "pip", url) +# function to check cuda version on the machine, this case is for the model TFT +def check_cuda(folders): + path = "/usr/local/cuda/version.txt" + exclude_tft = True + if os.path.exists(path): + with open(path, "w") as f: + if "10.1" in str(f.read()) or "10.0" in str(f.read()): + exclude_tft = False + if exclude_tft and "TFT" in folders: + del folders["TFT"] + return folders + + # function to calculate the mean and std of a list in the results dictionary def cal_mean_std(results) -> dict: mean_std = dict() @@ -174,11 +188,23 @@ def cal_mean_std(results) -> dict: # function to get all the folders benchmark folder -def get_all_folders() -> dict: +def get_all_folders(models, exclude) -> dict: folders = dict() + if isinstance(models, str): + model_list = models.split(",") + models = [m.lower().strip("[ ]") for m in model_list] + elif isinstance(models, list): + models = [m.lower() for m in models] + elif models is None: + models = [f.name.lower() for f in os.scandir("benchmarks")] + else: + raise ValueError("Input models type is not supported. Please provide str or list without space.") for f in os.scandir("benchmarks"): - path = Path("benchmarks") / f.name - folders[f.name] = str(path.resolve()) + add = xor(bool(f.name.lower() in models), bool(exclude)) + if add: + path = Path("benchmarks") / f.name + folders[f.name] = str(path.resolve()) + folders = check_cuda(folders) return folders @@ -225,13 +251,44 @@ def gen_and_save_md_table(metrics): # function to run the all the models -def run(times=1): +def run(times=1, models=None, exclude=False): """ Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future. Any PR to enhance this method is highly welcomed. + + Parameters: + ----------- + times : int + determines how many times the model should be running. + models : str or list + determines the specific model or list of models to run or exclude. + exclude : boolean + determines whether the model being used is excluded or included. + + Usage: + ------- + Here are some use cases of the function in the bash: + + .. code-block:: bash + + # Case 1 - run all models multiple times + python run_all_model.py 3 + + # Case 2 - run specific models multiple times + python run_all_model.py 3 dnn + + # Case 3 - run other models except those are given as arguments for multiple times + python run_all_model.py 3 [dnn,tft,lstm] True + + # Case 4 - run specific models for one time + python run_all_model.py --models=[dnn,lightgbm] + + # Case 5 - run other models except those are given as aruments for one time + python run_all_model.py --models=[dnn,tft,sfm] --exclude=True + """ # get all folders - folders = get_all_folders() + folders = get_all_folders(models, exclude) # set up compatible = True if sys.version_info < (3, 3):