1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00

Fix some API(for lb nn)

This commit is contained in:
Young
2021-04-07 03:31:50 +00:00
parent 8362780e22
commit 1dbb561744
5 changed files with 39 additions and 30 deletions

View File

@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
# FIXME: the `module_path` parameter is missed.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import abc
from typing import Union, Text
import numpy as np
import pandas as pd
import copy
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
EPS = 1e-12
def get_group_columns(df: pd.DataFrame, group: str):
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
"""
get a group of columns from multi-index columns DataFrame

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
@@ -41,16 +41,11 @@ def task_train(task_config: dict, experiment_name: str) -> str:
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
if record["class"] == SignalRecord.__name__:
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record.setdefault("kwargs", {})
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)
sr.generate()
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": recorder}
else:
rconf = {"recorder": recorder}
record.setdefault("kwargs", {})
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()
r = cls(**kwargs, **rconf)
r.generate()
return recorder.info["id"]

View File

@@ -24,7 +24,8 @@ import collections
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple
from typing import Union, Tuple, Any
from types import ModuleType
from ..config import C
from ..log import get_module_logger, set_log_with_config
@@ -165,24 +166,25 @@ def parse_field(field):
return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field))
def get_module_by_module_path(module_path):
def get_module_by_module_path(module_path: Union[str, ModuleType]):
"""Load module path
:param module_path:
:return:
"""
if module_path.endswith(".py"):
module_spec = importlib.util.spec_from_file_location("", module_path)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
if isinstance(module_path, ModuleType):
module = module_path
else:
module = importlib.import_module(module_path)
if module_path.endswith(".py"):
module_spec = importlib.util.spec_from_file_location("", module_path)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
else:
module = importlib.import_module(module_path)
return module
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType]=None) -> (type, dict):
"""
extract class and kwargs from config info
@@ -191,8 +193,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
config : [dict, str]
similar to config
module : Python module
default_module : Python module or str
It should be a python module to load the class type
This function will load class from the config['module_path'] first.
If config['module_path'] doesn't exists, it will load the class from default_module.
Returns
-------
@@ -200,10 +204,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
the class object and it's arguments.
"""
if isinstance(config, dict):
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
klass = getattr(module, config["class"])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
klass = getattr(module, config)
kwargs = {}
else:
@@ -212,8 +220,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def init_instance_by_config(
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
) -> object:
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
) -> Any:
"""
get initialized instance with config
@@ -230,10 +238,13 @@ def init_instance_by_config(
"ClassName": getattr(module, config)() will be used.
object example:
instance of accept_types
module : Python module
default_module : Python module
Optional. It should be a python module.
NOTE: the "module_path" will be override by `module` arguments
This function will load class from the config['module_path'] first.
If config['module_path'] doesn't exists, it will load the class from default_module.
accept_types: Union[type, Tuple[type]]
Optional. If the config is a instance of specific type, return the config directly.
This will be passed into the second parameter of isinstance.
@@ -246,10 +257,7 @@ def init_instance_by_config(
if isinstance(config, accept_types):
return config
if module is None:
module = get_module_by_module_path(config["module_path"])
klass, cls_kwargs = get_cls_kwargs(config, module)
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)

View File

@@ -145,6 +145,10 @@ class SignalRecord(RecordTemp):
del params["data_key"]
# The backend handler should be DataHandler
raw_label = self.dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
self.recorder.save_objects(**{"label.pkl": raw_label})
self.dataset.__class__ = orig_cls