From 27b573c7d690c2c628c4d8ad3ae7a7bd53791004 Mon Sep 17 00:00:00 2001 From: Jactus Date: Mon, 23 Nov 2020 15:10:14 +0800 Subject: [PATCH] Update run_all_model script --- .github/workflows/test.yml | 2 +- .../{GBDT => LightGBM}/requirements.txt | 0 .../workflow_config_lightgbm.yaml} | 0 examples/run_all_model.py | 66 ++++++++++++++----- examples/workflow_by_code.ipynb | 2 +- examples/workflow_by_code.py | 2 +- examples/workflow_by_code_finetune.py | 2 +- examples/workflow_by_code_gats.py | 2 +- examples/workflow_by_code_gru.py | 2 +- examples/workflow_by_code_lstm.py | 2 +- qlib/workflow/cli.py | 1 - 11 files changed, 57 insertions(+), 24 deletions(-) rename examples/benchmarks/{GBDT => LightGBM}/requirements.txt (100%) rename examples/benchmarks/{GBDT/workflow_config_gbdt.yaml => LightGBM/workflow_config_lightgbm.yaml} (100%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 935d03116..033d31536 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -56,4 +56,4 @@ jobs: - name: Test workflow by config run: | - qrun examples/benchmarks/GBDT/workflow_config_gbdt.yaml + qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml diff --git a/examples/benchmarks/GBDT/requirements.txt b/examples/benchmarks/LightGBM/requirements.txt similarity index 100% rename from examples/benchmarks/GBDT/requirements.txt rename to examples/benchmarks/LightGBM/requirements.txt diff --git a/examples/benchmarks/GBDT/workflow_config_gbdt.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml similarity index 100% rename from examples/benchmarks/GBDT/workflow_config_gbdt.yaml rename to examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml diff --git a/examples/run_all_model.py b/examples/run_all_model.py index 0b7e1dbbe..f8894afd3 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -3,10 +3,12 @@ import os import sys +import fire import venv import glob import shutil import tempfile +import statistics from pathlib import Path from subprocess import Popen, PIPE from threading import Thread @@ -18,9 +20,16 @@ import qlib from qlib.config import REG_CN from qlib.workflow import R from qlib.workflow.cli import workflow +from qlib.utils import exists_qlib_data # init qlib provider_uri = "~/.qlib/qlib_data/cn_data" +if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) + from get_data import GetData + + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) @@ -152,6 +161,18 @@ class ExtendedEnvBuilder(venv.EnvBuilder): self.install_script(context, "pip", url) +# function to calculate the mean and std of a list in the results dictionary +def cal_mean_std(results) -> dict: + mean_std = dict() + for fn in results: + mean_std[fn] = dict() + for metric in results[fn]: + mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0] + std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0 + mean_std[fn][metric] = [mean, std] + return mean_std + + # function to get all the folders benchmark folder def get_all_folders() -> dict: folders = dict() @@ -175,21 +196,29 @@ def get_all_results(folders) -> dict: for fn in folders: exp = R.get_exp(experiment_name=fn, create=False) recorders = exp.list_recorders() - recorder = R.get_recorder(recorder_id=next(iter(recorders)), experiment_name=fn) - metrics = recorder.list_metrics() - results[fn] = {key: metrics[key] for key in metrics if "with_cost" in key} + result = dict() + result["annualized_return_with_cost"] = list() + result["information_ratio_with_cost"] = list() + result["max_drawdown_with_cost"] = list() + for recorder_id in recorders: + recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn) + metrics = recorder.list_metrics() + result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"]) + result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"]) + result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"]) + results[fn] = result return results -# function to generate and save markdown tables -def gen_and_save_md_table(results): +# function to generate and save markdown table +def gen_and_save_md_table(metrics): table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n" table += "|---|---|---|---|\n" - for fn in results: - ar = metrics[fn]["excess_return_with_cost.annualized_return"] - ir = metrics[fn]["excess_return_with_cost.information_ratio"] - md = metrics[fn]["excess_return_with_cost.max_drawdown"] - table += f"| {fn} | {ar:9.5f} | {ir:9.5f} | {md:9.5f} |\n" + for fn in metrics: + ar = metrics[fn]["annualized_return_with_cost"] + ir = metrics[fn]["information_ratio_with_cost"] + md = metrics[fn]["max_drawdown_with_cost"] + table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n" pprint(table) with open("table.md", "w") as f: f.write(table) @@ -197,7 +226,7 @@ def gen_and_save_md_table(results): # function to run the all the models -def run(): +def run(times=1): """ Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future. Any PR to enhance this method is highly welcomed. @@ -225,6 +254,7 @@ def run(): nopip=False, verbose=False, ) + # run all the model for iterations for fn in folders: # create env temp_dir = tempfile.mkdtemp() @@ -246,16 +276,20 @@ def run(): os.system(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME! os.system(f"{python_path} -m pip install -e git+https://github.com/you-n-g/qlib#egg=pyqlib") # TODO: FIX ME! sys.stderr.write("\n") - # run workflow_by_config - sys.stderr.write(f"Running the model: {fn}...\n") - os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}") - sys.stderr.write("\n") + # run workflow_by_config for multiple times + for i in range(times): + sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n") + os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}") + sys.stderr.write("\n") # remove env sys.stderr.write(f"Deleting the environment: {env_path}...\n") shutil.rmtree(env_path) # getting all results sys.stderr.write(f"Retrieving results...\n") results = get_all_results(folders) + # calculating the mean and std + sys.stderr.write(f"Calculating the mean and std of results...\n") + results = cal_mean_std(results) # generating md table sys.stderr.write(f"Generating markdown table...\n") gen_and_save_md_table(results) @@ -264,7 +298,7 @@ def run(): if __name__ == "__main__": rc = 1 try: - run() # run all the model + fire.Fire(run) # run all the model rc = 0 except Exception as e: print("Error: %s" % e, file=sys.stderr) diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb index 1ac9ab17c..1b4183b29 100644 --- a/examples/workflow_by_code.ipynb +++ b/examples/workflow_by_code.ipynb @@ -49,7 +49,7 @@ " print(f\"Qlib data is not found in {provider_uri}\")\n", " sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n", " from get_data import GetData\n", - " GetData().qlib_data(target_dir=provider_uri, region=\"cn\")\n", + " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n", "qlib.init(provider_uri=provider_uri, region=REG_CN)" ] }, diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 8d495e05e..8fdb4332f 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -28,7 +28,7 @@ if __name__ == "__main__": sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region="cn") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/examples/workflow_by_code_finetune.py b/examples/workflow_by_code_finetune.py index 209cb4a1e..5e7c179ae 100644 --- a/examples/workflow_by_code_finetune.py +++ b/examples/workflow_by_code_finetune.py @@ -28,7 +28,7 @@ if __name__ == "__main__": sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region="cn") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/examples/workflow_by_code_gats.py b/examples/workflow_by_code_gats.py index 6d44bd1b6..6b15b77b4 100644 --- a/examples/workflow_by_code_gats.py +++ b/examples/workflow_by_code_gats.py @@ -30,7 +30,7 @@ if __name__ == "__main__": sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region="cn") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/examples/workflow_by_code_gru.py b/examples/workflow_by_code_gru.py index 96e461ba8..fdd0d9220 100644 --- a/examples/workflow_by_code_gru.py +++ b/examples/workflow_by_code_gru.py @@ -30,7 +30,7 @@ if __name__ == "__main__": sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region="cn") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/examples/workflow_by_code_lstm.py b/examples/workflow_by_code_lstm.py index 2b07f6925..ee50c9aff 100644 --- a/examples/workflow_by_code_lstm.py +++ b/examples/workflow_by_code_lstm.py @@ -30,7 +30,7 @@ if __name__ == "__main__": sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region="cn") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index 307a466a1..a946af9a7 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -8,7 +8,6 @@ import qlib import fire import pandas as pd import ruamel.yaml as yaml -from qlib.config import REG_CN from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord