mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
Fix some API(for lb nn)
This commit is contained in:
@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
# FIXME: the `module_path` parameter is missed.
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from typing import Union, Text
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
def get_group_columns(df: pd.DataFrame, group: str):
|
||||
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
|
||||
"""
|
||||
get a group of columns from multi-index columns DataFrame
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
@@ -41,16 +41,11 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
if record["class"] == SignalRecord.__name__:
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record.setdefault("kwargs", {})
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
if cls is SignalRecord:
|
||||
rconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record.setdefault("kwargs", {})
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
r = cls(**kwargs, **rconf)
|
||||
r.generate()
|
||||
return recorder.info["id"]
|
||||
|
||||
@@ -24,7 +24,8 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple
|
||||
from typing import Union, Tuple, Any
|
||||
from types import ModuleType
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
@@ -165,24 +166,25 @@ def parse_field(field):
|
||||
return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field))
|
||||
|
||||
|
||||
def get_module_by_module_path(module_path):
|
||||
def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
"""Load module path
|
||||
|
||||
:param module_path:
|
||||
:return:
|
||||
"""
|
||||
|
||||
if module_path.endswith(".py"):
|
||||
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
if isinstance(module_path, ModuleType):
|
||||
module = module_path
|
||||
else:
|
||||
module = importlib.import_module(module_path)
|
||||
|
||||
if module_path.endswith(".py"):
|
||||
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
else:
|
||||
module = importlib.import_module(module_path)
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType]=None) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
|
||||
@@ -191,8 +193,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
config : [dict, str]
|
||||
similar to config
|
||||
|
||||
module : Python module
|
||||
default_module : Python module or str
|
||||
It should be a python module to load the class type
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exists, it will load the class from default_module.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -200,10 +204,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
the class object and it's arguments.
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
|
||||
klass = getattr(module, config)
|
||||
kwargs = {}
|
||||
else:
|
||||
@@ -212,8 +220,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
|
||||
) -> object:
|
||||
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
get initialized instance with config
|
||||
|
||||
@@ -230,10 +238,13 @@ def init_instance_by_config(
|
||||
"ClassName": getattr(module, config)() will be used.
|
||||
object example:
|
||||
instance of accept_types
|
||||
module : Python module
|
||||
default_module : Python module
|
||||
Optional. It should be a python module.
|
||||
NOTE: the "module_path" will be override by `module` arguments
|
||||
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exists, it will load the class from default_module.
|
||||
|
||||
accept_types: Union[type, Tuple[type]]
|
||||
Optional. If the config is a instance of specific type, return the config directly.
|
||||
This will be passed into the second parameter of isinstance.
|
||||
@@ -246,10 +257,7 @@ def init_instance_by_config(
|
||||
if isinstance(config, accept_types):
|
||||
return config
|
||||
|
||||
if module is None:
|
||||
module = get_module_by_module_path(config["module_path"])
|
||||
|
||||
klass, cls_kwargs = get_cls_kwargs(config, module)
|
||||
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -145,6 +145,10 @@ class SignalRecord(RecordTemp):
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = self.dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.dataset.__class__ = orig_cls
|
||||
|
||||
Reference in New Issue
Block a user