1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

fix dataloader & add interface to datahandler

This commit is contained in:
Dong Zhou
2020-10-30 11:20:15 +08:00
committed by you-n-g
parent 9dc357bc81
commit c59058b47d
2 changed files with 101 additions and 46 deletions

View File

@@ -5,7 +5,7 @@
import abc
import bisect
import logging
from typing import Union, Tuple, List
from typing import Union, Tuple, List, Iterator, Optional
import pandas as pd
import numpy as np
@@ -113,8 +113,7 @@ class DataHandler(Serializable):
CS_ALL = "__all"
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
cln = len(df.columns.levels)
if cln == 1:
if not isinstance(df.columns, pd.MultiIndex):
return df
elif col_set == self.CS_ALL:
return df.droplevel(axis=1, level=0)
@@ -126,6 +125,7 @@ class DataHandler(Serializable):
selector: Union[pd.Timestamp, slice, str],
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -141,13 +141,22 @@ class DataHandler(Serializable):
select a set of meaningful columns.(e.g. features, columns)
if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels
squeeze : bool
whether squeeze columns and index
Returns
-------
pd.DataFrame:
"""
df = self._fetch_df_by_index(self._data, selector, level)
return self._fetch_df_by_col(df, col_set)
df = self._fetch_df_by_col(df, col_set)
if squeeze:
# squeeze columns
df = df.squeeze()
# squeeze index
if isinstance(selector, (str, pd.Timestamp)):
df = df.reset_index(level=level, drop=True)
return df
def get_cols(self, col_set=CS_ALL) -> list:
"""
@@ -167,6 +176,51 @@ class DataHandler(Serializable):
df = self._fetch_df_by_col(df, col_set)
return df.columns.to_list()
def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
"""
get range selector by number of periods
Args:
cur_date (pd.Timestamp or str): current date
periods (int): number of periods
"""
trading_dates = self.get_unique_index('datetime')
cur_loc = trading_dates.get_loc(cur_date)
pre_loc = cur_loc - periods + 1
if pre_loc < 0:
warnings.warn('`periods` is too large. the first date will be returned.')
pre_loc = 0
ref_date = trading_dates[pre_loc]
return slice(ref_date, cur_date)
def get_range_iterator(self, periods: int, min_periods: Optional[int] = None,
**kwargs) -> Iterator[Tuple[pd.Timestamp, pd.DataFrame]]:
"""
get a iterator of sliced data with given periods
Args:
periods (int): number of periods
min_periods (int): minimum periods for sliced dataframe
kwargs (dict): will be passed to `self.fetch`
"""
trading_dates = self.get_unique_index('datetime')
if min_periods is None:
min_periods = periods
for cur_date in trading_dates[min_periods:]:
selector = self.get_range_selector(cur_date, periods)
yield cur_date, self.fetch(selector, **kwargs)
def get_unique_index(self, level: Union[str, int] = 'datetime') -> pd.Index:
"""
get unique index by level id (int) or name (str)
Args:
level (str or int): index level
"""
if self._data is None:
raise ValueError('data is not loaded!')
return self._data.index.unique(level=level)
class DataHandlerLP(DataHandler):
"""

View File

@@ -1,78 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABC, abstractmethod
import abc
import warnings
import pandas as pd
from qlib.data import D
from typing import Tuple
from qlib.data import D
class DataLoader(ABC):
"""
class DataLoader(abc.ABC):
'''
DataLoader is designed for loading raw data from original data source.
"""
@abstractmethod
'''
@abc.abstractmethod
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
"""
load the data as pd.DataFrame
load the data as pd.DataFrame
Parameters
----------
self : [TODO:type]
[TODO:description]
instruments : [TODO:type]
[TODO:description]
start_time : [TODO:type]
[TODO:description]
end_time : [TODO:type]
[TODO:description]
Parameters
----------
self : [TODO:type]
[TODO:description]
instruments : [TODO:type]
[TODO:description]
start_time : [TODO:type]
[TODO:description]
end_time : [TODO:type]
[TODO:description]
Returns
-------
pd.DataFrame:
data load from the under layer source
Returns
-------
pd.DataFrame:
data load from the under layer source
Example of the data:
The multi-index of the columns is optional.
feature label
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
datetime instrument
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
Example of the data:
(The multi-index of the columns is optional.)
feature label
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
datetime instrument
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
"""
pass
class QlibDataLoader(DataLoader):
"""Same as QlibDataLoader. The fields can be define by config"""
'''Same as QlibDataLoader. The fields can be define by config'''
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
"""
Parameters
----------
config : Tuple[list ,tuple, dict]
config : Tuple[list, tuple, dict]
Config will be used to describe the fields and column names
<config> := {
"group_name1": <fields_info1>
"group_name2": <fields_info2>
}
or
<config> := <fields_info>
<fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
Here is a few examples to describe the fields
TODO:
"""
self.is_group = isinstance(config, dict)
self.is_group = isinstance(config, dict)
if self.is_group:
self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}
else:
self.fields = self._parse_fields_info(fields_info)
self.fields = self._parse_fields_info(config)
self.filter_pipe = filter_pipe
@@ -86,14 +83,18 @@ class QlibDataLoader(DataLoader):
return exprs, names
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
if isinstance(instruments, str):
instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
elif self.filter_pipe is not None:
warnings.warn('`filter_pipe` is not None, but it will not be used with `instruments` as list')
def _get_df(exprs, names):
df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time)
df = D.features(instruments, exprs, start_time, end_time)
df.columns = names
return df
if self.is_group:
df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1)
else:
exprs, names = self.fields
df = _get_df(exprs, names)
df = df.swaplevel().sort_index()
df = df.swaplevel().sort_index() # NOTE: always return <datetime, instrument>
return df