1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00
Files
qlib/qlib/data/dataset/loader.py
Linlang 2c33332dd6 More dataloader example (#1823)
* More dataloader example

* optimize code

* optimeze code

* optimeze code

* optimeze code

* optimeze code

* optimeze code

* fix pylint error

* fix CI error

* fix CI error

* Comments

* fix error type

---------

Co-authored-by: Young <afe.young@gmail.com>
2024-07-10 14:48:44 +08:00

408 lines
14 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import pickle
from pathlib import Path
import warnings
import pandas as pd
from typing import Tuple, Union, List, Dict
from qlib.data import D
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
class DataLoader(abc.ABC):
    """
    Abstract interface for loading raw data from the original data source.

    Concrete loaders only have to implement :meth:`load`.
    """

    @abc.abstractmethod
    def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Load the requested data and return it as a pd.DataFrame.

        Example of the returned data (the multi-index on the columns is optional):

            .. code-block:: text

                                        feature                                                             label
                                        $close     $volume     Ref($close, 1)  Mean($close, 3)  $high-$low  LABEL0
                datetime    instrument
                2010-01-04  SH600000    81.807068  17145150.0       83.737389        83.016739    2.741058  0.0032
                            SH600004    13.313329  11800983.0       13.313329        13.317701    0.183632  0.0042
                            SH600005    37.796539  12231662.0       38.258602        37.919757    0.970325  0.0289

        Parameters
        ----------
        instruments : str or dict
            Either the market name or the instruments config generated by InstrumentProvider.
            ``None`` means that no instrument filtering is applied.
        start_time : str
            Start of the time range.
        end_time : str
            End of the time range.

        Returns
        -------
        pd.DataFrame:
            The data loaded from the underlying source.

        Raises
        ------
        KeyError:
            If the given instruments filter is not supported by the loader.
        """
class DLWParser(DataLoader):
    """
    (D)ata(L)oader (W)ith (P)arser for features and names

    The field/column-name parsing lives in this class so that QlibDataLoader and other
    dataloaders (such as QdbDataLoader) can share it.
    """

    def __init__(self, config: Union[list, tuple, dict]):
        """
        Parameters
        ----------
        config : Union[list, tuple, dict]
            Describes the fields (expressions) to load and the column names to give them.

            .. code-block::

                <config> := {
                    "group_name1": <fields_info1>
                    "group_name2": <fields_info2>
                }
                or
                <config> := <fields_info>

                <fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
                # NOTE: list and tuple are treated the same way when parsing
        """
        self.is_group = isinstance(config, dict)
        if not self.is_group:
            # A bare <fields_info>: one (exprs, names) pair.
            self.fields = self._parse_fields_info(config)
            return

        # A dict of groups: one (exprs, names) pair per group name.
        parsed = {}
        for grp, fields_info in config.items():
            parsed[grp] = self._parse_fields_info(fields_info)
        self.fields = parsed

    def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
        """
        Normalize one ``<fields_info>`` into an ``(exprs, names)`` pair.

        A flat sequence of strings is used both as expressions and as column names;
        a pair of sequences is unpacked into expressions and names respectively.

        Raises
        ------
        ValueError
            If `fields_info` is empty.
        TypeError
            If `fields_info` is neither a list nor a tuple.
        NotImplementedError
            If the first element is neither a str nor a list/tuple.
        """
        if len(fields_info) == 0:
            raise ValueError("The size of fields must be greater than 0")
        if not isinstance(fields_info, (list, tuple)):
            raise TypeError("Unsupported type")

        # The first element decides which of the two supported layouts is in use.
        head = fields_info[0]
        if isinstance(head, str):
            return fields_info, fields_info
        if isinstance(head, (list, tuple)):
            exprs, names = fields_info
            return exprs, names
        raise NotImplementedError(f"This type of input is not supported")

    @abc.abstractmethod
    def load_group_df(
        self,
        instruments,
        exprs: list,
        names: list,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        gp_name: str = None,
    ) -> pd.DataFrame:
        """
        Load the dataframe of one specific group.

        Parameters
        ----------
        instruments :
            the instruments.
        exprs : list
            the expressions describing the content of the data.
        names : list
            the column names of the data.
        start_time : Union[str, pd.Timestamp]
            start of the time range.
        end_time : Union[str, pd.Timestamp]
            end of the time range.
        gp_name : str
            name of the group being loaded; `load` leaves it as None when the config is not grouped.

        Returns
        -------
        pd.DataFrame:
            the queried dataframe.
        """

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Load every configured group through `load_group_df`.

        For a grouped config the per-group frames are concatenated along the columns with the
        group names as the outer column level; otherwise the single frame is returned as is.
        """
        if not self.is_group:
            exprs, names = self.fields
            return self.load_group_df(instruments, exprs, names, start_time, end_time)

        group_dfs = {}
        for grp, (exprs, names) in self.fields.items():
            group_dfs[grp] = self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
        return pd.concat(group_dfs, axis=1)
class QlibDataLoader(DLWParser):
    """DataLoader that fetches the configured fields through qlib's data API (`D`). The fields can be defined by config."""

    def __init__(
        self,
        config: Union[list, tuple, dict],
        filter_pipe: List = None,
        swap_level: bool = True,
        freq: Union[str, dict] = "day",
        inst_processors: Union[dict, list] = None,
    ):
        """
        Parameters
        ----------
        config : Union[list, tuple, dict]
            Please refer to the doc of DLWParser
        filter_pipe :
            Filter pipe for the instruments
        swap_level :
            Whether to swap level of MultiIndex
        freq:  dict or str
            If type(config) == dict and type(freq) == str, load config data using freq.
            If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
        inst_processors: dict | list
            If inst_processors is not None and type(config) == dict; load config[<group_name>] data using inst_processors[<group_name>]
            If inst_processors is a list, then it will be applied to all groups.

        Raises
        ------
        ValueError
            If `config` is grouped, `freq` is a dict and some group of `config` has no entry in `freq`.
        AssertionError
            If `inst_processors` is neither a dict nor a list, or if it is empty while `freq` is
            given per group.
        """
        self.filter_pipe = filter_pipe
        self.swap_level = swap_level
        self.freq = freq

        # sample
        # `None` is normalized to an empty dict so that `.get(gp_name, [])` works in `load_group_df`.
        self.inst_processors = inst_processors if inst_processors is not None else {}
        assert isinstance(
            self.inst_processors, (dict, list)
        ), f"inst_processors(={self.inst_processors}) must be dict or list"

        super().__init__(config)

        if self.is_group:
            # check sample config
            if isinstance(freq, dict):
                # every group needs its own frequency when `freq` is given per group
                for _gp in config.keys():
                    if _gp not in freq:
                        raise ValueError(f"freq(={freq}) missing group(={_gp})")
                assert (
                    self.inst_processors
                ), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"

    def load_group_df(
        self,
        instruments,
        exprs: list,
        names: list,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        gp_name: str = None,
    ) -> pd.DataFrame:
        """
        Load one group of fields with `D.features` and rename the columns to `names`.

        Parameters
        ----------
        instruments :
            Market name (str), an instruments config, or None. None falls back to "all" with a warning.
        exprs : list
            The expressions to query.
        names : list
            The column names assigned to the result, in the same order as `exprs`.
        start_time, end_time : Union[str, pd.Timestamp]
            The time range forwarded to `D.features`.
        gp_name : str
            The group name used to look up the per-group `freq` / `inst_processors`.

        Returns
        -------
        pd.DataFrame:
            The queried dataframe; if `swap_level` is set, its index levels are swapped and sorted.
        """
        if instruments is None:
            warnings.warn("`instruments` is not set, will load all stocks")
            instruments = "all"
        if isinstance(instruments, str):
            # a market name is expanded to an instruments config; only here `filter_pipe` is applied
            instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
        elif self.filter_pipe is not None:
            warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")

        freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
        inst_processors = (
            self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])
        )
        df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)
        df.columns = names
        if self.swap_level:
            df = df.swaplevel().sort_index()  # NOTE: if swaplevel, return <datetime, instrument>
        return df
class StaticDataLoader(DataLoader, Serializable):
    """
    DataLoader that supports loading data from file or as provided.
    """

    # `_config` must survive serialization although its name starts with "_".
    include_attr = ["_config"]

    # Class-level default: `_data` is never pickled, so an instance restored from a pickle has no
    # `_data` in its own `__dict__`; this keeps `self._data` readable (as "not loaded yet").
    _data = None

    def __init__(self, config: Union[dict, str, pd.DataFrame], join="outer"):
        """
        Parameters
        ----------
        config : dict | str | pathlib.Path | pd.DataFrame
            - dict: ``{fields_group: <path or object>}``; every value is loaded with `load_dataset`
              and the results are concatenated along the columns.
            - str / Path: path of a pickle file that contains the data.
            - pd.DataFrame: used directly as the data.
        join : str
            How to align different dataframes
        """
        self._config = config  # using "_" to avoid confliction with the method `config` of Serializable
        self.join = join
        self._data = None

    def __getstate__(self) -> dict:
        """
        Return the picklable state.

        Private attributes (in particular the cached `self._data`) are dropped, except for the
        ones listed in `include_attr`. `_config` has to be kept, otherwise an unpickled loader
        would have nothing to (re)load its data from.
        """
        keep = set(self.include_attr)
        return {k: v for k, v in self.__dict__.items() if k in keep or not k.startswith("_")}

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Return the (lazily loaded) data, optionally filtered by instruments and time range.

        Parameters
        ----------
        instruments :
            Selector applied to the second index level; None means no filtering.
            (assumes the index is <datetime, instrument> -- confirm against the stored data)
        start_time, end_time :
            Bounds of the time slice; None means unbounded on that side.
        """
        self._maybe_load_raw_data()

        # 1) Filter by instruments
        if instruments is None:
            df = self._data
        else:
            df = self._data.loc(axis=0)[:, instruments]

        # 2) Filter by Datetime
        if start_time is None and end_time is None:
            return df  # NOTE: avoid copy by loc
        # pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
        start_time = time_to_slc_point(start_time)
        end_time = time_to_slc_point(end_time)
        return df.loc[start_time:end_time]

    def _maybe_load_raw_data(self):
        """Populate `self._data` from `self._config` on first use; later calls are no-ops."""
        if self._data is not None:
            return
        if isinstance(self._config, dict):
            self._data = pd.concat(
                {fields_group: load_dataset(path_or_obj) for fields_group, path_or_obj in self._config.items()},
                axis=1,
                join=self.join,
            )
            self._data.sort_index(inplace=True)
        elif isinstance(self._config, (str, Path)):
            # SECURITY NOTE: `pickle.load` can execute arbitrary code; only point `config` at trusted files.
            with Path(self._config).open("rb") as f:
                self._data = pickle.load(f)
        elif isinstance(self._config, pd.DataFrame):
            self._data = self._config
        # NOTE: any other config type leaves `self._data` as None.
class NestedDataLoader(DataLoader):
    """
    We have multiple DataLoader, we can use this class to combine them.
    """

    def __init__(self, dataloader_l: List[Union[Dict, DataLoader]], join="left") -> None:
        """
        Parameters
        ----------
        dataloader_l : list[dict | DataLoader]
            A list of dataloaders, given either as DataLoader instances or as configs that
            `init_instance_by_config` can instantiate, for example

            .. code-block:: python

                nd = NestedDataLoader(
                    dataloader_l=[
                        {
                            "class": "qlib.contrib.data.loader.Alpha158DL",
                        }, {
                            "class": "qlib.contrib.data.loader.Alpha360DL",
                            "kwargs": {
                                "config": {
                                    "label": ( ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
                                }
                            }
                        }
                    ]
                )

        join :
            it is passed as `how` to `pd.merge` when merging the loaded dataframes on their index.
        """
        super().__init__()
        self.data_loader_l = [
            (dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l
        ]
        self.join = join

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Load from every nested loader and merge the results on the index.

        Returns
        -------
        pd.DataFrame:
            The merged dataframe with its columns sorted.

        Raises
        ------
        ValueError
            If no data loader was configured.
        """
        if not self.data_loader_l:
            # Without this guard the method would fail below with an opaque AttributeError on None.
            raise ValueError("`dataloader_l` is empty, at least one data loader is required")

        df_full = None
        for dl in self.data_loader_l:
            try:
                df_current = dl.load(instruments, start_time, end_time)
            except KeyError:
                # A loader signals an unsupported `instruments` filter with KeyError:
                # fall back to loading everything from that loader.
                warnings.warn(
                    "If the value of `instruments` cannot be processed, it will set instruments to None to get all the data."
                )
                df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time)
            if df_full is None:
                df_full = df_current
            else:
                df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)
        return df_full.sort_index(axis=1)
