mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Update run_all_model
This commit is contained in:
@@ -5,8 +5,10 @@
|
||||
**GitHub**: https://github.com/google-research/google-research/tree/master/tft
|
||||
|
||||
## Run the Workflow
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. Please be **aware** that this script can only support Python 3.5 - 3.8.
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. The model must run in GPU, or an error will be raised.
|
||||
2. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`, and `Cuda 10.0 or 10.1`.
|
||||
2. Please remember to install `cudatoolkit==10.1` and `cudnn==7.6` on your machine.
|
||||
3. The model must run in GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
|
||||
@@ -10,6 +10,7 @@ import shutil
|
||||
import tempfile
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from subprocess import Popen, PIPE
|
||||
from threading import Thread
|
||||
from pprint import pprint
|
||||
@@ -161,6 +162,19 @@ class ExtendedEnvBuilder(venv.EnvBuilder):
|
||||
self.install_script(context, "pip", url)
|
||||
|
||||
|
||||
# function to check cuda version on the machine, this case is for the model TFT
|
||||
def check_cuda(folders):
|
||||
path = "/usr/local/cuda/version.txt"
|
||||
exclude_tft = True
|
||||
if os.path.exists(path):
|
||||
with open(path, "w") as f:
|
||||
if "10.1" in str(f.read()) or "10.0" in str(f.read()):
|
||||
exclude_tft = False
|
||||
if exclude_tft and "TFT" in folders:
|
||||
del folders["TFT"]
|
||||
return folders
|
||||
|
||||
|
||||
# function to calculate the mean and std of a list in the results dictionary
|
||||
def cal_mean_std(results) -> dict:
|
||||
mean_std = dict()
|
||||
@@ -174,11 +188,23 @@ def cal_mean_std(results) -> dict:
|
||||
|
||||
|
||||
# function to get all the folders benchmark folder
|
||||
def get_all_folders() -> dict:
|
||||
def get_all_folders(models, exclude) -> dict:
|
||||
folders = dict()
|
||||
if isinstance(models, str):
|
||||
model_list = models.split(",")
|
||||
models = [m.lower().strip("[ ]") for m in model_list]
|
||||
elif isinstance(models, list):
|
||||
models = [m.lower() for m in models]
|
||||
elif models is None:
|
||||
models = [f.name.lower() for f in os.scandir("benchmarks")]
|
||||
else:
|
||||
raise ValueError("Input models type is not supported. Please provide str or list without space.")
|
||||
for f in os.scandir("benchmarks"):
|
||||
path = Path("benchmarks") / f.name
|
||||
folders[f.name] = str(path.resolve())
|
||||
add = xor(bool(f.name.lower() in models), bool(exclude))
|
||||
if add:
|
||||
path = Path("benchmarks") / f.name
|
||||
folders[f.name] = str(path.resolve())
|
||||
folders = check_cuda(folders)
|
||||
return folders
|
||||
|
||||
|
||||
@@ -225,13 +251,44 @@ def gen_and_save_md_table(metrics):
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
def run(times=1):
|
||||
def run(times=1, models=None, exclude=False):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times the model should be running.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py 3
|
||||
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py 3 dnn
|
||||
|
||||
# Case 3 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py 3 [dnn,tft,lstm] True
|
||||
|
||||
# Case 4 - run specific models for one time
|
||||
python run_all_model.py --models=[dnn,lightgbm]
|
||||
|
||||
# Case 5 - run other models except those are given as aruments for one time
|
||||
python run_all_model.py --models=[dnn,tft,sfm] --exclude=True
|
||||
|
||||
"""
|
||||
# get all folders
|
||||
folders = get_all_folders()
|
||||
folders = get_all_folders(models, exclude)
|
||||
# set up
|
||||
compatible = True
|
||||
if sys.version_info < (3, 3):
|
||||
|
||||
Reference in New Issue
Block a user