mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Update run_all_model script
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -56,4 +56,4 @@ jobs:
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
qrun examples/benchmarks/GBDT/workflow_config_gbdt.yaml
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import fire
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import tempfile
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
from subprocess import Popen, PIPE
|
||||
from threading import Thread
|
||||
@@ -18,9 +20,16 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -152,6 +161,18 @@ class ExtendedEnvBuilder(venv.EnvBuilder):
|
||||
self.install_script(context, "pip", url)
|
||||
|
||||
|
||||
# function to calculate the mean and std of a list in the results dictionary
|
||||
def cal_mean_std(results) -> dict:
|
||||
mean_std = dict()
|
||||
for fn in results:
|
||||
mean_std[fn] = dict()
|
||||
for metric in results[fn]:
|
||||
mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0]
|
||||
std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0
|
||||
mean_std[fn][metric] = [mean, std]
|
||||
return mean_std
|
||||
|
||||
|
||||
# function to get all the folders benchmark folder
|
||||
def get_all_folders() -> dict:
|
||||
folders = dict()
|
||||
@@ -175,21 +196,29 @@ def get_all_results(folders) -> dict:
|
||||
for fn in folders:
|
||||
exp = R.get_exp(experiment_name=fn, create=False)
|
||||
recorders = exp.list_recorders()
|
||||
recorder = R.get_recorder(recorder_id=next(iter(recorders)), experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
results[fn] = {key: metrics[key] for key in metrics if "with_cost" in key}
|
||||
result = dict()
|
||||
result["annualized_return_with_cost"] = list()
|
||||
result["information_ratio_with_cost"] = list()
|
||||
result["max_drawdown_with_cost"] = list()
|
||||
for recorder_id in recorders:
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
|
||||
results[fn] = result
|
||||
return results
|
||||
|
||||
|
||||
# function to generate and save markdown tables
|
||||
def gen_and_save_md_table(results):
|
||||
# function to generate and save markdown table
|
||||
def gen_and_save_md_table(metrics):
|
||||
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
|
||||
table += "|---|---|---|---|\n"
|
||||
for fn in results:
|
||||
ar = metrics[fn]["excess_return_with_cost.annualized_return"]
|
||||
ir = metrics[fn]["excess_return_with_cost.information_ratio"]
|
||||
md = metrics[fn]["excess_return_with_cost.max_drawdown"]
|
||||
table += f"| {fn} | {ar:9.5f} | {ir:9.5f} | {md:9.5f} |\n"
|
||||
for fn in metrics:
|
||||
ar = metrics[fn]["annualized_return_with_cost"]
|
||||
ir = metrics[fn]["information_ratio_with_cost"]
|
||||
md = metrics[fn]["max_drawdown_with_cost"]
|
||||
table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
|
||||
pprint(table)
|
||||
with open("table.md", "w") as f:
|
||||
f.write(table)
|
||||
@@ -197,7 +226,7 @@ def gen_and_save_md_table(results):
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
def run():
|
||||
def run(times=1):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed.
|
||||
@@ -225,6 +254,7 @@ def run():
|
||||
nopip=False,
|
||||
verbose=False,
|
||||
)
|
||||
# run all the model for iterations
|
||||
for fn in folders:
|
||||
# create env
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
@@ -246,16 +276,20 @@ def run():
|
||||
os.system(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
os.system(f"{python_path} -m pip install -e git+https://github.com/you-n-g/qlib#egg=pyqlib") # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config
|
||||
sys.stderr.write(f"Running the model: {fn}...\n")
|
||||
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
# calculating the mean and std
|
||||
sys.stderr.write(f"Calculating the mean and std of results...\n")
|
||||
results = cal_mean_std(results)
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results)
|
||||
@@ -264,7 +298,7 @@ def run():
|
||||
if __name__ == "__main__":
|
||||
rc = 1
|
||||
try:
|
||||
run() # run all the model
|
||||
fire.Fire(run) # run all the model
|
||||
rc = 0
|
||||
except Exception as e:
|
||||
print("Error: %s" % e, file=sys.stderr)
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=\"cn\")\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import qlib
|
||||
import fire
|
||||
import pandas as pd
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
Reference in New Issue
Block a user