1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

StaticDataLoader supports multiple fields_group entries via a config dict

This commit is contained in:
Dong Zhou
2020-11-24 22:43:34 +08:00
parent dfa8bc10a5
commit db9758575b
2 changed files with 13 additions and 29 deletions

View File

@@ -161,29 +161,17 @@ class StaticDataLoader(DataLoader):
DataLoader that supports loading data from file or as provided.
"""
def __init__(
self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None
):
def __init__(self, config: dict, join='outer'):
"""
Parameters
----------
feature_path_or_obj : str or pd.DataFrame
file path or pandas object for feature
label_path_or_obj : str or pd.DataFrame
file path or pandas object for label
config : dict
{fields_group: <path or object>}
join : str
How to align different dataframes
"""
if isinstance(feature_path_or_obj, str):
assert os.path.exists(feature_path_or_obj), f"cannot find feature `{feature_path_or_obj}"
else:
assert isinstance(feature_path_or_obj, pd.DataFrame), f"need to be dataframe"
self._feature_path_or_obj = feature_path_or_obj
if isinstance(label_path_or_obj, str):
assert os.path.exists(label_path_or_obj), f"cannot find label `{label_path_or_obj}"
elif label_path_or_obj is not None:
assert isinstance(label_path_or_obj, pd.DataFrame), f"need to be dataframe"
self._label_path_or_obj = label_path_or_obj
self.config = config
self.join = join
self._data = None
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
@@ -199,13 +187,7 @@ class StaticDataLoader(DataLoader):
def _maybe_load_raw_data(self):
if self._data is not None:
return
self._data = load_dataset(self._feature_path_or_obj)
if self._label_path_or_obj is not None:
self._data = pd.concat({"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1)
if not isinstance(self._data.columns, pd.MultiIndex):
self._data.columns = pd.MultiIndex.from_arrays(
[
np.array(["feature", "label"])[self._data.columns.str.contains("^LABEL").astype(int)],
self._data.columns,
]
)
self._data = pd.concat({
fields_group: load_dataset(path_or_obj)
for fields_group, path_or_obj in self.config.items()
}, axis=1, join=self.join)

View File

@@ -701,6 +701,8 @@ def load_dataset(path_or_obj):
"""load dataset from multiple file formats"""
if isinstance(path_or_obj, pd.DataFrame):
return path_or_obj
if not os.path.exists(path_or_obj):
raise ValueError(f'file {path_or_obj} doesn\'t exist')
_, extension = os.path.splitext(path_or_obj)
if extension == ".h5":
return pd.read_hdf(path_or_obj)