mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
init commit
This commit is contained in:
173
tests/test_all_pipeline.py
Normal file
173
tests/test_all_pipeline.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import pearsonr
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import drop_nan_by_y_index
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.estimator.handler import QLibDataHandlerClose
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
# Keyword arguments unpacked into QLibDataHandlerClose(...) by train().
DATA_HANDLER_CONFIG = {
    "dropna_label": True,
    "start_date": "2008-01-01",
    "end_date": "2020-08-01",
    "market": "CSI300",
}

# Keyword arguments unpacked into LGBModel(...) by train().
MODEL_CONFIG = {
    "loss": "mse",
    "colsample_bytree": 0.8879,
    "learning_rate": 0.0421,
    "subsample": 0.8789,
    "lambda_l1": 205.6999,
    "lambda_l2": 580.9768,
    "max_depth": 8,
    "num_leaves": 210,
    "num_threads": 20,
}

# Keyword arguments unpacked into get_split_data(...) by train():
# the date boundaries of the train / validate / test splits.
TRAINER_CONFIG = {
    "train_start_date": "2008-01-01",
    "train_end_date": "2014-12-31",
    "validate_start_date": "2015-01-01",
    "validate_end_date": "2016-12-31",
    "test_start_date": "2017-01-01",
    "test_end_date": "2020-08-01",
}

# Keyword arguments unpacked into TopkDropoutStrategy(...) by backtest().
STRATEGY_CONFIG = {
    "topk": 50,
    "n_drop": 5,
}

# Extra keyword arguments forwarded to qlib.contrib.evaluate.backtest by backtest().
BACKTEST_CONFIG = {
    "verbose": False,
    "limit_threshold": 0.095,
    "account": 100000000,
    "benchmark": "SH000300",
    "deal_price": "close",
    "open_cost": 0.0005,
    "close_cost": 0.0015,
    "min_cost": 5,
}
|
||||
|
||||
|
||||
# train
|
||||
def train():
    """Fit the LightGBM model and evaluate it on the test split.

    The data handler is built from ``DATA_HANDLER_CONFIG``, the splits come
    from ``TRAINER_CONFIG`` and the model is built from ``MODEL_CONFIG``.

    Returns
    -------
    pred_score: pandas.DataFrame
        Indexed like the test features, with a single ``"score"`` column
        holding the first prediction column.
    performance: dict
        ``"model_score"`` is ``model.score`` on the full test split;
        ``"model_pearsonr"`` is the Pearson correlation between predictions
        and labels after rows were filtered by ``drop_nan_by_y_index``.
    """
    # Load the six train / validate / test frames in one go.
    handler = QLibDataHandlerClose(**DATA_HANDLER_CONFIG)
    splits = handler.get_split_data(**TRAINER_CONFIG)
    x_train, y_train, x_validate, y_validate, x_test, y_test = splits

    # Fit on the train split, using the validate split during fitting.
    model = LGBModel(**MODEL_CONFIG)
    model.fit(x_train, y_train, x_validate, y_validate)

    # Predictions on the unfiltered test split become the score frame.
    raw_pred = pd.DataFrame(model.predict(x_test), index=x_test.index, columns=y_test.columns)
    pred_score = pd.DataFrame(index=raw_pred.index)
    pred_score["score"] = raw_pred.iloc[:, 0]

    # The model's own score is taken before the test rows are filtered.
    model_score = model.score(x_test, y_test)

    # Filter the test rows by the label frame, then correlate the
    # predictions on the remaining rows with their labels.
    x_clean, y_clean, _ = drop_nan_by_y_index(x_test, y_test)
    flat_pred = np.ravel(model.predict(x_clean))
    flat_label = np.ravel(y_clean.values)
    model_pearsonr = pearsonr(flat_pred, flat_label)[0]

    performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
    return pred_score, performance
|
||||
|
||||
|
||||
def backtest(pred):
    """Run the normal backtest on prediction scores.

    The strategy is a ``TopkDropoutStrategy`` built from ``STRATEGY_CONFIG``;
    the remaining backtest settings come from ``BACKTEST_CONFIG``.

    Parameters
    ----------
    pred: pandas.DataFrame
        predict scores

    Returns
    -------
    report_normal: pandas.DataFrame
        First value returned by the backtest.
    positions_normal: dict
        Second value returned by the backtest.
    """
    topk_strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
    report, positions = normal_backtest(pred, strategy=topk_strategy, **BACKTEST_CONFIG)
    return report, positions
