1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

Add sample_config to QlibDataLoader, support multi-freq

This commit is contained in:
zhupr
2021-08-26 14:29:32 +08:00
committed by you-n-g
parent e8126b0c39
commit c99494eb76
3 changed files with 115 additions and 16 deletions

View File

@@ -58,6 +58,8 @@ class Alpha360(DataHandlerLP):
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
sample_config=None,
sample_benchmark=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -72,6 +74,8 @@ class Alpha360(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"sample_config": sample_config,
"sample_benchmark": sample_benchmark,
},
}
@@ -144,6 +148,8 @@ class Alpha158(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
sample_config=None,
sample_benchmark=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -158,6 +164,8 @@ class Alpha158(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"sample_config": sample_config,
"sample_benchmark": sample_benchmark,
},
}
super().__init__(

View File

@@ -7,12 +7,12 @@ import warnings
import numpy as np
import pandas as pd
from typing import Tuple, Union
from typing import Tuple, Union, List, Type
from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_cls_kwargs
from qlib.log import get_module_logger
@@ -62,11 +62,11 @@ class DLWParser(DataLoader):
Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields.
"""
def __init__(self, config: Tuple[list, tuple, dict]):
def __init__(self, config: Union[list, tuple, dict]):
"""
Parameters
----------
config : Tuple[list, tuple, dict]
config : Union[list, tuple, dict]
Config will be used to describe the fields and column names
.. code-block::
@@ -88,7 +88,7 @@ class DLWParser(DataLoader):
else:
self.fields = self._parse_fields_info(config)
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
if len(fields_info) == 0:
raise ValueError("The size of fields must be greater than 0")
@@ -104,7 +104,15 @@ class DLWParser(DataLoader):
return exprs, names
@abc.abstractmethod
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
def load_group_df(
self,
instruments,
exprs: list,
names: list,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
gp_name: str = None,
) -> pd.DataFrame:
"""
load the dataframe for specific group
@@ -128,7 +136,7 @@ class DLWParser(DataLoader):
if self.is_group:
df = pd.concat(
{
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
for grp, (exprs, names) in self.fields.items()
},
axis=1,
@@ -142,7 +150,15 @@ class DLWParser(DataLoader):
class QlibDataLoader(DLWParser):
"""Same as QlibDataLoader. The fields can be define by config"""
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
def __init__(
self,
config: Tuple[list, tuple, dict],
filter_pipe: List = None,
swap_level: bool = True,
freq: Union[str, dict] = "day",
sample_benchmark: str = None,
sample_config: dict = None,
):
"""
Parameters
----------
@@ -163,9 +179,53 @@ class QlibDataLoader(DLWParser):
self.filter_pipe = filter_pipe
self.swap_level = swap_level
self.freq = freq
# sample
self.sample_config = sample_config
self.sample_benchmark = sample_benchmark
self.can_sample = False
super().__init__(config)
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:
# check sample config
if isinstance(freq, dict):
for _gp in config.keys():
if _gp not in freq:
raise ValueError(f"freq(={freq}) missing group(={_gp})")
if len(set(freq.values())) == 1:
self.freq = list(freq.values())[0]
else:
assert self.sample_config, f"freq(={self.freq}), sample_config cannot be None/empty"
assert isinstance(self.sample_config, dict), f"sample_config(={self.sample_config}) must be dict"
assert (
self.sample_benchmark and self.sample_benchmark in self.fields
), f"sample_benchmark not to specification"
self.can_sample = True
def _get_sample_method(self, gp_name: str) -> Union[str, Type]:
_method = self.sample_config.get(gp_name, None)
if _method is None:
return _method
if isinstance(_method, str):
# pandas.DataFrame.resample
if not _method.startswith("resample"):
raise ValueError(f"sample method error, only pandas.DataFrame.resample is supported")
elif isinstance(_method, dict):
# module_path && func name
_method, _ = get_cls_kwargs(_method, obj_type="func")
else:
raise TypeError(f"sample_method only supports [str, dict], currently it is {_method}")
return _method
def load_group_df(
self,
instruments,
exprs: list,
names: list,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
gp_name: str = None,
) -> pd.DataFrame:
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
instruments = "all"
@@ -174,12 +234,39 @@ class QlibDataLoader(DLWParser):
elif self.filter_pipe is not None:
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
df = D.features(instruments, exprs, start_time, end_time, self.freq)
freq = self.freq[gp_name] if self.can_sample else self.freq
df = D.features(instruments, exprs, start_time, end_time, freq)
df.columns = names
if self.can_sample and self.sample_benchmark != gp_name:
sample_method = self._get_sample_method(gp_name)
if sample_method is None:
warnings.warn(f"{gp_name} sample_method is None")
if isinstance(sample_method, str):
df = eval(f"df.groupby(level='instrument').{sample_method}")
else:
df = df.groupby(level="instrument").apply(sample_method)
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
return df
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:
group = {
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
for grp, (exprs, names) in self.fields.items()
}
for grp, _df in group.items():
if grp == self.sample_benchmark:
continue
else:
group[grp] = _df.reindex(group[self.sample_benchmark].index)
df = pd.concat(group, axis=1)
else:
exprs, names = self.fields
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
return df
class StaticDataLoader(DataLoader):
"""

View File

@@ -189,9 +189,11 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
return module
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
def get_cls_kwargs(
config: Union[dict, str], default_module: Union[str, ModuleType] = None, obj_type: str = "class"
) -> (type, dict):
"""
extract class and kwargs from config info
extract class/func and kwargs from config info
Parameters
----------
@@ -203,25 +205,27 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
This function will load class from the config['module_path'] first.
If config['module_path'] doesn't exists, it will load the class from default_module.
obj_type: str
"class" or "func"
Returns
-------
(type, dict):
the class object and it's arguments.
the class/func object and it's arguments.
"""
if isinstance(config, dict):
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
klass = getattr(module, config["class"])
_obj = getattr(module, config[obj_type])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
klass = getattr(module, config)
_obj = getattr(module, config)
kwargs = {}
else:
raise NotImplementedError(f"This type of input is not supported")
return klass, kwargs
return _obj, kwargs
def init_instance_by_config(