1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

dataloader support static file or dataframe

This commit is contained in:
Dong Zhou
2020-11-24 10:08:52 +08:00
parent 6ded0d50c7
commit 5729b2242e
3 changed files with 80 additions and 3 deletions

View File

@@ -51,7 +51,7 @@ class DataHandler(Serializable):
def __init__(
self,
instruments,
instruments=None,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
@@ -242,7 +242,7 @@ class DataHandlerLP(DataHandler):
def __init__(
self,
instruments,
instruments=None,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,

View File

@@ -1,13 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import abc
import warnings
import numpy as np
import pandas as pd
from typing import Tuple
from typing import Tuple, Union
from qlib.data import D
from qlib.utils import load_dataset
class DataLoader(abc.ABC):
@@ -139,6 +142,9 @@ class QlibDataLoader(DLWParser):
super().__init__(config)
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is None:
warnings.warn('`instruments` is not set, will load all stocks')
instruments = 'all'
if isinstance(instruments, str):
instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
elif self.filter_pipe is not None:
@@ -148,3 +154,60 @@ class QlibDataLoader(DLWParser):
df.columns = names
df = df.swaplevel().sort_index() # NOTE: always return <datetime, instrument>
return df
class StaticDataLoader(DataLoader):
    """
    DataLoader that supports loading data from file or as provided.

    The raw data is read lazily on the first call to `load` and cached on the
    instance, so later calls only slice the cached frame.
    """

    def __init__(
        self,
        feature_path_or_obj: Union[str, pd.DataFrame],
        label_path_or_obj: Union[str, pd.DataFrame] = None,
    ):
        """
        Parameters
        ----------
        feature_path_or_obj : str or pd.DataFrame
            file path or pandas object for feature
        label_path_or_obj : str or pd.DataFrame
            file path or pandas object for label; when omitted, columns of the
            feature data whose name starts with ``LABEL`` are treated as labels
        """
        if isinstance(feature_path_or_obj, str):
            assert os.path.exists(feature_path_or_obj), f"cannot find feature `{feature_path_or_obj}`"
        else:
            assert isinstance(feature_path_or_obj, pd.DataFrame), "need to be dataframe"
        self._feature_path_or_obj = feature_path_or_obj

        if isinstance(label_path_or_obj, str):
            assert os.path.exists(label_path_or_obj), f"cannot find label `{label_path_or_obj}`"
        elif label_path_or_obj is not None:
            assert isinstance(label_path_or_obj, pd.DataFrame), "need to be dataframe"
        self._label_path_or_obj = label_path_or_obj

        # Filled in by `_maybe_load_raw_data` on first use.
        self._data = None

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Return the cached data restricted to the requested instruments and time range.

        Parameters
        ----------
        instruments : label or list of labels, optional
            selector applied to the second index level; `None` keeps every row
        start_time : optional
            inclusive lower bound on the first index level (anything accepted by
            `pd.Timestamp`); `None` leaves the range open on that side
        end_time : optional
            inclusive upper bound on the first index level; `None` leaves the
            range open on that side

        Returns
        -------
        pd.DataFrame
            the selected rows; when no filter is given this is the cached frame
            itself, not a copy
        """
        self._maybe_load_raw_data()
        if instruments is None:
            df = self._data
        else:
            df = self._data.loc(axis=0)[:, instruments]
        if start_time is None and end_time is None:
            return df  # NOTE: avoid copy by loc
        # `pd.Timestamp(None)` is NaT, which is not a valid slice bound; keep a
        # missing bound as `None` so the slice is open-ended on that side.
        lower = None if start_time is None else pd.Timestamp(start_time)
        upper = None if end_time is None else pd.Timestamp(end_time)
        return df.loc[lower:upper]

    def _maybe_load_raw_data(self):
        """Load and cache the raw data on first use; later calls are no-ops."""
        if self._data is not None:
            return
        # Build into a local so a failure while loading labels does not leave a
        # half-initialised cache behind.
        data = load_dataset(self._feature_path_or_obj)
        if self._label_path_or_obj is not None:
            data = pd.concat({"feature": data, "label": load_dataset(self._label_path_or_obj)}, axis=1)
        if not isinstance(data.columns, pd.MultiIndex):
            # Flat columns: group them under "feature"/"label" depending on
            # whether the column name starts with "LABEL".
            is_label = data.columns.str.contains("^LABEL").astype(int)
            grouped_columns = pd.MultiIndex.from_arrays([np.array(["feature", "label"])[is_label], data.columns])
            # `load_dataset` hands back a caller-supplied DataFrame as-is, so
            # relabel a shallow copy instead of mutating the caller's object.
            data = data.copy(deep=False)
            data.columns = grouped_columns
        self._data = data

View File

@@ -695,3 +695,17 @@ def register_wrapper(wrapper, cls_or_obj, module_path=None):
cls_or_obj = getattr(module, cls_or_obj)
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
wrapper.register(obj)
def load_dataset(path_or_obj):
    """Load a dataset from an in-memory DataFrame or from a supported file.

    Parameters
    ----------
    path_or_obj : str or pd.DataFrame
        a DataFrame (returned unchanged, not copied) or the path of a
        ``.h5``, ``.pkl`` or ``.csv`` file; the extension is matched
        case-insensitively

    Returns
    -------
    pd.DataFrame
        the loaded data; CSV files are read with their first two columns as
        the index and date parsing enabled for the index

    Raises
    ------
    ValueError
        if the file extension is not one of the supported types
    """
    if isinstance(path_or_obj, pd.DataFrame):
        return path_or_obj
    _, extension = os.path.splitext(path_or_obj)
    # Compare case-insensitively so that e.g. `data.CSV` is accepted too.
    file_type = extension.lower()
    if file_type == ".h5":
        return pd.read_hdf(path_or_obj)
    if file_type == ".pkl":
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        return pd.read_pickle(path_or_obj)
    if file_type == ".csv":
        return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
    raise ValueError(f"unsupported file type `{extension}`")