1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

add finetune example & fix serial bug

This commit is contained in:
Young
2020-11-16 13:11:39 +00:00
parent 3e04ded750
commit 90d41e4022
5 changed files with 203 additions and 37 deletions

View File

@@ -0,0 +1,131 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
if __name__ == "__main__":
    # Use the default CN dataset; fetch it with the get_data script when it is missing.
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        # Make <two directories above this file>/scripts importable so `get_data` resolves.
        scripts_dir = Path(__file__).resolve().parent.parent.joinpath("scripts")
        sys.path.append(str(scripts_dir))
        from get_data import GetData

        GetData().qlib_data_cn(target_dir=provider_uri)

    qlib.init(provider_uri=provider_uri, region=REG_CN)

    MARKET = "csi300"
    BENCHMARK = "SH000300"

    ###################################
    # train model
    ###################################
    DATA_HANDLER_CONFIG = {
        "start_time": "2008-01-01",
        "end_time": "2020-08-01",
        "fit_start_time": "2008-01-01",
        "fit_end_time": "2014-12-31",
        "instruments": MARKET,
    }

    # LightGBM model description consumed by init_instance_by_config
    model_config = {
        "class": "LGBModel",
        "module_path": "qlib.contrib.model.gbdt",
        "kwargs": {
            "loss": "mse",
            "colsample_bytree": 0.8879,
            "learning_rate": 0.0421,
            "subsample": 0.8789,
            "lambda_l1": 205.6999,
            "lambda_l2": 580.9768,
            "max_depth": 8,
            "num_leaves": 210,
            "num_threads": 20,
        },
    }

    # Dataset description: Alpha158 handler plus the train/valid/test segments
    dataset_config = {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": DATA_HANDLER_CONFIG,
            },
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    }

    task = {
        "model": model_config,
        "dataset": dataset_config,
        # You should record the data in a specific sequence
        "record": ["SignalRecord", "PortAnaRecord"],
    }

    strategy_config = {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.strategy",
        "kwargs": {
            "topk": 50,
            "n_drop": 5,
        },
    }
    backtest_config = {
        "verbose": False,
        "limit_threshold": 0.095,
        "account": 100000000,
        "benchmark": BENCHMARK,
        "deal_price": "close",
        "open_cost": 0.0005,
        "close_cost": 0.0015,
        "min_cost": 5,
    }
    port_analysis_config = {
        "strategy": strategy_config,
        "backtest": backtest_config,
    }

    # model initialization
    model = init_instance_by_config(task["model"])
    dataset = init_instance_by_config(task["dataset"])

    # start exp to train init model
    with R.start(experiment_name="init models"):
        model.fit(dataset)
        R.save_objects(init_model=model)
        # Remember the recorder id so the second experiment can load this model
        init_recorder_id = R.get_recorder().id

    # Finetune model based on previous trained model
    with R.start(experiment_name="finetune model"):
        init_recorder = R.get_recorder(init_recorder_id, experiment_name="init models")
        model = init_recorder.load_object("init_model")
        model.finetune(dataset, num_boost_round=10)
        R.save_objects(model=model)

        # prediction
        recorder = R.get_recorder()
        SignalRecord(model, dataset, recorder).generate()

        # backtest
        PortAnaRecord(recorder, port_analysis_config).generate()

View File

@@ -5,56 +5,54 @@ import numpy as np
import pandas as pd
import lightgbm as lgb
from ...model.base import Model
from ...model.base import ModelFT
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
class LGBModel(Model):
class LGBModel(ModelFT):
"""LightGBM Model"""
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self._params = {"objective": loss}
self._params.update(kwargs)
self.params = {"objective": loss}
self.params.update(kwargs)
self.model = None
def fit(
self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(["train", "valid"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train.values, label=y_train_1d)
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d)
self.model = lgb.train(
self._params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
dtrain = lgb.Dataset(x_train.values, label=y_train)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
return dtrain, dvalid
def fit(self,
        dataset: DatasetH,
        num_boost_round=1000,
        early_stopping_rounds=50,
        verbose_eval=20,
        evals_result=None,
        **kwargs):
    """
    Train the LightGBM booster and store it in ``self.model``.

    Parameters
    ----------
    dataset : DatasetH
        dataset handed to ``self._prepare_data`` to build the train/valid ``lgb.Dataset`` pair
    num_boost_round : int
        forwarded to ``lgb.train`` as ``num_boost_round``
    early_stopping_rounds : int
        forwarded to ``lgb.train`` as ``early_stopping_rounds``
    verbose_eval : int
        forwarded to ``lgb.train`` as ``verbose_eval``
    evals_result : dict, optional
        dict filled by ``lgb.train`` with the evaluation history; a fresh dict is
        created when omitted. On return its "train" and "valid" entries are replaced
        by the first value stored under each of them.
    **kwargs
        extra keyword arguments forwarded unchanged to ``lgb.train``
    """
    # A ``dict()`` default is evaluated once at definition time and would be
    # shared (and mutated) by every call that omits ``evals_result``.
    if evals_result is None:
        evals_result = {}
    dtrain, dvalid = self._prepare_data(dataset)
    self.model = lgb.train(self.params,
                           dtrain,
                           num_boost_round=num_boost_round,
                           valid_sets=[dtrain, dvalid],
                           valid_names=["train", "valid"],
                           early_stopping_rounds=early_stopping_rounds,
                           verbose_eval=verbose_eval,
                           evals_result=evals_result,
                           **kwargs)
    # Keep only the first value stored per split (presumably the history of the first metric)
    evals_result["train"] = list(evals_result["train"].values())[0]
    evals_result["valid"] = list(evals_result["valid"].values())[0]
@@ -63,3 +61,25 @@ class LGBModel(Model):
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
    """
    Continue training the current booster for a few more rounds.

    Parameters
    ----------
    dataset : DatasetH
        dataset for finetuning
    num_boost_round : int
        number of rounds to finetune the model
    verbose_eval : int
        verbose level
    """
    # Only the first dataset returned by _prepare_data is used here
    train_set, _unused_valid = self._prepare_data(dataset)
    # init_model=self.model makes lgb.train start from the existing booster
    booster = lgb.train(
        self.params,
        train_set,
        num_boost_round=num_boost_round,
        init_model=self.model,
        valid_sets=[train_set],
        valid_names=["train"],
        verbose_eval=verbose_eval,
    )
    self.model = booster

View File

@@ -45,3 +45,18 @@ class Model(BaseModel):
dataset will generate the processed dataset from model training
"""
raise NotImplementedError()
class ModelFT(Model):
    """Model (F)ine(t)unable"""

    @abc.abstractmethod
    def finetune(self, dataset: Dataset):
        """
        Finetune the model based on the given dataset.

        Parameters
        ----------
        dataset : Dataset
            dataset that supplies the processed data used for finetuning
        """
        raise NotImplementedError

View File

@@ -8,11 +8,11 @@ import pickle
class Serializable:
    """
    Serializable behaves like pickle.
    But it only saves the state whose name **does not** start with `_`
    """

    def __getstate__(self) -> dict:
        """Return the instance attributes to serialize, omitting names that start with `_`."""
        # Underscore-prefixed attributes are excluded from the saved state;
        # everything else in the instance __dict__ is kept.
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def __setstate__(self, state: dict):
        """Restore attributes by merging ``state`` into the instance ``__dict__``."""
        self.__dict__.update(state)

View File

@@ -226,7 +226,7 @@ class MLflowExperiment(Experiment):
return self.active_recorder
else:
raise Exception(
"Something went wrong when retrieving recorders. Please check if QlibRecorder is running or the name/id of the recorder is correct."
"Something went wrong when retrieving recorders. Please check if QlibRecorder is running."
)
else:
if recorder_id is not None:
@@ -235,7 +235,7 @@ class MLflowExperiment(Experiment):
else:
# mlflow does not support create a run with given id
raise Exception(
"Something went wrong when retrieving recorders. Please check if QlibRecorder is running or the name/id of the recorder is correct."
"Something went wrong when retrieving recorders. Please check if id of the recorder is correct."
)
else:
for rid in recorders:
@@ -250,7 +250,7 @@ class MLflowExperiment(Experiment):
return recorder
else:
raise Exception(
"Something went wrong when retrieving experiments. Please check if QlibRecorder is running or the name/id of the experiment is correct."
"Something went wrong when retrieving experiments. Please check if the name of the experiment is correct."
)
def list_recorders(self):