1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add catboost config and notebook

This commit is contained in:
Jactus
2020-11-19 17:18:18 +08:00
parent df406d58a5
commit c91698287a
9 changed files with 415 additions and 87 deletions

View File

@@ -0,0 +1,53 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: CatBoostModel
module_path: qlib.contrib.model.catboost_model
kwargs:
loss: RMSE
iterations: 5
learning_rate: 0.03
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360_Denoise
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,64 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
import xgboost as xgb
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
class XGBModel(Model):
"""XGBModel Model"""
def __init__(self, obj="mse", **kwargs):
if obj not in {"mse", "binary"}:
raise NotImplementedError
self._params = {"obj": obj}
self._params.update(kwargs)
self.model = None
def fit(
self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError("XGBoost doesn't support multi-label training")
dtrain = xgb.DMatrix(x_train.values, label=y_train_1d)
dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d)
self.model = xgb.train(
self._params,
dtrain=dtrain,
num_boost_round=num_boost_round,
evals=[(dtrain, "train"), (dvalid, "valid")],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
return pd.Series(self.model.predict(xgb.DMatrix(np.squeeze(x_test.values))), index=x_test.index)

View File

@@ -0,0 +1,330 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.config import REG_CN\n",
"from qlib.contrib.model.gbdt import LGBModel\n",
"from qlib.contrib.estimator.handler import Alpha158\n",
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
"from qlib.contrib.evaluate import (\n",
" backtest as normal_backtest,\n",
" risk_analysis,\n",
")\n",
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
"from qlib.workflow import R\n",
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# use default data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"market = \"csi300\"\n",
"benchmark = \"SH000300\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# train model\n",
"###################################\n",
"data_handler_config = {\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": market,\n",
"}\n",
"\n",
"task = {\n",
" \"model\": {\n",
" \"class\": \"LGBModel\",\n",
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
" \"kwargs\": {\n",
" \"loss\": \"mse\",\n",
" \"colsample_bytree\": 0.8879,\n",
" \"learning_rate\": 0.0421,\n",
" \"subsample\": 0.8789,\n",
" \"lambda_l1\": 205.6999,\n",
" \"lambda_l2\": 580.9768,\n",
" \"max_depth\": 8,\n",
" \"num_leaves\": 210,\n",
" \"num_threads\": 20,\n",
" },\n",
" },\n",
" \"dataset\": {\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": {\n",
" \"class\": \"Alpha158\",\n",
" \"module_path\": \"qlib.contrib.data.handler\",\n",
" \"kwargs\": data_handler_config,\n",
" },\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" },\n",
" },\n",
"}\n",
"\n",
"# model initiaiton\n",
"model = init_instance_by_config(task[\"model\"])\n",
"dataset = init_instance_by_config(task[\"dataset\"])\n",
"\n",
"# start exp to train model\n",
"with R.start(experiment_name=\"train_model\"):\n",
" R.log_paramters(**flatten_dict(task))\n",
" model.fit(dataset)\n",
" R.save_objects(trained_model=model)\n",
" rid = R.get_recorder().id\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# prediction, backtest & analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# prediction, backtest & analysis\n",
"###################################\n",
"port_analysis_config = {\n",
" \"strategy\": {\n",
" \"class\": \"TopkDropoutStrategy\",\n",
" \"module_path\": \"qlib.contrib.strategy.strategy\",\n",
" \"kwargs\": {\n",
" \"topk\": 50,\n",
" \"n_drop\": 5,\n",
" },\n",
" },\n",
" \"backtest\": {\n",
" \"verbose\": False,\n",
" \"limit_threshold\": 0.095,\n",
" \"account\": 100000000,\n",
" \"benchmark\": benchmark,\n",
" \"deal_price\": \"close\",\n",
" \"open_cost\": 0.0005,\n",
" \"close_cost\": 0.0015,\n",
" \"min_cost\": 5,\n",
" },\n",
"}\n",
"\n",
"\n",
"# backtest and analysis\n",
"with R.start(experiment_name=\"backtest_analysis\"):\n",
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
" model = recorder.load_object(\"trained_model\")\n",
"\n",
" # prediction\n",
" recorder = R.get_recorder()\n",
" ba_rid = recorder.id\n",
" sr = SignalRecord(model, dataset, recorder)\n",
" sr.generate()\n",
"\n",
" # backtest & analysis\n",
" par = PortAnaRecord(recorder, port_analysis_config)\n",
" par.generate()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze graphs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
"positions = recorder.load_object(\"portfolio_analysis/positions_normal.pkl\")\n",
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis.pkl\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis position"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### risk analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
"label_df.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### score IC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
"analysis_position.score_ic_graph(pred_label)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### model performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_model.model_performance_graph(pred_label)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -32,18 +32,18 @@ if __name__ == "__main__":
qlib.init(provider_uri=provider_uri, region=REG_CN)
MARKET = "csi300"
BENCHMARK = "SH000300"
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
DATA_HANDLER_CONFIG = {
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": MARKET,
"instruments": market,
}
task = {
@@ -69,7 +69,7 @@ if __name__ == "__main__":
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": DATA_HANDLER_CONFIG,
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
@@ -78,8 +78,6 @@ if __name__ == "__main__":
},
},
},
# You shoud record the data in specific sequence
"record": ["SignalRecord", "PortAnaRecord"],
}
port_analysis_config = {
@@ -95,7 +93,7 @@ if __name__ == "__main__":
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"benchmark": benchmark,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
@@ -108,7 +106,8 @@ if __name__ == "__main__":
dataset = init_instance_by_config(task["dataset"])
# start exp
with R.start("workflow"):
with R.start(experiment_name="workflow"):
R.log_paramters(**flatten_dict(task))
model.fit(dataset)
# prediction

View File

@@ -32,18 +32,18 @@ if __name__ == "__main__":
qlib.init(provider_uri=provider_uri, region=REG_CN)
MARKET = "csi300"
BENCHMARK = "SH000300"
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
DATA_HANDLER_CONFIG = {
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": MARKET,
"instruments": market,
}
task = {
@@ -69,7 +69,7 @@ if __name__ == "__main__":
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": DATA_HANDLER_CONFIG,
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
@@ -78,8 +78,6 @@ if __name__ == "__main__":
},
},
},
# You shoud record the data in specific sequence
"record": ["SignalRecord", "PortAnaRecord"],
}
port_analysis_config = {
@@ -95,7 +93,7 @@ if __name__ == "__main__":
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"benchmark": benchmark,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,

View File

@@ -130,7 +130,7 @@ _default_config = {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": 'file:' + str(Path(os.getcwd()).resolve() / "mlruns"),
"uri": "file:" + str(Path(os.getcwd()).resolve() / "mlruns"),
"default_exp_name": "Experiment",
},
},

View File

@@ -289,8 +289,12 @@ class DataHandlerLP(DataHandler):
getattr(self, pname).append(
init_instance_by_config(
proc,
None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module,
accept_types=processor_module.Processor))
None
if (isinstance(data_loader, dict) and "module_path" in data_loader)
else data_loader_module,
accept_types=processor_module.Processor,
)
)
self.process_type = process_type
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)

View File

@@ -32,7 +32,10 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
def fetch_df_by_index(
df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int], fetch_orig=True,
df: pd.DataFrame,
selector: Union[pd.Timestamp, slice, str, list],
level: Union[str, int],
fetch_orig=True,
) -> pd.DataFrame:
"""
fetch data from `data` with `selector` and `level`
@@ -55,8 +58,12 @@ def fetch_df_by_index(
if fetch_orig:
for slc in idx_slc:
if slc != slice(None, None):
return df.loc[pd.IndexSlice[idx_slc],]
return df.loc[
pd.IndexSlice[idx_slc],
]
else:
return df
else:
return df.loc[pd.IndexSlice[idx_slc],]
return df.loc[
pd.IndexSlice[idx_slc],
]

View File

@@ -5,6 +5,7 @@ import sys, traceback, signal, atexit
from . import R
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")