diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index d1dbe1777..0a2fed637 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -51,7 +51,7 @@ class DataHandler(Serializable): def __init__( self, - instruments, + instruments=None, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader] = None, @@ -242,7 +242,7 @@ class DataHandlerLP(DataHandler): def __init__( self, - instruments, + instruments=None, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader] = None, diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 564a7e5d5..7e8dd507c 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -1,13 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import os import abc import warnings +import numpy as np import pandas as pd -from typing import Tuple +from typing import Tuple, Union from qlib.data import D +from qlib.utils import load_dataset class DataLoader(abc.ABC): @@ -139,6 +142,9 @@ class QlibDataLoader(DLWParser): super().__init__(config) def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: + if instruments is None: + warnings.warn('`instruments` is not set, will load all stocks') + instruments = 'all' if isinstance(instruments, str): instruments = D.instruments(instruments, filter_pipe=self.filter_pipe) elif self.filter_pipe is not None: @@ -148,3 +154,60 @@ class QlibDataLoader(DLWParser): df.columns = names df = df.swaplevel().sort_index() # NOTE: always return return df + + +class StaticDataLoader(DataLoader): + """ + DataLoader that supports loading data from file or as provided. + """ + + def __init__(self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None): + """ + Parameters + ---------- + feature_path_or_obj : str or pd.DataFrame + file path or pandas object for feature + label_path_or_obj : str or pd.DataFrame + file path or pandas object for label + """ + if isinstance(feature_path_or_obj, str): + assert os.path.exists(feature_path_or_obj), f"cannot find feature `{feature_path_or_obj}" + else: + assert isinstance(feature_path_or_obj, pd.DataFrame), f"need to be dataframe" + self._feature_path_or_obj = feature_path_or_obj + + if isinstance(label_path_or_obj, str): + assert os.path.exists(label_path_or_obj), f"cannot find label `{label_path_or_obj}" + elif label_path_or_obj is not None: + assert isinstance(label_path_or_obj, pd.DataFrame), f"need to be dataframe" + self._label_path_or_obj = label_path_or_obj + + self._data = None + + def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + self._maybe_load_raw_data() + if instruments is None: + df = self._data + else: + df = self._data.loc(axis=0)[:, instruments] + if start_time is None and end_time is None: + return df # NOTE: avoid copy by loc + return df.loc[pd.Timestamp(start_time):pd.Timestamp(end_time)] + + def _maybe_load_raw_data(self): + if self._data is not None: + return + self._data = load_dataset(self._feature_path_or_obj) + if self._label_path_or_obj is not None: + self._data = pd.concat( + {"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1 + ) + if not isinstance(self._data.columns, pd.MultiIndex): + self._data.columns = pd.MultiIndex.from_arrays( + [ + np.array(["feature", "label"])[ + self._data.columns.str.contains("^LABEL").astype(int) + ], + self._data.columns, + ] + ) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 79fd6fe5c..c77c67fa2 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -695,3 +695,17 @@ def register_wrapper(wrapper, cls_or_obj, module_path=None): cls_or_obj = getattr(module, cls_or_obj) obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj wrapper.register(obj) + + +def load_dataset(path_or_obj): + """load dataset from multiple file formats""" + if isinstance(path_or_obj, pd.DataFrame): + return path_or_obj + _, extension = os.path.splitext(path_or_obj) + if extension == '.h5': + return pd.read_hdf(path_or_obj) + elif extension == '.pkl': + return pd.read_pickle(path_or_obj) + elif extension == '.csv': + return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1]) + raise ValueError(f'unsupported file type `{extension}`')