From c99494eb76b53717fd59a4b86ef4a60515ca6b6b Mon Sep 17 00:00:00 2001 From: zhupr Date: Thu, 26 Aug 2021 14:29:32 +0800 Subject: [PATCH] Add sample_config to QlibDataLoader, support multi-freq --- qlib/contrib/data/handler.py | 8 +++ qlib/data/dataset/loader.py | 107 +++++++++++++++++++++++++++++++---- qlib/utils/__init__.py | 16 ++++-- 3 files changed, 115 insertions(+), 16 deletions(-) diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index b22741f4a..51c056522 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -58,6 +58,8 @@ class Alpha360(DataHandlerLP): fit_start_time=None, fit_end_time=None, filter_pipe=None, + sample_config=None, + sample_benchmark=None, **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -72,6 +74,8 @@ class Alpha360(DataHandlerLP): }, "filter_pipe": filter_pipe, "freq": freq, + "sample_config": sample_config, + "sample_benchmark": sample_benchmark, }, } @@ -144,6 +148,8 @@ class Alpha158(DataHandlerLP): fit_end_time=None, process_type=DataHandlerLP.PTYPE_A, filter_pipe=None, + sample_config=None, + sample_benchmark=None, **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -158,6 +164,8 @@ class Alpha158(DataHandlerLP): }, "filter_pipe": filter_pipe, "freq": freq, + "sample_config": sample_config, + "sample_benchmark": sample_benchmark, }, } super().__init__( diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index dd0572660..70be66d13 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -7,12 +7,12 @@ import warnings import numpy as np import pandas as pd -from typing import Tuple, Union +from typing import Tuple, Union, List, Type from qlib.data import D from qlib.data import filter as filter_module from qlib.data.filter import BaseDFilter -from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point +from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_cls_kwargs from qlib.log import get_module_logger @@ -62,11 +62,11 @@ class DLWParser(DataLoader): Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields. """ - def __init__(self, config: Tuple[list, tuple, dict]): + def __init__(self, config: Union[list, tuple, dict]): """ Parameters ---------- - config : Tuple[list, tuple, dict] + config : Union[list, tuple, dict] Config will be used to describe the fields and column names .. code-block:: @@ -88,7 +88,7 @@ class DLWParser(DataLoader): else: self.fields = self._parse_fields_info(config) - def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]: + def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]: if len(fields_info) == 0: raise ValueError("The size of fields must be greater than 0") @@ -104,7 +104,15 @@ class DLWParser(DataLoader): return exprs, names @abc.abstractmethod - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + def load_group_df( + self, + instruments, + exprs: list, + names: list, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + gp_name: str = None, + ) -> pd.DataFrame: """ load the dataframe for specific group @@ -128,7 +136,7 @@ class DLWParser(DataLoader): if self.is_group: df = pd.concat( { - grp: self.load_group_df(instruments, exprs, names, start_time, end_time) + grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp) for grp, (exprs, names) in self.fields.items() }, axis=1, @@ -142,7 +150,15 @@ class DLWParser(DataLoader): class QlibDataLoader(DLWParser): """Same as QlibDataLoader. The fields can be define by config""" - def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"): + def __init__( + self, + config: Tuple[list, tuple, dict], + filter_pipe: List = None, + swap_level: bool = True, + freq: Union[str, dict] = "day", + sample_benchmark: str = None, + sample_config: dict = None, + ): """ Parameters ---------- @@ -163,9 +179,53 @@ class QlibDataLoader(DLWParser): self.filter_pipe = filter_pipe self.swap_level = swap_level self.freq = freq + + # sample + self.sample_config = sample_config + self.sample_benchmark = sample_benchmark + self.can_sample = False super().__init__(config) - def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + if self.is_group: + # check sample config + if isinstance(freq, dict): + for _gp in config.keys(): + if _gp not in freq: + raise ValueError(f"freq(={freq}) missing group(={_gp})") + if len(set(freq.values())) == 1: + self.freq = list(freq.values())[0] + else: + assert self.sample_config, f"freq(={self.freq}), sample_config cannot be None/empty" + assert isinstance(self.sample_config, dict), f"sample_config(={self.sample_config}) must be dict" + assert ( + self.sample_benchmark and self.sample_benchmark in self.fields + ), f"sample_benchmark not to specification" + self.can_sample = True + + def _get_sample_method(self, gp_name: str) -> Union[str, Type]: + _method = self.sample_config.get(gp_name, None) + if _method is None: + return _method + if isinstance(_method, str): + # pandas.DataFrame.resample + if not _method.startswith("resample"): + raise ValueError(f"sample method error, only pandas.DataFrame.resample is supported") + elif isinstance(_method, dict): + # module_path && func name + _method, _ = get_cls_kwargs(_method, obj_type="func") + else: + raise TypeError(f"sample_method only supports [str, dict], currently it is {_method}") + return _method + + def load_group_df( + self, + instruments, + exprs: list, + names: list, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + gp_name: str = None, + ) -> pd.DataFrame: if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") instruments = "all" @@ -174,12 +234,39 @@ class QlibDataLoader(DLWParser): elif self.filter_pipe is not None: warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") - df = D.features(instruments, exprs, start_time, end_time, self.freq) + freq = self.freq[gp_name] if self.can_sample else self.freq + df = D.features(instruments, exprs, start_time, end_time, freq) df.columns = names + + if self.can_sample and self.sample_benchmark != gp_name: + sample_method = self._get_sample_method(gp_name) + if sample_method is None: + warnings.warn(f"{gp_name} sample_method is None") + if isinstance(sample_method, str): + df = eval(f"df.groupby(level='instrument').{sample_method}") + else: + df = df.groupby(level="instrument").apply(sample_method) if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df + def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + if self.is_group: + group = { + grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp) + for grp, (exprs, names) in self.fields.items() + } + for grp, _df in group.items(): + if grp == self.sample_benchmark: + continue + else: + group[grp] = _df.reindex(group[self.sample_benchmark].index) + df = pd.concat(group, axis=1) + else: + exprs, names = self.fields + df = self.load_group_df(instruments, exprs, names, start_time, end_time) + return df + class StaticDataLoader(DataLoader): """ diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 41fdfb748..007cafbce 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -189,9 +189,11 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]): return module -def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict): +def get_cls_kwargs( + config: Union[dict, str], default_module: Union[str, ModuleType] = None, obj_type: str = "class" +) -> (type, dict): """ - extract class and kwargs from config info + extract class/func and kwargs from config info Parameters ---------- @@ -203,25 +205,27 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy This function will load class from the config['module_path'] first. If config['module_path'] doesn't exists, it will load the class from default_module. + obj_type: str + "class" or "func" Returns ------- (type, dict): - the class object and it's arguments. + the class/func object and it's arguments. """ if isinstance(config, dict): module = get_module_by_module_path(config.get("module_path", default_module)) # raise AttributeError - klass = getattr(module, config["class"]) + _obj = getattr(module, config[obj_type]) kwargs = config.get("kwargs", {}) elif isinstance(config, str): module = get_module_by_module_path(default_module) - klass = getattr(module, config) + _obj = getattr(module, config) kwargs = {} else: raise NotImplementedError(f"This type of input is not supported") - return klass, kwargs + return _obj, kwargs def init_instance_by_config(