From 393584e535e3b9104199cddb20626619ce261cfe Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 23 Oct 2020 03:37:10 +0000 Subject: [PATCH] Update handler interface round2 --- examples/workflow_by_code.py | 22 +- qlib/contrib/data/handler.py | 132 +++++---- qlib/contrib/online/manager.py | 2 +- qlib/contrib/online/utils.py | 17 +- qlib/data/dataset/handler.py | 522 ++++++++++----------------------- qlib/data/dataset/loader.py | 274 +++++++++++++++++ qlib/data/dataset/processor.py | 229 +++++++-------- qlib/log.py | 23 ++ qlib/utils/__init__.py | 66 +++++ 9 files changed, 715 insertions(+), 572 deletions(-) create mode 100644 qlib/data/dataset/loader.py diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 3179cbab3..9f0d5b02f 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -31,7 +31,7 @@ if __name__ == "__main__": qlib.init(provider_uri=provider_uri, region=REG_CN) - MARKET = "CSI300" + MARKET = "csi300" BENCHMARK = "SH000300" @@ -39,27 +39,27 @@ if __name__ == "__main__": # train model ################################### DATA_HANDLER_CONFIG = { - "start_date": "2008-01-01", - "end_date": "2020-08-01", + "start_time": "2008-01-01", + "end_time": "2020-08-01", "fit_start_time":"2008-01-01", "fit_end_time":"2014-12-31", - "market": MARKET, + "instruments": MARKET, } TRAINER_CONFIG = { - "train_start_date": "2008-01-01", - "train_end_date": "2014-12-31", - "validate_start_date": "2015-01-01", - "validate_end_date": "2016-12-31", - "test_start_date": "2017-01-01", - "test_end_date": "2020-08-01", + "train_start_time": "2008-01-01", + "train_end_time": "2014-12-31", + "validate_start_time": "2015-01-01", + "validate_end_time": "2016-12-31", + "test_start_time": "2017-01-01", + "test_end_time": "2020-08-01", } # use default DataHandler # custom DataHandler, refer to: TODO: DataHandler API url handler = Alpha158(**DATA_HANDLER_CONFIG) - data = handler.fetch(slice('2008-01-01', '2014-12-31'), key=handler.DK_I) + data = 
handler.fetch(slice('2008-01-01', '2014-12-31'), data_key=handler.DK_I) print(data) sys.exit(0) # I have tested the code above --------------------------------------------- diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 6f53670dd..c9959535a 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -1,41 +1,73 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from ...data.dataset.handler import ConfigQLibDataHandler -from ...data.dataset.processor import Processor, MinMaxNorm, ZscoreNorm, get_cls_kwargs +from ...data.dataset.handler import DataHandlerLP +from ...data.dataset.processor import Processor, MinMaxNorm, ZscoreNorm +from ...utils import get_cls_kwargs +from ...data.dataset import processor as processor_module from ...log import TimeInspector import copy -class ALPHA360(ConfigQLibDataHandler): - config_template = { - "price": {"windows": range(60)}, - "volume": {"windows": range(60)}, - } +class ALPHA360(DataHandlerLP): + def __init__(self, instruments="csi500", start_time=None, end_time=None): + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": { + "feature": { + "price": { + "windows": range(60) + }, + "volume": { + "windows": range(60) + }, + }, + "label": self.get_label_config() + }, + "group_fields": True, + } + } + infer_processors = ["ConfigSectionProcessor"] # ConfigSectionProcessor will normalize LABEL0 + super().__init__(instruments, start_time, end_time, data_loader=data_loader, infer_processors=infer_processors) + + def get_label_config(self): + return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]) -class QLibDataHandlerV1(ConfigQLibDataHandler): - config_template = { - "kbar": {}, - "price": { - "windows": [0], - "feature": ["OPEN", "HIGH", "LOW", "VWAP"], - }, - "rolling": {}, - } +class ALPHA360vwap(ALPHA360): + def get_label_config(self): + return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"]) - def __init__(self, start_date, 
end_date, infer_processors=[], learn_processors=["DropnaLabel"], fit_start_time=None, fit_end_time=None, **kwargs): + +class Alpha158(DataHandlerLP): + def __init__( + self, + instruments="csi500", + start_time=None, + end_time=None, + infer_processors=[], + learn_processors=["DropnaLabel", { + "class": "CSZScoreNorm", + "kwargs": { + "fields_group": "label" + } + }], + fit_start_time=None, + fit_end_time=None, + ): def check_transform_proc(proc_l): new_l = [] for p in proc_l: if not isinstance(p, Processor): - klass, pkwargs = get_cls_kwargs(p) + klass, pkwargs = get_cls_kwargs(p, processor_module) + # FIXME: It's hard code here!!!!! if isinstance(klass, (MinMaxNorm, ZscoreNorm)): - assert(fit_start_time is not None and fit_end_time is not None) + assert (fit_start_time is not None and fit_end_time is not None) pkwargs.update({ "fit_start_time": fit_start_time, "fit_end_time": fit_end_time, - }) + }) new_l.append({"class": klass.__name__, "kwargs": pkwargs}) else: new_l.append(p) @@ -44,37 +76,37 @@ class QLibDataHandlerV1(ConfigQLibDataHandler): infer_processors = check_transform_proc(infer_processors) learn_processors = check_transform_proc(learn_processors) - super().__init__(start_date, end_date, infer_processors=infer_processors, learn_processors=learn_processors, **kwargs) + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": { + "feature": self.get_feature_config(), + "label": self.get_label_config() + }, + "group_fields": True, + } + } + super().__init__(instruments, + start_time, + end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors) - def load_label(self): - """ - load the labels df - :return: df_labels - """ - TimeInspector.set_time_mark() + def get_feature_config(self): + return { + "kbar": {}, + "price": { + "windows": [0], + "feature": ["OPEN", "HIGH", "LOW", "VWAP"], + }, + "rolling": {}, + } - df_labels = super().load_label() - - ## calculate new labels - 
df_labels["LABEL1"] = df_labels["LABEL0"].groupby(level="datetime").apply(lambda x: (x - x.mean()) / x.std()) - - df_labels = df_labels.drop(["LABEL0"], axis=1) - - TimeInspector.log_cost_time("Finished loading labels.") - - return df_labels + def get_label_config(self): + return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]) -class Alpha158(QLibDataHandlerV1): - config_template = { - "kbar": {}, - "price": { - "windows": [0], - "feature": ["OPEN", "HIGH", "LOW", "CLOSE"], - }, - "rolling": {}, - } - - def __init__(self, *args, **kwargs): - kwargs["labels"] = ["Ref($close, -2)/Ref($close, -1) - 1"] - super().__init__(*args, **kwargs) +class Alpha158vwap(Alpha158): + def get_label_config(self): + return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"]) diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py index 7e9c766e8..cf850b9da 100644 --- a/qlib/contrib/online/manager.py +++ b/qlib/contrib/online/manager.py @@ -11,7 +11,7 @@ from ..backtest.account import Account from ..backtest.exchange import Exchange from .user import User from .utils import load_instance -from .utils import save_instance, init_instance_by_config +from ...utils import save_instance, init_instance_by_config class UserManager: diff --git a/qlib/contrib/online/utils.py b/qlib/contrib/online/utils.py index cf08e4dbe..611af63e4 100644 --- a/qlib/contrib/online/utils.py +++ b/qlib/contrib/online/utils.py @@ -7,7 +7,7 @@ import yaml import pandas as pd from ...data import D from ...log import get_module_logger -from ...utils import get_module_by_module_path +from ...utils import get_module_by_module_path, init_instance_by_config from ...utils import get_next_trading_date from ..backtest.exchange import Exchange @@ -45,21 +45,6 @@ def save_instance(instance, file_path): pickle.dump(instance, fr) -def init_instance_by_config(config): - """ - generate an instance with settings in config - Parameter - config : dict - python dict indicate a init parameters to create an 
item - :return - An instance - """ - module = get_module_by_module_path(config["module_path"]) - instance_class = getattr(module, config["class"]) - instance = instance_class(**config["args"]) - return instance - - def create_user_folder(path): path = pathlib.Path(path) if path.exists(): diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index e523fbfef..7cc7995ea 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -5,7 +5,7 @@ import abc import bisect import logging -from typing import Union +from typing import Union, Tuple import pandas as pd import numpy as np @@ -13,11 +13,13 @@ import numpy as np from ...log import get_module_logger, TimeInspector from ...data import D from ...config import C -from ...utils import parse_config, transform_end_date +from ...utils import parse_config, transform_end_date, init_instance_by_config from ...utils.serial import Serializable from pathlib import Path +from .loader import DataLoader from . import processor as processor_module +from . import loader as data_loader_module # TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed. @@ -30,44 +32,57 @@ class DataHandler(Serializable): The data handler try to maintain a handler with 2 level. `datetime` & `instruments`. - + Any order of the index level can be suported(The order will implied in the data). The order <`datetime`, `instruments`> will be used when the dataframe index name is missed. Example of the data: - - $close $volume Ref($close, 1) Mean($close, 3) $high-$low + The multi-index of the columns is optional. 
+ feature label + $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument - 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 - SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 - SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 - SH600006 22.672380 7095624.0 22.508326 22.573947 0.557785 + 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 + SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 + SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 ''' - def __init__(self, init_data=True): + def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, init_data=True): # Set logger self.logger = get_module_logger("DataHandler") - # Setup data. - self._data = {} + # Setup data loader + assert(data_loader is not None) # to make start_time end_time could have None default value + self.data_loader = init_instance_by_config(data_loader, data_loader_module, accept_types=DataLoader) + + self.instruments = instruments + self.start_time = start_time + self.end_time = end_time if init_data: self.init() super().__init__() - def init(self, force_reload: bool=True): + def init(self, enable_cache: bool=True): """ initialize the data. In case of running intialization for multiple time, it will do nothing for the second time. + It is responsible for maintaining following variable + 1) self._data + Parameters ---------- - force_reload : bool - force to reload the data even if the data have been initialized + enable_cache : bool + default value is false + if `enable_cache` == True + the processed data will be saved on disk, and handler will load the cached data from the disk directly + when we call `init` next time """ - pass - # if force_reload or hasattr(self, '_initialized', False): + # Setup data. + # _data may be with multiple column index level. 
The outer level indicates the feature set name + self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) + # TODO: cache - def get_level_index(self, df: pd.DataFrame, level=Union[str, int]) -> int: + def _get_level_index(self, df: pd.DataFrame, level=Union[str, int]) -> int: """ get the level index of `df` given `level` @@ -88,40 +103,78 @@ class DataHandler(Serializable): try: return df.index.names.index(level) except (AttributeError, ValueError): - # NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument') + # NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument') return ('datetime', 'instrument').index(level) elif isinstance(level, int): return level else: raise NotImplementedError(f"This type of input is not supported") - def _fetch_df(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]): + def _fetch_df_by_index(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]) -> pd.DataFrame: """ fetch data from `data` with `selector` and `level` Parameters ---------- - df : pd.DataFrame - the data frame to be selected selector : Union[pd.Timestamp, slice, str, list] selector - level : Union[pd.Timestamp, slice, str] + level : Union[int, str] the level to use the selector """ # Try to get the right index idx_slc = (selector, slice(None, None)) - if self.get_level_index(df, level) == 1: - idx_slc = idx_slc[1], idx_slc[0] + if self._get_level_index(df, level) == 1: + idx_slc = idx_slc[1], idx_slc[0] return df.loc(axis=0)[idx_slc] - - def fetch(self, selector: Union[pd.Timestamp, slice, str], level='datetime', key=None) -> Union[pd.DataFrame, dict]: - if key is None: - res = {} - for k, df in self._data.items(): - res[k] = self._fetch_df(df, selector, level) + + CS_ALL = '_all' + + def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame: + cln 
= len(df.columns.levels) + if cln == 1: + return df + elif col_set == self.CS_ALL: + return df.droplevel(axis=1, level=0) else: - res = self._fetch_df(self._data[key], selector, level) - return res + return df.loc(axis=1)[col_set] + + def fetch(self, selector: Union[pd.Timestamp, slice, str], level: Union[str, int]='datetime', col_set=CS_ALL) -> pd.DataFrame: + """ + fetch data from underlying data source + + Parameters + ---------- + selector : Union[pd.Timestamp, slice, str] + describe how to select data by index + level : Union[str, int] + which index level to select the data + col_set : str + select a set of meaningful columns.(e.g. features, columns) + + Returns + ------- + pd.DataFrame: + """ + df = self._fetch_df_by_index(self._data, selector, level) + return self._fetch_df_by_col(df, col_set) + + def get_cols(self, col_set=CS_ALL) -> list: + """ + get the column names + + Parameters + ---------- + col_set : str + select a set of meaningful columns.(e.g. features, columns) + + Returns + ------- + list: + list of column names + """ + df = self._data.head() + df = self._fetch_df_by_col(df, col_set) + return df.columns.to_list() class DataHandlerLP(DataHandler): @@ -142,14 +195,13 @@ class DataHandlerLP(DataHandler): # - _proc_learn_df will be processed by infer_processors + learn_processors # - (e.g. 
_proc_infer_df processed by learn_processors ) - def __init__(self, infer_processors=[], learn_processors=[], process_type=PTYPE_A, **kwargs): + def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, infer_processors=[], learn_processors=[], process_type=PTYPE_A, **kwargs): """ - Parameters ---------- infer_processors : list list of of processors to generate data for inference - example of : + example of : 1) classname & kwargs: { "class": "MinMaxNorm", @@ -180,24 +232,18 @@ class DataHandlerLP(DataHandler): self.learn_processors = [] # for lint for pname in 'infer_processors', 'learn_processors': for proc in locals()[pname]: - getattr(self, pname).append(processor_module.init_proc_obj(proc)) + getattr(self, pname).append(init_instance_by_config(proc, processor_module, + accept_types=(processor_module.Processor,))) self.process_type = process_type - super().__init__(**kwargs) + super().__init__(instruments, start_time, end_time, data_loader, **kwargs) def get_all_processors(self): return self.infer_processors + self.learn_processors - def _init_raw_data(self): - """ - initialize the raw data - the raw data will be saved in to `self._data['raw']` - """ - raise NotImplementedError(f"Please implement the `_init_raw_data` method") - def fit(self): for proc in self.get_all_processors(): - proc.fit(self) + proc.fit(self._data) def fit_process_data(self): """ @@ -206,7 +252,7 @@ class DataHandlerLP(DataHandler): The input of the `fit` will be the output of the previous processor """ self.process_data(with_fit=True) - + def process_data(self, with_fit: bool=False): """ @@ -218,50 +264,56 @@ class DataHandlerLP(DataHandler): The input of the `fit` will be the output of the previous processor """ # data for inference - _infer_df = self._data[DataHandlerLP.DK_R] + _infer_df = self._data + if len(self.infer_processors) > 0: # avoid modifying the original data + _infer_df = _infer_df.copy() + for proc in 
self.infer_processors: if not proc.is_for_infer(): raise TypeError("Only processors usable for inference can be used in `infer_processors` ") if with_fit: - proc.fit(self, _infer_df) + proc.fit(_infer_df) _infer_df = proc(_infer_df) + self._infer = _infer_df # data for learning if self.process_type == DataHandlerLP.PTYPE_I: - _learn_df = self._data[DataHandlerLP.DK_R] + _learn_df = self._data elif self.process_type == DataHandlerLP.PTYPE_A: # based on `infer_df` and append the processor _learn_df = _infer_df else: raise NotImplementedError(f"This type of input is not supported") + if len(self.learn_processors) > 0: # avoid modifying the original data + _learn_df = _learn_df.copy() for proc in self.learn_processors: if with_fit: - proc.fit(self, _learn_df) + proc.fit(_learn_df) _learn_df = proc(_learn_df) - - self._data.update({ - DataHandlerLP.DK_I: _infer_df, - DataHandlerLP.DK_L: _learn_df, - }) + self._learn = _learn_df # init type IT_FIT_SEQ = 'fit_seq' # the input of `fit` will be the output of the previous processor IT_FIT_IND = 'fit_ind' # the input of `fit` will be the original df IT_LS = 'load_state' # The state of the object has been load by pickle - - def init(self, init_type: str=IT_FIT_SEQ, path: Path=None): + + def init(self, init_type: str=IT_FIT_SEQ, enable_cache: bool=False): """ Initialize the data of Qlib Parameters ---------- init_type : str - 'fit' or 'load_state' - path : path - if `init_type` == 'load_state': `path` will be used to load_state + The type `IT_*` listed above + enable_cache : bool + default value is false + if `enable_cache` == True: + the processed data will be saved on disk, and handler will load the cached data from the disk directly + when we call `init` next time """ - self._init_raw_data() + # init raw data + super().init(enable_cache=enable_cache) if init_type == DataHandlerLP.IT_FIT_IND: self.fit() @@ -275,311 +327,53 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. 
Save the memory for data processing + def _get_df_by_key(self, data_key: str=DK_I) -> pd.DataFrame: + df = getattr(self, {self.DK_R: '_data', self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) + return df -class DataHandlerLPWL(DataHandlerLP): - ''' - DataHandler with (L)earnable (P)rocessor with (L)abel - ''' - - def _init_raw_data(self): + def fetch(self, + selector: Union[pd.Timestamp, slice, str], + level: Union[str, int] = 'datetime', + col_set=DataHandler.CS_ALL, + data_key: str = DK_I) -> pd.DataFrame: """ - init raw_df, feature_names, label_names of DataHandler - if the index of df_feature and df_label are not same, user need to overload this method to merge (e.g. inner, left, right merge). + fetch data from underlying data source + + Parameters + ---------- + selector : Union[pd.Timestamp, slice, str] + describe how to select data by index + level : Union[str, int] + which index level to select the data + col_set : str + select a set of meaningful columns.(e.g. features, columns) + data_key: str + The data to fetch: DK_* + + Returns + ------- + pd.DataFrame: """ - df_features = self.load_feature() - feature_names = df_features.columns + df = self._get_df_by_key(data_key) + df = self._fetch_df_by_index(df, selector, level) + return self._fetch_df_by_col(df, col_set) - df_labels = self.load_label() - label_names = df_labels.columns - - raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left") - self.feature_names = feature_names - self.label_names = label_names - self._data['raw'] = raw_df - - def load_feature(self): + def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str=DK_I) -> list: """ - Implement this method to load raw feature. - the format of the feature is below - return: df_features + get the column names + + Parameters + ---------- + col_set : str + select a set of meaningful columns.(e.g. 
features, columns) + data_key: str + The data to fetch: DK_* + + Returns + ------- + list: + list of column names """ - raise NotImplementedError(f"Please implement `load_feature`") - - def load_label(self): - """ - Implement this method to load and calculate label. - the format of the label is below - - return: df_label - """ - raise NotImplementedError(f"Please implement `load_label`") - - def get_feature_names(self): - return self.feature_names - - def get_label_names(self): - return self.label_names - - -class QLibDataHandler(DataHandlerLPWL): - def __init__(self, start_date, end_date, *args, **kwargs): - # Dates. - self.start_date = start_date - self.end_date = end_date - - # Instruments - instruments = kwargs.pop("instruments", None) - if instruments is None: - market = kwargs.pop("market", "csi500").lower() - data_filter_list = kwargs.pop("data_filter_list", list()) - self.instruments = D.instruments(market, filter_pipe=data_filter_list) - else: - self.instruments = instruments - - # Config of features and labels - self._fields = kwargs.pop("fields", []) - self._names = kwargs.pop("names", []) - self._labels = kwargs.pop("labels", []) - self._label_names = kwargs.pop("label_names", []) - - # Check arguments - assert len(self._fields) > 0, "features list is empty" - assert len(self._labels) > 0, "labels list is empty" - - # Check end_date - # If test_end_date is -1 or greater than the last date, the last date is used - self.end_date = transform_end_date(self.end_date) - - super().__init__(*args, **kwargs) - - def load_feature(self): - """ - Load the raw data. 
- return: df_features - """ - TimeInspector.set_time_mark() - - if len(self._names) == 0: - names = ["F%d" % i for i in range(len(self._fields))] - else: - names = self._names - - df_features = D.features(self.instruments, self._fields, self.start_date, self.end_date) - df_features.columns = names - - TimeInspector.log_cost_time("Finished loading features.") - - return df_features - - def load_label(self): - """ - Build up labels in df through users' method - :return: df_labels - """ - TimeInspector.set_time_mark() - - if len(self._label_names) == 0: - label_names = ["LABEL%d" % i for i in range(len(self._labels))] - else: - label_names = self._label_names - - df_labels = D.features(self.instruments, self._labels, self.start_date, self.end_date) - df_labels.columns = label_names - - TimeInspector.log_cost_time("Finished loading labels.") - - return df_labels - - -def parse_config_to_fields(config): - """create factors from config - - config = { - 'kbar': {}, # whether to use some hard-code kbar features - 'price': { # whether to use raw price features - 'windows': [0, 1, 2, 3, 4], # use price at n days ago - 'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use - }, - 'volume': { # whether to use raw volume features - 'windows': [0, 1, 2, 3, 4], # use volume at n days ago - }, - 'rolling': { # whether to use rolling operator based features - 'windows': [5, 10, 20, 30, 60], # rolling windows size - 'include': ['ROC', 'MA', 'STD'], # rolling operator to use - #if include is None we will use default operators - 'exclude': ['RANK'], # rolling operator not to use - } - } - """ - fields = [] - names = [] - if "kbar" in config: - fields += [ - "($close-$open)/$open", - "($high-$low)/$open", - "($close-$open)/($high-$low+1e-12)", - "($high-Greater($open, $close))/$open", - "($high-Greater($open, $close))/($high-$low+1e-12)", - "(Less($open, $close)-$low)/$open", - "(Less($open, $close)-$low)/($high-$low+1e-12)", - "(2*$close-$high-$low)/$open", - 
"(2*$close-$high-$low)/($high-$low+1e-12)", - ] - names += [ - "KMID", - "KLEN", - "KMID2", - "KUP", - "KUP2", - "KLOW", - "KLOW2", - "KSFT", - "KSFT2", - ] - if "price" in config: - windows = config["price"].get("windows", range(5)) - feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"]) - for field in feature: - field = field.lower() - fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows] - names += [field.upper() + str(d) for d in windows] - if "volume" in config: - windows = config["volume"].get("windows", range(5)) - fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows] - names += ["VOLUME" + str(d) for d in windows] - if "rolling" in config: - windows = config["rolling"].get("windows", [5, 10, 20, 30, 60]) - include = config["rolling"].get("include", None) - exclude = config["rolling"].get("exclude", []) - # `exclude` in dataset config unnecessary filed - # `include` in dataset config necessary field - use = lambda x: x not in exclude and (include is None or x in include) - if use("ROC"): - fields += ["Ref($close, %d)/$close" % d for d in windows] - names += ["ROC%d" % d for d in windows] - if use("MA"): - fields += ["Mean($close, %d)/$close" % d for d in windows] - names += ["MA%d" % d for d in windows] - if use("STD"): - fields += ["Std($close, %d)/$close" % d for d in windows] - names += ["STD%d" % d for d in windows] - if use("BETA"): - fields += ["Slope($close, %d)/$close" % d for d in windows] - names += ["BETA%d" % d for d in windows] - if use("RSQR"): - fields += ["Rsquare($close, %d)" % d for d in windows] - names += ["RSQR%d" % d for d in windows] - if use("RESI"): - fields += ["Resi($close, %d)/$close" % d for d in windows] - names += ["RESI%d" % d for d in windows] - if use("MAX"): - fields += ["Max($high, %d)/$close" % d for d in windows] - names += ["MAX%d" % d for d in windows] - if use("LOW"): - fields += ["Min($low, %d)/$close" % 
d for d in windows] - names += ["MIN%d" % d for d in windows] - if use("QTLU"): - fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows] - names += ["QTLU%d" % d for d in windows] - if use("QTLD"): - fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows] - names += ["QTLD%d" % d for d in windows] - if use("RANK"): - fields += ["Rank($close, %d)" % d for d in windows] - names += ["RANK%d" % d for d in windows] - if use("RSV"): - fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows] - names += ["RSV%d" % d for d in windows] - if use("IMAX"): - fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows] - names += ["IMAX%d" % d for d in windows] - if use("IMIN"): - fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows] - names += ["IMIN%d" % d for d in windows] - if use("IMXD"): - fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows] - names += ["IMXD%d" % d for d in windows] - if use("CORR"): - fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows] - names += ["CORR%d" % d for d in windows] - if use("CORD"): - fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows] - names += ["CORD%d" % d for d in windows] - if use("CNTP"): - fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows] - names += ["CNTP%d" % d for d in windows] - if use("CNTN"): - fields += ["Mean($closeRef($close, 1), %d)-Mean($close pd.DataFrame: + """ + load the data as pd.DataFrame + + Returns + ------- + pd.DataFrame: + data load from the under layer source + + Example of the data: + The multi-index of the columns is optional. 
+ feature label + $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 + datetime instrument + 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 + SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 + SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 + """ + pass + + +class QlibDataLoader(DataLoader): + '''Same as QlibDataLoader. The fields can be define by config''' + def __init__(self, config: Tuple[list, tuple, dict], group_fields: bool = False, filter_pipe=None): + """ + Parameters + ---------- + config : Tuple[list ,tuple, dict] + Config will be used to describe the fields and column names + + if `group_fields`: + := { + "group_name1": + "group_name2": + } + else: + := + + := ["expr", ...] | (["expr", ...], ["col_name", ...]) | + + is a config with dict type which could be parsed by `parse_config_to_fields` + + Here is a few examples to describe the fields + TODO: + + group_fields : bool + Will the fields be grouped. 
Multi-index will be used for the group + """ + if group_fields: + fields_all = [] + name_grp_info = [] + for grp, fields_info in config.items(): + fields, names = self._parse_fields_info(fields_info) + fields_all.extend(fields) + name_grp_info.extend([(grp, n) for n in names]) + self.fields, self.names = fields_all, name_grp_info + else: + self.fields, self.names = self._parse_fields_info(fields_info) + + self.group_fields = group_fields + self.filter_pipe = filter_pipe + + def _parse_fields_info(self, fields_info: Tuple[list, tuple, dict]) -> Tuple[list, list]: + if isinstance(fields_info, dict): + fields, names = parse_config_to_fields(fields_info) + elif isinstance(fields_info, list): + fields = fields_info + names = fields + elif isinstance(fields_info, tuple): + fields, names = fields_info + else: + raise NotImplementedError(f"This type of input is not supported") + return fields, names + + def load(self, + instruments, + config: Tuple[list, tuple, dict], + group_fields=False, + start_time=None, + end_time=None) -> Tuple[pd.DataFrame, dict]: + df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), self.fields, start_time, end_time) + df.columns = pd.MultiIndex.from_tuples(self.names) if self.group_fields else self.names + df = df.swaplevel().sort_index() + return df + + +# TODO: make it easier to understand the config language +def parse_config_to_fields(config): + """create factors from config + + config = { + 'kbar': {}, # whether to use some hard-code kbar features + 'price': { # whether to use raw price features + 'windows': [0, 1, 2, 3, 4], # use price at n days ago + 'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use + }, + 'volume': { # whether to use raw volume features + 'windows': [0, 1, 2, 3, 4], # use volume at n days ago + }, + 'rolling': { # whether to use rolling operator based features + 'windows': [5, 10, 20, 30, 60], # rolling windows size + 'include': ['ROC', 'MA', 'STD'], # rolling operator to use + #if include 
is None we will use default operators + 'exclude': ['RANK'], # rolling operator not to use + } + } + """ + fields = [] + names = [] + if "kbar" in config: + fields += [ + "($close-$open)/$open", + "($high-$low)/$open", + "($close-$open)/($high-$low+1e-12)", + "($high-Greater($open, $close))/$open", + "($high-Greater($open, $close))/($high-$low+1e-12)", + "(Less($open, $close)-$low)/$open", + "(Less($open, $close)-$low)/($high-$low+1e-12)", + "(2*$close-$high-$low)/$open", + "(2*$close-$high-$low)/($high-$low+1e-12)", + ] + names += [ + "KMID", + "KLEN", + "KMID2", + "KUP", + "KUP2", + "KLOW", + "KLOW2", + "KSFT", + "KSFT2", + ] + if "price" in config: + windows = config["price"].get("windows", range(5)) + feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"]) + for field in feature: + field = field.lower() + fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows] + names += [field.upper() + str(d) for d in windows] + if "volume" in config: + windows = config["volume"].get("windows", range(5)) + fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows] + names += ["VOLUME" + str(d) for d in windows] + if "rolling" in config: + windows = config["rolling"].get("windows", [5, 10, 20, 30, 60]) + include = config["rolling"].get("include", None) + exclude = config["rolling"].get("exclude", []) + # `exclude` in dataset config unnecessary filed + # `include` in dataset config necessary field + use = lambda x: x not in exclude and (include is None or x in include) + if use("ROC"): + fields += ["Ref($close, %d)/$close" % d for d in windows] + names += ["ROC%d" % d for d in windows] + if use("MA"): + fields += ["Mean($close, %d)/$close" % d for d in windows] + names += ["MA%d" % d for d in windows] + if use("STD"): + fields += ["Std($close, %d)/$close" % d for d in windows] + names += ["STD%d" % d for d in windows] + if use("BETA"): + fields += ["Slope($close, %d)/$close" 
% d for d in windows] + names += ["BETA%d" % d for d in windows] + if use("RSQR"): + fields += ["Rsquare($close, %d)" % d for d in windows] + names += ["RSQR%d" % d for d in windows] + if use("RESI"): + fields += ["Resi($close, %d)/$close" % d for d in windows] + names += ["RESI%d" % d for d in windows] + if use("MAX"): + fields += ["Max($high, %d)/$close" % d for d in windows] + names += ["MAX%d" % d for d in windows] + if use("LOW"): + fields += ["Min($low, %d)/$close" % d for d in windows] + names += ["MIN%d" % d for d in windows] + if use("QTLU"): + fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows] + names += ["QTLU%d" % d for d in windows] + if use("QTLD"): + fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows] + names += ["QTLD%d" % d for d in windows] + if use("RANK"): + fields += ["Rank($close, %d)" % d for d in windows] + names += ["RANK%d" % d for d in windows] + if use("RSV"): + fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows] + names += ["RSV%d" % d for d in windows] + if use("IMAX"): + fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows] + names += ["IMAX%d" % d for d in windows] + if use("IMIN"): + fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows] + names += ["IMIN%d" % d for d in windows] + if use("IMXD"): + fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows] + names += ["IMXD%d" % d for d in windows] + if use("CORR"): + fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows] + names += ["CORR%d" % d for d in windows] + if use("CORD"): + fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows] + names += ["CORD%d" % d for d in windows] + if use("CNTP"): + fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows] + names += ["CNTP%d" % d for d in windows] + if use("CNTN"): + fields += ["Mean($closeRef($close, 1), %d)-Mean($close (type, dict): - """ - 
extract class and kwargs from processor info - - Parameters - ---------- - processor : [dict, str] - similar to processor - - Returns - ------- - (type, dict): - the class object and it's arguments. - """ - if isinstance(processor, dict): - # raise AttributeError - klass = globals()[processor['class']] - kwargs = processor['kwargs'] - elif isinstance(processor, str): - klass = globals()[processor] - kwargs = {} - else: - raise NotImplementedError(f"This type of input is not supported") - return klass, kwargs - - -# Place the function here to be able to reference the Processor -def init_proc_obj(processor: [dict, str, Processor]) -> Processor: - """ - Initialize Processor Object - - Parameters - ---------- - processor : [dict, str, Processor] - The info to initialize processor - - Returns - ------- - Processor: - initialized Processor - """ - if not isinstance(processor, Processor): - klass, pkwargs = get_cls_kwargs(processor) - processor = klass(**pkwargs) - return processor - - -class InferProcessor(Processor): - '''This processor is usable for inference''' def is_for_infer(self) -> bool: """ Is this processor usable for inference + Some processors are not usable for inference. Returns ------- @@ -105,37 +72,24 @@ class InferProcessor(Processor): return True -class NInferProcessor(Processor): - '''This processor is not usable for inference''' - def is_for_infer(self) -> bool: - """ - Is this processor usable for inference +class DropnaProcessor(Processor): + def __init__(self, group=None): + self.group = group - Returns - ------- - bool: - if it is usable for infenrece - """ + def __call__(self, df): + return df.dropna(subset=get_group_columns(df, self.group)) + + +class DropnaLabel(DropnaProcessor): + def __init__(self, group='label'): + super().__init__(group=group) + + def is_for_infer(self) -> bool: + '''The samples are dropped according to label. 
So it is not usable for inference''' return False -class DropnaFeature(InferProcessor): - def fit(self, handler, df=None): - self.feature_names = copy.deepcopy(handler.get_feature_names()) - - def __call__(self, df): - return df.dropna(subset=self.feature_names) - - -class DropnaLabel(InferProcessor): - def fit(self, handler, df=None): - self.label_names = copy.deepcopy(handler.get_label_names()) - - def __call__(self, df): - return df.dropna(subset=self.label_names) - - -class ProcessInf(InferProcessor): +class ProcessInf(Processor): '''Process infinity ''' def __call__(self, df): def replace_inf(data): @@ -151,22 +105,20 @@ class ProcessInf(InferProcessor): return replace_inf(df) -class MinMaxNorm(InferProcessor): - def __init__(self, fit_start_time, fit_end_time): +class MinMaxNorm(Processor): + def __init__(self, fit_start_time, fit_end_time, fields_group=None): self.fit_start_time = fit_start_time self.fit_end_time = fit_end_time + self.fields_group = fields_group - def fit(self, handler, df): - # TODO: 看看这里怎么取数据 - self.min_val = np.nanmin(df[handler.get_feature_names()].values, axis=0) - self.max_val = np.nanmax(df[handler.get_feature_names()].values, axis=0) + def fit(self, df): + cols = get_group_columns(df, self.fields_group) + self.min_val = np.nanmin(df[cols].values, axis=0) + self.max_val = np.nanmax(df[cols].values, axis=0) self.ignore = self.min_val == self.max_val - self.feature_names = copy.deepcopy(handler.get_feature_names()) + self.cols = cols def __call__(self, df): - # FIXME: The df will be changed inplace. 
It's very dangerous - # The code below is ugly - df = df.copy() # currently copy is used def normalize(x, min_val=self.min_val, max_val=self.max_val, ignore=self.ignore): if (~ignore).all(): return (x - min_val) / (max_val - min_val) @@ -174,25 +126,24 @@ class MinMaxNorm(InferProcessor): if not ignore[i]: x[i] = (x[i] - min_val) / (max_val - min_val) return x - df.loc(axis=1)[self.feature_names] = normalize(df[self.feature_names].values) + df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df -class ZscoreNorm(InferProcessor): - def __init__(self, fit_start_time, fit_end_time): +class ZscoreNorm(Processor): + def __init__(self, fit_start_time, fit_end_time, fields_group=None): self.fit_start_time = fit_start_time self.fit_end_time = fit_end_time + self.fields_group = fields_group - def fit(self, handler, df): - self.mean_train = np.nanmean(df[handler.get_feature_names()].values, axis=0) - self.std_train = np.nanstd(df[handler.get_feature_names()].values, axis=0) + def fit(self, df): + cols = get_group_columns(df, self.fields_group) + self.mean_train = np.nanmean(df[cols].values, axis=0) + self.std_train = np.nanstd(df[cols].values, axis=0) self.ignore = self.std_train == 0 - self.feature_names = handler.get_feature_names() + self.cols = cols def __call__(self, df): - # FIXME: The df will be changed inplace. 
It's very dangerous - # The code below is ugly - df = df.copy() # currently copy is used def normalize(x, mean_train=self.mean_train, std_train=self.std_train, ignore=self.ignore): if (~ignore).all(): return (x - mean_train) / std_train @@ -200,12 +151,27 @@ class ZscoreNorm(InferProcessor): if not ignore[i]: x[i] = (x[i] - mean_train) / std_train return x - df.loc(axis=1)[self.feature_names] = normalize(df[self.feature_names].values) + df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df -class ConfigSectionProcessor(InferProcessor): - def __init__(self, **kwargs): +class CSZScoreNorm(Processor): + '''Cross Sectional ZScore Normalization''' + def __init__(self, fields_group=None): + self.fields_group = fields_group + + def __call__(self, df): + # try not modify original dataframe + cols = get_group_columns(df,self.fields_group) + df[cols] = df[cols].groupby('datetime').apply(lambda df: (df - df.mean()).div(df.std())) + return df + + +# TODO: make the config language easier to understand +class ConfigSectionProcessor(Processor): + # TODO: this class is not well tested + # FIXME: this will raise error when multi-index is passed in + def __init__(self, fields_group=None, **kwargs): super().__init__() # Options self.fillna_feature = kwargs.get("fillna_feature", True) @@ -214,9 +180,7 @@ class ConfigSectionProcessor(InferProcessor): self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True) self.clip_label_outlier = kwargs.get("clip_label_outlier", False) - def fit(self, handler, df=None): - self.feature_names = handler.get_feature_names() - self.label_names = handler.get_label_names() + self.fields_group = None def __call__(self, df): return self._transform(df) @@ -245,19 +209,22 @@ class ConfigSectionProcessor(InferProcessor): TimeInspector.set_time_mark() - # Copy - df_new = df.copy() + # Copy the focus part and change it to single level + selected_cols = get_group_columns(df, self.fields_group) + df_focus = df[selected_cols].copy() + 
if len(df_focus.columns.levels) > 1: + df_focus = df_focus.droplevel(level=0) # Label - cols = df.columns[df.columns.str.contains("^LABEL")] - df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")] + df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm) # Features - cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")] - df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")] + df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm) - cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")] - df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")] + df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm) _cols = [ "KMID", @@ -282,27 +249,29 @@ class ConfigSectionProcessor(InferProcessor): "VSUMD", ] pat = "|".join(["^" + x for x in _cols]) - cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))] - df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))] + df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm) - cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")] - df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")] + df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm) - cols = df.columns[df.columns.str.contains("^RSQR")] - df_new[cols] = 
df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")] + df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm) - cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")] - df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")] + df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm) - cols = df.columns[df.columns.str.contains("^MIN|^LOW0")] - df_new[cols] = df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")] + df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm) - cols = df.columns[df.columns.str.contains("^CORR|^CORD")] - df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")] + df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm) - cols = df.columns[df.columns.str.contains("^WVMA")] - df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm) + cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")] + df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm) + + df[selected_cols] = df_focus.values TimeInspector.log_cost_time("Finished preprocessing data.") - return df_new + return df diff --git a/qlib/log.py b/qlib/log.py index 7db9ea92d..1f06f87f5 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -8,6 +8,7 @@ import os import re from logging import config as logging_config from time import time +from contextlib import contextmanager from .config import C @@ -79,6 +80,28 @@ class TimeInspector(object): cost_time 
= time() - cls.time_marks.pop() cls.timer_logger.info("Time cost: {0:.5f} | {1}".format(cost_time, info)) + @contextmanager + @classmethod + def logt(cls, name="", show_start=False): + """logt. + Log the time of the inside code + + Parameters + ---------- + name : + name + show_start : + show_start + """ + if show_start: + cls.timer_logger.info(f"Begin {name}") + cls.set_time_mark() + try: + yield None + finally: + pass + cls.log_cost_time() + def set_log_with_config(log_config: dict): """set log with config diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 0e0b76e1c..b10735868 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -23,6 +23,7 @@ import contextlib import numpy as np import pandas as pd from pathlib import Path +from typing import Union, Tuple from ..config import C from ..log import get_module_logger @@ -164,6 +165,71 @@ def get_module_by_module_path(module_path): return module +def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): + """ + extract class and kwargs from config info + + Parameters + ---------- + config : [dict, str] + similar to config + + module : Python module + It should be a python module to load the class type + + Returns + ------- + (type, dict): + the class object and it's arguments. + """ + if isinstance(config, dict): + # raise AttributeError + klass = getattr(module, config['class']) + kwargs = config['kwargs'] + elif isinstance(config, str): + klass = getattr(module, config) + kwargs = {} + else: + raise NotImplementedError(f"This type of input is not supported") + return klass, kwargs + + +def init_instance_by_config(config: Union[str, dict], module=None, accept_types: Tuple[type]=tuple([])) -> object: + """ + get initialized instance with config + + Parameters + ---------- + config : Union[str, dict] + dict example. + { + 'class': 'ClassName', + 'kwargs': dict, # It is optional. 
{} will be used if not given + 'module_path': path, # It is optional if module is given + } + str example. + "ClassName": getattr(module, config)() will be used. + module : Python module + Optional. It should be a python module. + + accept_types: Tuple[type] + Optional. If the config is an instance of one of these types, return the config directly. + + Returns + ------- + object: + An initialized object based on the config info + """ + if isinstance(config, accept_types): + return config + + if module is None: + module = get_module_by_module_path(config["module_path"]) + + klass, kwargs = get_cls_kwargs(config, module) + return klass(**kwargs) + + def compare_dict_value(src_data: dict, dst_data: dict): """Compare dict value