mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
adjust data and model interface
This commit is contained in:
@@ -1,22 +1,26 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import six
|
||||
from ..utils.serial import Serializable
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Model(object):
|
||||
"""Model base class"""
|
||||
class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
'''Modeling things'''
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return type(self).__name__
|
||||
@abc.abstractmethod
|
||||
def predict(self, *args, **kwargs) -> object:
|
||||
""" Make predictions after modeling things """
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" levarge Python syntactic sugar to make the models' behaviors like functions """
|
||||
return self.predict(*args, **kwargs)
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
'''Learnable Models'''
|
||||
|
||||
# TODO: Make the model easier.
|
||||
def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
|
||||
"""fix train with cross-validation
|
||||
Fit model when ex_config.finetune is False
|
||||
@@ -43,25 +47,7 @@ class Model(object):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def score(self, x_test, y_test, w_test=None, **kwargs):
|
||||
"""evaluate model with test data/label
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_test : pd.dataframe
|
||||
test data
|
||||
y_test : pd.dataframe
|
||||
test label
|
||||
w_test : pd.dataframe
|
||||
test weight
|
||||
|
||||
Returns
|
||||
----------
|
||||
float
|
||||
evaluation score
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, x_test, **kwargs):
|
||||
"""predict given test data
|
||||
|
||||
@@ -76,80 +62,3 @@ class Model(object):
|
||||
test predict label
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def save(self, fname, **kwargs):
|
||||
"""save model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : str
|
||||
model filename
|
||||
"""
|
||||
# TODO: Currently need to save the model as a single file, otherwise the estimator may not be compatible
|
||||
raise NotImplementedError()
|
||||
|
||||
def load(self, buffer, **kwargs):
|
||||
"""load model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
buffer : bytes
|
||||
binary data of model parameters
|
||||
|
||||
Returns
|
||||
----------
|
||||
Model
|
||||
loaded model
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_data_with_date(self, date, **kwargs):
|
||||
"""
|
||||
Will be called in online module
|
||||
need to return the data that used to predict the label (score) of stocks at date.
|
||||
|
||||
:param
|
||||
date: pd.Timestamp
|
||||
predict date
|
||||
:return:
|
||||
data: the input data that used to predict the label (score) of stocks at predict date.
|
||||
"""
|
||||
raise NotImplementedError("get_data_with_date for this model is not implemented.")
|
||||
|
||||
def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
|
||||
"""Finetune model
|
||||
In `RollingTrainer`:
|
||||
if loader.model_index is None:
|
||||
If provide 'Static Model', based on the provided 'Static' model update.
|
||||
If provide 'Rolling Model', skip the model of load, based on the last 'provided model' update.
|
||||
|
||||
if loader.model_index is not None:
|
||||
Based on the provided model(loader.model_index) update.
|
||||
|
||||
In `StaticTrainer`:
|
||||
If the load is 'static model':
|
||||
Based on the 'static model' update
|
||||
If the load is 'rolling model':
|
||||
Based on the provided model(`loader.model_index`) update. If `loader.model_index` is None, use the last model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_train : pd.dataframe
|
||||
train data
|
||||
y_train : pd.dataframe
|
||||
train label
|
||||
x_valid : pd.dataframe
|
||||
valid data
|
||||
y_valid : pd.dataframe
|
||||
valid label
|
||||
w_train : pd.dataframe
|
||||
train weight
|
||||
w_valid : pd.dataframe
|
||||
valid weight
|
||||
|
||||
Returns
|
||||
----------
|
||||
Model
|
||||
finetune model
|
||||
"""
|
||||
raise NotImplementedError("Finetune for this model is not implemented.")
|
||||
|
||||
Reference in New Issue
Block a user