1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00
Files
qlib/qlib/model/base.py
SunsetWolf 144e1e2459 Fix pylint (#888)
* add_pylint_to_workflow

* fix-pylint

* fix_pylinterror

* fix-issue
2022-01-26 19:27:24 +08:00

111 lines
3.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
from typing import Text, Union
from ..utils.serial import Serializable
from ..data.dataset import Dataset
from ..data.dataset.weight import Reweighter
class BaseModel(Serializable, metaclass=abc.ABCMeta):
    """Modeling things

    Root abstract interface for models: subclasses must implement
    :meth:`predict`, and instances are callable as a shorthand for it.
    """

    @abc.abstractmethod
    def predict(self, *args, **kwargs) -> object:
        """Make predictions after modeling things

        Abstract; concrete subclasses must override it.

        Raises
        ------
        NotImplementedError
            If reached (e.g. through ``super().predict(...)``) instead of a
            concrete override. This matches the other abstract methods in
            this module, rather than silently returning ``None``.
        """
        raise NotImplementedError()

    def __call__(self, *args, **kwargs) -> object:
        """Leverage Python syntactic sugar so that ``model(...)`` behaves like ``model.predict(...)``.

        All positional and keyword arguments are forwarded unchanged, and the
        result of :meth:`predict` is returned as-is.
        """
        return self.predict(*args, **kwargs)
class Model(BaseModel):
    """Learnable Models"""

    def fit(self, dataset: Dataset, reweighter: Union[Reweighter, None] = None):
        """
        Learn model from the base model

        .. note::

            The attribute names of learned model should `not` start with '_'. So that the model could be
            dumped to disk.

        The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:

            .. code-block:: Python

                # get features and labels
                df_train, df_valid = dataset.prepare(
                    ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
                )
                x_train, y_train = df_train["feature"], df_train["label"]
                x_valid, y_valid = df_valid["feature"], df_valid["label"]

                # get weights
                try:
                    wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"],
                                                           data_key=DataHandlerLP.DK_L)
                    w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
                except KeyError:
                    w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
                    w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)

        Parameters
        ----------
        dataset : Dataset
            dataset will generate the processed data for model training.
        reweighter : Reweighter, optional
            Optional sample reweighter. How (or whether) it is applied is up to
            each concrete subclass. Defaults to ``None`` so that ``model.fit(dataset)``
            is a valid call, as used in the ``ModelFT.finetune`` example.

        Raises
        ------
        NotImplementedError
            Always; concrete subclasses must override this method.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
        """Give predictions for the given Dataset

        Parameters
        ----------
        dataset : Dataset
            dataset will generate the processed dataset from which predictions are made.
        segment : Text or slice
            dataset will use this segment to prepare data. (default=test)

        Returns
        -------
        Prediction results with certain type such as `pandas.Series`.

        Raises
        ------
        NotImplementedError
            If reached instead of a concrete override.
        """
        raise NotImplementedError()
class ModelFT(Model):
    """Model (F)ine(t)unable"""

    @abc.abstractmethod
    def finetune(self, dataset: Dataset):
        """Finetune the model based on the given dataset

        A typical use case of finetuning model with qlib.workflow.R

        .. code-block:: python

            # start exp to train init model
            with R.start(experiment_name="init models"):
                model.fit(dataset)
                R.save_objects(init_model=model)
                rid = R.get_recorder().id

            # Finetune model based on previous trained model
            with R.start(experiment_name="finetune model"):
                recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
                model = recorder.load_object("init_model")
                model.finetune(dataset, num_boost_round=10)

        .. note::

            This base signature only takes ``dataset``. Extra keyword arguments
            such as ``num_boost_round`` in the example above are only accepted
            if the concrete subclass's override declares them.

        Parameters
        ----------
        dataset : Dataset
            dataset will generate the processed dataset used for finetuning.

        Raises
        ------
        NotImplementedError
            If reached instead of a concrete override.
        """
        raise NotImplementedError()