1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00
Files
qlib/qlib/model/ens/ensemble.py
you-n-g 50409ff17b Add log info for ensemble (#1113)
* Add log info for ensemble

* Update ensemble.py

* Update setup.py
2022-06-14 11:58:57 +08:00

135 lines
4.4 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Ensemble module can merge the objects in an Ensemble. For example, if there are many submodel predictions, we may need to merge them into an ensemble prediction.
"""
from typing import Union
import pandas as pd
from qlib.utils import FLATTEN_TUPLE, flatten_dict
from qlib.log import get_module_logger
class Ensemble:
    """Base class for callables that merge an ``ensemble_dict`` into one object.

    For example: {Rollinga_b: object, Rollingb_c: object} -> object

    When calling this class:

        Args:
            ensemble_dict (dict): a mapping like {name: things} whose values wait to be merged

        Returns:
            object: the ensemble object
    """

    def __call__(self, ensemble_dict: dict, *args, **kwargs):
        # The base class carries no merging logic of its own: this
        # implementation unconditionally raises, whatever arguments it gets.
        message = "Please implement the `__call__` method."
        raise NotImplementedError(message)
class SingleKeyEnsemble(Ensemble):
    """
    Unwrap a dict that holds exactly one key/value pair, to make results more readable.

    {Only key: Only value} -> Only value

    A dict with zero keys or with more than one key is returned as a dict.
    The unwrapping can be applied recursively to nested dicts.

    NOTE: Recursion is enabled by default.

    When calling this class:

        Args:
            ensemble_dict (dict): the dict to simplify. The key of a single-entry dict is dropped.
            recursion (bool): when True, simplify every value first (depth-first) before
                looking at the current level.

        Returns:
            object: the only value if the (possibly simplified) dict has exactly one entry,
            otherwise the dict itself. Non-dict inputs are returned unchanged.
    """

    def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
        # Leaves of the nested structure (anything that is not a dict) pass through as-is.
        if not isinstance(ensemble_dict, dict):
            return ensemble_dict

        if recursion:
            # Simplify children first, so a chain of single-key dicts collapses fully.
            collapsed = {key: self(value, recursion) for key, value in ensemble_dict.items()}
        else:
            collapsed = ensemble_dict

        if len(collapsed) != 1:
            return collapsed

        # Exactly one entry: drop the key and hand back the value.
        (only_value,) = collapsed.values()
        return only_value
class RollingEnsemble(Ensemble):
    """Concatenate a dict of rolling dataframes (e.g. `prediction` or `IC`) into one ensemble.

    NOTE: Every value of the dict must be a pd.DataFrame whose index has a "datetime" level.

    When calling this class:

        Args:
            ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
                The keys are only logged; they do not influence the merged result.

        Returns:
            pd.DataFrame: the complete rolling result, de-duplicated and sorted by index.
    """

    def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
        logger = get_module_logger("RollingEnsemble")
        logger.info(f"keys in group: {list(ensemble_dict.keys())}")

        def _earliest_datetime(frame):
            # Smallest value found in the "datetime" level of the frame's index.
            return frame.index.get_level_values("datetime").min()

        # Order the frames by where they start in time, then stack them vertically.
        ordered_frames = sorted(ensemble_dict.values(), key=_earliest_datetime)
        merged = pd.concat(ordered_frames)

        # When an index entry appears more than once, keep only its last occurrence,
        # i.e. the one contributed by the frame that sorts later (the latest prediction).
        repeated_mask = merged.index.duplicated(keep="last")
        return merged[~repeated_mask].sort_index()
class AverageEnsemble(Ensemble):
    """
    Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.

    NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it.

    When calling this class:

        Args:
            ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
                The key of the dict will be ignored.

        Returns:
            pd.DataFrame: the complete result of averaging and standardizing.
    """

    def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
        """Standardize every member per "datetime" group, then average across members.

        Usage sample:

            from qlib.model.ens.ensemble import AverageEnsemble
            pred_res['new_key_name'] = AverageEnsemble()(predict_dict)

        Parameters
        ----------
        ensemble_dict : dict
            Dictionary you want to ensemble. It is passed through ``flatten_dict`` first,
            so it may be nested.

        Returns
        -------
        pd.DataFrame
            The ensembling result, sorted by index. NOTE: although annotated as
            pd.DataFrame, the row-wise ``mean(axis=1)`` below collapses the columns,
            so in practice a pd.Series is returned.
        """
        # need to flatten the nested dict
        ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
        get_module_logger("AverageEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}")
        values = list(ensemble_dict.values())
        # NOTE: this may change the style underlying data!!!!
        # from pd.DataFrame to pd.Series
        # Put every member side by side as columns, aligned on the index.
        results = pd.concat(values, axis=1)
        # Z-score each column within every "datetime" group.
        # `group_keys=False` is passed explicitly: the lambda returns an object indexed
        # like its input, and with the pandas >= 2.0 default (`group_keys=True`) the group
        # key would be prepended to the result index, duplicating the "datetime" level.
        results = results.groupby("datetime", group_keys=False).apply(lambda df: (df - df.mean()) / df.std())
        # Average the standardized columns row by row.
        results = results.mean(axis=1)
        results = results.sort_index()
        return results