1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

get_cls_kwargs renamed get_callable_kwargs

This commit is contained in:
zhupr
2021-08-28 10:16:45 +08:00
committed by you-n-g
parent 76a05f37a9
commit 6011a21308
6 changed files with 15 additions and 19 deletions

View File

@@ -3,7 +3,7 @@
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_cls_kwargs
from ...utils import get_callable_kwargs
from ...data.dataset import processor as processor_module
from ...log import TimeInspector
from inspect import getfullargspec
@@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
new_l = []
for p in proc_l:
if not isinstance(p, Processor):
klass, pkwargs = get_cls_kwargs(p, processor_module)
klass, pkwargs = get_callable_kwargs(p, processor_module)
args = getfullargspec(klass).args
if "fit_start_time" in args and "fit_end_time" in args:
assert (

View File

@@ -12,7 +12,7 @@ from typing import Tuple, Union, List, Type
from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_cls_kwargs
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_callable_kwargs
from qlib.log import get_module_logger
@@ -212,7 +212,7 @@ class QlibDataLoader(DLWParser):
raise ValueError(f"sample method error, only pandas.DataFrame.resample is supported")
elif isinstance(_method, dict):
# module_path && func name
_method, _ = get_cls_kwargs(_method, obj_type="func")
_method, _ = get_callable_kwargs(_method)
else:
raise TypeError(f"sample_method only supports [str, dict], currently it is {_method}")
return _method

View File

@@ -15,7 +15,7 @@ from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps
from ..log import get_module_logger
from ..utils import get_cls_kwargs
from ..utils import get_callable_kwargs
try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
@@ -1513,7 +1513,7 @@ class OpsWrapper:
"""
for _operator in ops_list:
if isinstance(_operator, dict):
_ops_class, _ = get_cls_kwargs(_operator)
_ops_class, _ = get_callable_kwargs(_operator)
else:
_ops_class = _operator

View File

@@ -17,7 +17,7 @@ from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
@@ -71,7 +71,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": rec}
else:

View File

@@ -189,9 +189,7 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
return module
def get_cls_kwargs(
config: Union[dict, str], default_module: Union[str, ModuleType] = None, obj_type: str = "class"
) -> (type, dict):
def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
"""
extract class/func and kwargs from config info
@@ -205,8 +203,6 @@ def get_cls_kwargs(
This function will load class from the config['module_path'] first.
If config['module_path'] doesn't exists, it will load the class from default_module.
obj_type: str
"class" or "func"
Returns
-------
(type, dict):
@@ -216,16 +212,16 @@ def get_cls_kwargs(
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
_obj = getattr(module, config[obj_type])
_callable = getattr(module, config["class" if "class" in config else "func"])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
_obj = getattr(module, config)
_callable = getattr(module, config)
kwargs = {}
else:
raise NotImplementedError(f"This type of input is not supported")
return _obj, kwargs
return _callable, kwargs
def init_instance_by_config(
@@ -276,7 +272,7 @@ def init_instance_by_config(
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
return pickle.load(f)
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)

View File

@@ -11,7 +11,7 @@ from typing import List, Union
from qlib.data.dataset import TSDatasetH
from qlib.log import get_module_logger
from qlib.utils import get_cls_kwargs
from qlib.utils import get_callable_kwargs
from qlib.utils.exceptions import LoadObjectError
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
@@ -172,7 +172,7 @@ class OnlineToolR(OnlineTool):
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
try: