mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
Add sample_config to QlibDataLoader, support multi-freq
This commit is contained in:
@@ -58,6 +58,8 @@ class Alpha360(DataHandlerLP):
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
sample_config=None,
|
||||
sample_benchmark=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -72,6 +74,8 @@ class Alpha360(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"sample_config": sample_config,
|
||||
"sample_benchmark": sample_benchmark,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -144,6 +148,8 @@ class Alpha158(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
sample_config=None,
|
||||
sample_benchmark=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -158,6 +164,8 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"sample_config": sample_config,
|
||||
"sample_benchmark": sample_benchmark,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
|
||||
@@ -7,12 +7,12 @@ import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, List, Type
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, get_cls_kwargs
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
@@ -62,11 +62,11 @@ class DLWParser(DataLoader):
|
||||
Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict]):
|
||||
def __init__(self, config: Union[list, tuple, dict]):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
config : Tuple[list, tuple, dict]
|
||||
config : Union[list, tuple, dict]
|
||||
Config will be used to describe the fields and column names
|
||||
|
||||
.. code-block::
|
||||
@@ -88,7 +88,7 @@ class DLWParser(DataLoader):
|
||||
else:
|
||||
self.fields = self._parse_fields_info(config)
|
||||
|
||||
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
|
||||
def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
|
||||
if len(fields_info) == 0:
|
||||
raise ValueError("The size of fields must be greater than 0")
|
||||
|
||||
@@ -104,7 +104,15 @@ class DLWParser(DataLoader):
|
||||
return exprs, names
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load_group_df(
|
||||
self,
|
||||
instruments,
|
||||
exprs: list,
|
||||
names: list,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
gp_name: str = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
load the dataframe for specific group
|
||||
|
||||
@@ -128,7 +136,7 @@ class DLWParser(DataLoader):
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
|
||||
for grp, (exprs, names) in self.fields.items()
|
||||
},
|
||||
axis=1,
|
||||
@@ -142,7 +150,15 @@ class DLWParser(DataLoader):
|
||||
class QlibDataLoader(DLWParser):
|
||||
"""Same as QlibDataLoader. The fields can be define by config"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
|
||||
def __init__(
|
||||
self,
|
||||
config: Tuple[list, tuple, dict],
|
||||
filter_pipe: List = None,
|
||||
swap_level: bool = True,
|
||||
freq: Union[str, dict] = "day",
|
||||
sample_benchmark: str = None,
|
||||
sample_config: dict = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -163,9 +179,53 @@ class QlibDataLoader(DLWParser):
|
||||
self.filter_pipe = filter_pipe
|
||||
self.swap_level = swap_level
|
||||
self.freq = freq
|
||||
|
||||
# sample
|
||||
self.sample_config = sample_config
|
||||
self.sample_benchmark = sample_benchmark
|
||||
self.can_sample = False
|
||||
super().__init__(config)
|
||||
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if self.is_group:
|
||||
# check sample config
|
||||
if isinstance(freq, dict):
|
||||
for _gp in config.keys():
|
||||
if _gp not in freq:
|
||||
raise ValueError(f"freq(={freq}) missing group(={_gp})")
|
||||
if len(set(freq.values())) == 1:
|
||||
self.freq = list(freq.values())[0]
|
||||
else:
|
||||
assert self.sample_config, f"freq(={self.freq}), sample_config cannot be None/empty"
|
||||
assert isinstance(self.sample_config, dict), f"sample_config(={self.sample_config}) must be dict"
|
||||
assert (
|
||||
self.sample_benchmark and self.sample_benchmark in self.fields
|
||||
), f"sample_benchmark not to specification"
|
||||
self.can_sample = True
|
||||
|
||||
def _get_sample_method(self, gp_name: str) -> Union[str, Type]:
    """
    Look up and validate the sample method configured for a group.

    Parameters
    ----------
    gp_name : str
        Group name used as the key into ``self.sample_config``.

    Returns
    -------
    Union[str, Type]
        - ``None`` when the group has no entry in ``self.sample_config``.
        - The configured string unchanged when it starts with ``"resample"``.
        - The object returned by ``get_cls_kwargs(..., obj_type="func")`` when
          the entry is a dict; the kwargs it also returns are discarded.

    Raises
    ------
    ValueError
        If the entry is a string that does not start with ``"resample"``.
    TypeError
        If the entry is neither a string nor a dict.
    """
    configured = self.sample_config.get(gp_name, None)

    # No entry for this group: hand the None back to the caller.
    if configured is None:
        return None

    if isinstance(configured, str):
        # String form: only expressions beginning with "resample" are accepted.
        if configured.startswith("resample"):
            return configured
        raise ValueError("sample method error, only pandas.DataFrame.resample is supported")

    if isinstance(configured, dict):
        # Dict form: resolve the configured function; its kwargs are not used.
        resolved, _ = get_cls_kwargs(configured, obj_type="func")
        return resolved

    raise TypeError(f"sample_method only supports [str, dict], currently it is {configured}")
|
||||
|
||||
def load_group_df(
|
||||
self,
|
||||
instruments,
|
||||
exprs: list,
|
||||
names: list,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
gp_name: str = None,
|
||||
) -> pd.DataFrame:
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
instruments = "all"
|
||||
@@ -174,12 +234,39 @@ class QlibDataLoader(DLWParser):
|
||||
elif self.filter_pipe is not None:
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
df = D.features(instruments, exprs, start_time, end_time, self.freq)
|
||||
freq = self.freq[gp_name] if self.can_sample else self.freq
|
||||
df = D.features(instruments, exprs, start_time, end_time, freq)
|
||||
df.columns = names
|
||||
|
||||
if self.can_sample and self.sample_benchmark != gp_name:
|
||||
sample_method = self._get_sample_method(gp_name)
|
||||
if sample_method is None:
|
||||
warnings.warn(f"{gp_name} sample_method is None")
|
||||
if isinstance(sample_method, str):
|
||||
df = eval(f"df.groupby(level='instrument').{sample_method}")
|
||||
else:
|
||||
df = df.groupby(level="instrument").apply(sample_method)
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
return df
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
    """
    Load the configured fields as a single DataFrame.

    Parameters
    ----------
    instruments :
        Forwarded unchanged to ``self.load_group_df``.
    start_time :
        Forwarded unchanged to ``self.load_group_df``.
    end_time :
        Forwarded unchanged to ``self.load_group_df``.

    Returns
    -------
    pd.DataFrame
        - Not grouped (``self.is_group`` is false): the frame returned by
          ``self.load_group_df`` for ``self.fields``.
        - Grouped: one frame per group, concatenated along the columns and
          keyed by group name. When ``self.sample_benchmark`` is set, every
          other group is first reindexed to the benchmark group's index.

    Raises
    ------
    KeyError
        If ``self.sample_benchmark`` is set but is not one of the groups.
    """
    if not self.is_group:
        exprs, names = self.fields
        return self.load_group_df(instruments, exprs, names, start_time, end_time)

    group = {
        grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
        for grp, (exprs, names) in self.fields.items()
    }

    benchmark = self.sample_benchmark
    # Without a benchmark there is nothing to align against. Looking it up
    # unconditionally made every plain grouped load fail with `KeyError: None`.
    if benchmark is not None:
        benchmark_index = group[benchmark].index
        group = {
            grp: _df if grp == benchmark else _df.reindex(benchmark_index)
            for grp, _df in group.items()
        }

    return pd.concat(group, axis=1)
|
||||
|
||||
|
||||
class StaticDataLoader(DataLoader):
|
||||
"""
|
||||
|
||||
@@ -189,9 +189,11 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
def get_cls_kwargs(
|
||||
config: Union[dict, str], default_module: Union[str, ModuleType] = None, obj_type: str = "class"
|
||||
) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
extract class/func and kwargs from config info
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -203,25 +205,27 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exist, it will load the class from default_module.
|
||||
|
||||
obj_type: str
|
||||
"class" or "func"
|
||||
Returns
|
||||
-------
|
||||
(type, dict):
|
||||
the class object and it's arguments.
|
||||
the class/func object and its arguments.
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
_obj = getattr(module, config[obj_type])
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
|
||||
klass = getattr(module, config)
|
||||
_obj = getattr(module, config)
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return klass, kwargs
|
||||
return _obj, kwargs
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
|
||||
Reference in New Issue
Block a user