1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00
Files
qlib/qlib/data/dataset/utils.py
2021-06-30 07:34:23 +00:00

114 lines
3.1 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Union, List
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
    """
    Get the positional index of `level` within the index of `df`.

    Parameters
    ----------
    df : pd.DataFrame
        data
    level : Union[str, int]
        index level, given either by name (`str`) or by position (`int`)

    Returns
    -------
    int:
        The level index in the multiple index

    Raises
    ------
    ValueError
        If `level` is a name found neither in `df.index.names` nor in the
        default ("datetime", "instrument") layout.
    NotImplementedError
        If `level` is neither a `str` nor an `int`.
    """
    if isinstance(level, int):
        # A position is already the answer; it is returned without validation.
        return level
    if isinstance(level, str):
        try:
            return df.index.names.index(level)
        except (AttributeError, ValueError):
            # NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
            return ("datetime", "instrument").index(level)
    raise NotImplementedError("This type of input is not supported")
def fetch_df_by_index(
    df: pd.DataFrame,
    selector: Union[pd.Timestamp, slice, str, list],
    level: Union[str, int, None],
    fetch_orig=True,
) -> pd.DataFrame:
    """
    Fetch data from `df` with `selector` applied on index level `level`.

    Parameters
    ----------
    df : pd.DataFrame
        data to select from
    selector : Union[pd.Timestamp, slice, str, list]
        selector
    level : Union[int, str, None]
        the level to use the selector; if `None`, `selector` is applied
        directly to the row index
    fetch_orig : bool
        if True, `df` itself is returned (instead of the result of a `.loc`
        lookup) when the selector does not restrict any level, by default True

    Returns
    -------
    Data of the given index.
    """
    # level = None -> use selector directly
    if level is None:
        return df.loc(axis=0)[selector]

    # Put the selector on the requested level and take everything on the other one
    everything = slice(None, None)
    if get_level_index(df, level) == 1:
        idx_slc = (everything, selector)
    else:
        idx_slc = (selector, everything)

    # Checking the type first keeps non-slice selectors (lists, array-likes) away
    # from an `==` comparison against a slice.
    selects_everything = all(isinstance(slc, slice) and slc == everything for slc in idx_slc)
    if fetch_orig and selects_everything:
        return df

    # The trailing comma wraps the per-level tuple in a 1-tuple, so the whole
    # tuple is used as the row indexer.
    return df.loc[
        pd.IndexSlice[idx_slc],
    ]
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
    """
    Select columns of `df` according to `col_set`.

    Parameters
    ----------
    df : pd.DataFrame
        data to select from
    col_set : Union[str, List[str]]
        column selection; compared against `DataHandler.CS_RAW` and
        `DataHandler.CS_ALL`, otherwise used as a label selector on the columns

    Returns
    -------
    pd.DataFrame
        - `df` unchanged if its columns are not a `pd.MultiIndex` or `col_set`
          equals `DataHandler.CS_RAW`
        - `df` with the first column level dropped if `col_set` equals
          `DataHandler.CS_ALL`
        - otherwise the result of selecting `col_set` along the column axis
    """
    # Imported inside the function, as in the original.
    # NOTE(review): presumably to avoid a circular import with .handler - confirm
    from .handler import DataHandler

    if not isinstance(df.columns, pd.MultiIndex):
        return df
    if col_set == DataHandler.CS_RAW:
        return df
    if col_set == DataHandler.CS_ALL:
        return df.droplevel(axis=1, level=0)
    return df.loc(axis=1)[col_set]
def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
    """
    Make `level` the first level of the index of `df`.

    - If `get_level_index` reports position 1 for `level`, the index levels are
      swapped and the result is sorted by index.
    - Otherwise `df` is returned as it was passed in.

    NOTE:
        the number of levels of df.MultiIndex should be 2

    Parameters
    ----------
    df : Union[pd.DataFrame, pd.Series]
        raw DataFrame/Series
    level : str, optional
        the level that will be converted to the first one, by default "datetime"

    Returns
    -------
    Union[pd.DataFrame, pd.Series]
        converted DataFrame/Series
    """
    level_is_second = get_level_index(df, level=level) == 1
    if not level_is_second:
        return df
    swapped = df.swaplevel()
    return swapped.sort_index()