diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 970b032d6..be2016ea3 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time): "fit_end_time": fit_end_time, } ) + # FIXME: the `module_path` parameter is missed. new_l.append({"class": klass.__name__, "kwargs": pkwargs}) else: new_l.append(p) diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 5a06f66be..972a1294a 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import abc +from typing import Union, Text import numpy as np import pandas as pd import copy @@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply EPS = 1e-12 -def get_group_columns(df: pd.DataFrame, group: str): +def get_group_columns(df: pd.DataFrame, group: Union[Text, None]): """ get a group of columns from multi-index columns DataFrame diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index a4df92218..57132a33e 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from qlib.utils import init_instance_by_config, flatten_dict +from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord @@ -41,16 +41,11 @@ def task_train(task_config: dict, experiment_name: str) -> str: if isinstance(records, dict): # prevent only one dict records = [records] for record in records: - if record["class"] == SignalRecord.__name__: - srconf = {"model": model, "dataset": dataset, "recorder": recorder} - record.setdefault("kwargs", {}) - record["kwargs"].update(srconf) - sr = init_instance_by_config(record) - sr.generate() + cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp") + if cls is SignalRecord: + rconf = {"model": model, "dataset": dataset, "recorder": recorder} else: rconf = {"recorder": recorder} - record.setdefault("kwargs", {}) - record["kwargs"].update(rconf) - ar = init_instance_by_config(record) - ar.generate() + r = cls(**kwargs, **rconf) + r.generate() return recorder.info["id"] diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index be7969b65..606d007a8 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -24,7 +24,8 @@ import collections import numpy as np import pandas as pd from pathlib import Path -from typing import Union, Tuple +from typing import Union, Tuple, Any +from types import ModuleType from ..config import C from ..log import get_module_logger, set_log_with_config @@ -165,24 +166,25 @@ def parse_field(field): return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field)) -def get_module_by_module_path(module_path): +def get_module_by_module_path(module_path: Union[str, ModuleType]): """Load module path :param module_path: :return: """ - - if module_path.endswith(".py"): - module_spec = importlib.util.spec_from_file_location("", module_path) - module = importlib.util.module_from_spec(module_spec) - module_spec.loader.exec_module(module) + if isinstance(module_path, ModuleType): + module = module_path else: - module = importlib.import_module(module_path) - + if module_path.endswith(".py"): + module_spec = importlib.util.spec_from_file_location("", module_path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + else: + module = importlib.import_module(module_path) return module -def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): +def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType]=None) -> (type, dict): """ extract class and kwargs from config info @@ -191,8 +193,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): config : [dict, str] similar to config - module : Python module + default_module : Python module or str It should be a python module to load the class type + This function will load class from the config['module_path'] first. + If config['module_path'] doesn't exists, it will load the class from default_module. Returns ------- @@ -200,10 +204,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): the class object and it's arguments. """ if isinstance(config, dict): + module = get_module_by_module_path(config.get("module_path", default_module)) + # raise AttributeError klass = getattr(module, config["class"]) kwargs = config.get("kwargs", {}) elif isinstance(config, str): + module = get_module_by_module_path(default_module) + klass = getattr(module, config) kwargs = {} else: @@ -212,8 +220,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): def init_instance_by_config( - config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs -) -> object: + config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs +) -> Any: """ get initialized instance with config @@ -230,10 +238,13 @@ def init_instance_by_config( "ClassName": getattr(module, config)() will be used. object example: instance of accept_types - module : Python module + default_module : Python module Optional. It should be a python module. NOTE: the "module_path" will be override by `module` arguments + This function will load class from the config['module_path'] first. + If config['module_path'] doesn't exists, it will load the class from default_module. + accept_types: Union[type, Tuple[type]] Optional. If the config is a instance of specific type, return the config directly. This will be passed into the second parameter of isinstance. @@ -246,10 +257,7 @@ def init_instance_by_config( if isinstance(config, accept_types): return config - if module is None: - module = get_module_by_module_path(config["module_path"]) - - klass, cls_kwargs = get_cls_kwargs(config, module) + klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module) return klass(**cls_kwargs, **kwargs) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index be458a24d..c54a6f700 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -145,6 +145,10 @@ class SignalRecord(RecordTemp): del params["data_key"] # The backend handler should be DataHandler raw_label = self.dataset.prepare(**params) + except AttributeError: + # The data handler is initialize with `drop_raw=True`... + # So raw_label is not available + raw_label = None self.recorder.save_objects(**{"label.pkl": raw_label}) self.dataset.__class__ = orig_cls