diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 04715c892..7f1dbd179 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -5,7 +5,7 @@ import abc import bisect import logging -from typing import Union, Tuple, List +from typing import Union, Tuple, List, Iterator, Optional import pandas as pd import numpy as np @@ -113,8 +113,7 @@ class DataHandler(Serializable): CS_ALL = "__all" def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame: - cln = len(df.columns.levels) - if cln == 1: + if not isinstance(df.columns, pd.MultiIndex): return df elif col_set == self.CS_ALL: return df.droplevel(axis=1, level=0) @@ -126,6 +125,7 @@ class DataHandler(Serializable): selector: Union[pd.Timestamp, slice, str], level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = CS_ALL, + squeeze: bool = False ) -> pd.DataFrame: """ fetch data from underlying data source @@ -141,13 +141,22 @@ class DataHandler(Serializable): select a set of meaningful columns.(e.g. features, columns) if isinstance(col_set, List[str]): select several sets of meaningful columns, the returned data has multiple levels + squeeze : bool + whether squeeze columns and index Returns ------- pd.DataFrame: """ df = self._fetch_df_by_index(self._data, selector, level) - return self._fetch_df_by_col(df, col_set) + df = self._fetch_df_by_col(df, col_set) + if squeeze: + # squeeze columns + df = df.squeeze() + # squeeze index + if isinstance(selector, (str, pd.Timestamp)): + df = df.reset_index(level=level, drop=True) + return df def get_cols(self, col_set=CS_ALL) -> list: """ @@ -167,6 +176,51 @@ class DataHandler(Serializable): df = self._fetch_df_by_col(df, col_set) return df.columns.to_list() + def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice: + """ + get range selector by number of periods + + Args: + cur_date (pd.Timestamp or str): current date + periods (int): number of periods + """ + trading_dates = self.get_unique_index('datetime') + cur_loc = trading_dates.get_loc(cur_date) + pre_loc = cur_loc - periods + 1 + if pre_loc < 0: + warnings.warn('`periods` is too large. the first date will be returned.') + pre_loc = 0 + ref_date = trading_dates[pre_loc] + return slice(ref_date, cur_date) + + def get_range_iterator(self, periods: int, min_periods: Optional[int] = None, + **kwargs) -> Iterator[Tuple[pd.Timestamp, pd.DataFrame]]: + """ + get a iterator of sliced data with given periods + + Args: + periods (int): number of periods + min_periods (int): minimum periods for sliced dataframe + kwargs (dict): will be passed to `self.fetch` + """ + trading_dates = self.get_unique_index('datetime') + if min_periods is None: + min_periods = periods + for cur_date in trading_dates[min_periods:]: + selector = self.get_range_selector(cur_date, periods) + yield cur_date, self.fetch(selector, **kwargs) + + def get_unique_index(self, level: Union[str, int] = 'datetime') -> pd.Index: + """ + get unique index by level id (int) or name (str) + + Args: + level (str or int): index level + """ + if self._data is None: + raise ValueError('data is not loaded!') + return self._data.index.unique(level=level) + class DataHandlerLP(DataHandler): """ diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index e4f2f8619..abfc695d9 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -1,78 +1,75 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from abc import ABC, abstractmethod +import abc +import warnings import pandas as pd -from qlib.data import D + from typing import Tuple +from qlib.data import D -class DataLoader(ABC): - """ +class DataLoader(abc.ABC): + ''' DataLoader is designed for loading raw data from original data source. - """ - - @abstractmethod + ''' + @abc.abstractmethod def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: """ - load the data as pd.DataFrame + load the data as pd.DataFrame - Parameters - ---------- - self : [TODO:type] - [TODO:description] - instruments : [TODO:type] - [TODO:description] - start_time : [TODO:type] - [TODO:description] - end_time : [TODO:type] - [TODO:description] + Parameters + ---------- + self : [TODO:type] + [TODO:description] + instruments : [TODO:type] + [TODO:description] + start_time : [TODO:type] + [TODO:description] + end_time : [TODO:type] + [TODO:description] - Returns - ------- - pd.DataFrame: - data load from the under layer source + Returns + ------- + pd.DataFrame: + data load from the under layer source - Example of the data: - The multi-index of the columns is optional. - feature label - $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 - datetime instrument - 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 - SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 - SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 + Example of the data: + (The multi-index of the columns is optional.) + feature label + $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 + datetime instrument + 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 + SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 + SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 """ pass class QlibDataLoader(DataLoader): - """Same as QlibDataLoader. The fields can be define by config""" - + '''Same as QlibDataLoader. The fields can be define by config''' def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None): """ Parameters ---------- - config : Tuple[list ,tuple, dict] + config : Tuple[list, tuple, dict] Config will be used to describe the fields and column names := { "group_name1": "group_name2": } - + or := := ["expr", ...] | (["expr", ...], ["col_name", ...]) - - Here is a few examples to describe the fields - TODO: """ - self.is_group = isinstance(config, dict) + self.is_group = isinstance(config, dict) if self.is_group: self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()} else: - self.fields = self._parse_fields_info(fields_info) + self.fields = self._parse_fields_info(config) self.filter_pipe = filter_pipe @@ -86,14 +83,18 @@ class QlibDataLoader(DataLoader): return exprs, names def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: + if isinstance(instruments, str): + instruments = D.instruments(instruments, filter_pipe=self.filter_pipe) + elif self.filter_pipe is not None: + warnings.warn('`filter_pipe` is not None, but it will not be used with `instruments` as list') def _get_df(exprs, names): - df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time) + df = D.features(instruments, exprs, start_time, end_time) df.columns = names return df - if self.is_group: df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1) else: + exprs, names = self.fields df = _get_df(exprs, names) - df = df.swaplevel().sort_index() + df = df.swaplevel().sort_index() # NOTE: always return return df