mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
dataloader support static file or dataframe
This commit is contained in:
@@ -51,7 +51,7 @@ class DataHandler(Serializable):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instruments,
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
@@ -242,7 +242,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instruments,
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import abc
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Union
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.utils import load_dataset
|
||||
|
||||
|
||||
class DataLoader(abc.ABC):
|
||||
@@ -139,6 +142,9 @@ class QlibDataLoader(DLWParser):
|
||||
super().__init__(config)
|
||||
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is None:
|
||||
warnings.warn('`instruments` is not set, will load all stocks')
|
||||
instruments = 'all'
|
||||
if isinstance(instruments, str):
|
||||
instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
|
||||
elif self.filter_pipe is not None:
|
||||
@@ -148,3 +154,60 @@ class QlibDataLoader(DLWParser):
|
||||
df.columns = names
|
||||
df = df.swaplevel().sort_index() # NOTE: always return <datetime, instrument>
|
||||
return df
|
||||
|
||||
|
||||
class StaticDataLoader(DataLoader):
|
||||
"""
|
||||
DataLoader that supports loading data from file or as provided.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
feature_path_or_obj : str or pd.DataFrame
|
||||
file path or pandas object for feature
|
||||
label_path_or_obj : str or pd.DataFrame
|
||||
file path or pandas object for label
|
||||
"""
|
||||
if isinstance(feature_path_or_obj, str):
|
||||
assert os.path.exists(feature_path_or_obj), f"cannot find feature `{feature_path_or_obj}"
|
||||
else:
|
||||
assert isinstance(feature_path_or_obj, pd.DataFrame), f"need to be dataframe"
|
||||
self._feature_path_or_obj = feature_path_or_obj
|
||||
|
||||
if isinstance(label_path_or_obj, str):
|
||||
assert os.path.exists(label_path_or_obj), f"cannot find label `{label_path_or_obj}"
|
||||
elif label_path_or_obj is not None:
|
||||
assert isinstance(label_path_or_obj, pd.DataFrame), f"need to be dataframe"
|
||||
self._label_path_or_obj = label_path_or_obj
|
||||
|
||||
self._data = None
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
self._maybe_load_raw_data()
|
||||
if instruments is None:
|
||||
df = self._data
|
||||
else:
|
||||
df = self._data.loc(axis=0)[:, instruments]
|
||||
if start_time is None and end_time is None:
|
||||
return df # NOTE: avoid copy by loc
|
||||
return df.loc[pd.Timestamp(start_time):pd.Timestamp(end_time)]
|
||||
|
||||
def _maybe_load_raw_data(self):
|
||||
if self._data is not None:
|
||||
return
|
||||
self._data = load_dataset(self._feature_path_or_obj)
|
||||
if self._label_path_or_obj is not None:
|
||||
self._data = pd.concat(
|
||||
{"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1
|
||||
)
|
||||
if not isinstance(self._data.columns, pd.MultiIndex):
|
||||
self._data.columns = pd.MultiIndex.from_arrays(
|
||||
[
|
||||
np.array(["feature", "label"])[
|
||||
self._data.columns.str.contains("^LABEL").astype(int)
|
||||
],
|
||||
self._data.columns,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -695,3 +695,17 @@ def register_wrapper(wrapper, cls_or_obj, module_path=None):
|
||||
cls_or_obj = getattr(module, cls_or_obj)
|
||||
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
|
||||
wrapper.register(obj)
|
||||
|
||||
|
||||
def load_dataset(path_or_obj):
|
||||
"""load dataset from multiple file formats"""
|
||||
if isinstance(path_or_obj, pd.DataFrame):
|
||||
return path_or_obj
|
||||
_, extension = os.path.splitext(path_or_obj)
|
||||
if extension == '.h5':
|
||||
return pd.read_hdf(path_or_obj)
|
||||
elif extension == '.pkl':
|
||||
return pd.read_pickle(path_or_obj)
|
||||
elif extension == '.csv':
|
||||
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
|
||||
raise ValueError(f'unsupported file type `{extension}`')
|
||||
|
||||
Reference in New Issue
Block a user