class DataLoaderDH(DataLoader):
    """DataLoaderDH
    DataLoader based on (D)ata (H)andler
    It is designed to load multiple data from data handler
    - If you just want to load data from single datahandler, you can write them in single data handler

    TODO: What makes this module not that easy to use.

    - For online scenario

        - The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
    """

    def __init__(self, handler_config: dict, fetch_kwargs: dict = None, is_group=False):
        """
        Parameters
        ----------
        handler_config : dict
            handler_config will be used to describe the handlers

            .. code-block::

                <handler_config> := {
                    "group_name1": <handler>
                    "group_name2": <handler>
                }
                or
                <handler_config> := <handler>
                <handler> := DataHandler Instance | DataHandler Config

        fetch_kwargs : dict
            fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
            They override the default ``{"col_set": DataHandler.CS_RAW}``; None means no overrides.

        is_group: bool
            is_group will be used to describe whether the key of handler_config is group
        """
        from qlib.data.dataset.handler import DataHandler  # pylint: disable=C0415

        if is_group:
            self.handlers = {
                grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
            }
        else:
            self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)

        self.is_group = is_group
        self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
        # `None` replaces the former mutable `{}` default; user values take precedence over the default.
        if fetch_kwargs is not None:
            self.fetch_kwargs.update(fetch_kwargs)

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Fetch the ``[start_time, end_time]`` slice on the "datetime" level from the handler(s).

        `instruments` is not supported by this loader: a non-None value is ignored with a warning.
        For grouped handlers the fetched frames are concatenated along the columns, keyed by group name.
        """
        if instruments is not None:
            get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")

        if self.is_group:
            df = pd.concat(
                {
                    grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
                    for grp, dh in self.handlers.items()
                },
                axis=1,
            )
        else:
            df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
        return df