1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

Auto injecting model and dataset for Recorder (#645)

* Auto injecting model and dataset for Recorder

* Support using Feature in expression
This commit is contained in:
you-n-g
2021-10-15 13:50:24 +08:00
committed by GitHub
parent 334b92ace7
commit 7c31012b50
3 changed files with 36 additions and 11 deletions

View File

@@ -13,7 +13,7 @@ import pandas as pd
from typing import Union, List, Type
from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps
from .base import Expression, ExpressionOps, Feature
from ..log import get_module_logger
from ..utils import get_callable_kwargs
@@ -1485,6 +1485,7 @@ OpsList = [
IdxMax,
IdxMin,
If,
Feature,
]
@@ -1517,7 +1518,7 @@ class OpsWrapper:
else:
_ops_class = _operator
if not issubclass(_ops_class, ExpressionOps):
if not issubclass(_ops_class, Expression):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
if _ops_class.__name__ in self._ops:

View File

@@ -70,9 +70,9 @@ def fill_placeholder(config: dict, config_extend: dict):
# bfs
top = 0
tail = 1
item_quene = [config]
item_queue = [config]
while top < tail:
now_item = item_quene[top]
now_item = item_queue[top]
top += 1
if isinstance(now_item, list):
item_keys = range(len(now_item))
@@ -80,9 +80,9 @@ def fill_placeholder(config: dict, config_extend: dict):
item_keys = now_item.keys()
for key in item_keys:
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
item_quene.append(now_item[key])
item_queue.append(now_item[key])
tail += 1
elif now_item[key] in config_extend.keys():
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
now_item[key] = config_extend[now_item[key]]
return config
@@ -114,10 +114,19 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
task_config = fill_placeholder(task_config, placehorder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
if isinstance(records, dict): # uniform the data format to list
records = [records]
for record in records:
r = init_instance_by_config(record, recorder=rec)
# Some recorder require the parameter `model` and `dataset`.
# try to automatically pass in them to the initialization function
# to make defining the tasking easier
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
return rec

View File

@@ -27,7 +27,7 @@ import collections
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Any, Text, Optional
from typing import Dict, Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse
@@ -232,7 +232,11 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
def init_instance_by_config(
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
config: Union[str, dict, object],
default_module=None,
accept_types: Union[type, Tuple[type]] = (),
try_kwargs: Dict = {},
**kwargs,
) -> Any:
"""
get initialized instance with config
@@ -270,6 +274,10 @@ def init_instance_by_config(
Optional. If the config is a instance of specific type, return the config directly.
This will be passed into the second parameter of isinstance.
try_kwargs: Dict
Try to pass in kwargs in `try_kwargs` when initialized the instance
If error occurred, it will fail back to initialization without try_kwargs.
Returns
-------
object:
@@ -286,7 +294,14 @@ def init_instance_by_config(
return pickle.load(f)
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)
try:
return klass(**cls_kwargs, **try_kwargs, **kwargs)
except (TypeError,):
# TypeError for handling errors like
# 1: `XXX() got multiple values for keyword argument 'YYY'`
# 2: `XXX() got an unexpected keyword argument 'YYY'
return klass(**cls_kwargs, **kwargs)
@contextlib.contextmanager