mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
static data loader supports fields_group
This commit is contained in:
@@ -161,29 +161,17 @@ class StaticDataLoader(DataLoader):
|
||||
DataLoader that supports loading data from file or as provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None
|
||||
):
|
||||
def __init__(self, config: dict, join='outer'):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
feature_path_or_obj : str or pd.DataFrame
|
||||
file path or pandas object for feature
|
||||
label_path_or_obj : str or pd.DataFrame
|
||||
file path or pandas object for label
|
||||
config : dict
|
||||
{fields_group: <path or object>}
|
||||
join : str
|
||||
How to align different dataframes
|
||||
"""
|
||||
if isinstance(feature_path_or_obj, str):
|
||||
assert os.path.exists(feature_path_or_obj), f"cannot find feature `{feature_path_or_obj}"
|
||||
else:
|
||||
assert isinstance(feature_path_or_obj, pd.DataFrame), f"need to be dataframe"
|
||||
self._feature_path_or_obj = feature_path_or_obj
|
||||
|
||||
if isinstance(label_path_or_obj, str):
|
||||
assert os.path.exists(label_path_or_obj), f"cannot find label `{label_path_or_obj}"
|
||||
elif label_path_or_obj is not None:
|
||||
assert isinstance(label_path_or_obj, pd.DataFrame), f"need to be dataframe"
|
||||
self._label_path_or_obj = label_path_or_obj
|
||||
|
||||
self.config = config
|
||||
self.join = join
|
||||
self._data = None
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
@@ -199,13 +187,7 @@ class StaticDataLoader(DataLoader):
|
||||
def _maybe_load_raw_data(self):
|
||||
if self._data is not None:
|
||||
return
|
||||
self._data = load_dataset(self._feature_path_or_obj)
|
||||
if self._label_path_or_obj is not None:
|
||||
self._data = pd.concat({"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1)
|
||||
if not isinstance(self._data.columns, pd.MultiIndex):
|
||||
self._data.columns = pd.MultiIndex.from_arrays(
|
||||
[
|
||||
np.array(["feature", "label"])[self._data.columns.str.contains("^LABEL").astype(int)],
|
||||
self._data.columns,
|
||||
]
|
||||
)
|
||||
self._data = pd.concat({
|
||||
fields_group: load_dataset(path_or_obj)
|
||||
for fields_group, path_or_obj in self.config.items()
|
||||
}, axis=1, join=self.join)
|
||||
|
||||
@@ -701,6 +701,8 @@ def load_dataset(path_or_obj):
|
||||
"""load dataset from multiple file formats"""
|
||||
if isinstance(path_or_obj, pd.DataFrame):
|
||||
return path_or_obj
|
||||
if not os.path.exists(path_or_obj):
|
||||
raise ValueError(f'file {path_or_obj} doesn\'t exist')
|
||||
_, extension = os.path.splitext(path_or_obj)
|
||||
if extension == ".h5":
|
||||
return pd.read_hdf(path_or_obj)
|
||||
|
||||
Reference in New Issue
Block a user