mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge branch 'main' of github.com:you-n-g/qlib into main
This commit is contained in:
9
.github/workflows/test.yml
vendored
9
.github/workflows/test.yml
vendored
@@ -50,9 +50,10 @@ jobs:
|
||||
cd tests
|
||||
pytest . --durations=0
|
||||
|
||||
- name: Test data downloads and examples
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
# cd examples
|
||||
# estimator -c estimator/estimator_config.yaml
|
||||
# jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
workflow_by_config examples/benchmarks/GBDT/workflow_config_gbdt.yaml
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
@@ -28,18 +29,16 @@ task:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 360
|
||||
output_dim: 1
|
||||
layers: [256, 512, 1024, 512, 256, 128, 64]
|
||||
lr: 0.001
|
||||
max_steps: 300
|
||||
batch_size: 2000
|
||||
early_stop_rounds: 50
|
||||
eval_steps: 20
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: gd
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
|
||||
267
examples/run_all_model.py
Normal file
267
examples/run_all_model.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from subprocess import Popen, PIPE
|
||||
from threading import Thread
|
||||
from pprint import pprint
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
|
||||
# Point qlib at the locally stored CN-market dataset and initialise it at
# import time, before any helper defined in this file is called.
provider_uri = "~/.qlib/qlib_data/cn_data"
qlib.init(region=REG_CN, provider_uri=provider_uri)
|
||||
|
||||
|
||||
class ExtendedEnvBuilder(venv.EnvBuilder):
    """
    Virtual-environment builder that also bootstraps setuptools and pip.

    This class is modified based on https://docs.python.org/3/library/venv.html.
    This builder installs setuptools and pip so that you can pip or
    easy_install other packages into the created virtual environment.

    :param nodist: If true, setuptools and pip are not installed into the
                   created virtual environment.
    :param nopip: If true, pip is not installed into the created
                  virtual environment.
    :param progress: If setuptools or pip are installed, the progress of the
                     installation can be monitored by passing a progress
                     callable. If specified, it is called with two
                     arguments: a string indicating some progress, and a
                     context indicating where the string is coming from.
                     The context argument can have one of three values:
                     'main', indicating that it is called from this builder
                     itself, and 'stdout' and 'stderr', which are obtained
                     by reading lines from the output streams of a subprocess
                     which is used to install the app.

                     If a callable is not specified, default progress
                     information is output to sys.stderr.
    :param verbose: If true, subprocess output is echoed to sys.stderr
                    instead of printing one dot per line.
    """

    def __init__(self, *args, **kwargs):
        # Pop the builder-specific options first; everything that remains is
        # forwarded untouched to venv.EnvBuilder.
        self.nodist = kwargs.pop("nodist", False)
        self.nopip = kwargs.pop("nopip", False)
        self.progress = kwargs.pop("progress", None)
        self.verbose = kwargs.pop("verbose", False)
        super().__init__(*args, **kwargs)

    def post_setup(self, context):
        """
        Set up any packages which need to be pre-installed into the
        virtual environment being created.

        :param context: The information for the virtual environment
                        creation request being processed.
        """
        # NOTE: this mutates the environment of the current process, so every
        # subprocess started afterwards inherits VIRTUAL_ENV.
        os.environ["VIRTUAL_ENV"] = context.env_dir
        if not self.nodist:
            self.install_setuptools(context)
        # Can't install pip without setuptools
        if not self.nopip and not self.nodist:
            self.install_pip(context)

    def reader(self, stream, context):
        """
        Read lines from a subprocess' output stream and either pass to a progress
        callable (if specified) or write progress information to sys.stderr.

        The stream is always closed, even if the progress callable raises.

        :param stream: A binary stream exposing ``readline()`` and ``close()``.
        :param context: Label forwarded to the progress callable
                        ('stdout' or 'stderr' when called from install_script).
        """
        progress = self.progress
        try:
            while True:
                s = stream.readline()
                if not s:
                    break
                if progress is not None:
                    progress(s, context)
                else:
                    if not self.verbose:
                        sys.stderr.write(".")
                    else:
                        # Installer output is not guaranteed to be valid UTF-8;
                        # replace undecodable bytes instead of killing the thread.
                        sys.stderr.write(s.decode("utf-8", errors="replace"))
                    sys.stderr.flush()
        finally:
            stream.close()

    def install_script(self, context, name, url):
        """
        Download the script at *url* into the environment's bin directory,
        run it with the environment's interpreter and delete it afterwards.

        :param context: The information for the virtual environment
                        creation request being processed.
        :param name: Human-readable name used in progress messages.
        :param url: Location of the bootstrap script to download and execute.
        """
        _, _, path, _, _, _ = urlparse(url)
        fn = os.path.split(path)[-1]
        binpath = context.bin_path
        distpath = os.path.join(binpath, fn)
        progress = self.progress
        term = "\n" if self.verbose else ""
        try:
            # Download script into the virtual environment's binaries folder
            urlretrieve(url, distpath)
            if progress is not None:
                progress("Installing %s ...%s" % (name, term), "main")
            else:
                sys.stderr.write("Installing %s ...%s" % (name, term))
                sys.stderr.flush()
            # Install in the virtual environment
            args = [context.env_exe, fn]
            p = Popen(args, stdout=PIPE, stderr=PIPE, cwd=binpath)
            # Drain both pipes concurrently so the child can never block on a
            # full pipe buffer while we wait for it.
            t1 = Thread(target=self.reader, args=(p.stdout, "stdout"))
            t1.start()
            t2 = Thread(target=self.reader, args=(p.stderr, "stderr"))
            t2.start()
            p.wait()
            t1.join()
            t2.join()
            if progress is not None:
                progress("done.", "main")
            else:
                sys.stderr.write("done.\n")
        finally:
            # Clean up - no longer needed. Done in `finally` so a failed
            # download or install does not leave the script behind.
            if os.path.exists(distpath):
                os.unlink(distpath)

    def install_setuptools(self, context):
        """
        Install setuptools in the virtual environment.

        :param context: The information for the virtual environment
                        creation request being processed.
        """
        # NOTE(review): ez_setup.py is a legacy bootstrap script; confirm this
        # URL is still served before relying on it.
        url = "https://bootstrap.pypa.io/ez_setup.py"
        self.install_script(context, "setuptools", url)
        # clear up the setuptools archive which gets downloaded
        for entry in os.listdir(context.bin_path):
            if entry.startswith("setuptools-") and entry.endswith(".tar.gz"):
                os.unlink(os.path.join(context.bin_path, entry))

    def install_pip(self, context):
        """
        Install pip in the virtual environment.

        :param context: The information for the virtual environment
                        creation request being processed.
        """
        url = "https://bootstrap.pypa.io/get-pip.py"
        self.install_script(context, "pip", url)
