diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 51c056522..1254c0d26 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -3,7 +3,7 @@ from ...data.dataset.handler import DataHandlerLP from ...data.dataset.processor import Processor -from ...utils import get_cls_kwargs +from ...utils import get_callable_kwargs from ...data.dataset import processor as processor_module from ...log import TimeInspector from inspect import getfullargspec @@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time): new_l = [] for p in proc_l: if not isinstance(p, Processor): - klass, pkwargs = get_cls_kwargs(p, processor_module) + klass, pkwargs = get_callable_kwargs(p, processor_module) args = getfullargspec(klass).args if "fit_start_time" in args and "fit_end_time" in args: assert ( diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 70be66d13..beef5c9fb 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -12,7 +12,7 @@ from typing import Tuple, Union, List, Type from qlib.data import D from qlib.data import filter as filter_module from qlib.data.filter import BaseDFilter -from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_cls_kwargs +from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_callable_kwargs from qlib.log import get_module_logger @@ -212,7 +212,7 @@ class QlibDataLoader(DLWParser): raise ValueError(f"sample method error, only pandas.DataFrame.resample is supported") elif isinstance(_method, dict): # module_path && func name - _method, _ = get_cls_kwargs(_method, obj_type="func") + _method, _ = get_callable_kwargs(_method) else: raise TypeError(f"sample_method only supports [str, dict], currently it is {_method}") return _method diff --git a/qlib/data/ops.py b/qlib/data/ops.py index e044533d3..a34b2ed35 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -15,7 +15,7 @@ from scipy.stats import percentileofscore from .base import Expression, ExpressionOps from ..log import get_module_logger -from ..utils import get_cls_kwargs +from ..utils import get_callable_kwargs try: from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi @@ -1513,7 +1513,7 @@ class OpsWrapper: """ for _operator in ops_list: if isinstance(_operator, dict): - _ops_class, _ = get_cls_kwargs(_operator) + _ops_class, _ = get_callable_kwargs(_operator) else: _ops_class = _operator diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 3f1ae8a96..fb731196d 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -17,7 +17,7 @@ from typing import Callable, List from qlib.data.dataset import Dataset from qlib.log import get_module_logger from qlib.model.base import Model -from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config +from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord from qlib.workflow.recorder import Recorder @@ -71,7 +71,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: if isinstance(records, dict): # prevent only one dict records = [records] for record in records: - cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp") + cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp") if cls is SignalRecord: rconf = {"model": model, "dataset": dataset, "recorder": rec} else: diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 007cafbce..3934a1992 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -189,9 +189,7 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]): return module -def get_cls_kwargs( - config: Union[dict, str], default_module: Union[str, ModuleType] = None, obj_type: str = "class" -) -> (type, dict): +def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict): """ extract class/func and kwargs from config info @@ -205,8 +203,6 @@ def get_cls_kwargs( This function will load class from the config['module_path'] first. If config['module_path'] doesn't exists, it will load the class from default_module. - obj_type: str - "class" or "func" Returns ------- (type, dict): @@ -216,16 +212,16 @@ def get_cls_kwargs( module = get_module_by_module_path(config.get("module_path", default_module)) # raise AttributeError - _obj = getattr(module, config[obj_type]) + _callable = getattr(module, config["class" if "class" in config else "func"]) kwargs = config.get("kwargs", {}) elif isinstance(config, str): module = get_module_by_module_path(default_module) - _obj = getattr(module, config) + _callable = getattr(module, config) kwargs = {} else: raise NotImplementedError(f"This type of input is not supported") - return _obj, kwargs + return _callable, kwargs def init_instance_by_config( @@ -276,7 +272,7 @@ def init_instance_by_config( with open(os.path.join(pr.netloc, pr.path), "rb") as f: return pickle.load(f) - klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module) + klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module) return klass(**cls_kwargs, **kwargs) diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index 2cd972494..0fdec7b34 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -11,7 +11,7 @@ from typing import List, Union from qlib.data.dataset import TSDatasetH from qlib.log import get_module_logger -from qlib.utils import get_cls_kwargs +from qlib.utils import get_callable_kwargs from qlib.utils.exceptions import LoadObjectError from qlib.workflow.online.update import PredUpdater from qlib.workflow.recorder import Recorder @@ -172,7 +172,7 @@ class OnlineToolR(OnlineTool): hist_ref = 0 task = rec.load_object("task") # Special treatment of historical dependencies - cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset") + cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset") if issubclass(cls, TSDatasetH): hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN) try: