mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Merge branch 'main' of github.com:you-n-g/qlib into main
This commit is contained in:
@@ -5,8 +5,10 @@
|
||||
**GitHub**: https://github.com/google-research/google-research/tree/master/tft
|
||||
|
||||
## Run the Workflow
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. Please be **aware** that this script can only support Python 3.5 - 3.8.
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. The model must run in GPU, or an error will be raised.
|
||||
2. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
|
||||
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
|
||||
3. The model must run in GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
|
||||
@@ -10,6 +10,7 @@ import shutil
|
||||
import tempfile
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from subprocess import Popen, PIPE
|
||||
from threading import Thread
|
||||
from pprint import pprint
|
||||
@@ -174,11 +175,22 @@ def cal_mean_std(results) -> dict:
|
||||
|
||||
|
||||
# function to get all the folders benchmark folder
|
||||
def get_all_folders() -> dict:
|
||||
def get_all_folders(models, exclude) -> dict:
|
||||
folders = dict()
|
||||
if isinstance(models, str):
|
||||
model_list = models.split(",")
|
||||
models = [m.lower().strip("[ ]") for m in model_list]
|
||||
elif isinstance(models, list):
|
||||
models = [m.lower() for m in models]
|
||||
elif models is None:
|
||||
models = [f.name.lower() for f in os.scandir("benchmarks")]
|
||||
else:
|
||||
raise ValueError("Input models type is not supported. Please provide str or list without space.")
|
||||
for f in os.scandir("benchmarks"):
|
||||
path = Path("benchmarks") / f.name
|
||||
folders[f.name] = str(path.resolve())
|
||||
add = xor(bool(f.name.lower() in models), bool(exclude))
|
||||
if add:
|
||||
path = Path("benchmarks") / f.name
|
||||
folders[f.name] = str(path.resolve())
|
||||
return folders
|
||||
|
||||
|
||||
@@ -225,13 +237,44 @@ def gen_and_save_md_table(metrics):
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
def run(times=1):
|
||||
def run(times=1, models=None, exclude=False):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times the model should be running.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py 3
|
||||
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py 3 dnn
|
||||
|
||||
# Case 3 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py 3 [dnn,tft,lstm] True
|
||||
|
||||
# Case 4 - run specific models for one time
|
||||
python run_all_model.py --models=[dnn,lightgbm]
|
||||
|
||||
# Case 5 - run other models except those are given as aruments for one time
|
||||
python run_all_model.py --models=[dnn,tft,sfm] --exclude=True
|
||||
|
||||
"""
|
||||
# get all folders
|
||||
folders = get_all_folders()
|
||||
folders = get_all_folders(models, exclude)
|
||||
# set up
|
||||
compatible = True
|
||||
if sys.version_info < (3, 3):
|
||||
|
||||
Reference in New Issue
Block a user