mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
fix dataloader & add interface to datahandler
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
import abc
|
||||
import bisect
|
||||
import logging
|
||||
from typing import Union, Tuple, List
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -113,8 +113,7 @@ class DataHandler(Serializable):
|
||||
CS_ALL = "__all"
|
||||
|
||||
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
|
||||
cln = len(df.columns.levels)
|
||||
if cln == 1:
|
||||
if not isinstance(df.columns, pd.MultiIndex):
|
||||
return df
|
||||
elif col_set == self.CS_ALL:
|
||||
return df.droplevel(axis=1, level=0)
|
||||
@@ -126,6 +125,7 @@ class DataHandler(Serializable):
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -141,13 +141,22 @@ class DataHandler(Serializable):
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
if isinstance(col_set, List[str]):
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
squeeze : bool
|
||||
whether squeeze columns and index
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._fetch_df_by_index(self._data, selector, level)
|
||||
return self._fetch_df_by_col(df, col_set)
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
df = df.squeeze()
|
||||
# squeeze index
|
||||
if isinstance(selector, (str, pd.Timestamp)):
|
||||
df = df.reset_index(level=level, drop=True)
|
||||
return df
|
||||
|
||||
def get_cols(self, col_set=CS_ALL) -> list:
|
||||
"""
|
||||
@@ -167,6 +176,51 @@ class DataHandler(Serializable):
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
|
||||
"""
|
||||
get range selector by number of periods
|
||||
|
||||
Args:
|
||||
cur_date (pd.Timestamp or str): current date
|
||||
periods (int): number of periods
|
||||
"""
|
||||
trading_dates = self.get_unique_index('datetime')
|
||||
cur_loc = trading_dates.get_loc(cur_date)
|
||||
pre_loc = cur_loc - periods + 1
|
||||
if pre_loc < 0:
|
||||
warnings.warn('`periods` is too large. the first date will be returned.')
|
||||
pre_loc = 0
|
||||
ref_date = trading_dates[pre_loc]
|
||||
return slice(ref_date, cur_date)
|
||||
|
||||
def get_range_iterator(self, periods: int, min_periods: Optional[int] = None,
|
||||
**kwargs) -> Iterator[Tuple[pd.Timestamp, pd.DataFrame]]:
|
||||
"""
|
||||
get a iterator of sliced data with given periods
|
||||
|
||||
Args:
|
||||
periods (int): number of periods
|
||||
min_periods (int): minimum periods for sliced dataframe
|
||||
kwargs (dict): will be passed to `self.fetch`
|
||||
"""
|
||||
trading_dates = self.get_unique_index('datetime')
|
||||
if min_periods is None:
|
||||
min_periods = periods
|
||||
for cur_date in trading_dates[min_periods:]:
|
||||
selector = self.get_range_selector(cur_date, periods)
|
||||
yield cur_date, self.fetch(selector, **kwargs)
|
||||
|
||||
def get_unique_index(self, level: Union[str, int] = 'datetime') -> pd.Index:
|
||||
"""
|
||||
get unique index by level id (int) or name (str)
|
||||
|
||||
Args:
|
||||
level (str or int): index level
|
||||
"""
|
||||
if self._data is None:
|
||||
raise ValueError('data is not loaded!')
|
||||
return self._data.index.unique(level=level)
|
||||
|
||||
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
|
||||
@@ -1,78 +1,75 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import abc
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from qlib.data import D
|
||||
|
||||
class DataLoader(ABC):
|
||||
"""
|
||||
class DataLoader(abc.ABC):
|
||||
'''
|
||||
DataLoader is designed for loading raw data from original data source.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
'''
|
||||
@abc.abstractmethod
|
||||
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
"""
|
||||
load the data as pd.DataFrame
|
||||
load the data as pd.DataFrame
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
instruments : [TODO:type]
|
||||
[TODO:description]
|
||||
start_time : [TODO:type]
|
||||
[TODO:description]
|
||||
end_time : [TODO:type]
|
||||
[TODO:description]
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
instruments : [TODO:type]
|
||||
[TODO:description]
|
||||
start_time : [TODO:type]
|
||||
[TODO:description]
|
||||
end_time : [TODO:type]
|
||||
[TODO:description]
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
|
||||
Example of the data:
|
||||
The multi-index of the columns is optional.
|
||||
feature label
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
Example of the data:
|
||||
(The multi-index of the columns is optional.)
|
||||
feature label
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class QlibDataLoader(DataLoader):
|
||||
"""Same as QlibDataLoader. The fields can be define by config"""
|
||||
|
||||
'''Same as QlibDataLoader. The fields can be define by config'''
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
config : Tuple[list ,tuple, dict]
|
||||
config : Tuple[list, tuple, dict]
|
||||
Config will be used to describe the fields and column names
|
||||
|
||||
<config> := {
|
||||
"group_name1": <fields_info1>
|
||||
"group_name2": <fields_info2>
|
||||
}
|
||||
|
||||
or
|
||||
<config> := <fields_info>
|
||||
|
||||
<fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
|
||||
|
||||
Here is a few examples to describe the fields
|
||||
TODO:
|
||||
"""
|
||||
self.is_group = isinstance(config, dict)
|
||||
self.is_group = isinstance(config, dict)
|
||||
|
||||
if self.is_group:
|
||||
self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}
|
||||
else:
|
||||
self.fields = self._parse_fields_info(fields_info)
|
||||
self.fields = self._parse_fields_info(config)
|
||||
|
||||
self.filter_pipe = filter_pipe
|
||||
|
||||
@@ -86,14 +83,18 @@ class QlibDataLoader(DataLoader):
|
||||
return exprs, names
|
||||
|
||||
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if isinstance(instruments, str):
|
||||
instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
|
||||
elif self.filter_pipe is not None:
|
||||
warnings.warn('`filter_pipe` is not None, but it will not be used with `instruments` as list')
|
||||
def _get_df(exprs, names):
|
||||
df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time)
|
||||
df = D.features(instruments, exprs, start_time, end_time)
|
||||
df.columns = names
|
||||
return df
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1)
|
||||
else:
|
||||
exprs, names = self.fields
|
||||
df = _get_df(exprs, names)
|
||||
df = df.swaplevel().sort_index()
|
||||
df = df.swaplevel().sort_index() # NOTE: always return <datetime, instrument>
|
||||
return df
|
||||
|
||||
Reference in New Issue
Block a user