|
||||
|
||||
|
||||
def get_all_folders(root="benchmarks", exclude=("TFT",)) -> dict:
    """
    Collect the model folders found directly under *root*.

    Only directories are returned; plain files living next to the model
    folders are ignored because they cannot hold a workflow config.

    :param root: Directory to scan, resolved against the current working
                 directory when relative.
    :param exclude: Folder names to leave out of the result.
    :return: Mapping of folder name to its resolved absolute path (as str).
    """
    root_path = Path(root)
    # The context manager releases the directory handle deterministically.
    with os.scandir(root_path) as entries:
        return {
            entry.name: str((root_path / entry.name).resolve())
            for entry in entries
            if entry.is_dir() and entry.name not in exclude
        }
|
||||
|
||||
|
||||
def get_all_files(folder_path) -> (str, str):
    """
    Locate the workflow config and the requirements file of a model folder.

    When several files match a pattern, the alphabetically first one is
    returned so that the choice is deterministic.

    :param folder_path: Path of the model folder to inspect.
    :return: Tuple ``(yaml_path, requirements_path)``.
    :raises FileNotFoundError: If the folder has no ``*.yaml`` or no ``*.txt`` file.
    """
    folder = Path(f"{folder_path}")
    found = []
    for pattern in ("*.yaml", "*.txt"):
        matches = sorted(glob.glob(str(folder / pattern)))
        if not matches:
            # An explicit error is easier to act on than an IndexError.
            raise FileNotFoundError(f"No '{pattern}' file found in {folder}")
        found.append(matches[0])
    return found[0], found[1]