|
||||
|
||||
|
||||
def analyze(report_normal):
    """Run risk analysis on the backtest report, print it and return it.

    Parameters
    ----------
    report_normal: pandas.DataFrame
        Backtest report; its ``"return"``, ``"bench"`` and ``"cost"`` columns are read.

    Returns
    -------
    analysis_df: pandas.DataFrame
        Concatenation of two ``risk_analysis`` results keyed ``"sub_bench"``
        (return minus benchmark) and ``"sub_cost"`` (return minus benchmark
        minus cost).
    """
    excess_return = report_normal["return"] - report_normal["bench"]
    results = {
        "sub_bench": risk_analysis(excess_return),
        "sub_cost": risk_analysis(excess_return - report_normal["cost"]),
    }
    analysis_df = pd.concat(results)  # type: pd.DataFrame
    print(analysis_df)
    return analysis_df
|
||||
|
||||
|
||||
class TestAllFlow(unittest.TestCase):
    """End-to-end check of the train -> backtest -> analyze pipeline.

    The tests share state through class attributes and must run in name
    order: ``test_0_train`` stores the prediction scores that
    ``test_1_backtest`` consumes.
    """

    # Prediction scores stored by test_0_train.
    PRED_SCORE = None
    # Backtest report and positions stored by test_1_backtest.
    REPORT_NORMAL = None
    POSITIONS = None

    @classmethod
    def setUpClass(cls) -> None:
        """Download the default CN dataset if it is missing, then initialise qlib."""
        # use default data
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            # get_data lives in the "scripts" directory two levels above this
            # file, so that directory has to be put on sys.path first.
            sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
            from get_data import GetData

            GetData().qlib_data_cn(provider_uri)
        # Point qlib at the dataset that was just checked or downloaded.
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def test_0_train(self):
        """Train the model and require a non-negative Pearson correlation."""
        TestAllFlow.PRED_SCORE, performance = train()
        self.assertGreaterEqual(performance["model_pearsonr"], 0, "train failed")

    def test_1_backtest(self):
        """Backtest the stored scores and require the analysed "sub_cost"/"annual" value to be at least 0.10."""
        TestAllFlow.REPORT_NORMAL, TestAllFlow.POSITIONS = backtest(TestAllFlow.PRED_SCORE)
        analyze_df = analyze(TestAllFlow.REPORT_NORMAL)
        self.assertGreaterEqual(
            analyze_df.loc(axis=0)["sub_cost", "annual"].values[0],
            0.10,
            "backtest failed",
        )
|
||||
|
||||
|
||||
def suite():
    """Build a suite that runs the training test before the backtest test."""
    ordered_suite = unittest.TestSuite()
    for test_name in ("test_0_train", "test_1_backtest"):
        ordered_suite.addTest(TestAllFlow(test_name))
    return ordered_suite
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the explicitly ordered suite with a plain text runner.
    unittest.TextTestRunner().run(suite())
|
||||
88
tests/test_dump_data.py
Normal file
88
tests/test_dump_data.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
from dump_bin import DumpData
|
||||
|
||||
|
||||
# Scratch directories next to this file, created as soon as the module is imported.
# SOURCE_DIR receives the CSV files fetched by GetData; QLIB_DIR is the
# qlib_dir handed to DumpData and the provider_uri handed to qlib.init.
DATA_DIR = Path(__file__).parent / "test_data"
SOURCE_DIR = DATA_DIR / "source"
QLIB_DIR = DATA_DIR / "qlib"
SOURCE_DIR.mkdir(parents=True, exist_ok=True)
QLIB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class TestDumpData(unittest.TestCase):
    """Dump the downloaded CSV files with DumpData and read them back through ``D``.

    The tests share state through class attributes and must run in name
    order: ``test_2_dump_features`` stores the frame that
    ``test_3_dump_features_simple`` compares against.
    """

    FIELDS = "open,close,high,low,volume,factor,change".split(",")
    # The same field names with the "$" prefix used when querying D.features.
    QLIB_FIELDS = [f"${name}" for name in FIELDS]
    # DumpData instance and upper-cased stock names, both set in setUpClass.
    DUMP_DATA = None
    STOCK_NAMES = None

    # simple data: features of the first stock, stored by test_2_dump_features
    SIMPLE_DATA = None

    @classmethod
    def setUpClass(cls) -> None:
        """Fetch the CSV files, build the dumper and initialise qlib on QLIB_DIR."""
        GetData().csv_data_cn(SOURCE_DIR)
        TestDumpData.DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR)
        # Only *.csv files are considered: the [:-4] slice strips exactly that
        # suffix, and test_1_dump_instruments builds its names the same way.
        TestDumpData.STOCK_NAMES = [path.name[:-4].upper() for path in SOURCE_DIR.glob("*.csv")]
        provider_uri = str(QLIB_DIR.resolve())
        qlib.init(
            provider_uri=provider_uri,
            expression_cache=None,
            dataset_cache=None,
        )

    @classmethod
    def tearDownClass(cls) -> None:
        """Delete the whole test data directory once every test has run."""
        shutil.rmtree(str(DATA_DIR.resolve()))

    def test_0_dump_calendars(self):
        """The dumped calendar file and ``D.calendar()`` must contain the same dates."""
        self.DUMP_DATA.dump_calendars()
        calendar_file = QLIB_DIR.joinpath("calendars", "day.txt")
        ori_calendars = set(map(pd.Timestamp, pd.read_csv(calendar_file, header=None).loc[:, 0].values))
        res_calendars = set(D.calendar())
        self.assertSetEqual(ori_calendars, res_calendars, "dump calendars failed")

    def test_1_dump_instruments(self):
        """The CSV file names and the instruments listed by ``D`` must match exactly."""
        self.DUMP_DATA.dump_instruments()
        ori_ins = {path.name[:-4].upper() for path in SOURCE_DIR.glob("*.csv")}
        res_ins = set(D.list_instruments(D.instruments("all"), as_list=True))
        # Set equality checks both directions: nothing missing and nothing extra.
        self.assertSetEqual(ori_ins, res_ins, "dump instruments failed")

    def test_2_dump_features(self):
        """Dump every field for every stock and check that the data loads back."""
        self.DUMP_DATA.dump_features(include_fields=self.FIELDS)
        df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS)
        # Keep the first stock's rows for test_3_dump_features_simple.
        TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :]
        self.assertFalse(df.dropna().empty, "features data failed")
        self.assertListEqual(list(df.columns), self.QLIB_FIELDS, "features columns failed")

    def test_3_dump_features_simple(self):
        """Dump a single stock's CSV again and compare it with the stored frame."""
        stock = self.STOCK_NAMES[0]
        dump_data = DumpData(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR)
        dump_data.dump_features(include_fields=self.FIELDS, calendar_path=QLIB_DIR.joinpath("calendars", "day.txt"))

        df = D.features([stock], self.QLIB_FIELDS)

        self.assertEqual(len(df), len(TestDumpData.SIMPLE_DATA), "dump features simple failed")
        self.assertTrue(np.isclose(df.dropna(), self.SIMPLE_DATA.dropna()).all(), "dump features simple failed")
|
||||
|
||||
|
||||
# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
52
tests/test_get_data.py
Normal file
52
tests/test_get_data.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
# Scratch directories next to this file, created as soon as the module is imported.
# SOURCE_DIR receives the CSV download; QLIB_DIR receives the qlib dataset
# and is the provider_uri handed to qlib.init.
DATA_DIR = Path(__file__).parent / "test_data"
SOURCE_DIR = DATA_DIR / "source"
QLIB_DIR = DATA_DIR / "qlib"
SOURCE_DIR.mkdir(parents=True, exist_ok=True)
QLIB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class TestGetData(unittest.TestCase):
    """Check the two GetData downloads: the qlib dataset and the raw CSV files."""

    FIELDS = "$open,$close,$high,$low,$volume,$factor,$change".split(",")

    @classmethod
    def setUpClass(cls) -> None:
        """Initialise qlib on QLIB_DIR, passing None for both cache settings."""
        qlib.init(
            provider_uri=str(QLIB_DIR.resolve()),
            expression_cache=None,
            dataset_cache=None,
        )

    @classmethod
    def tearDownClass(cls) -> None:
        """Delete the whole test data directory once every test has run."""
        data_root = DATA_DIR.resolve()
        shutil.rmtree(str(data_root))

    def test_0_qlib_data(self):
        """Download the qlib dataset and load the csi300 features from it."""
        GetData().qlib_data_cn(QLIB_DIR)
        features = D.features(D.instruments("csi300"), self.FIELDS)
        loaded_columns = list(features.columns)
        self.assertListEqual(loaded_columns, self.FIELDS, "get qlib data failed")
        self.assertFalse(features.dropna().empty, "get qlib data failed")

    def test_1_csv_data(self):
        """Download the CSV files and expect 96 distinct stock names."""
        GetData().csv_data_cn(SOURCE_DIR)
        stock_names = {csv_file.name[:-4].upper() for csv_file in SOURCE_DIR.glob("*.csv")}
        self.assertEqual(len(stock_names), 96, "get csv data failed")
|
||||
|
||||
|
||||
# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user