mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
add finetune example & fix serial bug
This commit is contained in:
131
examples/workflow_by_code_finetune.py
Normal file
131
examples/workflow_by_code_finetune.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(target_dir=provider_uri)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
# You shoud record the data in specific sequence
|
||||
"record": ["SignalRecord", "PortAnaRecord"],
|
||||
}
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.strategy",
|
||||
"kwargs": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# start exp to train init model
|
||||
with R.start(experiment_name="init models"):
|
||||
model.fit(dataset)
|
||||
R.save_objects(init_model=model)
|
||||
rid = R.get_recorder().id
|
||||
|
||||
|
||||
# Finetune model based on previous trained model
|
||||
with R.start(experiment_name="finetune model"):
|
||||
recorder = R.get_recorder(rid, experiment_name="init models")
|
||||
model = recorder.load_object("init_model")
|
||||
model.finetune(dataset, num_boost_round=10)
|
||||
R.save_objects(model=model)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
@@ -5,56 +5,54 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from ...model.base import Model
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class LGBModel(Model):
|
||||
class LGBModel(ModelFT):
|
||||
"""LightGBM Model"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self._params = {"objective": loss}
|
||||
self._params.update(kwargs)
|
||||
self.params = {"objective": loss}
|
||||
self.params.update(kwargs)
|
||||
self.model = None
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train_1d)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d)
|
||||
self.model = lgb.train(
|
||||
self._params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def fit(self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs):
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
self.model = lgb.train(self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
@@ -63,3 +61,25 @@ class LGBModel(Model):
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
"""
|
||||
finetune model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
num_boost_round : int
|
||||
number of round to finetune model
|
||||
verbose_eval : int
|
||||
verbose level
|
||||
"""
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval)
|
||||
|
||||
@@ -45,3 +45,18 @@ class Model(BaseModel):
|
||||
dataset will generate the processed dataset from model training
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ModelFT(Model):
|
||||
'''Model (F)ine(t)unable'''
|
||||
|
||||
@abc.abstractmethod
|
||||
def finetune(self, dataset: Dataset):
|
||||
"""finetune model based given dataset
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : Dataset
|
||||
dataset will generate the processed dataset from model training
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -8,11 +8,11 @@ import pickle
|
||||
class Serializable:
|
||||
"""
|
||||
Serializable behaves like pickle.
|
||||
But it only save the state whose name starts with `_`
|
||||
But it only saves the state whose name **does not** start with `_`
|
||||
"""
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
return {k: v for k, v in self.__dict__.items() if k.startswith("_")}
|
||||
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
|
||||
|
||||
def __setstate__(self, state: dict):
|
||||
self.__dict__.update(state)
|
||||
|
||||
@@ -226,7 +226,7 @@ class MLflowExperiment(Experiment):
|
||||
return self.active_recorder
|
||||
else:
|
||||
raise Exception(
|
||||
"Something went wrong when retrieving recorders. Please check if QlibRecorder is running or the name/id of the recorder is correct."
|
||||
"Something went wrong when retrieving recorders. Please check if QlibRecorder is running."
|
||||
)
|
||||
else:
|
||||
if recorder_id is not None:
|
||||
@@ -235,7 +235,7 @@ class MLflowExperiment(Experiment):
|
||||
else:
|
||||
# mlflow does not support create a run with given id
|
||||
raise Exception(
|
||||
"Something went wrong when retrieving recorders. Please check if QlibRecorder is running or the name/id of the recorder is correct."
|
||||
"Something went wrong when retrieving recorders. Please check if id of the recorder is correct."
|
||||
)
|
||||
else:
|
||||
for rid in recorders:
|
||||
@@ -250,7 +250,7 @@ class MLflowExperiment(Experiment):
|
||||
return recorder
|
||||
else:
|
||||
raise Exception(
|
||||
"Something went wrong when retrieving experiments. Please check if QlibRecorder is running or the name/id of the experiment is correct."
|
||||
"Something went wrong when retrieving experiments. Please check if the name of the experiment is correct."
|
||||
)
|
||||
|
||||
def list_recorders(self):
|
||||
|
||||
Reference in New Issue
Block a user