mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Update handler interface round2
This commit is contained in:
@@ -31,7 +31,7 @@ if __name__ == "__main__":
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "CSI300"
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
|
||||
@@ -39,27 +39,27 @@ if __name__ == "__main__":
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_date": "2008-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time":"2008-01-01",
|
||||
"fit_end_time":"2014-12-31",
|
||||
"market": MARKET,
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2008-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
# use default DataHandler
|
||||
# custom DataHandler, refer to: TODO: DataHandler API url
|
||||
handler = Alpha158(**DATA_HANDLER_CONFIG)
|
||||
|
||||
data = handler.fetch(slice('2008-01-01', '2014-12-31'), key=handler.DK_I)
|
||||
data = handler.fetch(slice('2008-01-01', '2014-12-31'), data_key=handler.DK_I)
|
||||
print(data)
|
||||
|
||||
sys.exit(0) # I have tested the code above ---------------------------------------------
|
||||
|
||||
@@ -1,41 +1,73 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from ...data.dataset.handler import ConfigQLibDataHandler
|
||||
from ...data.dataset.processor import Processor, MinMaxNorm, ZscoreNorm, get_cls_kwargs
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor, MinMaxNorm, ZscoreNorm
|
||||
from ...utils import get_cls_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
import copy
|
||||
|
||||
|
||||
class ALPHA360(ConfigQLibDataHandler):
|
||||
config_template = {
|
||||
"price": {"windows": range(60)},
|
||||
"volume": {"windows": range(60)},
|
||||
}
|
||||
class ALPHA360(DataHandlerLP):
|
||||
def __init__(self, instruments="csi500", start_time=None, end_time=None):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": {
|
||||
"price": {
|
||||
"windows": range(60)
|
||||
},
|
||||
"volume": {
|
||||
"windows": range(60)
|
||||
},
|
||||
},
|
||||
"label": self.get_label_config()
|
||||
},
|
||||
"group_fields": True,
|
||||
}
|
||||
}
|
||||
infer_processors = ["ConfigSectionProcessor"] # ConfigSectionProcessor will normalize LABEL0
|
||||
super().__init__(instruments, start_time, end_time, data_loader=data_loader, infer_processors=infer_processors)
|
||||
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
|
||||
class QLibDataHandlerV1(ConfigQLibDataHandler):
|
||||
config_template = {
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
class ALPHA360vwap(ALPHA360):
|
||||
def get_label_config(self):
|
||||
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
|
||||
|
||||
def __init__(self, start_date, end_date, infer_processors=[], learn_processors=["DropnaLabel"], fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
|
||||
class Alpha158(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi500",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=["DropnaLabel", {
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {
|
||||
"fields_group": "label"
|
||||
}
|
||||
}],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p)
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
# FIXME: It's hard code here!!!!!
|
||||
if isinstance(klass, (MinMaxNorm, ZscoreNorm)):
|
||||
assert(fit_start_time is not None and fit_end_time is not None)
|
||||
assert (fit_start_time is not None and fit_end_time is not None)
|
||||
pkwargs.update({
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
})
|
||||
})
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
@@ -44,37 +76,37 @@ class QLibDataHandlerV1(ConfigQLibDataHandler):
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
super().__init__(start_date, end_date, infer_processors=infer_processors, learn_processors=learn_processors, **kwargs)
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"label": self.get_label_config()
|
||||
},
|
||||
"group_fields": True,
|
||||
}
|
||||
}
|
||||
super().__init__(instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors)
|
||||
|
||||
def load_label(self):
|
||||
"""
|
||||
load the labels df
|
||||
:return: df_labels
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
def get_feature_config(self):
|
||||
return {
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
|
||||
df_labels = super().load_label()
|
||||
|
||||
## calculate new labels
|
||||
df_labels["LABEL1"] = df_labels["LABEL0"].groupby(level="datetime").apply(lambda x: (x - x.mean()) / x.std())
|
||||
|
||||
df_labels = df_labels.drop(["LABEL0"], axis=1)
|
||||
|
||||
TimeInspector.log_cost_time("Finished loading labels.")
|
||||
|
||||
return df_labels
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
|
||||
class Alpha158(QLibDataHandlerV1):
|
||||
config_template = {
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "CLOSE"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["labels"] = ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
super().__init__(*args, **kwargs)
|
||||
class Alpha158vwap(Alpha158):
|
||||
def get_label_config(self):
|
||||
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
|
||||
|
||||
@@ -11,7 +11,7 @@ from ..backtest.account import Account
|
||||
from ..backtest.exchange import Exchange
|
||||
from .user import User
|
||||
from .utils import load_instance
|
||||
from .utils import save_instance, init_instance_by_config
|
||||
from ...utils import save_instance, init_instance_by_config
|
||||
|
||||
|
||||
class UserManager:
|
||||
|
||||
@@ -7,7 +7,7 @@ import yaml
|
||||
import pandas as pd
|
||||
from ...data import D
|
||||
from ...log import get_module_logger
|
||||
from ...utils import get_module_by_module_path
|
||||
from ...utils import get_module_by_module_path, init_instance_by_config
|
||||
from ...utils import get_next_trading_date
|
||||
from ..backtest.exchange import Exchange
|
||||
|
||||
@@ -45,21 +45,6 @@ def save_instance(instance, file_path):
|
||||
pickle.dump(instance, fr)
|
||||
|
||||
|
||||
def init_instance_by_config(config):
|
||||
"""
|
||||
generate an instance with settings in config
|
||||
Parameter
|
||||
config : dict
|
||||
python dict indicate a init parameters to create an item
|
||||
:return
|
||||
An instance
|
||||
"""
|
||||
module = get_module_by_module_path(config["module_path"])
|
||||
instance_class = getattr(module, config["class"])
|
||||
instance = instance_class(**config["args"])
|
||||
return instance
|
||||
|
||||
|
||||
def create_user_folder(path):
|
||||
path = pathlib.Path(path)
|
||||
if path.exists():
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import abc
|
||||
import bisect
|
||||
import logging
|
||||
from typing import Union
|
||||
from typing import Union, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -13,11 +13,13 @@ import numpy as np
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
from . import processor as processor_module
|
||||
from . import loader as data_loader_module
|
||||
|
||||
|
||||
# TODO: A more general handler interface which does not rely on the internal pd.DataFrame is needed.
|
||||
@@ -30,44 +32,57 @@ class DataHandler(Serializable):
|
||||
|
||||
The data handler tries to maintain a handler with 2 levels.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
|
||||
Any order of the index levels can be supported (the order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missing.
|
||||
|
||||
Example of the data:
|
||||
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
The multi-index of the columns is optional.
|
||||
feature label
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325
|
||||
SH600006 22.672380 7095624.0 22.508326 22.573947 0.557785
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
'''
|
||||
def __init__(self, init_data=True):
|
||||
def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, init_data=True):
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
|
||||
# Setup data.
|
||||
self._data = {}
|
||||
# Setup data loader
|
||||
assert(data_loader is not None) # to make start_time end_time could have None default value
|
||||
self.data_loader = init_instance_by_config(data_loader, data_loader_module, accept_types=DataLoader)
|
||||
|
||||
self.instruments = instruments
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
if init_data:
|
||||
self.init()
|
||||
super().__init__()
|
||||
|
||||
def init(self, force_reload: bool=True):
|
||||
def init(self, enable_cache: bool=True):
|
||||
"""
|
||||
initialize the data.
|
||||
If initialization is run multiple times, it will do nothing the second time.
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
force_reload : bool
|
||||
force to reload the data even if the data have been initialized
|
||||
enable_cache : bool
|
||||
default value is false
|
||||
if `enable_cache` == True
|
||||
the processed data will be saved on disk, and handler will load the cached data from the disk directly
|
||||
when we call `init` next time
|
||||
"""
|
||||
pass
|
||||
# if force_reload or hasattr(self, '_initialized', False):
|
||||
# Setup data.
|
||||
# _data may be with multiple column index level. The outer level indicates the feature set name
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
# TODO: cache
|
||||
|
||||
def get_level_index(self, df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
def _get_level_index(self, df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
"""
|
||||
|
||||
get the level index of `df` given `level`
|
||||
@@ -88,40 +103,78 @@ class DataHandler(Serializable):
|
||||
try:
|
||||
return df.index.names.index(level)
|
||||
except (AttributeError, ValueError):
|
||||
# NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
|
||||
# NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
|
||||
return ('datetime', 'instrument').index(level)
|
||||
elif isinstance(level, int):
|
||||
return level
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def _fetch_df(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]):
|
||||
def _fetch_df_by_index(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from `data` with `selector` and `level`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
the data frame to be selected
|
||||
selector : Union[pd.Timestamp, slice, str, list]
|
||||
selector
|
||||
level : Union[pd.Timestamp, slice, str]
|
||||
level : Union[int, str]
|
||||
the level to use the selector
|
||||
"""
|
||||
# Try to get the right index
|
||||
idx_slc = (selector, slice(None, None))
|
||||
if self.get_level_index(df, level) == 1:
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
if self._get_level_index(df, level) == 1:
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
return df.loc(axis=0)[idx_slc]
|
||||
|
||||
def fetch(self, selector: Union[pd.Timestamp, slice, str], level='datetime', key=None) -> Union[pd.DataFrame, dict]:
|
||||
if key is None:
|
||||
res = {}
|
||||
for k, df in self._data.items():
|
||||
res[k] = self._fetch_df(df, selector, level)
|
||||
|
||||
CS_ALL = '_all'
|
||||
|
||||
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
    """
    select a set of columns from `df`

    Parameters
    ----------
    df : pd.DataFrame
        the data frame to be selected
    col_set : str
        name of the outer column level to keep.
        `self.CS_ALL` keeps every column and drops the outer column level.

    Returns
    -------
    pd.DataFrame:
        `df` itself when its columns have a single level, otherwise the
        selected columns.
    """
    # `nlevels` is defined for plain and multi-level column indexes alike,
    # whereas `.levels` exists only on `pd.MultiIndex` and raises
    # AttributeError for ungrouped columns.
    if df.columns.nlevels == 1:
        return df
    if col_set == self.CS_ALL:
        return df.droplevel(axis=1, level=0)
    return df.loc(axis=1)[col_set]
|
||||
|
||||
def fetch(self, selector: Union[pd.Timestamp, slice, str], level: Union[str, int]='datetime', col_set=CS_ALL) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
describe how to select data by index
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._fetch_df_by_index(self._data, selector, level)
|
||||
return self._fetch_df_by_col(df, col_set)
|
||||
|
||||
def get_cols(self, col_set=CS_ALL) -> list:
|
||||
"""
|
||||
get the column names
|
||||
|
||||
Parameters
|
||||
----------
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
|
||||
Returns
|
||||
-------
|
||||
list:
|
||||
list of column names
|
||||
"""
|
||||
df = self._data.head()
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
|
||||
class DataHandlerLP(DataHandler):
|
||||
@@ -142,14 +195,13 @@ class DataHandlerLP(DataHandler):
|
||||
# - _proc_learn_df will be processed by infer_processors + learn_processors
|
||||
# - (e.g. _proc_infer_df processed by learn_processors )
|
||||
|
||||
def __init__(self, infer_processors=[], learn_processors=[], process_type=PTYPE_A, **kwargs):
|
||||
def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, infer_processors=[], learn_processors=[], process_type=PTYPE_A, **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
infer_processors : list
|
||||
list of <description info> of processors to generate data for inference
|
||||
example of <description info>:
|
||||
example of <description info>:
|
||||
1) classname & kwargs:
|
||||
{
|
||||
"class": "MinMaxNorm",
|
||||
@@ -180,24 +232,18 @@ class DataHandlerLP(DataHandler):
|
||||
self.learn_processors = [] # for lint
|
||||
for pname in 'infer_processors', 'learn_processors':
|
||||
for proc in locals()[pname]:
|
||||
getattr(self, pname).append(processor_module.init_proc_obj(proc))
|
||||
getattr(self, pname).append(init_instance_by_config(proc, processor_module,
|
||||
accept_types=(processor_module.Processor,)))
|
||||
|
||||
self.process_type = process_type
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
|
||||
def get_all_processors(self):
|
||||
return self.infer_processors + self.learn_processors
|
||||
|
||||
def _init_raw_data(self):
|
||||
"""
|
||||
initialize the raw data
|
||||
the raw data will be saved in to `self._data['raw']`
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_init_raw_data` method")
|
||||
|
||||
def fit(self):
|
||||
for proc in self.get_all_processors():
|
||||
proc.fit(self)
|
||||
proc.fit(self._data)
|
||||
|
||||
def fit_process_data(self):
|
||||
"""
|
||||
@@ -206,7 +252,7 @@ class DataHandlerLP(DataHandler):
|
||||
The input of the `fit` will be the output of the previous processor
|
||||
"""
|
||||
self.process_data(with_fit=True)
|
||||
|
||||
|
||||
|
||||
def process_data(self, with_fit: bool=False):
|
||||
"""
|
||||
@@ -218,50 +264,56 @@ class DataHandlerLP(DataHandler):
|
||||
The input of the `fit` will be the output of the previous processor
|
||||
"""
|
||||
# data for inference
|
||||
_infer_df = self._data[DataHandlerLP.DK_R]
|
||||
_infer_df = self._data
|
||||
if len(self.infer_processors) > 0: # avoid modifying the original data
|
||||
_infer_df = _infer_df.copy()
|
||||
|
||||
for proc in self.infer_processors:
|
||||
if not proc.is_for_infer():
|
||||
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
|
||||
if with_fit:
|
||||
proc.fit(self, _infer_df)
|
||||
proc.fit(_infer_df)
|
||||
_infer_df = proc(_infer_df)
|
||||
self._infer = _infer_df
|
||||
|
||||
# data for learning
|
||||
if self.process_type == DataHandlerLP.PTYPE_I:
|
||||
_learn_df = self._data[DataHandlerLP.DK_R]
|
||||
_learn_df = self._data
|
||||
elif self.process_type == DataHandlerLP.PTYPE_A:
|
||||
# based on `infer_df` and append the processor
|
||||
_learn_df = _infer_df
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
if len(self.learn_processors) > 0: # avoid modifying the original data
|
||||
_learn_df = _learn_df.copy()
|
||||
for proc in self.learn_processors:
|
||||
if with_fit:
|
||||
proc.fit(self, _learn_df)
|
||||
proc.fit(_learn_df)
|
||||
_learn_df = proc(_learn_df)
|
||||
|
||||
self._data.update({
|
||||
DataHandlerLP.DK_I: _infer_df,
|
||||
DataHandlerLP.DK_L: _learn_df,
|
||||
})
|
||||
self._learn = _learn_df
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = 'fit_seq' # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = 'fit_ind' # the input of `fit` will be the original df
|
||||
IT_LS = 'load_state' # The state of the object has been loaded by pickle
|
||||
|
||||
def init(self, init_type: str=IT_FIT_SEQ, path: Path=None):
|
||||
|
||||
def init(self, init_type: str=IT_FIT_SEQ, enable_cache: bool=False):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
|
||||
Parameters
|
||||
----------
|
||||
init_type : str
|
||||
'fit' or 'load_state'
|
||||
path : path
|
||||
if `init_type` == 'load_state': `path` will be used to load_state
|
||||
The type `IT_*` listed above
|
||||
enable_cache : bool
|
||||
default value is false
|
||||
if `enable_cache` == True:
|
||||
the processed data will be saved on disk, and handler will load the cached data from the disk directly
|
||||
when we call `init` next time
|
||||
"""
|
||||
self._init_raw_data()
|
||||
# init raw data
|
||||
super().init(enable_cache=enable_cache)
|
||||
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
self.fit()
|
||||
@@ -275,311 +327,53 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
def _get_df_by_key(self, data_key: str=DK_I) -> pd.DataFrame:
|
||||
df = getattr(self, {self.DK_R: '_data', self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
return df
|
||||
|
||||
class DataHandlerLPWL(DataHandlerLP):
|
||||
'''
|
||||
DataHandler with (L)earnable (P)rocessor with (L)abel
|
||||
'''
|
||||
|
||||
def _init_raw_data(self):
|
||||
def fetch(self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
level: Union[str, int] = 'datetime',
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I) -> pd.DataFrame:
|
||||
"""
|
||||
init raw_df, feature_names, label_names of DataHandler
|
||||
if the index of df_feature and df_label are not same, user need to overload this method to merge (e.g. inner, left, right merge).
|
||||
fetch data from underlying data source
|
||||
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
describe how to select data by index
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
data_key: str
|
||||
The data to fetch: DK_*
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df_features = self.load_feature()
|
||||
feature_names = df_features.columns
|
||||
df = self._get_df_by_key(data_key)
|
||||
df = self._fetch_df_by_index(df, selector, level)
|
||||
return self._fetch_df_by_col(df, col_set)
|
||||
|
||||
df_labels = self.load_label()
|
||||
label_names = df_labels.columns
|
||||
|
||||
raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left")
|
||||
self.feature_names = feature_names
|
||||
self.label_names = label_names
|
||||
self._data['raw'] = raw_df
|
||||
|
||||
def load_feature(self):
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str=DK_I) -> list:
|
||||
"""
|
||||
Implement this method to load raw feature.
|
||||
the format of the feature is below
|
||||
return: df_features
|
||||
get the column names
|
||||
|
||||
Parameters
|
||||
----------
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
data_key: str
|
||||
The data to fetch: DK_*
|
||||
|
||||
Returns
|
||||
-------
|
||||
list:
|
||||
list of column names
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement `load_feature`")
|
||||
|
||||
def load_label(self):
|
||||
"""
|
||||
Implement this method to load and calculate label.
|
||||
the format of the label is below
|
||||
|
||||
return: df_label
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement `load_label`")
|
||||
|
||||
def get_feature_names(self):
|
||||
return self.feature_names
|
||||
|
||||
def get_label_names(self):
|
||||
return self.label_names
|
||||
|
||||
|
||||
class QLibDataHandler(DataHandlerLPWL):
|
||||
def __init__(self, start_date, end_date, *args, **kwargs):
|
||||
# Dates.
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
|
||||
# Instruments
|
||||
instruments = kwargs.pop("instruments", None)
|
||||
if instruments is None:
|
||||
market = kwargs.pop("market", "csi500").lower()
|
||||
data_filter_list = kwargs.pop("data_filter_list", list())
|
||||
self.instruments = D.instruments(market, filter_pipe=data_filter_list)
|
||||
else:
|
||||
self.instruments = instruments
|
||||
|
||||
# Config of features and labels
|
||||
self._fields = kwargs.pop("fields", [])
|
||||
self._names = kwargs.pop("names", [])
|
||||
self._labels = kwargs.pop("labels", [])
|
||||
self._label_names = kwargs.pop("label_names", [])
|
||||
|
||||
# Check arguments
|
||||
assert len(self._fields) > 0, "features list is empty"
|
||||
assert len(self._labels) > 0, "labels list is empty"
|
||||
|
||||
# Check end_date
|
||||
# If test_end_date is -1 or greater than the last date, the last date is used
|
||||
self.end_date = transform_end_date(self.end_date)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def load_feature(self):
|
||||
"""
|
||||
Load the raw data.
|
||||
return: df_features
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
if len(self._names) == 0:
|
||||
names = ["F%d" % i for i in range(len(self._fields))]
|
||||
else:
|
||||
names = self._names
|
||||
|
||||
df_features = D.features(self.instruments, self._fields, self.start_date, self.end_date)
|
||||
df_features.columns = names
|
||||
|
||||
TimeInspector.log_cost_time("Finished loading features.")
|
||||
|
||||
return df_features
|
||||
|
||||
def load_label(self):
|
||||
"""
|
||||
Build up labels in df through users' method
|
||||
:return: df_labels
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
if len(self._label_names) == 0:
|
||||
label_names = ["LABEL%d" % i for i in range(len(self._labels))]
|
||||
else:
|
||||
label_names = self._label_names
|
||||
|
||||
df_labels = D.features(self.instruments, self._labels, self.start_date, self.end_date)
|
||||
df_labels.columns = label_names
|
||||
|
||||
TimeInspector.log_cost_time("Finished loading labels.")
|
||||
|
||||
return df_labels
|
||||
|
||||
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary field
|
||||
# `include` in dataset config necessary field
|
||||
use = lambda x: x not in exclude and (include is None or x in include)
|
||||
if use("ROC"):
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class ConfigQLibDataHandler(QLibDataHandler):
    """QLibDataHandler whose fields and names are generated from a config dict.

    The class-level ``config_template`` is copied, optionally updated with the
    ``config_update`` keyword argument, and passed to
    ``parse_config_to_fields`` to produce the ``fields`` / ``names`` keyword
    arguments handed to the parent constructor.
    """

    config_template = {}  # template

    def __init__(self, start_date, end_date, infer_processors=None, learn_processors=None, **kwargs):
        """
        Parameters
        ----------
        start_date :
            forwarded unchanged to the parent constructor.
        end_date :
            forwarded unchanged to the parent constructor.
        infer_processors : list, optional
            defaults to ``["ConfigSectionProcessor"]`` when not given.
        learn_processors : list, optional
            defaults to an empty list when not given.
        **kwargs :
            extra arguments for the parent constructor. ``config_update`` (a dict)
            is merged into a copy of ``config_template``; ``fields`` and ``names``
            are always overwritten; ``labels`` gets a default when absent.
        """
        # Build the default lists per call: list literals in the signature would be
        # created once and shared (and possibly mutated) across all instances.
        if infer_processors is None:
            infer_processors = ["ConfigSectionProcessor"]
        if learn_processors is None:
            learn_processors = []

        # Shallow copy so that updating the config does not touch the class template.
        config = self.config_template.copy()
        if "config_update" in kwargs:
            # NOTE(review): `config_update` is left in kwargs and therefore also
            # forwarded to the parent constructor — confirm the parent accepts it.
            config.update(kwargs["config_update"])
        fields, names = parse_config_to_fields(config)
        kwargs["fields"] = fields
        kwargs["names"] = names
        if "labels" not in kwargs:
            kwargs["labels"] = ["Ref($vwap, -2)/Ref($vwap, -1) - 1"]

        super().__init__(
            start_date,
            end_date,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            **kwargs,
        )
|
||||
df = self._get_df_by_key(data_key).head()
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
274
qlib/data/dataset/loader.py
Normal file
274
qlib/data/dataset/loader.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class DataLoader(ABC):
    """
    Abstract interface for loading raw data from the original data source.
    """

    @abstractmethod
    def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Load the data as a pd.DataFrame.

        Parameters
        ----------
        instruments :
            the instruments to load; interpretation is left to the implementation.
        start_time :
            optional start of the time range.
        end_time :
            optional end of the time range.

        Returns
        -------
        pd.DataFrame:
            data loaded from the underlying source

            Example of the data:
            The multi-index of the columns is optional.

                                      feature                                                            label
                                       $close     $volume  Ref($close, 1)  Mean($close, 3)  $high-$low  LABEL0
            datetime    instrument
            2010-01-04  SH600000    81.807068  17145150.0       83.737389        83.016739    2.741058  0.0032
                        SH600004    13.313329  11800983.0       13.313329        13.317701    0.183632  0.0042
                        SH600005    37.796539  12231662.0       38.258602        37.919757    0.970325  0.0289
        """
|
||||
|
||||
|
||||
class QlibDataLoader(DataLoader):
    """DataLoader that fetches expression fields through ``D.features``.

    The fields (and their column names) are described by ``config`` at
    construction time.
    """

    def __init__(self, config: Tuple[list, tuple, dict], group_fields: bool = False, filter_pipe=None):
        """
        Parameters
        ----------
        config : list, tuple or dict
            Config will be used to describe the fields and column names

            if `group_fields`:
                <config> := {
                    "group_name1": <fields_info1>
                    "group_name2": <fields_info2>
                }
            else:
                <config> := <fields_info>

            <fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...]) | <fields_info_config>

            <fields_info_config> is a config with dict type which could be parsed by `parse_config_to_fields`

        group_fields : bool
            Whether the fields are grouped. When True, the column names are
            ``(group_name, col_name)`` tuples and a column MultiIndex is built in `load`.
        filter_pipe :
            passed as ``filter_pipe`` to ``D.instruments`` when loading.
        """
        if group_fields:
            fields_all = []
            name_grp_info = []
            for grp, fields_info in config.items():
                fields, names = self._parse_fields_info(fields_info)
                fields_all.extend(fields)
                name_grp_info.extend([(grp, n) for n in names])
            self.fields, self.names = fields_all, name_grp_info
        else:
            # Without grouping the whole config IS the fields_info. (The previous
            # code referenced the undefined name `fields_info` here -> NameError.)
            self.fields, self.names = self._parse_fields_info(config)

        self.group_fields = group_fields
        self.filter_pipe = filter_pipe

    def _parse_fields_info(self, fields_info: Tuple[list, tuple, dict]) -> Tuple[list, list]:
        """Turn one <fields_info> into a ``(fields, names)`` pair.

        - dict  : parsed by `parse_config_to_fields`
        - list  : the expressions are used as their own column names
        - tuple : unpacked as ``(fields, names)``

        Raises
        ------
        NotImplementedError
            if `fields_info` is of any other type.
        """
        if isinstance(fields_info, dict):
            fields, names = parse_config_to_fields(fields_info)
        elif isinstance(fields_info, list):
            fields = fields_info
            names = fields
        elif isinstance(fields_info, tuple):
            fields, names = fields_info
        else:
            raise NotImplementedError("This type of input is not supported")
        return fields, names

    def load(self,
             instruments,
             config: Tuple[list, tuple, dict] = None,
             group_fields=False,
             start_time=None,
             end_time=None) -> pd.DataFrame:
        """Load the configured fields for `instruments`.

        NOTE: `config` and `group_fields` are not used here; the values given to
        the constructor are used instead. They are kept (in their original
        positions) only so existing callers keep working, which means
        `start_time` / `end_time` should be passed by keyword.

        Returns
        -------
        pd.DataFrame:
            the result of ``D.features`` with columns renamed to the configured
            names, index levels swapped and the index sorted.
        """
        df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), self.fields, start_time, end_time)
        df.columns = pd.MultiIndex.from_tuples(self.names) if self.group_fields else self.names
        # presumably turns an (instrument, datetime) index into (datetime, instrument)
        # — confirm against the index layout returned by D.features
        df = df.swaplevel().sort_index()
        return df
|
||||
|
||||
|
||||
# TODO: make it easier to understand the config language
|
||||
def parse_config_to_fields(config):
    """create factors from config

    Each top-level key that is present in `config` switches on one family of
    expressions; the value of the key holds the options of that family.

    config = {
        'kbar': {},  # whether to use some hard-code kbar features
        'price': {  # whether to use raw price features
            'windows': [0, 1, 2, 3, 4],  # use price at n days ago
            'feature': ['OPEN', 'HIGH', 'LOW']  # which price field to use
        },
        'volume': {  # whether to use raw volume features
            'windows': [0, 1, 2, 3, 4],  # use volume at n days ago
        },
        'rolling': {  # whether to use rolling operator based features
            'windows': [5, 10, 20, 30, 60],  # rolling windows size
            'include': ['ROC', 'MA', 'STD'],  # rolling operator to use
            # if include is None we will use default operators
            'exclude': ['RANK'],  # rolling operator not to use
        }
    }

    Returns
    -------
    (list, list):
        the expression strings and, in the same order, their column names.
    """
    fields, names = [], []

    if "kbar" in config:
        kbar_specs = (
            ("KMID", "($close-$open)/$open"),
            ("KLEN", "($high-$low)/$open"),
            ("KMID2", "($close-$open)/($high-$low+1e-12)"),
            ("KUP", "($high-Greater($open, $close))/$open"),
            ("KUP2", "($high-Greater($open, $close))/($high-$low+1e-12)"),
            ("KLOW", "(Less($open, $close)-$low)/$open"),
            ("KLOW2", "(Less($open, $close)-$low)/($high-$low+1e-12)"),
            ("KSFT", "(2*$close-$high-$low)/$open"),
            ("KSFT2", "(2*$close-$high-$low)/($high-$low+1e-12)"),
        )
        for label, expr in kbar_specs:
            fields.append(expr)
            names.append(label)

    if "price" in config:
        price_cfg = config["price"]
        windows = price_cfg.get("windows", range(5))
        feature = price_cfg.get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
        for raw_name in feature:
            col = raw_name.lower()
            for d in windows:
                # window 0 is the current bar, so no Ref() wrapper is needed
                if d == 0:
                    fields.append("$%s/$close" % col)
                else:
                    fields.append("Ref($%s, %d)/$close" % (col, d))
            names.extend(col.upper() + str(d) for d in windows)

    if "volume" in config:
        windows = config["volume"].get("windows", range(5))
        for d in windows:
            if d == 0:
                fields.append("$volume/$volume")
            else:
                fields.append("Ref($volume, %d)/$volume" % d)
        names.extend("VOLUME" + str(d) for d in windows)

    if "rolling" in config:
        rolling_cfg = config["rolling"]
        windows = rolling_cfg.get("windows", [5, 10, 20, 30, 60])
        include = rolling_cfg.get("include", None)
        exclude = rolling_cfg.get("exclude", [])

        def use(op):
            # `exclude` always wins; `include` of None means "everything"
            if op in exclude:
                return False
            return include is None or op in include

        # (switch key, column-name prefix, expression template).
        # Every "%d" in a template is filled with the window size.
        # Note the "LOW" switch produces the "MIN<d>" columns.
        rolling_specs = (
            ("ROC", "ROC", "Ref($close, %d)/$close"),
            ("MA", "MA", "Mean($close, %d)/$close"),
            ("STD", "STD", "Std($close, %d)/$close"),
            ("BETA", "BETA", "Slope($close, %d)/$close"),
            ("RSQR", "RSQR", "Rsquare($close, %d)"),
            ("RESI", "RESI", "Resi($close, %d)/$close"),
            ("MAX", "MAX", "Max($high, %d)/$close"),
            ("LOW", "MIN", "Min($low, %d)/$close"),
            ("QTLU", "QTLU", "Quantile($close, %d, 0.8)/$close"),
            ("QTLD", "QTLD", "Quantile($close, %d, 0.2)/$close"),
            ("RANK", "RANK", "Rank($close, %d)"),
            ("RSV", "RSV", "($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)"),
            ("IMAX", "IMAX", "IdxMax($high, %d)/%d"),
            ("IMIN", "IMIN", "IdxMin($low, %d)/%d"),
            ("IMXD", "IMXD", "(IdxMax($high, %d)-IdxMin($low, %d))/%d"),
            ("CORR", "CORR", "Corr($close, Log($volume+1), %d)"),
            ("CORD", "CORD", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)"),
            ("CNTP", "CNTP", "Mean($close>Ref($close, 1), %d)"),
            ("CNTN", "CNTN", "Mean($close<Ref($close, 1), %d)"),
            ("CNTD", "CNTD", "Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)"),
            (
                "SUMP",
                "SUMP",
                "Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)",
            ),
            (
                "SUMN",
                "SUMN",
                "Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)",
            ),
            (
                "SUMD",
                "SUMD",
                "(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
                "/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)",
            ),
            ("VMA", "VMA", "Mean($volume, %d)/($volume+1e-12)"),
            ("VSTD", "VSTD", "Std($volume, %d)/($volume+1e-12)"),
            (
                "WVMA",
                "WVMA",
                "Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)",
            ),
            (
                "VSUMP",
                "VSUMP",
                "Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)",
            ),
            (
                "VSUMN",
                "VSUMN",
                "Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)",
            ),
            (
                "VSUMD",
                "VSUMD",
                "(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
                "/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)",
            ),
        )

        for key, prefix, template in rolling_specs:
            if not use(key):
                continue
            slots = template.count("%d")
            fields.extend(template % ((d,) * slots) for d in windows)
            names.extend("%s%d" % (prefix, d) for d in windows)

    return fields, names
|
||||
@@ -12,16 +12,31 @@ from ...utils.serial import Serializable
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
def get_group_columns(df: pd.DataFrame, group: str):
    """
    Select the columns of one group from a DataFrame with multi-index columns.

    Parameters
    ----------
    df : pd.DataFrame
        a DataFrame with multi-level columns.
    group : str
        the name of the feature group, i.e. the first level value of the group index.
        When it is None, all columns are returned.

    Returns
    -------
    the column labels of `df` selected by ``df.columns.get_loc(group)``
    (or ``df.columns`` itself when `group` is None).
    """
    all_columns = df.columns
    if group is None:
        return all_columns
    return all_columns[all_columns.get_loc(group)]
|
||||
|
||||
|
||||
class Processor(Serializable):
|
||||
|
||||
def fit(self, handler, df: pd.DataFrame=None):
|
||||
def fit(self, df: pd.DataFrame=None):
|
||||
"""
|
||||
learn data processing parameters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler : DataHandlerLP
|
||||
The data handler to processing data
|
||||
df : pd.DataFrame
|
||||
When we fit and process data with processor one by one. The fit function reiles on the output of previous
|
||||
processor, i.e. `df`.
|
||||
@@ -34,7 +49,8 @@ class Processor(Serializable):
|
||||
"""
|
||||
process the data
|
||||
|
||||
NOTE: The processor should not change the content of `df`
|
||||
NOTE: **The processor could change the content of `df` inplace !!!!! **
|
||||
User should keep a copy of data outside
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -43,59 +59,10 @@ class Processor(Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_cls_kwargs(processor: [dict, str]) -> (type, dict):
    """
    Resolve processor info into a class object and its keyword arguments.

    Parameters
    ----------
    processor : [dict, str]
        either ``{'class': <name>, 'kwargs': {...}}`` or just ``<name>``;
        the name is looked up in this module's globals.

    Returns
    -------
    (type, dict):
        the class object and its arguments (empty dict for the str form).
    """
    if isinstance(processor, dict):
        return globals()[processor['class']], processor['kwargs']
    if isinstance(processor, str):
        return globals()[processor], {}
    raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
|
||||
# Place the function here to be able to reference the Processor
|
||||
def init_proc_obj(processor: [dict, str, Processor]) -> Processor:
    """
    Return a Processor instance for the given processor info.

    Parameters
    ----------
    processor : [dict, str, Processor]
        an existing Processor (returned as-is), or the dict / str info
        understood by `get_cls_kwargs`.

    Returns
    -------
    Processor:
        initialized Processor
    """
    if isinstance(processor, Processor):
        return processor
    proc_cls, proc_kwargs = get_cls_kwargs(processor)
    return proc_cls(**proc_kwargs)
|
||||
|
||||
|
||||
class InferProcessor(Processor):
|
||||
'''This processor is usable for inference'''
|
||||
def is_for_infer(self) -> bool:
|
||||
"""
|
||||
Is this processor usable for inference
|
||||
Some processors are not usable for inference.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -105,37 +72,24 @@ class InferProcessor(Processor):
|
||||
return True
|
||||
|
||||
|
||||
class NInferProcessor(Processor):
|
||||
'''This processor is not usable for inference'''
|
||||
def is_for_infer(self) -> bool:
|
||||
"""
|
||||
Is this processor usable for inference
|
||||
class DropnaProcessor(Processor):
|
||||
def __init__(self, group=None):
|
||||
self.group = group
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
if it is usable for infenrece
|
||||
"""
|
||||
def __call__(self, df):
|
||||
return df.dropna(subset=get_group_columns(df, self.group))
|
||||
|
||||
|
||||
class DropnaLabel(DropnaProcessor):
    """DropnaProcessor whose column group defaults to 'label'."""

    def __init__(self, group='label'):
        # only the default of `group` differs from the parent constructor
        super().__init__(group=group)

    def is_for_infer(self) -> bool:
        '''The samples are dropped according to label. So it is not usable for inference'''
        return False
|
||||
|
||||
|
||||
class DropnaFeature(InferProcessor):
    """Drop the rows that contain NaN in any of the handler's feature columns."""

    def fit(self, handler, df=None):
        # keep an independent copy of the names returned by the handler; `df` is unused
        self.feature_names = copy.deepcopy(handler.get_feature_names())

    def __call__(self, df):
        # returns the result of dropna restricted to the feature columns stored in `fit`
        return df.dropna(subset=self.feature_names)
|
||||
|
||||
|
||||
class DropnaLabel(InferProcessor):
    """Drop the rows that contain NaN in any of the handler's label columns."""

    def fit(self, handler, df=None):
        # keep an independent copy of the names returned by the handler; `df` is unused
        self.label_names = copy.deepcopy(handler.get_label_names())

    def __call__(self, df):
        # returns the result of dropna restricted to the label columns stored in `fit`
        return df.dropna(subset=self.label_names)
|
||||
|
||||
|
||||
class ProcessInf(InferProcessor):
|
||||
class ProcessInf(Processor):
|
||||
'''Process infinity '''
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
@@ -151,22 +105,20 @@ class ProcessInf(InferProcessor):
|
||||
return replace_inf(df)
|
||||
|
||||
|
||||
class MinMaxNorm(InferProcessor):
|
||||
def __init__(self, fit_start_time, fit_end_time):
|
||||
class MinMaxNorm(Processor):
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
|
||||
def fit(self, handler, df):
|
||||
# TODO: 看看这里怎么取数据
|
||||
self.min_val = np.nanmin(df[handler.get_feature_names()].values, axis=0)
|
||||
self.max_val = np.nanmax(df[handler.get_feature_names()].values, axis=0)
|
||||
def fit(self, df):
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
self.min_val = np.nanmin(df[cols].values, axis=0)
|
||||
self.max_val = np.nanmax(df[cols].values, axis=0)
|
||||
self.ignore = self.min_val == self.max_val
|
||||
self.feature_names = copy.deepcopy(handler.get_feature_names())
|
||||
self.cols = cols
|
||||
|
||||
def __call__(self, df):
|
||||
# FIXME: The df will be changed inplace. It's very dangerous
|
||||
# The code below is ugly
|
||||
df = df.copy() # currently copy is used
|
||||
def normalize(x, min_val=self.min_val, max_val=self.max_val, ignore=self.ignore):
|
||||
if (~ignore).all():
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
@@ -174,25 +126,24 @@ class MinMaxNorm(InferProcessor):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - min_val) / (max_val - min_val)
|
||||
return x
|
||||
df.loc(axis=1)[self.feature_names] = normalize(df[self.feature_names].values)
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
|
||||
|
||||
class ZscoreNorm(InferProcessor):
|
||||
def __init__(self, fit_start_time, fit_end_time):
|
||||
class ZscoreNorm(Processor):
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
|
||||
def fit(self, handler, df):
|
||||
self.mean_train = np.nanmean(df[handler.get_feature_names()].values, axis=0)
|
||||
self.std_train = np.nanstd(df[handler.get_feature_names()].values, axis=0)
|
||||
def fit(self, df):
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
self.mean_train = np.nanmean(df[cols].values, axis=0)
|
||||
self.std_train = np.nanstd(df[cols].values, axis=0)
|
||||
self.ignore = self.std_train == 0
|
||||
self.feature_names = handler.get_feature_names()
|
||||
self.cols = cols
|
||||
|
||||
def __call__(self, df):
|
||||
# FIXME: The df will be changed inplace. It's very dangerous
|
||||
# The code below is ugly
|
||||
df = df.copy() # currently copy is used
|
||||
def normalize(x, mean_train=self.mean_train, std_train=self.std_train, ignore=self.ignore):
|
||||
if (~ignore).all():
|
||||
return (x - mean_train) / std_train
|
||||
@@ -200,12 +151,27 @@ class ZscoreNorm(InferProcessor):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - mean_train) / std_train
|
||||
return x
|
||||
df.loc(axis=1)[self.feature_names] = normalize(df[self.feature_names].values)
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
|
||||
|
||||
class ConfigSectionProcessor(InferProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
class CSZScoreNorm(Processor):
    """Cross Sectional ZScore Normalization"""

    def __init__(self, fields_group=None):
        # name of the column group to normalize; None selects every column
        self.fields_group = fields_group

    def __call__(self, df):
        """Z-score the selected columns within each 'datetime' group.

        NOTE: the result is assigned back into `df`, i.e. the input frame is
        modified in place and then returned.
        """
        def _zscore(section):
            # (x - mean) / std, computed column-wise on one datetime slice
            return (section - section.mean()).div(section.std())

        target_cols = get_group_columns(df, self.fields_group)
        normalized = df[target_cols].groupby('datetime').apply(_zscore)
        df[target_cols] = normalized
        return df
|
||||
|
||||
|
||||
# TODO: make the config language easier to understand
|
||||
class ConfigSectionProcessor(Processor):
|
||||
# TODO: this class is not well tested
|
||||
# FIXME: this will raise error when multi-index is passed in
|
||||
def __init__(self, fields_group=None, **kwargs):
|
||||
super().__init__()
|
||||
# Options
|
||||
self.fillna_feature = kwargs.get("fillna_feature", True)
|
||||
@@ -214,9 +180,7 @@ class ConfigSectionProcessor(InferProcessor):
|
||||
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
|
||||
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
|
||||
|
||||
def fit(self, handler, df=None):
|
||||
self.feature_names = handler.get_feature_names()
|
||||
self.label_names = handler.get_label_names()
|
||||
self.fields_group = None
|
||||
|
||||
def __call__(self, df):
|
||||
return self._transform(df)
|
||||
@@ -245,19 +209,22 @@ class ConfigSectionProcessor(InferProcessor):
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
# Copy
|
||||
df_new = df.copy()
|
||||
# Copy the focus part and change it to single level
|
||||
selected_cols = get_group_columns(df, self.fields_group)
|
||||
df_focus = df[selected_cols].copy()
|
||||
if len(df_focus.columns.levels) > 1:
|
||||
df_focus = df_focus.droplevel(level=0)
|
||||
|
||||
# Label
|
||||
cols = df.columns[df.columns.str.contains("^LABEL")]
|
||||
df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
@@ -282,27 +249,29 @@ class ConfigSectionProcessor(InferProcessor):
|
||||
"VSUMD",
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^RSQR")]
|
||||
df_new[cols] = df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^MIN|^LOW0")]
|
||||
df_new[cols] = df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^CORR|^CORD")]
|
||||
df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^WVMA")]
|
||||
df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
df[selected_cols] = df_focus.values
|
||||
|
||||
TimeInspector.log_cost_time("Finished preprocessing data.")
|
||||
|
||||
return df_new
|
||||
return df
|
||||
|
||||
23
qlib/log.py
23
qlib/log.py
@@ -8,6 +8,7 @@ import os
|
||||
import re
|
||||
from logging import config as logging_config
|
||||
from time import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
from .config import C
|
||||
|
||||
@@ -79,6 +80,28 @@ class TimeInspector(object):
|
||||
cost_time = time() - cls.time_marks.pop()
|
||||
cls.timer_logger.info("Time cost: {0:.5f} | {1}".format(cost_time, info))
|
||||
|
||||
@classmethod
@contextmanager
def logt(cls, name="", show_start=False):
    """logt.
    Log the time cost of the code inside the ``with`` block.

    `@classmethod` must be the outermost decorator: `contextmanager` needs to
    wrap the plain generator function, and the classmethod then binds `cls`.
    (With the reverse order the wrapper tries to call a classmethod object.)

    Parameters
    ----------
    name :
        name used in the "Begin" message
    show_start :
        whether to log a message before the block starts
    """
    if show_start:
        cls.timer_logger.info(f"Begin {name}")
    cls.set_time_mark()
    try:
        yield None
    finally:
        # Always consume the mark pushed above, even when the block raises,
        # so the time-mark stack stays balanced for enclosing timers.
        cls.log_cost_time()
|
||||
|
||||
|
||||
def set_log_with_config(log_config: dict):
|
||||
"""set log with config
|
||||
|
||||
@@ -23,6 +23,7 @@ import contextlib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
@@ -164,6 +165,71 @@ def get_module_by_module_path(module_path):
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
    """
    extract class and kwargs from config info

    Parameters
    ----------
    config : [dict, str]
        either ``{'class': 'ClassName', 'kwargs': {...}}`` (``kwargs`` is
        optional, ``{}`` is used when it is missing) or just ``'ClassName'``.

    module : Python module
        It should be a python module to load the class type

    Returns
    -------
    (type, dict):
        the class object and it's arguments.

    Raises
    ------
    NotImplementedError
        if `config` is neither a dict nor a str.
    """
    if isinstance(config, dict):
        klass = getattr(module, config['class'])
        # `kwargs` is documented as optional, so do not require the key
        kwargs = config.get('kwargs', {})
    elif isinstance(config, str):
        klass = getattr(module, config)
        kwargs = {}
    else:
        raise NotImplementedError("This type of input is not supported")
    return klass, kwargs
|
||||
|
||||
|
||||
def init_instance_by_config(config: Union[str, dict], module=None, accept_types: Tuple[type] = ()) -> object:
    """
    get initialized instance with config

    Parameters
    ----------
    config : Union[str, dict]
        dict example.
            {
                'class': 'ClassName',
                'kwargs': dict,  # It is optional. {} will be used if not given
                'module_path': path,  # It is optional if module is given
            }
        str example.
            "ClassName": getattr(module, config)() will be used.
    module : Python module
        Optional. It should be a python module.
        It is required when `config` is a str, because a str carries no 'module_path'.

    accept_types: Tuple[type]
        Optional. If the config is a instance of specific type, return the config directly.

    Returns
    -------
    object:
        An initialized object based on the config info

    Raises
    ------
    TypeError
        if `config` is a str and no `module` is given.
    """
    if isinstance(config, accept_types):
        return config

    if module is None:
        if isinstance(config, str):
            # Indexing a str with "module_path" would raise an opaque TypeError;
            # raise the same exception type with an actionable message instead.
            raise TypeError(
                "`module` must be given when `config` is a class name string: {!r}".format(config)
            )
        module = get_module_by_module_path(config["module_path"])

    klass, kwargs = get_cls_kwargs(config, module)
    return klass(**kwargs)
|
||||
|
||||
|
||||
def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
"""Compare dict value
|
||||
|
||||
|
||||
Reference in New Issue
Block a user