1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Update run_all_model

This commit is contained in:
Jactus
2020-11-26 11:54:06 +08:00
parent f185f48185
commit 21eb86e5cb
2 changed files with 67 additions and 8 deletions

View File

@@ -5,8 +5,10 @@
**GitHub**: https://github.com/google-research/google-research/tree/master/tft
## Run the Workflow
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. Please be **aware** that this script can only support Python 3.5 - 3.8.
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
### Notes
1. The model must run in GPU, or an error will be raised.
2. New datasets should be registered in ``data_formatters``, for detail please visit the source.
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`, and `Cuda 10.0 or 10.1`.
2. Please remember to install `cudatoolkit==10.1` and `cudnn==7.6` on your machine.
3. The model must run in GPU, or an error will be raised.
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.

View File

@@ -10,6 +10,7 @@ import shutil
import tempfile
import statistics
from pathlib import Path
from operator import xor
from subprocess import Popen, PIPE
from threading import Thread
from pprint import pprint
@@ -161,6 +162,19 @@ class ExtendedEnvBuilder(venv.EnvBuilder):
self.install_script(context, "pip", url)
# function to check cuda version on the machine, this case is for the model TFT
def check_cuda(folders):
path = "/usr/local/cuda/version.txt"
exclude_tft = True
if os.path.exists(path):
with open(path, "w") as f:
if "10.1" in str(f.read()) or "10.0" in str(f.read()):
exclude_tft = False
if exclude_tft and "TFT" in folders:
del folders["TFT"]
return folders
# function to calculate the mean and std of a list in the results dictionary
def cal_mean_std(results) -> dict:
mean_std = dict()
@@ -174,11 +188,23 @@ def cal_mean_std(results) -> dict:
# function to get all the folders benchmark folder
def get_all_folders() -> dict:
def get_all_folders(models, exclude) -> dict:
folders = dict()
if isinstance(models, str):
model_list = models.split(",")
models = [m.lower().strip("[ ]") for m in model_list]
elif isinstance(models, list):
models = [m.lower() for m in models]
elif models is None:
models = [f.name.lower() for f in os.scandir("benchmarks")]
else:
raise ValueError("Input models type is not supported. Please provide str or list without space.")
for f in os.scandir("benchmarks"):
path = Path("benchmarks") / f.name
folders[f.name] = str(path.resolve())
add = xor(bool(f.name.lower() in models), bool(exclude))
if add:
path = Path("benchmarks") / f.name
folders[f.name] = str(path.resolve())
folders = check_cuda(folders)
return folders
@@ -225,13 +251,44 @@ def gen_and_save_md_table(metrics):
# function to run the all the models
def run(times=1):
def run(times=1, models=None, exclude=False):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed.
Parameters:
-----------
times : int
determines how many times the model should be running.
models : str or list
determines the specific model or list of models to run or exclude.
exclude : boolean
determines whether the model being used is excluded or included.
Usage:
-------
Here are some use cases of the function in the bash:
.. code-block:: bash
# Case 1 - run all models multiple times
python run_all_model.py 3
# Case 2 - run specific models multiple times
python run_all_model.py 3 dnn
# Case 3 - run other models except those are given as arguments for multiple times
python run_all_model.py 3 [dnn,tft,lstm] True
# Case 4 - run specific models for one time
python run_all_model.py --models=[dnn,lightgbm]
# Case 5 - run other models except those are given as aruments for one time
python run_all_model.py --models=[dnn,tft,sfm] --exclude=True
"""
# get all folders
folders = get_all_folders()
folders = get_all_folders(models, exclude)
# set up
compatible = True
if sys.version_info < (3, 3):