|
||||
|
||||
|
||||
def get_all_results(folders) -> dict:
    """
    Gather the cost-adjusted metrics recorded for every model.

    For each name in *folders* the experiment of the same name is looked up
    (never created), its first listed recorder is taken, and only the metrics
    whose key contains ``"with_cost"`` are kept.

    :param folders: Iterable of experiment names.
    :return: Mapping of experiment name to its filtered metrics dict.
    """

    def _with_cost_metrics(name):
        experiment = R.get_exp(experiment_name=name, create=False)
        first_recorder_id = next(iter(experiment.list_recorders()))
        rec = R.get_recorder(recorder_id=first_recorder_id, experiment_name=name)
        all_metrics = rec.list_metrics()
        return {key: all_metrics[key] for key in all_metrics if "with_cost" in key}

    return {name: _with_cost_metrics(name) for name in folders}
|
||||
|
||||
|
||||
def gen_and_save_md_table(results, path="table.md"):
    """
    Render the collected metrics as a markdown table, print and save it.

    :param results: Mapping of model name to a metrics dict containing the keys
                    ``excess_return_with_cost.annualized_return``,
                    ``excess_return_with_cost.information_ratio`` and
                    ``excess_return_with_cost.max_drawdown``.
    :param path: File the table is written to.
    :return: The markdown table as a string.
    :raises KeyError: If a model's metrics lack one of the keys above.
    """
    lines = [
        "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n",
        "|---|---|---|---|\n",
    ]
    for fn, model_metrics in results.items():
        # Read from `results`; the previous code referenced an undefined
        # `metrics` name and therefore always raised NameError.
        ar = model_metrics["excess_return_with_cost.annualized_return"]
        ir = model_metrics["excess_return_with_cost.information_ratio"]
        md = model_metrics["excess_return_with_cost.max_drawdown"]
        lines.append(f"| {fn} | {ar:9.5f} | {ir:9.5f} | {md:9.5f} |\n")
    table = "".join(lines)
    pprint(table)
    with open(path, "w") as f:
        f.write(table)
    return table
|
||||
|
||||
|
||||
def _run_command(args):
    """Run *args* (an argument list, no shell) and warn on stderr if it fails."""
    cmd = [str(a) for a in args]
    try:
        returncode = Popen(cmd).wait()
    except OSError as err:
        # e.g. the interpreter is missing; report it like a failed command.
        sys.stderr.write(f"Could not start {cmd!r}: {err}\n")
        return -1
    if returncode != 0:
        sys.stderr.write(f"Command {cmd!r} exited with code {returncode}\n")
    return returncode


def _venv_python(env_path):
    """Return the interpreter path inside the virtual environment at *env_path*."""
    if os.name == "nt":
        return env_path / "Scripts" / "python.exe"
    return env_path / "bin" / "python"


def run():
    """
    Run every benchmark model in its own throw-away virtual environment.

    For each folder returned by ``get_all_folders`` a fresh environment is
    created in a temporary directory, the model's requirements and qlib are
    installed into it, and the workflow CLI is executed with the model's yaml
    config. The environment is always deleted afterwards. Finally the recorded
    metrics are collected and written out as a markdown table.

    A failing install or workflow command is reported on stderr but does not
    abort the remaining models.

    :raises ValueError: If the interpreter lacks ``sys.base_prefix`` support
                        (Python older than 3.3).
    """
    # get all folders
    folders = get_all_folders()
    # set up
    if sys.version_info < (3, 3) or not hasattr(sys, "base_prefix"):
        raise ValueError("This script is only for use with Python 3.3 or later")
    builder = ExtendedEnvBuilder(
        system_site_packages=False,
        clear=False,
        symlinks=os.name != "nt",  # symlinks are not used on Windows
        upgrade=False,
        nodist=False,
        nopip=False,
        verbose=False,
    )
    for fn in folders:
        # create env
        temp_dir = tempfile.mkdtemp()
        env_path = Path(temp_dir).absolute()
        try:
            sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
            builder.create(str(env_path))
            python_path = _venv_python(env_path)
            sys.stderr.write("\n")
            # get all files
            sys.stderr.write("Retrieving files...\n")
            yaml_path, req_path = get_all_files(folders[fn])
            sys.stderr.write("\n")
            # install requirements.txt
            sys.stderr.write("Installing requirements.txt...\n")
            _run_command([python_path, "-m", "pip", "install", "-r", req_path])
            sys.stderr.write("\n")
            # install qlib
            sys.stderr.write("Installing qlib...\n")
            _run_command([python_path, "-m", "pip", "install", "--upgrade", "cython"])  # TODO: FIX ME!
            # TODO: FIX ME! qlib is installed from a personal fork.
            _run_command(
                [python_path, "-m", "pip", "install", "-e", "git+https://github.com/you-n-g/qlib#egg=pyqlib"]
            )
            sys.stderr.write("\n")
            # run workflow_by_config
            sys.stderr.write(f"Running the model: {fn}...\n")
            cli_path = env_path / "src" / "pyqlib" / "qlib" / "workflow" / "cli.py"
            _run_command([python_path, cli_path, yaml_path, fn])
            sys.stderr.write("\n")
        finally:
            # remove env, even when one of the steps above raised
            sys.stderr.write(f"Deleting the environment: {env_path}...\n")
            shutil.rmtree(env_path)
    # getting all results
    sys.stderr.write("Retrieving results...\n")
    results = get_all_results(folders)
    # generating md table
    sys.stderr.write("Generating markdown table...\n")
    gen_and_save_md_table(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Exit with 0 when every model ran, 1 when `run` raised an Exception
    # (the message is reported on stderr).
    try:
        run()  # run all the model
    except Exception as e:
        print("Error: %s" % e, file=sys.stderr)
        exit_code = 1
    else:
        exit_code = 0
    sys.exit(exit_code)
|
||||
@@ -1,5 +1,15 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -13,7 +23,7 @@
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.estimator.handler import Alpha158\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
|
||||
145
examples/workflow_by_code_gats.py
Normal file
145
examples/workflow_by_code_gats.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_gats import GAT
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":

    # use default data; download it first when it is not available locally
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
        from get_data import GetData

        GetData().qlib_data_cn(target_dir=provider_uri)

    qlib.init(provider_uri=provider_uri, region=REG_CN)

    MARKET = "csi300"
    BENCHMARK = "SH000300"

    ###################################
    # train model
    ###################################
    # Single source of truth for the date ranges; they were previously repeated
    # in the handler config, the segments and an unused TRAINER_CONFIG dict.
    TRAIN_SEGMENT = ("2008-01-01", "2014-12-31")
    VALID_SEGMENT = ("2015-01-01", "2016-12-31")
    TEST_SEGMENT = ("2017-01-01", "2020-08-01")

    DATA_HANDLER_CONFIG = {
        "start_time": TRAIN_SEGMENT[0],
        "end_time": TEST_SEGMENT[1],
        "fit_start_time": TRAIN_SEGMENT[0],
        "fit_end_time": TRAIN_SEGMENT[1],
        "instruments": MARKET,
    }

    task = {
        "model": {
            "class": "GAT",
            "module_path": "qlib.contrib.model.pytorch_gats",
            "kwargs": {
                "d_feat": 6,
                "hidden_size": 64,
                "num_layers": 2,
                "dropout": 0.0,
                "n_epochs": 200,
                "lr": 1e-3,
                "early_stop": 20,
                "batch_size": 800,
                "metric": "IC",
                "loss": "mse",
                "base_model": "GRU",
                "seed": 0,
                "GPU": 0,
            },
        },
        "dataset": {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "ALPHA360_Denoise",
                    "module_path": "qlib.contrib.data.handler",
                    "kwargs": DATA_HANDLER_CONFIG,
                },
                "segments": {
                    "train": TRAIN_SEGMENT,
                    "valid": VALID_SEGMENT,
                    "test": TEST_SEGMENT,
                },
            },
        },
        # You should record the data in a specific sequence
        # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
    }

    # Build model and dataset from their config dicts, then train and predict.
    model = init_instance_by_config(task["model"])
    dataset = init_instance_by_config(task["dataset"])
    model.fit(dataset)

    pred_score = model.predict(dataset)

    # save pred_score to file
    # NOTE(review): the GRU and LSTM example scripts write to this same path,
    # so running them one after another overwrites this file.
    pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
    pred_score_path.parent.mkdir(exist_ok=True, parents=True)
    pred_score.to_pickle(pred_score_path)

    ###################################
    # backtest
    ###################################
    STRATEGY_CONFIG = {
        "topk": 50,
        "n_drop": 5,
    }
    BACKTEST_CONFIG = {
        "verbose": False,
        "limit_threshold": 0.095,
        "account": 100000000,
        "benchmark": BENCHMARK,
        "deal_price": "close",
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5,
    }

    # use default strategy
    # custom Strategy, refer to: TODO: Strategy API url
    strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
    report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)

    ###################################
    # analyze
    # For a more detailed analysis, refer to the train-and-backtest notebook
    # under examples/ (TODO confirm the exact file name).
    ###################################
    excess_return = report_normal["return"] - report_normal["bench"]
    analysis = dict()
    analysis["excess_return_without_cost"] = risk_analysis(excess_return)
    analysis["excess_return_with_cost"] = risk_analysis(excess_return - report_normal["cost"])
    analysis_df = pd.concat(analysis)  # type: pd.DataFrame
    print(analysis_df)
|
||||
144
examples/workflow_by_code_gru.py
Normal file
144
examples/workflow_by_code_gru.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_gru import GRU
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":

    # use default data; download it first when it is not available locally
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
        from get_data import GetData

        GetData().qlib_data_cn(target_dir=provider_uri)

    qlib.init(provider_uri=provider_uri, region=REG_CN)

    MARKET = "csi300"
    BENCHMARK = "SH000300"

    ###################################
    # train model
    ###################################
    # Single source of truth for the date ranges; they were previously repeated
    # in the handler config, the segments and an unused TRAINER_CONFIG dict.
    TRAIN_SEGMENT = ("2008-01-01", "2014-12-31")
    VALID_SEGMENT = ("2015-01-01", "2016-12-31")
    TEST_SEGMENT = ("2017-01-01", "2020-08-01")

    DATA_HANDLER_CONFIG = {
        "start_time": TRAIN_SEGMENT[0],
        "end_time": TEST_SEGMENT[1],
        "fit_start_time": TRAIN_SEGMENT[0],
        "fit_end_time": TRAIN_SEGMENT[1],
        "instruments": MARKET,
    }

    task = {
        "model": {
            "class": "GRU",
            "module_path": "qlib.contrib.model.pytorch_gru",
            "kwargs": {
                "d_feat": 6,
                "hidden_size": 64,
                "num_layers": 2,
                "dropout": 0.0,
                "n_epochs": 200,
                "lr": 1e-3,
                "early_stop": 20,
                "batch_size": 800,
                "metric": "IC",
                "loss": "mse",
                "seed": 0,
                "GPU": 0,
            },
        },
        "dataset": {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "ALPHA360_Denoise",
                    "module_path": "qlib.contrib.data.handler",
                    "kwargs": DATA_HANDLER_CONFIG,
                },
                "segments": {
                    "train": TRAIN_SEGMENT,
                    "valid": VALID_SEGMENT,
                    "test": TEST_SEGMENT,
                },
            },
        },
        # You should record the data in a specific sequence
        # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
    }

    # Build model and dataset from their config dicts, then train and predict.
    model = init_instance_by_config(task["model"])
    dataset = init_instance_by_config(task["dataset"])
    model.fit(dataset)

    pred_score = model.predict(dataset)

    # save pred_score to file
    # NOTE(review): the GATs and LSTM example scripts write to this same path,
    # so running them one after another overwrites this file.
    pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
    pred_score_path.parent.mkdir(exist_ok=True, parents=True)
    pred_score.to_pickle(pred_score_path)

    ###################################
    # backtest
    ###################################
    STRATEGY_CONFIG = {
        "topk": 50,
        "n_drop": 5,
    }
    BACKTEST_CONFIG = {
        "verbose": False,
        "limit_threshold": 0.095,
        "account": 100000000,
        "benchmark": BENCHMARK,
        "deal_price": "close",
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5,
    }

    # use default strategy
    # custom Strategy, refer to: TODO: Strategy API url
    strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
    report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)

    ###################################
    # analyze
    # For a more detailed analysis, refer to the train-and-backtest notebook
    # under examples/ (TODO confirm the exact file name).
    ###################################
    excess_return = report_normal["return"] - report_normal["bench"]
    analysis = dict()
    analysis["excess_return_without_cost"] = risk_analysis(excess_return)
    analysis["excess_return_with_cost"] = risk_analysis(excess_return - report_normal["cost"])
    analysis_df = pd.concat(analysis)  # type: pd.DataFrame
    print(analysis_df)
|
||||
144
examples/workflow_by_code_lstm.py
Normal file
144
examples/workflow_by_code_lstm.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_lstm import LSTM
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":

    # use default data; download it first when it is not available locally
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
        from get_data import GetData

        GetData().qlib_data_cn(target_dir=provider_uri)

    qlib.init(provider_uri=provider_uri, region=REG_CN)

    MARKET = "csi300"
    BENCHMARK = "SH000300"

    ###################################
    # train model
    ###################################
    # Single source of truth for the date ranges; they were previously repeated
    # in the handler config, the segments and an unused TRAINER_CONFIG dict.
    TRAIN_SEGMENT = ("2008-01-01", "2014-12-31")
    VALID_SEGMENT = ("2015-01-01", "2016-12-31")
    TEST_SEGMENT = ("2017-01-01", "2020-08-01")

    DATA_HANDLER_CONFIG = {
        "start_time": TRAIN_SEGMENT[0],
        "end_time": TEST_SEGMENT[1],
        "fit_start_time": TRAIN_SEGMENT[0],
        "fit_end_time": TRAIN_SEGMENT[1],
        "instruments": MARKET,
    }

    task = {
        "model": {
            "class": "LSTM",
            "module_path": "qlib.contrib.model.pytorch_lstm",
            "kwargs": {
                "d_feat": 6,
                "hidden_size": 64,
                "num_layers": 2,
                "dropout": 0.0,
                "n_epochs": 200,
                "lr": 1e-3,
                "early_stop": 20,
                "batch_size": 800,
                "metric": "IC",
                "loss": "mse",
                "seed": 0,
                "GPU": 0,
            },
        },
        "dataset": {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "ALPHA360_Denoise",
                    "module_path": "qlib.contrib.data.handler",
                    "kwargs": DATA_HANDLER_CONFIG,
                },
                "segments": {
                    "train": TRAIN_SEGMENT,
                    "valid": VALID_SEGMENT,
                    "test": TEST_SEGMENT,
                },
            },
        },
        # You should record the data in a specific sequence
        # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
    }

    # Build model and dataset from their config dicts, then train and predict.
    model = init_instance_by_config(task["model"])
    dataset = init_instance_by_config(task["dataset"])
    model.fit(dataset)

    pred_score = model.predict(dataset)

    # save pred_score to file
    # NOTE(review): the GATs and GRU example scripts write to this same path,
    # so running them one after another overwrites this file.
    pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
    pred_score_path.parent.mkdir(exist_ok=True, parents=True)
    pred_score.to_pickle(pred_score_path)

    ###################################
    # backtest
    ###################################
    STRATEGY_CONFIG = {
        "topk": 50,
        "n_drop": 5,
    }
    BACKTEST_CONFIG = {
        "verbose": False,
        "limit_threshold": 0.095,
        "account": 100000000,
        "benchmark": BENCHMARK,
        "deal_price": "close",
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5,
    }

    # use default strategy
    # custom Strategy, refer to: TODO: Strategy API url
    strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
    report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)

    ###################################
    # analyze
    # For a more detailed analysis, refer to the train-and-backtest notebook
    # under examples/ (TODO confirm the exact file name).
    ###################################
    excess_return = report_normal["return"] - report_normal["bench"]
    analysis = dict()
    analysis["excess_return_without_cost"] = risk_analysis(excess_return)
    analysis["excess_return_with_cost"] = risk_analysis(excess_return - report_normal["cost"])
    analysis_df = pd.concat(analysis)  # type: pd.DataFrame
    print(analysis_df)
|
||||
@@ -1,59 +0,0 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -291,7 +291,9 @@ class DataHandlerLP(DataHandler):
|
||||
init_instance_by_config(
|
||||
proc,
|
||||
None if (isinstance(proc, dict) and "module_path" in proc) else processor_module,
|
||||
accept_types=processor_module.Processor))
|
||||
accept_types=processor_module.Processor,
|
||||
)
|
||||
)
|
||||
|
||||
self.process_type = process_type
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
|
||||
@@ -659,7 +659,7 @@ def flatten_dict(d, parent_key="", sep="."):
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, collections.MutableMapping):
|
||||
if isinstance(v, collections.abc.MutableMapping):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
|
||||
@@ -13,13 +13,15 @@ from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
|
||||
# workflow handler function
|
||||
def workflow(config_path, experiment_name="workflow"):
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.Loader)
|
||||
|
||||
provider_uri = config.get("provider_uri")
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
region = config.get("region")
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
|
||||
# model initialization
|
||||
model = init_instance_by_config(config.get("task")["model"])
|
||||
|
||||
@@ -159,6 +159,36 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_artifacts` method.")
|
||||
|
||||
def list_metrics(self):
    """
    Return every metric stored for this recorder.

    Abstract hook: this base implementation always raises.

    Returns
    -------
    A dictionary of the stored metrics.
    """
    message = "Please implement the `list_metrics` method."
    raise NotImplementedError(message)
|
||||
|
||||
def list_params(self):
    """
    Return every param stored for this recorder.

    Abstract hook: this base implementation always raises.

    Returns
    -------
    A dictionary of the stored params.
    """
    message = "Please implement the `list_params` method."
    raise NotImplementedError(message)
|
||||
|
||||
def list_tags(self):
    """
    Return every tag stored for this recorder.

    Abstract hook: this base implementation always raises.

    Returns
    -------
    A dictionary of the stored tags.
    """
    message = "Please implement the `list_tags` method."
    raise NotImplementedError(message)
|
||||
|
||||
|
||||
class MLflowRecorder(Recorder):
|
||||
"""
|
||||
@@ -239,7 +269,7 @@ class MLflowRecorder(Recorder):
|
||||
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_metric(self.id, name, data)
|
||||
self.client.log_metric(self.id, name, data, step=step)
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
@@ -261,3 +291,15 @@ class MLflowRecorder(Recorder):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return artifacts
|
||||
|
||||
def list_metrics(self):
|
||||
run = self.client.get_run(self.id)
|
||||
return run.data.metrics
|
||||
|
||||
def list_params(self):
|
||||
run = self.client.get_run(self.id)
|
||||
return run.data.params
|
||||
|
||||
def list_tags(self):
|
||||
run = self.client.get_run(self.id)
|
||||
return run.data.tags
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,7 +11,7 @@ import pandas as pd
|
||||
from scipy.stats import pearsonr
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.config import REG_CN, C
|
||||
from qlib.utils import drop_nan_by_y_index
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
@@ -19,51 +20,78 @@ from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"dropna_label": True,
|
||||
"start_date": "2008-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"market": "CSI300",
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
MODEL_CONFIG = {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2008-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
}
|
||||
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": "SH000300",
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.strategy",
|
||||
"kwargs": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -78,34 +106,32 @@ def train():
|
||||
performance: dict
|
||||
model performance
|
||||
"""
|
||||
# get data
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
|
||||
**TRAINER_CONFIG
|
||||
)
|
||||
|
||||
# train
|
||||
model = LGBModel(**MODEL_CONFIG)
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
_pred = model.predict(x_test)
|
||||
_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
|
||||
pred_score = pd.DataFrame(index=_pred.index)
|
||||
pred_score["score"] = _pred.iloc(axis=1)[0]
|
||||
# model initiation
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# get performance
|
||||
try:
|
||||
model_score = model.score(x_test, y_test)
|
||||
except NotImplementedError:
|
||||
model_score = None
|
||||
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
|
||||
pred_test = model.predict(x_test)
|
||||
model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
return pred_score, {"model_score": model_score, "model_pearsonr": model_pearsonr}
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load()
|
||||
|
||||
y_test = dataset.prepare("test", col_set="label")
|
||||
pred_score, y_test, __ = drop_nan_by_y_index(pred_score, y_test)
|
||||
model_pearsonr = pearsonr(np.ravel(pred_score.values), np.ravel(y_test.values))[0]
|
||||
|
||||
return pred_score, {"model_pearsonr": model_pearsonr}, rid
|
||||
|
||||
|
||||
def backtest(pred):
|
||||
"""backtest
|
||||
def backtest_analysis(pred, rid):
|
||||
"""backtest and analysis
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -114,23 +140,14 @@ def backtest(pred):
|
||||
|
||||
Returns
|
||||
-------
|
||||
report_normal: pandas.DataFrame
|
||||
|
||||
positions_normal: dict
|
||||
analysis result : pandas.DataFrame
|
||||
|
||||
"""
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
_report_normal, _positions_normal = normal_backtest(pred, strategy=strategy, **BACKTEST_CONFIG)
|
||||
return _report_normal, _positions_normal
|
||||
|
||||
|
||||
def analyze(report_normal):
|
||||
_analysis = dict()
|
||||
_analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
_analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(_analysis) # type: pd.DataFrame
|
||||
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
analysis_df = par.load("port_analysis.pkl")
|
||||
print(analysis_df)
|
||||
return analysis_df
|
||||
|
||||
@@ -139,6 +156,7 @@ class TestAllFlow(unittest.TestCase):
|
||||
PRED_SCORE = None
|
||||
REPORT_NORMAL = None
|
||||
POSITIONS = None
|
||||
RID = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
@@ -154,13 +172,16 @@ class TestAllFlow(unittest.TestCase):
|
||||
)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
|
||||
|
||||
def test_0_train(self):
|
||||
TestAllFlow.PRED_SCORE, model_pearsonr = train()
|
||||
TestAllFlow.PRED_SCORE, model_pearsonr, TestAllFlow.RID = train()
|
||||
self.assertGreaterEqual(model_pearsonr["model_pearsonr"], 0, "train failed")
|
||||
|
||||
def test_1_backtest(self):
|
||||
TestAllFlow.REPORT_NORMAL, TestAllFlow.POSITIONS = backtest(TestAllFlow.PRED_SCORE)
|
||||
analyze_df = analyze(TestAllFlow.REPORT_NORMAL)
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
0.10,
|
||||
|
||||
Reference in New Issue
Block a user