From aee507d5ddd2eade265ebe8bd62be868faa98bd5 Mon Sep 17 00:00:00 2001 From: Young Date: Mon, 26 Oct 2020 13:26:01 +0000 Subject: [PATCH] adjust data and model interface --- examples/workflow_by_code.py | 44 ++++++ qlib/contrib/data/handler.py | 185 ++++++++++++++++++++++- qlib/contrib/data/processor.py | 117 +++++++++++++++ qlib/contrib/online/__init__.py | 18 +++ qlib/data/base.py | 5 +- qlib/data/data.py | 16 +- qlib/data/dataset/__init__.py | 8 + qlib/data/dataset/loader.py | 256 +++++--------------------------- qlib/data/dataset/processor.py | 110 -------------- qlib/data/filter.py | 5 +- qlib/model/base.py | 125 +++------------- qlib/workflow/__init__.py | 0 12 files changed, 431 insertions(+), 458 deletions(-) create mode 100644 qlib/contrib/data/processor.py create mode 100644 qlib/workflow/__init__.py diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 9f0d5b02f..326e9dee0 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -16,6 +16,8 @@ from qlib.contrib.evaluate import ( ) from qlib.utils import exists_qlib_data +from qlib.model.learner import train_model + if __name__ == "__main__": @@ -62,6 +64,48 @@ if __name__ == "__main__": data = handler.fetch(slice('2008-01-01', '2014-12-31'), data_key=handler.DK_I) print(data) + task = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + } + }, + "data": { + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + 'handler': { + "class": "Alpha158", + "kwargs": DATA_HANDLER_CONFIG + }, + "train_start_time": "2008-01-01", + "train_end_time": "2014-12-31", + "validate_start_time": "2015-01-01", + "validate_end_time": "2016-12-31", + "test_start_time": "2017-01-01", + "test_end_time": "2020-08-01", + } + } + }, + # You shoud record the data in specific sequence + # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'], + } + + model = train_model(task) + + + sys.exit(0) # I have tested the code above --------------------------------------------- x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data( diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index c9959535a..45e4855c1 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -25,10 +25,12 @@ class ALPHA360(DataHandlerLP): }, "label": self.get_label_config() }, - "group_fields": True, } } - infer_processors = ["ConfigSectionProcessor"] # ConfigSectionProcessor will normalize LABEL0 + infer_processors = [{ + "class": "ConfigSectionProcessor", + "module_path": "qlib.contrib.data.processor" + }] # ConfigSectionProcessor will normalize LABEL0 super().__init__(instruments, start_time, end_time, data_loader=data_loader, infer_processors=infer_processors) def get_label_config(self): @@ -83,7 +85,6 @@ class Alpha158(DataHandlerLP): "feature": self.get_feature_config(), "label": self.get_label_config() }, - "group_fields": True, } } super().__init__(instruments, @@ -94,7 +95,7 @@ class Alpha158(DataHandlerLP): learn_processors=learn_processors) def get_feature_config(self): - return { + conf = { "kbar": {}, "price": { "windows": [0], @@ -102,10 +103,186 @@ class Alpha158(DataHandlerLP): }, "rolling": {}, } + return self.parse_config_to_fields(conf) def get_label_config(self): return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]) + @staticmethod + def parse_config_to_fields(config): + """create factors from config + + config = { + 'kbar': {}, # whether to use some hard-code kbar features + 'price': { # whether to use raw price features + 'windows': [0, 1, 2, 3, 4], # use price at n days ago + 'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use + }, + 'volume': { # whether to use raw volume features + 'windows': [0, 1, 2, 3, 4], # use volume at n days ago + }, + 'rolling': { # whether to use rolling operator based features + 'windows': [5, 10, 20, 30, 60], # rolling windows size + 'include': ['ROC', 'MA', 'STD'], # rolling operator to use + #if include is None we will use default operators + 'exclude': ['RANK'], # rolling operator not to use + } + } + """ + fields = [] + names = [] + if "kbar" in config: + fields += [ + "($close-$open)/$open", + "($high-$low)/$open", + "($close-$open)/($high-$low+1e-12)", + "($high-Greater($open, $close))/$open", + "($high-Greater($open, $close))/($high-$low+1e-12)", + "(Less($open, $close)-$low)/$open", + "(Less($open, $close)-$low)/($high-$low+1e-12)", + "(2*$close-$high-$low)/$open", + "(2*$close-$high-$low)/($high-$low+1e-12)", + ] + names += [ + "KMID", + "KLEN", + "KMID2", + "KUP", + "KUP2", + "KLOW", + "KLOW2", + "KSFT", + "KSFT2", + ] + if "price" in config: + windows = config["price"].get("windows", range(5)) + feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"]) + for field in feature: + field = field.lower() + fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows] + names += [field.upper() + str(d) for d in windows] + if "volume" in config: + windows = config["volume"].get("windows", range(5)) + fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows] + names += ["VOLUME" + str(d) for d in windows] + if "rolling" in config: + windows = config["rolling"].get("windows", [5, 10, 20, 30, 60]) + include = config["rolling"].get("include", None) + exclude = config["rolling"].get("exclude", []) + # `exclude` in dataset config unnecessary filed + # `include` in dataset config necessary field + use = lambda x: x not in exclude and (include is None or x in include) + if use("ROC"): + fields += ["Ref($close, %d)/$close" % d for d in windows] + names += ["ROC%d" % d for d in windows] + if use("MA"): + fields += ["Mean($close, %d)/$close" % d for d in windows] + names += ["MA%d" % d for d in windows] + if use("STD"): + fields += ["Std($close, %d)/$close" % d for d in windows] + names += ["STD%d" % d for d in windows] + if use("BETA"): + fields += ["Slope($close, %d)/$close" % d for d in windows] + names += ["BETA%d" % d for d in windows] + if use("RSQR"): + fields += ["Rsquare($close, %d)" % d for d in windows] + names += ["RSQR%d" % d for d in windows] + if use("RESI"): + fields += ["Resi($close, %d)/$close" % d for d in windows] + names += ["RESI%d" % d for d in windows] + if use("MAX"): + fields += ["Max($high, %d)/$close" % d for d in windows] + names += ["MAX%d" % d for d in windows] + if use("LOW"): + fields += ["Min($low, %d)/$close" % d for d in windows] + names += ["MIN%d" % d for d in windows] + if use("QTLU"): + fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows] + names += ["QTLU%d" % d for d in windows] + if use("QTLD"): + fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows] + names += ["QTLD%d" % d for d in windows] + if use("RANK"): + fields += ["Rank($close, %d)" % d for d in windows] + names += ["RANK%d" % d for d in windows] + if use("RSV"): + fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows] + names += ["RSV%d" % d for d in windows] + if use("IMAX"): + fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows] + names += ["IMAX%d" % d for d in windows] + if use("IMIN"): + fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows] + names += ["IMIN%d" % d for d in windows] + if use("IMXD"): + fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows] + names += ["IMXD%d" % d for d in windows] + if use("CORR"): + fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows] + names += ["CORR%d" % d for d in windows] + if use("CORD"): + fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows] + names += ["CORD%d" % d for d in windows] + if use("CNTP"): + fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows] + names += ["CNTP%d" % d for d in windows] + if use("CNTN"): + fields += ["Mean($closeRef($close, 1), %d)-Mean($close= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True) + if self.fillna_feature: + x.fillna(0, inplace=True) + return x + + TimeInspector.set_time_mark() + + # Copy the focus part and change it to single level + selected_cols = get_group_columns(df, self.fields_group) + df_focus = df[selected_cols].copy() + if len(df_focus.columns.levels) > 1: + df_focus = df_focus.droplevel(level=0) + + # Label + cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")] + df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm) + + # Features + cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")] + df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm) + + cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")] + df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm) + + _cols = [ + "KMID", + "KSFT", + "OPEN", + "HIGH", + "LOW", + "CLOSE", + "VWAP", + "ROC", + "MA", + "BETA", + "RESI", + "QTLU", + "QTLD", + "RSV", + "SUMP", + "SUMN", + "SUMD", + "VSUMP", + "VSUMN", + "VSUMD", + ] + pat = "|".join(["^" + x for x in _cols]) + cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))] + df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm) + + cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")] + df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm) + + cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")] + df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm) + + cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")] + df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm) + + cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")] + df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm) + + cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")] + df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm) + + cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")] + df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm) + + df[selected_cols] = df_focus.values + + TimeInspector.log_cost_time("Finished preprocessing data.") + + return df diff --git a/qlib/contrib/online/__init__.py b/qlib/contrib/online/__init__.py index e69de29bb..71389882e 100644 --- a/qlib/contrib/online/__init__.py +++ b/qlib/contrib/online/__init__.py @@ -0,0 +1,18 @@ +''' +TODO: + +- Online needs that the model have such method + def get_data_with_date(self, date, **kwargs): + """ + Will be called in online module + need to return the data that used to predict the label (score) of stocks at date. + + :param + date: pd.Timestamp + predict date + :return: + data: the input data that used to predict the label (score) of stocks at predict date. + """ + raise NotImplementedError("get_data_with_date for this model is not implemented.") + +''' diff --git a/qlib/data/base.py b/qlib/data/base.py index c357700c0..433b6585a 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -6,12 +6,10 @@ from __future__ import division from __future__ import print_function import abc -import six import pandas as pd -@six.add_metaclass(abc.ABCMeta) -class Expression(object): +class Expression(abc.ABC): """Expression base class""" def __str__(self): @@ -218,7 +216,6 @@ class Feature(Expression): return 0, 0 -@six.add_metaclass(abc.ABCMeta) class ExpressionOps(Expression): """Operator Expression diff --git a/qlib/data/data.py b/qlib/data/data.py index dc2c5886c..c41d32f6e 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -7,7 +7,6 @@ from __future__ import print_function import os import abc -import six import time import queue import bisect @@ -27,8 +26,7 @@ from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache -@six.add_metaclass(abc.ABCMeta) -class CalendarProvider(object): +class CalendarProvider(abc.ABC): """Calendar provider base class Provide calendar data. @@ -128,8 +126,7 @@ class CalendarProvider(object): return hash_args(start_time, end_time, freq, future) -@six.add_metaclass(abc.ABCMeta) -class InstrumentProvider(object): +class InstrumentProvider(abc.ABC): """Instrument provider base class Provide instrument data. @@ -214,8 +211,7 @@ class InstrumentProvider(object): raise ValueError(f"Unknown instrument type {inst}") -@six.add_metaclass(abc.ABCMeta) -class FeatureProvider(object): +class FeatureProvider(abc.ABC): """Feature provider class Provide feature data. @@ -246,8 +242,7 @@ class FeatureProvider(object): raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method") -@six.add_metaclass(abc.ABCMeta) -class ExpressionProvider(object): +class ExpressionProvider(abc.ABC): """Expression provider class Provide Expression data. @@ -298,8 +293,7 @@ class ExpressionProvider(object): raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method") -@six.add_metaclass(abc.ABCMeta) -class DatasetProvider(object): +class DatasetProvider(abc.ABC): """Dataset provider class Provide Dataset data. diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index e69de29bb..ec6cb2c4b 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -0,0 +1,8 @@ + +class Dataset: + ''' + Preparing data for model training. + The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.) + ''' + def generate(self): + pass diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 7da042b7c..b94280a83 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -16,6 +16,17 @@ class DataLoader(ABC): """ load the data as pd.DataFrame + Parameters + ---------- + self : [TODO:type] + [TODO:description] + instruments : [TODO:type] + [TODO:description] + start_time : [TODO:type] + [TODO:description] + end_time : [TODO:type] + [TODO:description] + Returns ------- pd.DataFrame: @@ -35,240 +46,51 @@ class DataLoader(ABC): class QlibDataLoader(DataLoader): '''Same as QlibDataLoader. The fields can be define by config''' - def __init__(self, config: Tuple[list, tuple, dict], group_fields: bool = False, filter_pipe=None): + def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None): """ Parameters ---------- config : Tuple[list ,tuple, dict] Config will be used to describe the fields and column names - if `group_fields`: - := { - "group_name1": - "group_name2": - } - else: - := + := { + "group_name1": + "group_name2": + } - := ["expr", ...] | (["expr", ...], ["col_name", ...]) | + := - is a config with dict type which could be parsed by `parse_config_to_fields` + := ["expr", ...] | (["expr", ...], ["col_name", ...]) - Here is a few examples to describe the fields + Here is a few examples to describe the fields TODO: - - group_fields : bool - Will the fields be grouped. Multi-index will be used for the group """ - if group_fields: - fields_all = [] - name_grp_info = [] - for grp, fields_info in config.items(): - fields, names = self._parse_fields_info(fields_info) - fields_all.extend(fields) - name_grp_info.extend([(grp, n) for n in names]) - self.fields, self.names = fields_all, name_grp_info - else: - self.fields, self.names = self._parse_fields_info(fields_info) + self.is_group = isinstance(config, dict) + + if self.is_group: + self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()} + else: + self.fields = self._parse_fields_info(fields_info) - self.group_fields = group_fields self.filter_pipe = filter_pipe - def _parse_fields_info(self, fields_info: Tuple[list, tuple, dict]) -> Tuple[list, list]: - if isinstance(fields_info, dict): - fields, names = parse_config_to_fields(fields_info) - elif isinstance(fields_info, list): - fields = fields_info - names = fields + def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]: + if isinstance(fields_info, list): + exprs = names = fields_info elif isinstance(fields_info, tuple): - fields, names = fields_info + exprs, names = fields_info else: raise NotImplementedError(f"This type of input is not supported") - return fields, names + return exprs, names - def load(self, - instruments, - config: Tuple[list, tuple, dict], - group_fields=False, - start_time=None, - end_time=None) -> Tuple[pd.DataFrame, dict]: - df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), self.fields, start_time, end_time) - df.columns = pd.MultiIndex.from_tuples(self.names) if self.group_fields else self.names + def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: + def _get_df(exprs, names): + df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time) + df.columns = names + return df + if self.is_group: + df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1) + else: + df = _get_df(exprs, names) df = df.swaplevel().sort_index() return df - - -# TODO: make it easier to understand the config language -def parse_config_to_fields(config): - """create factors from config - - config = { - 'kbar': {}, # whether to use some hard-code kbar features - 'price': { # whether to use raw price features - 'windows': [0, 1, 2, 3, 4], # use price at n days ago - 'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use - }, - 'volume': { # whether to use raw volume features - 'windows': [0, 1, 2, 3, 4], # use volume at n days ago - }, - 'rolling': { # whether to use rolling operator based features - 'windows': [5, 10, 20, 30, 60], # rolling windows size - 'include': ['ROC', 'MA', 'STD'], # rolling operator to use - #if include is None we will use default operators - 'exclude': ['RANK'], # rolling operator not to use - } - } - """ - fields = [] - names = [] - if "kbar" in config: - fields += [ - "($close-$open)/$open", - "($high-$low)/$open", - "($close-$open)/($high-$low+1e-12)", - "($high-Greater($open, $close))/$open", - "($high-Greater($open, $close))/($high-$low+1e-12)", - "(Less($open, $close)-$low)/$open", - "(Less($open, $close)-$low)/($high-$low+1e-12)", - "(2*$close-$high-$low)/$open", - "(2*$close-$high-$low)/($high-$low+1e-12)", - ] - names += [ - "KMID", - "KLEN", - "KMID2", - "KUP", - "KUP2", - "KLOW", - "KLOW2", - "KSFT", - "KSFT2", - ] - if "price" in config: - windows = config["price"].get("windows", range(5)) - feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"]) - for field in feature: - field = field.lower() - fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows] - names += [field.upper() + str(d) for d in windows] - if "volume" in config: - windows = config["volume"].get("windows", range(5)) - fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows] - names += ["VOLUME" + str(d) for d in windows] - if "rolling" in config: - windows = config["rolling"].get("windows", [5, 10, 20, 30, 60]) - include = config["rolling"].get("include", None) - exclude = config["rolling"].get("exclude", []) - # `exclude` in dataset config unnecessary filed - # `include` in dataset config necessary field - use = lambda x: x not in exclude and (include is None or x in include) - if use("ROC"): - fields += ["Ref($close, %d)/$close" % d for d in windows] - names += ["ROC%d" % d for d in windows] - if use("MA"): - fields += ["Mean($close, %d)/$close" % d for d in windows] - names += ["MA%d" % d for d in windows] - if use("STD"): - fields += ["Std($close, %d)/$close" % d for d in windows] - names += ["STD%d" % d for d in windows] - if use("BETA"): - fields += ["Slope($close, %d)/$close" % d for d in windows] - names += ["BETA%d" % d for d in windows] - if use("RSQR"): - fields += ["Rsquare($close, %d)" % d for d in windows] - names += ["RSQR%d" % d for d in windows] - if use("RESI"): - fields += ["Resi($close, %d)/$close" % d for d in windows] - names += ["RESI%d" % d for d in windows] - if use("MAX"): - fields += ["Max($high, %d)/$close" % d for d in windows] - names += ["MAX%d" % d for d in windows] - if use("LOW"): - fields += ["Min($low, %d)/$close" % d for d in windows] - names += ["MIN%d" % d for d in windows] - if use("QTLU"): - fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows] - names += ["QTLU%d" % d for d in windows] - if use("QTLD"): - fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows] - names += ["QTLD%d" % d for d in windows] - if use("RANK"): - fields += ["Rank($close, %d)" % d for d in windows] - names += ["RANK%d" % d for d in windows] - if use("RSV"): - fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows] - names += ["RSV%d" % d for d in windows] - if use("IMAX"): - fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows] - names += ["IMAX%d" % d for d in windows] - if use("IMIN"): - fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows] - names += ["IMIN%d" % d for d in windows] - if use("IMXD"): - fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows] - names += ["IMXD%d" % d for d in windows] - if use("CORR"): - fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows] - names += ["CORR%d" % d for d in windows] - if use("CORD"): - fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows] - names += ["CORD%d" % d for d in windows] - if use("CNTP"): - fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows] - names += ["CNTP%d" % d for d in windows] - if use("CNTN"): - fields += ["Mean($closeRef($close, 1), %d)-Mean($close= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True) - if self.fillna_feature: - x.fillna(0, inplace=True) - return x - - TimeInspector.set_time_mark() - - # Copy the focus part and change it to single level - selected_cols = get_group_columns(df, self.fields_group) - df_focus = df[selected_cols].copy() - if len(df_focus.columns.levels) > 1: - df_focus = df_focus.droplevel(level=0) - - # Label - cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")] - df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm) - - # Features - cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")] - df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm) - - cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")] - df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm) - - _cols = [ - "KMID", - "KSFT", - "OPEN", - "HIGH", - "LOW", - "CLOSE", - "VWAP", - "ROC", - "MA", - "BETA", - "RESI", - "QTLU", - "QTLD", - "RSV", - "SUMP", - "SUMN", - "SUMD", - "VSUMP", - "VSUMN", - "VSUMD", - ] - pat = "|".join(["^" + x for x in _cols]) - cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))] - df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm) - - cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")] - df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm) - - cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")] - df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm) - - cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")] - df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm) - - cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")] - df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm) - - cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")] - df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm) - - cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")] - df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm) - - df[selected_cols] = df_focus.values - - TimeInspector.log_cost_time("Finished preprocessing data.") - - return df diff --git a/qlib/data/filter.py b/qlib/data/filter.py index 3a36b1678..47b093b67 100644 --- a/qlib/data/filter.py +++ b/qlib/data/filter.py @@ -7,14 +7,12 @@ from abc import abstractmethod import re import pandas as pd import numpy as np -import six import abc from .data import Cal, DatasetD -@six.add_metaclass(abc.ABCMeta) -class BaseDFilter(object): +class BaseDFilter(abc.ABC): """Dynamic Instruments Filter Abstract class Users can override this class to construct their own filter @@ -50,7 +48,6 @@ class BaseDFilter(object): raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method") -@six.add_metaclass(abc.ABCMeta) class SeriesDFilter(BaseDFilter): """Dynamic Instruments Filter Abstract class to filter a series of certain features diff --git a/qlib/model/base.py b/qlib/model/base.py index b3ea917a5..66b54705a 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -1,22 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - - -from __future__ import division -from __future__ import print_function - import abc -import six +from ..utils.serial import Serializable -@six.add_metaclass(abc.ABCMeta) -class Model(object): - """Model base class""" +class BaseModel(Serializable, metaclass=abc.ABCMeta): + '''Modeling things''' - @property - def name(self): - return type(self).__name__ + @abc.abstractmethod + def predict(self, *args, **kwargs) -> object: + """ Make predictions after modeling things """ + pass + def __call__(self, *args, **kwargs) -> object: + """ levarge Python syntactic sugar to make the models' behaviors like functions """ + return self.predict(*args, **kwargs) + + +class Model(BaseModel): + '''Learnable Models''' + + # TODO: Make the model easier. def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): """fix train with cross-validation Fit model when ex_config.finetune is False @@ -43,25 +47,7 @@ class Model(object): """ raise NotImplementedError() - def score(self, x_test, y_test, w_test=None, **kwargs): - """evaluate model with test data/label - - Parameters - ---------- - x_test : pd.dataframe - test data - y_test : pd.dataframe - test label - w_test : pd.dataframe - test weight - - Returns - ---------- - float - evaluation score - """ - raise NotImplementedError() - + @abc.abstractmethod def predict(self, x_test, **kwargs): """predict given test data @@ -76,80 +62,3 @@ class Model(object): test predict label """ raise NotImplementedError() - - def save(self, fname, **kwargs): - """save model - - Parameters - ---------- - fname : str - model filename - """ - # TODO: Currently need to save the model as a single file, otherwise the estimator may not be compatible - raise NotImplementedError() - - def load(self, buffer, **kwargs): - """load model - - Parameters - ---------- - buffer : bytes - binary data of model parameters - - Returns - ---------- - Model - loaded model - """ - raise NotImplementedError() - - def get_data_with_date(self, date, **kwargs): - """ - Will be called in online module - need to return the data that used to predict the label (score) of stocks at date. - - :param - date: pd.Timestamp - predict date - :return: - data: the input data that used to predict the label (score) of stocks at predict date. - """ - raise NotImplementedError("get_data_with_date for this model is not implemented.") - - def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): - """Finetune model - In `RollingTrainer`: - if loader.model_index is None: - If provide 'Static Model', based on the provided 'Static' model update. - If provide 'Rolling Model', skip the model of load, based on the last 'provided model' update. - - if loader.model_index is not None: - Based on the provided model(loader.model_index) update. - - In `StaticTrainer`: - If the load is 'static model': - Based on the 'static model' update - If the load is 'rolling model': - Based on the provided model(`loader.model_index`) update. If `loader.model_index` is None, use the last model. - - Parameters - ---------- - x_train : pd.dataframe - train data - y_train : pd.dataframe - train label - x_valid : pd.dataframe - valid data - y_valid : pd.dataframe - valid label - w_train : pd.dataframe - train weight - w_valid : pd.dataframe - valid weight - - Returns - ---------- - Model - finetune model - """ - raise NotImplementedError("Finetune for this model is not implemented.") diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py new file mode 100644 index 000000000..e69de29bb