1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00
Files
qlib/tests/test_all_pipeline.py
you-n-g de9e13b171 release-0.5.0 (#1)
* init commit

* change the version number

* enrich the docs & fix cache docs

* update index readme

* Modify cache class name

* Modify sharpe to information_ratio

* Modify Group- to Group

* add the description of graphical results & fix the backtest docs

* fix docs in details

* update docs

* Update introduction.rst

* Update README.md

* Update introduction.rst

* Update introduction.rst

* Update introduction.rst

* Update installation.rst

* Update installation.rst

* Update initialization.rst

* Update getdata.rst

* Update integration.rst

* Update initialization.rst

* Update getdata.rst

* Update estimator.rst

Modify some typos.

* Update README.md

Modify the typos.

* Update initialization.rst

* Update data.rst

* Update report.rst

* Update estimator.rst

* Update cumulative_return.py

* Update model.rst

* Update rank_label.py

* Update cumulative_return.py

* Update strategy.rst

* Update getdata.rst

* Update backtest.rst

* Update integration.rst

* Update getdata.rst

* Update introduction.rst

* Update introduction.rst

* Update README.md

* Update report.rst

* Update integration.rst

Fix typos

* Update installation.rst

Fix typos

* Update getdata.rst

* Update initialization.rst

Fix typos.

* add quick start docs & fix details

* fix estimator docs & fix strategy docs

* fix the cache in data.rst

* update documents

* Fix Corr && Rsquare

* fix data retrieval example to csi300 & fix a data bug

* fix filter bug

* Fix data collector

* Modify model args

* add the log & fix README.md\quick.rst

* add environment dependencies & add introduction of qlib-server online mode

* fix image center format & set log_only of docs to True

* fix README.md format

* update data preparation & readme logo image

* get_data support version

* Modify analysis names

* Modify analysis graph

* update report.rst & data.rst

* commit estimator for merge

* minimal requirements

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* update estimator

* Fix doc urls

* fix get_data.py docstring

* update test_get_data.py

* Update docs

* Update docs

* Update docs

Co-authored-by: bxdd <bxddream@gmail.com>
Co-authored-by: zhupr <zhu.pengrong@foxmail.com>
Co-authored-by: Wendi Li <wendili.academic@qq.com>
Co-authored-by: Dingsu Wang <dingsu.wang@gmail.com>
Co-authored-by: bxdd <45119470+bxdd@users.noreply.github.com>
Co-authored-by: cslwqxx <cslwqxx@users.noreply.github.com>
2020-09-24 12:01:39 +08:00

174 lines
4.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
import qlib
from qlib.config import REG_CN
from qlib.utils import drop_nan_by_y_index
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.estimator.handler import QLibDataHandlerClose
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data
# Keyword arguments unpacked into ``QLibDataHandlerClose(...)`` by ``train()``.
DATA_HANDLER_CONFIG = {
    "dropna_label": True,
    "start_date": "2008-01-01",
    "end_date": "2020-08-01",
    "market": "CSI300",
}

# Keyword arguments unpacked into ``LGBModel(...)`` by ``train()``.
MODEL_CONFIG = {
    "loss": "mse",
    "colsample_bytree": 0.8879,
    "learning_rate": 0.0421,
    "subsample": 0.8789,
    "lambda_l1": 205.6999,
    "lambda_l2": 580.9768,
    "max_depth": 8,
    "num_leaves": 210,
    "num_threads": 20,
}

# Keyword arguments unpacked into ``get_split_data(...)`` by ``train()``:
# the date boundaries of the train / validate / test segments.
TRAINER_CONFIG = {
    "train_start_date": "2008-01-01",
    "train_end_date": "2014-12-31",
    "validate_start_date": "2015-01-01",
    "validate_end_date": "2016-12-31",
    "test_start_date": "2017-01-01",
    "test_end_date": "2020-08-01",
}

# Keyword arguments unpacked into ``TopkDropoutStrategy(...)`` by ``backtest()``.
STRATEGY_CONFIG = {
    "topk": 50,
    "n_drop": 5,
}

# Extra keyword arguments unpacked into the backtest call by ``backtest()``.
BACKTEST_CONFIG = {
    "verbose": False,
    "limit_threshold": 0.095,
    "account": 100000000,
    "benchmark": "SH000300",
    "deal_price": "close",
    "open_cost": 0.0005,
    "close_cost": 0.0015,
    "min_cost": 5,
}
# train
def train():
    """Fit the LightGBM model and measure it on the test split.

    Returns
    -------
    pred_score: pandas.DataFrame
        a single ``score`` column holding the first column of the model's
        predictions for the test features, indexed like the test features
    performance: dict
        ``model_score``: result of ``model.score`` on the test split;
        ``model_pearsonr``: Pearson correlation between predictions and
        labels, computed after the NaN-label rows have been dropped
    """
    # Load the train / validate / test splits.
    handler = QLibDataHandlerClose(**DATA_HANDLER_CONFIG)
    x_train, y_train, x_validate, y_validate, x_test, y_test = handler.get_split_data(**TRAINER_CONFIG)

    # Fit the model; the validation split is handed to ``fit`` as well.
    model = LGBModel(**MODEL_CONFIG)
    model.fit(x_train, y_train, x_validate, y_validate)

    # Predict on the complete test split and keep the first prediction column
    # under the name ``score``.
    raw_pred = model.predict(x_test)
    pred_frame = pd.DataFrame(raw_pred, index=x_test.index, columns=y_test.columns)
    pred_score = pd.DataFrame(index=pred_frame.index)
    pred_score["score"] = pred_frame.iloc[:, 0]

    # Score on the unfiltered test split, before any rows are dropped.
    model_score = model.score(x_test, y_test)

    # Drop the rows whose label contains NaN in any column, then predict again
    # on the filtered features so predictions and labels line up for pearsonr.
    x_clean, y_clean, _ = drop_nan_by_y_index(x_test, y_test)
    clean_pred = model.predict(x_clean)
    model_pearsonr = pearsonr(np.ravel(clean_pred), np.ravel(y_clean.values))[0]

    performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
    return pred_score, performance
def backtest(pred):
    """Backtest a table of prediction scores with the top-k dropout strategy.

    Parameters
    ----------
    pred: pandas.DataFrame
        predict scores

    Returns
    -------
    report_normal: pandas.DataFrame
    positions_normal: dict
    """
    topk_strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
    # The backtest yields a (report, positions) pair, handed back unchanged.
    report, positions = normal_backtest(pred, strategy=topk_strategy, **BACKTEST_CONFIG)
    return report, positions
def analyze(report_normal):
    """Run risk analysis on the excess return over the benchmark.

    Parameters
    ----------
    report_normal: pandas.DataFrame
        backtest report; its ``return``, ``bench`` and ``cost`` columns are read

    Returns
    -------
    analysis_df: pandas.DataFrame
        the two ``risk_analysis`` results concatenated under the keys
        ``excess_return_without_cost`` and ``excess_return_with_cost``;
        the frame is also printed
    """
    excess = report_normal["return"] - report_normal["bench"]
    excess_after_cost = excess - report_normal["cost"]
    results = {
        "excess_return_without_cost": risk_analysis(excess),
        "excess_return_with_cost": risk_analysis(excess_after_cost),
    }
    analysis_df = pd.concat(results)  # type: pd.DataFrame
    print(analysis_df)
    return analysis_df
class TestAllFlow(unittest.TestCase):
    """End-to-end check: train the model, backtest its predictions, analyze the report.

    The tests share their results through class attributes, so ``test_0_train``
    has to run before ``test_1_backtest``.
    """

    # Prediction scores stored by ``test_0_train`` and read by ``test_1_backtest``.
    PRED_SCORE = None
    # Backtest outputs stored by ``test_1_backtest``.
    REPORT_NORMAL = None
    POSITIONS = None

    @classmethod
    def setUpClass(cls) -> None:
        """Download the default CN dataset when it is missing, then initialize qlib."""
        # use default data
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            # Make the ``scripts`` directory next to this file's parent directory
            # importable so that ``get_data`` can be imported.
            sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
            from get_data import GetData

            GetData().qlib_data_cn(provider_uri)
        # The setting is spelled ``provider_uri``; a misspelled keyword would not
        # hand the data location checked above to qlib.
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def test_0_train(self):
        """Train the model and require a non-negative Pearson correlation."""
        TestAllFlow.PRED_SCORE, performance = train()
        self.assertGreaterEqual(performance["model_pearsonr"], 0, "train failed")

    def test_1_backtest(self):
        """Backtest the stored predictions and require >= 10% annualized excess return with cost."""
        # Fail with a clear message instead of an error deep inside the backtest
        # when the training test has not populated the shared predictions.
        self.assertIsNotNone(TestAllFlow.PRED_SCORE, "test_0_train must run before test_1_backtest")
        TestAllFlow.REPORT_NORMAL, TestAllFlow.POSITIONS = backtest(TestAllFlow.PRED_SCORE)
        analyze_df = analyze(TestAllFlow.REPORT_NORMAL)
        annualized_return = analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0]
        self.assertGreaterEqual(annualized_return, 0.10, "backtest failed")
def suite():
    """Build a suite that runs the training test first and the backtest test second."""
    ordered_cases = ("test_0_train", "test_1_backtest")
    flow_suite = unittest.TestSuite()
    flow_suite.addTests(TestAllFlow(case) for case in ordered_cases)
    return flow_suite
if __name__ == "__main__":
    # Run the ordered suite with the plain-text runner when executed as a script.
    unittest.TextTestRunner().run(suite())