1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

Update run_all_model script

This commit is contained in:
Jactus
2020-11-23 15:10:14 +08:00
parent 0c3f50e426
commit 27b573c7d6
11 changed files with 57 additions and 24 deletions

View File

@@ -56,4 +56,4 @@ jobs:
- name: Test workflow by config
run: |
qrun examples/benchmarks/GBDT/workflow_config_gbdt.yaml
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml

View File

@@ -3,10 +3,12 @@
import os
import sys
import fire
import venv
import glob
import shutil
import tempfile
import statistics
from pathlib import Path
from subprocess import Popen, PIPE
from threading import Thread
@@ -18,9 +20,16 @@ import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.cli import workflow
from qlib.utils import exists_qlib_data
# init qlib
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
@@ -152,6 +161,18 @@ class ExtendedEnvBuilder(venv.EnvBuilder):
self.install_script(context, "pip", url)
# function to calculate the mean and std of a list in the results dictionary
def cal_mean_std(results) -> dict:
mean_std = dict()
for fn in results:
mean_std[fn] = dict()
for metric in results[fn]:
mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0]
std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0
mean_std[fn][metric] = [mean, std]
return mean_std
# function to get all the folders benchmark folder
def get_all_folders() -> dict:
folders = dict()
@@ -175,21 +196,29 @@ def get_all_results(folders) -> dict:
for fn in folders:
exp = R.get_exp(experiment_name=fn, create=False)
recorders = exp.list_recorders()
recorder = R.get_recorder(recorder_id=next(iter(recorders)), experiment_name=fn)
metrics = recorder.list_metrics()
results[fn] = {key: metrics[key] for key in metrics if "with_cost" in key}
result = dict()
result["annualized_return_with_cost"] = list()
result["information_ratio_with_cost"] = list()
result["max_drawdown_with_cost"] = list()
for recorder_id in recorders:
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
metrics = recorder.list_metrics()
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
results[fn] = result
return results
# function to generate and save markdown tables
def gen_and_save_md_table(results):
# function to generate and save markdown table
def gen_and_save_md_table(metrics):
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
table += "|---|---|---|---|\n"
for fn in results:
ar = metrics[fn]["excess_return_with_cost.annualized_return"]
ir = metrics[fn]["excess_return_with_cost.information_ratio"]
md = metrics[fn]["excess_return_with_cost.max_drawdown"]
table += f"| {fn} | {ar:9.5f} | {ir:9.5f} | {md:9.5f} |\n"
for fn in metrics:
ar = metrics[fn]["annualized_return_with_cost"]
ir = metrics[fn]["information_ratio_with_cost"]
md = metrics[fn]["max_drawdown_with_cost"]
table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
pprint(table)
with open("table.md", "w") as f:
f.write(table)
@@ -197,7 +226,7 @@ def gen_and_save_md_table(results):
# function to run the all the models
def run():
def run(times=1):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed.
@@ -225,6 +254,7 @@ def run():
nopip=False,
verbose=False,
)
# run all the model for iterations
for fn in folders:
# create env
temp_dir = tempfile.mkdtemp()
@@ -246,16 +276,20 @@ def run():
os.system(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
os.system(f"{python_path} -m pip install -e git+https://github.com/you-n-g/qlib#egg=pyqlib") # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config
sys.stderr.write(f"Running the model: {fn}...\n")
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
shutil.rmtree(env_path)
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
# calculating the mean and std
sys.stderr.write(f"Calculating the mean and std of results...\n")
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results)
@@ -264,7 +298,7 @@ def run():
if __name__ == "__main__":
rc = 1
try:
run() # run all the model
fire.Fire(run) # run all the model
rc = 0
except Exception as e:
print("Error: %s" % e, file=sys.stderr)

View File

@@ -49,7 +49,7 @@
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data(target_dir=provider_uri, region=\"cn\")\n",
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -30,7 +30,7 @@ if __name__ == "__main__":
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -30,7 +30,7 @@ if __name__ == "__main__":
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -30,7 +30,7 @@ if __name__ == "__main__":
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -8,7 +8,6 @@ import qlib
import fire
import pandas as pd
import ruamel.yaml as yaml
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord