mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
adjust data and model interface
This commit is contained in:
@@ -16,6 +16,8 @@ from qlib.contrib.evaluate import (
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
from qlib.model.learner import train_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -62,6 +64,48 @@ if __name__ == "__main__":
|
||||
data = handler.fetch(slice('2008-01-01', '2014-12-31'), data_key=handler.DK_I)
|
||||
print(data)
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
'handler': {
|
||||
"class": "Alpha158",
|
||||
"kwargs": DATA_HANDLER_CONFIG
|
||||
},
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
}
|
||||
},
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
model = train_model(task)
|
||||
|
||||
|
||||
|
||||
sys.exit(0) # I have tested the code above ---------------------------------------------
|
||||
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
|
||||
|
||||
@@ -25,10 +25,12 @@ class ALPHA360(DataHandlerLP):
|
||||
},
|
||||
"label": self.get_label_config()
|
||||
},
|
||||
"group_fields": True,
|
||||
}
|
||||
}
|
||||
infer_processors = ["ConfigSectionProcessor"] # ConfigSectionProcessor will normalize LABEL0
|
||||
infer_processors = [{
|
||||
"class": "ConfigSectionProcessor",
|
||||
"module_path": "qlib.contrib.data.processor"
|
||||
}] # ConfigSectionProcessor will normalize LABEL0
|
||||
super().__init__(instruments, start_time, end_time, data_loader=data_loader, infer_processors=infer_processors)
|
||||
|
||||
def get_label_config(self):
|
||||
@@ -83,7 +85,6 @@ class Alpha158(DataHandlerLP):
|
||||
"feature": self.get_feature_config(),
|
||||
"label": self.get_label_config()
|
||||
},
|
||||
"group_fields": True,
|
||||
}
|
||||
}
|
||||
super().__init__(instruments,
|
||||
@@ -94,7 +95,7 @@ class Alpha158(DataHandlerLP):
|
||||
learn_processors=learn_processors)
|
||||
|
||||
def get_feature_config(self):
|
||||
return {
|
||||
conf = {
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
@@ -102,10 +103,186 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
return self.parse_config_to_fields(conf)
|
||||
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
@staticmethod
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
use = lambda x: x not in exclude and (include is None or x in include)
|
||||
if use("ROC"):
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d) for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)" %
|
||||
(d, d) for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d) for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha158vwap(Alpha158):
|
||||
def get_label_config(self):
|
||||
|
||||
117
qlib/contrib/data/processor.py
Normal file
117
qlib/contrib/data/processor.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
from ...log import TimeInspector
|
||||
from ...utils.serial import Serializable
|
||||
from ...data.dataset.processor import Processor, get_group_columns
|
||||
|
||||
|
||||
class ConfigSectionProcessor(Processor):
|
||||
'''
|
||||
This processor is designed for Alpha158. And will be replaced by simple processors in the future
|
||||
'''
|
||||
def __init__(self, fields_group=None, **kwargs):
|
||||
super().__init__()
|
||||
# Options
|
||||
self.fillna_feature = kwargs.get("fillna_feature", True)
|
||||
self.fillna_label = kwargs.get("fillna_label", True)
|
||||
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
|
||||
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
|
||||
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
|
||||
|
||||
self.fields_group = None
|
||||
|
||||
def __call__(self, df):
|
||||
return self._transform(df)
|
||||
|
||||
def _transform(self, df):
|
||||
def _label_norm(x):
|
||||
x = x - x.mean() # copy
|
||||
x /= x.std()
|
||||
if self.clip_label_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.fillna_label:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
def _feature_norm(x):
|
||||
x = x - x.median() # copy
|
||||
x /= x.abs().median() * 1.4826
|
||||
if self.clip_feature_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.shrink_feature_outlier:
|
||||
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
|
||||
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
|
||||
if self.fillna_feature:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
# Copy the focus part and change it to single level
|
||||
selected_cols = get_group_columns(df, self.fields_group)
|
||||
df_focus = df[selected_cols].copy()
|
||||
if len(df_focus.columns.levels) > 1:
|
||||
df_focus = df_focus.droplevel(level=0)
|
||||
|
||||
# Label
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
"KSFT",
|
||||
"OPEN",
|
||||
"HIGH",
|
||||
"LOW",
|
||||
"CLOSE",
|
||||
"VWAP",
|
||||
"ROC",
|
||||
"MA",
|
||||
"BETA",
|
||||
"RESI",
|
||||
"QTLU",
|
||||
"QTLD",
|
||||
"RSV",
|
||||
"SUMP",
|
||||
"SUMN",
|
||||
"SUMD",
|
||||
"VSUMP",
|
||||
"VSUMN",
|
||||
"VSUMD",
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
df[selected_cols] = df_focus.values
|
||||
|
||||
TimeInspector.log_cost_time("Finished preprocessing data.")
|
||||
|
||||
return df
|
||||
@@ -0,0 +1,18 @@
|
||||
'''
|
||||
TODO:
|
||||
|
||||
- Online needs that the model have such method
|
||||
def get_data_with_date(self, date, **kwargs):
|
||||
"""
|
||||
Will be called in online module
|
||||
need to return the data that used to predict the label (score) of stocks at date.
|
||||
|
||||
:param
|
||||
date: pd.Timestamp
|
||||
predict date
|
||||
:return:
|
||||
data: the input data that used to predict the label (score) of stocks at predict date.
|
||||
"""
|
||||
raise NotImplementedError("get_data_with_date for this model is not implemented.")
|
||||
|
||||
'''
|
||||
|
||||
@@ -6,12 +6,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import six
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Expression(object):
|
||||
class Expression(abc.ABC):
|
||||
"""Expression base class"""
|
||||
|
||||
def __str__(self):
|
||||
@@ -218,7 +216,6 @@ class Feature(Expression):
|
||||
return 0, 0
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class ExpressionOps(Expression):
|
||||
"""Operator Expression
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import abc
|
||||
import six
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
@@ -27,8 +26,7 @@ from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class CalendarProvider(object):
|
||||
class CalendarProvider(abc.ABC):
|
||||
"""Calendar provider base class
|
||||
|
||||
Provide calendar data.
|
||||
@@ -128,8 +126,7 @@ class CalendarProvider(object):
|
||||
return hash_args(start_time, end_time, freq, future)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class InstrumentProvider(object):
|
||||
class InstrumentProvider(abc.ABC):
|
||||
"""Instrument provider base class
|
||||
|
||||
Provide instrument data.
|
||||
@@ -214,8 +211,7 @@ class InstrumentProvider(object):
|
||||
raise ValueError(f"Unknown instrument type {inst}")
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class FeatureProvider(object):
|
||||
class FeatureProvider(abc.ABC):
|
||||
"""Feature provider class
|
||||
|
||||
Provide feature data.
|
||||
@@ -246,8 +242,7 @@ class FeatureProvider(object):
|
||||
raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method")
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class ExpressionProvider(object):
|
||||
class ExpressionProvider(abc.ABC):
|
||||
"""Expression provider class
|
||||
|
||||
Provide Expression data.
|
||||
@@ -298,8 +293,7 @@ class ExpressionProvider(object):
|
||||
raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method")
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class DatasetProvider(object):
|
||||
class DatasetProvider(abc.ABC):
|
||||
"""Dataset provider class
|
||||
|
||||
Provide Dataset data.
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
|
||||
class Dataset:
|
||||
'''
|
||||
Preparing data for model training.
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
'''
|
||||
def generate(self):
|
||||
pass
|
||||
|
||||
@@ -16,6 +16,17 @@ class DataLoader(ABC):
|
||||
"""
|
||||
load the data as pd.DataFrame
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
instruments : [TODO:type]
|
||||
[TODO:description]
|
||||
start_time : [TODO:type]
|
||||
[TODO:description]
|
||||
end_time : [TODO:type]
|
||||
[TODO:description]
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
@@ -35,240 +46,51 @@ class DataLoader(ABC):
|
||||
|
||||
class QlibDataLoader(DataLoader):
|
||||
'''Same as QlibDataLoader. The fields can be define by config'''
|
||||
def __init__(self, config: Tuple[list, tuple, dict], group_fields: bool = False, filter_pipe=None):
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
config : Tuple[list ,tuple, dict]
|
||||
Config will be used to describe the fields and column names
|
||||
|
||||
if `group_fields`:
|
||||
<config> := {
|
||||
"group_name1": <fields_info1>
|
||||
"group_name2": <fields_info2>
|
||||
}
|
||||
else:
|
||||
<config> := <fields_info>
|
||||
<config> := {
|
||||
"group_name1": <fields_info1>
|
||||
"group_name2": <fields_info2>
|
||||
}
|
||||
|
||||
<fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...]) | <fields_info_config>
|
||||
<config> := <fields_info>
|
||||
|
||||
<fields_info_config> is a config with dict type which could be parsed by `parse_config_to_fields`
|
||||
<fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
|
||||
|
||||
Here is a few examples to describe the fields
|
||||
Here is a few examples to describe the fields
|
||||
TODO:
|
||||
|
||||
group_fields : bool
|
||||
Will the fields be grouped. Multi-index will be used for the group
|
||||
"""
|
||||
if group_fields:
|
||||
fields_all = []
|
||||
name_grp_info = []
|
||||
for grp, fields_info in config.items():
|
||||
fields, names = self._parse_fields_info(fields_info)
|
||||
fields_all.extend(fields)
|
||||
name_grp_info.extend([(grp, n) for n in names])
|
||||
self.fields, self.names = fields_all, name_grp_info
|
||||
else:
|
||||
self.fields, self.names = self._parse_fields_info(fields_info)
|
||||
self.is_group = isinstance(config, dict)
|
||||
|
||||
if self.is_group:
|
||||
self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}
|
||||
else:
|
||||
self.fields = self._parse_fields_info(fields_info)
|
||||
|
||||
self.group_fields = group_fields
|
||||
self.filter_pipe = filter_pipe
|
||||
|
||||
def _parse_fields_info(self, fields_info: Tuple[list, tuple, dict]) -> Tuple[list, list]:
|
||||
if isinstance(fields_info, dict):
|
||||
fields, names = parse_config_to_fields(fields_info)
|
||||
elif isinstance(fields_info, list):
|
||||
fields = fields_info
|
||||
names = fields
|
||||
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
|
||||
if isinstance(fields_info, list):
|
||||
exprs = names = fields_info
|
||||
elif isinstance(fields_info, tuple):
|
||||
fields, names = fields_info
|
||||
exprs, names = fields_info
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return fields, names
|
||||
return exprs, names
|
||||
|
||||
def load(self,
|
||||
instruments,
|
||||
config: Tuple[list, tuple, dict],
|
||||
group_fields=False,
|
||||
start_time=None,
|
||||
end_time=None) -> Tuple[pd.DataFrame, dict]:
|
||||
df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), self.fields, start_time, end_time)
|
||||
df.columns = pd.MultiIndex.from_tuples(self.names) if self.group_fields else self.names
|
||||
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def _get_df(exprs, names):
|
||||
df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time)
|
||||
df.columns = names
|
||||
return df
|
||||
if self.is_group:
|
||||
df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1)
|
||||
else:
|
||||
df = _get_df(exprs, names)
|
||||
df = df.swaplevel().sort_index()
|
||||
return df
|
||||
|
||||
|
||||
# TODO: make it easier to understand the config language
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
use = lambda x: x not in exclude and (include is None or x in include)
|
||||
if use("ROC"):
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d) for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)" %
|
||||
(d, d) for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d) for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
@@ -165,113 +165,3 @@ class CSZScoreNorm(Processor):
|
||||
cols = get_group_columns(df,self.fields_group)
|
||||
df[cols] = df[cols].groupby('datetime').apply(lambda df: (df - df.mean()).div(df.std()))
|
||||
return df
|
||||
|
||||
|
||||
# TODO: make the config language easier to understand
|
||||
class ConfigSectionProcessor(Processor):
|
||||
# TODO: this class is not well tested
|
||||
# FIXME: this will raise error when multi-index is passed in
|
||||
def __init__(self, fields_group=None, **kwargs):
|
||||
super().__init__()
|
||||
# Options
|
||||
self.fillna_feature = kwargs.get("fillna_feature", True)
|
||||
self.fillna_label = kwargs.get("fillna_label", True)
|
||||
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
|
||||
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
|
||||
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
|
||||
|
||||
self.fields_group = None
|
||||
|
||||
def __call__(self, df):
|
||||
return self._transform(df)
|
||||
|
||||
def _transform(self, df):
|
||||
def _label_norm(x):
|
||||
x = x - x.mean() # copy
|
||||
x /= x.std()
|
||||
if self.clip_label_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.fillna_label:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
def _feature_norm(x):
|
||||
x = x - x.median() # copy
|
||||
x /= x.abs().median() * 1.4826
|
||||
if self.clip_feature_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.shrink_feature_outlier:
|
||||
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
|
||||
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
|
||||
if self.fillna_feature:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
# Copy the focus part and change it to single level
|
||||
selected_cols = get_group_columns(df, self.fields_group)
|
||||
df_focus = df[selected_cols].copy()
|
||||
if len(df_focus.columns.levels) > 1:
|
||||
df_focus = df_focus.droplevel(level=0)
|
||||
|
||||
# Label
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
"KSFT",
|
||||
"OPEN",
|
||||
"HIGH",
|
||||
"LOW",
|
||||
"CLOSE",
|
||||
"VWAP",
|
||||
"ROC",
|
||||
"MA",
|
||||
"BETA",
|
||||
"RESI",
|
||||
"QTLU",
|
||||
"QTLD",
|
||||
"RSV",
|
||||
"SUMP",
|
||||
"SUMN",
|
||||
"SUMD",
|
||||
"VSUMP",
|
||||
"VSUMN",
|
||||
"VSUMD",
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
df[selected_cols] = df_focus.values
|
||||
|
||||
TimeInspector.log_cost_time("Finished preprocessing data.")
|
||||
|
||||
return df
|
||||
|
||||
@@ -7,14 +7,12 @@ from abc import abstractmethod
|
||||
import re
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import six
|
||||
import abc
|
||||
|
||||
from .data import Cal, DatasetD
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class BaseDFilter(object):
|
||||
class BaseDFilter(abc.ABC):
|
||||
"""Dynamic Instruments Filter Abstract class
|
||||
|
||||
Users can override this class to construct their own filter
|
||||
@@ -50,7 +48,6 @@ class BaseDFilter(object):
|
||||
raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method")
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class SeriesDFilter(BaseDFilter):
|
||||
"""Dynamic Instruments Filter Abstract class to filter a series of certain features
|
||||
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import six
|
||||
from ..utils.serial import Serializable
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Model(object):
|
||||
"""Model base class"""
|
||||
class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
'''Modeling things'''
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return type(self).__name__
|
||||
@abc.abstractmethod
|
||||
def predict(self, *args, **kwargs) -> object:
|
||||
""" Make predictions after modeling things """
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" levarge Python syntactic sugar to make the models' behaviors like functions """
|
||||
return self.predict(*args, **kwargs)
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
'''Learnable Models'''
|
||||
|
||||
# TODO: Make the model easier.
|
||||
def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
|
||||
"""fix train with cross-validation
|
||||
Fit model when ex_config.finetune is False
|
||||
@@ -43,25 +47,7 @@ class Model(object):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def score(self, x_test, y_test, w_test=None, **kwargs):
|
||||
"""evaluate model with test data/label
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_test : pd.dataframe
|
||||
test data
|
||||
y_test : pd.dataframe
|
||||
test label
|
||||
w_test : pd.dataframe
|
||||
test weight
|
||||
|
||||
Returns
|
||||
----------
|
||||
float
|
||||
evaluation score
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, x_test, **kwargs):
|
||||
"""predict given test data
|
||||
|
||||
@@ -76,80 +62,3 @@ class Model(object):
|
||||
test predict label
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def save(self, fname, **kwargs):
|
||||
"""save model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : str
|
||||
model filename
|
||||
"""
|
||||
# TODO: Currently need to save the model as a single file, otherwise the estimator may not be compatible
|
||||
raise NotImplementedError()
|
||||
|
||||
def load(self, buffer, **kwargs):
|
||||
"""load model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
buffer : bytes
|
||||
binary data of model parameters
|
||||
|
||||
Returns
|
||||
----------
|
||||
Model
|
||||
loaded model
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_data_with_date(self, date, **kwargs):
|
||||
"""
|
||||
Will be called in online module
|
||||
need to return the data that used to predict the label (score) of stocks at date.
|
||||
|
||||
:param
|
||||
date: pd.Timestamp
|
||||
predict date
|
||||
:return:
|
||||
data: the input data that used to predict the label (score) of stocks at predict date.
|
||||
"""
|
||||
raise NotImplementedError("get_data_with_date for this model is not implemented.")
|
||||
|
||||
def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
|
||||
"""Finetune model
|
||||
In `RollingTrainer`:
|
||||
if loader.model_index is None:
|
||||
If provide 'Static Model', based on the provided 'Static' model update.
|
||||
If provide 'Rolling Model', skip the model of load, based on the last 'provided model' update.
|
||||
|
||||
if loader.model_index is not None:
|
||||
Based on the provided model(loader.model_index) update.
|
||||
|
||||
In `StaticTrainer`:
|
||||
If the load is 'static model':
|
||||
Based on the 'static model' update
|
||||
If the load is 'rolling model':
|
||||
Based on the provided model(`loader.model_index`) update. If `loader.model_index` is None, use the last model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_train : pd.dataframe
|
||||
train data
|
||||
y_train : pd.dataframe
|
||||
train label
|
||||
x_valid : pd.dataframe
|
||||
valid data
|
||||
y_valid : pd.dataframe
|
||||
valid label
|
||||
w_train : pd.dataframe
|
||||
train weight
|
||||
w_valid : pd.dataframe
|
||||
valid weight
|
||||
|
||||
Returns
|
||||
----------
|
||||
Model
|
||||
finetune model
|
||||
"""
|
||||
raise NotImplementedError("Finetune for this model is not implemented.")
|
||||
|
||||
0
qlib/workflow/__init__.py
Normal file
0
qlib/workflow/__init__.py
Normal file
Reference in New Issue
Block a user