mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
get_cls_kwargs renamed get_callable_kwargs
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_cls_kwargs
|
||||
from ...utils import get_callable_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
@@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
klass, pkwargs = get_callable_kwargs(p, processor_module)
|
||||
args = getfullargspec(klass).args
|
||||
if "fit_start_time" in args and "fit_end_time" in args:
|
||||
assert (
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import Tuple, Union, List, Type
|
||||
from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_cls_kwargs
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_callable_kwargs
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
@@ -212,7 +212,7 @@ class QlibDataLoader(DLWParser):
|
||||
raise ValueError(f"sample method error, only pandas.DataFrame.resample is supported")
|
||||
elif isinstance(_method, dict):
|
||||
# module_path && func name
|
||||
_method, _ = get_cls_kwargs(_method, obj_type="func")
|
||||
_method, _ = get_callable_kwargs(_method)
|
||||
else:
|
||||
raise TypeError(f"sample_method only supports [str, dict], currently it is {_method}")
|
||||
return _method
|
||||
|
||||
@@ -15,7 +15,7 @@ from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_cls_kwargs
|
||||
from ..utils import get_callable_kwargs
|
||||
|
||||
try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
@@ -1513,7 +1513,7 @@ class OpsWrapper:
|
||||
"""
|
||||
for _operator in ops_list:
|
||||
if isinstance(_operator, dict):
|
||||
_ops_class, _ = get_cls_kwargs(_operator)
|
||||
_ops_class, _ = get_callable_kwargs(_operator)
|
||||
else:
|
||||
_ops_class = _operator
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Callable, List
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
@@ -71,7 +71,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
if cls is SignalRecord:
|
||||
rconf = {"model": model, "dataset": dataset, "recorder": rec}
|
||||
else:
|
||||
|
||||
@@ -189,9 +189,7 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(
|
||||
config: Union[dict, str], default_module: Union[str, ModuleType] = None, obj_type: str = "class"
|
||||
) -> (type, dict):
|
||||
def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
"""
|
||||
extract class/func and kwargs from config info
|
||||
|
||||
@@ -205,8 +203,6 @@ def get_cls_kwargs(
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exists, it will load the class from default_module.
|
||||
|
||||
obj_type: str
|
||||
"class" or "func"
|
||||
Returns
|
||||
-------
|
||||
(type, dict):
|
||||
@@ -216,16 +212,16 @@ def get_cls_kwargs(
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
_obj = getattr(module, config[obj_type])
|
||||
_callable = getattr(module, config["class" if "class" in config else "func"])
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
|
||||
_obj = getattr(module, config)
|
||||
_callable = getattr(module, config)
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return _obj, kwargs
|
||||
return _callable, kwargs
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
@@ -276,7 +272,7 @@ def init_instance_by_config(
|
||||
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
|
||||
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import List, Union
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.utils import get_callable_kwargs
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
@@ -172,7 +172,7 @@ class OnlineToolR(OnlineTool):
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
# Special treatment of historical dependencies
|
||||
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
|
||||
cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset")
|
||||
if issubclass(cls, TSDatasetH):
|
||||
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user