mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Add run_all_model script
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
@@ -28,18 +29,16 @@ task:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 360
|
||||
output_dim: 1
|
||||
layers: [256, 512, 1024, 512, 256, 128, 64]
|
||||
lr: 0.001
|
||||
max_steps: 300
|
||||
batch_size: 2000
|
||||
early_stop_rounds: 50
|
||||
eval_steps: 20
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: gd
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
267
examples/run_all_model.py
Normal file
267
examples/run_all_model.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from subprocess import Popen, PIPE
|
||||
from threading import Thread
|
||||
from pprint import pprint
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
class ExtendedEnvBuilder(venv.EnvBuilder):
|
||||
"""
|
||||
Thie class is modified based on https://docs.python.org/3/library/venv.html.
|
||||
This builder installs setuptools and pip so that you can pip or
|
||||
easy_install other packages into the created virtual environment.
|
||||
|
||||
:param nodist: If true, setuptools and pip are not installed into the
|
||||
created virtual environment.
|
||||
:param nopip: If true, pip is not installed into the created
|
||||
virtual environment.
|
||||
:param progress: If setuptools or pip are installed, the progress of the
|
||||
installation can be monitored by passing a progress
|
||||
callable. If specified, it is called with two
|
||||
arguments: a string indicating some progress, and a
|
||||
context indicating where the string is coming from.
|
||||
The context argument can have one of three values:
|
||||
'main', indicating that it is called from virtualize()
|
||||
itself, and 'stdout' and 'stderr', which are obtained
|
||||
by reading lines from the output streams of a subprocess
|
||||
which is used to install the app.
|
||||
|
||||
If a callable is not specified, default progress
|
||||
information is output to sys.stderr.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.nodist = kwargs.pop("nodist", False)
|
||||
self.nopip = kwargs.pop("nopip", False)
|
||||
self.progress = kwargs.pop("progress", None)
|
||||
self.verbose = kwargs.pop("verbose", False)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def post_setup(self, context):
|
||||
"""
|
||||
Set up any packages which need to be pre-installed into the
|
||||
virtual environment being created.
|
||||
|
||||
:param context: The information for the virtual environment
|
||||
creation request being processed.
|
||||
"""
|
||||
os.environ["VIRTUAL_ENV"] = context.env_dir
|
||||
if not self.nodist:
|
||||
self.install_setuptools(context)
|
||||
# Can't install pip without setuptools
|
||||
if not self.nopip and not self.nodist:
|
||||
self.install_pip(context)
|
||||
|
||||
def reader(self, stream, context):
|
||||
"""
|
||||
Read lines from a subprocess' output stream and either pass to a progress
|
||||
callable (if specified) or write progress information to sys.stderr.
|
||||
"""
|
||||
progress = self.progress
|
||||
while True:
|
||||
s = stream.readline()
|
||||
if not s:
|
||||
break
|
||||
if progress is not None:
|
||||
progress(s, context)
|
||||
else:
|
||||
if not self.verbose:
|
||||
sys.stderr.write(".")
|
||||
else:
|
||||
sys.stderr.write(s.decode("utf-8"))
|
||||
sys.stderr.flush()
|
||||
stream.close()
|
||||
|
||||
def install_script(self, context, name, url):
|
||||
_, _, path, _, _, _ = urlparse(url)
|
||||
fn = os.path.split(path)[-1]
|
||||
binpath = context.bin_path
|
||||
distpath = os.path.join(binpath, fn)
|
||||
# Download script into the virtual environment's binaries folder
|
||||
urlretrieve(url, distpath)
|
||||
progress = self.progress
|
||||
if self.verbose:
|
||||
term = "\n"
|
||||
else:
|
||||
term = ""
|
||||
if progress is not None:
|
||||
progress("Installing %s ...%s" % (name, term), "main")
|
||||
else:
|
||||
sys.stderr.write("Installing %s ...%s" % (name, term))
|
||||
sys.stderr.flush()
|
||||
# Install in the virtual environment
|
||||
args = [context.env_exe, fn]
|
||||
p = Popen(args, stdout=PIPE, stderr=PIPE, cwd=binpath)
|
||||
t1 = Thread(target=self.reader, args=(p.stdout, "stdout"))
|
||||
t1.start()
|
||||
t2 = Thread(target=self.reader, args=(p.stderr, "stderr"))
|
||||
t2.start()
|
||||
p.wait()
|
||||
t1.join()
|
||||
t2.join()
|
||||
if progress is not None:
|
||||
progress("done.", "main")
|
||||
else:
|
||||
sys.stderr.write("done.\n")
|
||||
# Clean up - no longer needed
|
||||
os.unlink(distpath)
|
||||
|
||||
def install_setuptools(self, context):
|
||||
"""
|
||||
Install setuptools in the virtual environment.
|
||||
|
||||
:param context: The information for the virtual environment
|
||||
creation request being processed.
|
||||
"""
|
||||
url = "https://bootstrap.pypa.io/ez_setup.py"
|
||||
self.install_script(context, "setuptools", url)
|
||||
# clear up the setuptools archive which gets downloaded
|
||||
pred = lambda o: o.startswith("setuptools-") and o.endswith(".tar.gz")
|
||||
files = filter(pred, os.listdir(context.bin_path))
|
||||
for f in files:
|
||||
f = os.path.join(context.bin_path, f)
|
||||
os.unlink(f)
|
||||
|
||||
def install_pip(self, context):
|
||||
"""
|
||||
Install pip in the virtual environment.
|
||||
|
||||
:param context: The information for the virtual environment
|
||||
creation request being processed.
|
||||
"""
|
||||
url = "https://bootstrap.pypa.io/get-pip.py"
|
||||
self.install_script(context, "pip", url)
|
||||
|
||||
|
||||
# function to get all the folders benchmark folder
|
||||
def get_all_folders() -> dict:
|
||||
folders = dict()
|
||||
for f in os.scandir("benchmarks"):
|
||||
path = Path("benchmarks") / f.name
|
||||
if f.name != "TFT":
|
||||
folders[f.name] = str(path.resolve())
|
||||
return folders
|
||||
|
||||
|
||||
# function to get all the files under the model folder
|
||||
def get_all_files(folder_path) -> (str, str):
|
||||
yaml_path = str(Path(f"{folder_path}") / "*.yaml")
|
||||
req_path = str(Path(f"{folder_path}") / "*.txt")
|
||||
return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
|
||||
|
||||
|
||||
# function to retrieve all the results
|
||||
def get_all_results(folders) -> dict:
|
||||
results = dict()
|
||||
for fn in folders:
|
||||
exp = R.get_exp(experiment_name=fn, create=False)
|
||||
recorders = exp.list_recorders()
|
||||
recorder = R.get_recorder(recorder_id=next(iter(recorders)), experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
results[fn] = {key: metrics[key] for key in metrics if "with_cost" in key}
|
||||
return results
|
||||
|
||||
|
||||
# function to generate and save markdown tables
|
||||
def gen_and_save_md_table(results):
|
||||
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
|
||||
table += "|---|---|---|---|\n"
|
||||
for fn in results:
|
||||
ar = metrics[fn]["excess_return_with_cost.annualized_return"]
|
||||
ir = metrics[fn]["excess_return_with_cost.information_ratio"]
|
||||
md = metrics[fn]["excess_return_with_cost.max_drawdown"]
|
||||
table += f"| {fn} | {ar:9.5f} | {ir:9.5f} | {md:9.5f} |\n"
|
||||
pprint(table)
|
||||
with open("table.md", "w") as f:
|
||||
f.write(table)
|
||||
return table
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
def run():
|
||||
# get all folders
|
||||
folders = get_all_folders()
|
||||
# set up
|
||||
compatible = True
|
||||
if sys.version_info < (3, 3):
|
||||
compatible = False
|
||||
elif not hasattr(sys, "base_prefix"):
|
||||
compatible = False
|
||||
if not compatible:
|
||||
raise ValueError("This script is only for use with " "Python 3.3 or later")
|
||||
if os.name == "nt":
|
||||
use_symlinks = False
|
||||
else:
|
||||
use_symlinks = True
|
||||
builder = ExtendedEnvBuilder(
|
||||
system_site_packages=False,
|
||||
clear=False,
|
||||
symlinks=use_symlinks,
|
||||
upgrade=False,
|
||||
nodist=False,
|
||||
nopip=False,
|
||||
verbose=False,
|
||||
)
|
||||
for fn in folders:
|
||||
# create env
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
env_path = Path(temp_dir).absolute()
|
||||
sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
|
||||
builder.create(str(env_path))
|
||||
python_path = env_path / "bin" / "python" # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn])
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
os.system(f"{python_path} -m pip install -r {req_path}")
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
os.system(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
os.system(f"{python_path} -m pip install -e git+https://github.com/you-n-g/qlib#egg=pyqlib") # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config
|
||||
sys.stderr.write(f"Running the model: {fn}...\n")
|
||||
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
rc = 1
|
||||
try:
|
||||
run() # run all the model
|
||||
rc = 0
|
||||
except Exception as e:
|
||||
print("Error: %s" % e, file=sys.stderr)
|
||||
sys.exit(rc)
|
||||
@@ -1,5 +1,15 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -13,7 +23,7 @@
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.estimator.handler import Alpha158\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -290,7 +290,9 @@ class DataHandlerLP(DataHandler):
|
||||
init_instance_by_config(
|
||||
proc,
|
||||
None if (isinstance(proc, dict) and "module_path" in proc) else processor_module,
|
||||
accept_types=processor_module.Processor))
|
||||
accept_types=processor_module.Processor,
|
||||
)
|
||||
)
|
||||
|
||||
self.process_type = process_type
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
|
||||
@@ -13,13 +13,15 @@ from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
|
||||
# worflow handler function
|
||||
def workflow(config_path, experiment_name="workflow"):
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.Loader)
|
||||
|
||||
provider_uri = config.get("provider_uri")
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
region = config.get("region")
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(config.get("task")["model"])
|
||||
|
||||
@@ -159,6 +159,36 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_artifacts` method.")
|
||||
|
||||
def list_metrics(self):
|
||||
"""
|
||||
List all the metrics of a recorder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary of metrics that being stored.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_metrics` method.")
|
||||
|
||||
def list_params(self):
|
||||
"""
|
||||
List all the params of a recorder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary of params that being stored.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_params` method.")
|
||||
|
||||
def list_tags(self):
|
||||
"""
|
||||
List all the tags of a recorder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary of tags that being stored.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_tags` method.")
|
||||
|
||||
|
||||
class MLflowRecorder(Recorder):
|
||||
"""
|
||||
@@ -239,7 +269,7 @@ class MLflowRecorder(Recorder):
|
||||
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_metric(self.id, name, data)
|
||||
self.client.log_metric(self.id, name, data, step=step)
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
@@ -261,3 +291,15 @@ class MLflowRecorder(Recorder):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return artifacts
|
||||
|
||||
def list_metrics(self):
|
||||
run = self.client.get_run(self.id)
|
||||
return run.data.metrics
|
||||
|
||||
def list_params(self):
|
||||
run = self.client.get_run(self.id)
|
||||
return run.data.params
|
||||
|
||||
def list_tags(self):
|
||||
run = self.client.get_run(self.id)
|
||||
return run.data.tags
|
||||
|
||||
Reference in New Issue
Block a user