diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index eddbca044..6d90907f4 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -161,29 +161,17 @@ class StaticDataLoader(DataLoader): DataLoader that supports loading data from file or as provided. """ - def __init__( - self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None - ): + def __init__(self, config: dict, join='outer'): """ Parameters ---------- - feature_path_or_obj : str or pd.DataFrame - file path or pandas object for feature - label_path_or_obj : str or pd.DataFrame - file path or pandas object for label + config : dict + {fields_group: } + join : str + How to align different dataframes """ - if isinstance(feature_path_or_obj, str): - assert os.path.exists(feature_path_or_obj), f"cannot find feature `{feature_path_or_obj}" - else: - assert isinstance(feature_path_or_obj, pd.DataFrame), f"need to be dataframe" - self._feature_path_or_obj = feature_path_or_obj - - if isinstance(label_path_or_obj, str): - assert os.path.exists(label_path_or_obj), f"cannot find label `{label_path_or_obj}" - elif label_path_or_obj is not None: - assert isinstance(label_path_or_obj, pd.DataFrame), f"need to be dataframe" - self._label_path_or_obj = label_path_or_obj - + self.config = config + self.join = join self._data = None def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: @@ -199,13 +187,7 @@ class StaticDataLoader(DataLoader): def _maybe_load_raw_data(self): if self._data is not None: return - self._data = load_dataset(self._feature_path_or_obj) - if self._label_path_or_obj is not None: - self._data = pd.concat({"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1) - if not isinstance(self._data.columns, pd.MultiIndex): - self._data.columns = pd.MultiIndex.from_arrays( - [ - np.array(["feature", "label"])[self._data.columns.str.contains("^LABEL").astype(int)], - self._data.columns, - ] - ) + self._data = pd.concat({ + fields_group: load_dataset(path_or_obj) + for fields_group, path_or_obj in self.config.items() + }, axis=1, join=self.join) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 8a1436799..14480b7b5 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -701,6 +701,8 @@ def load_dataset(path_or_obj): """load dataset from multiple file formats""" if isinstance(path_or_obj, pd.DataFrame): return path_or_obj + if not os.path.exists(path_or_obj): + raise ValueError(f'file {path_or_obj} doesn\'t exist') _, extension = os.path.splitext(path_or_obj) if extension == ".h5": return pd.read_hdf(path_or_obj)