1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00
Files
qlib/qlib/data/dataset/storage.py
YQ Tsui cc01812c62 Fix typos and grammar errors in docstrings and comments (#1366)
* fix grammar errors in docstrings

* fix typos in exchange.py

* fix typos and grammar errors

* fix typo and rename function param to avoid shadowing Python keyword

* remove redundant parentheses; pass kwargs to parent class

* fix pyblack

* further correction

* assign -> be assigned to
2022-11-20 14:15:59 +08:00

159 lines
6.2 KiB
Python

import pandas as pd
import numpy as np
from .handler import DataHandler
from typing import Union, List, Callable
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
class BaseHandlerStorage:
    """
    Base data storage for datahandler

    - pd.DataFrame is the default data storage format in Qlib datahandler
    - If users want to use a custom data storage, they should define a subclass that inherits
      from BaseHandlerStorage and implement the following methods
    """

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
        fetch_orig: bool = True,
        proc_func: Callable = None,
        **kwargs,
    ) -> pd.DataFrame:
        """fetch data from the data storage

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str, list]
            describe how to select data by index
        level : Union[str, int]
            which index level to select the data
            - if level is None, apply selector to df directly
        col_set : Union[str, List[str]]
            - if isinstance(col_set, str):
                select a set of meaningful columns.(e.g. features, columns)
                if col_set == DataHandler.CS_RAW:
                    the raw dataset will be returned.
            - if isinstance(col_set, List[str]):
                select several sets of meaningful columns, the returned data has multiple levels
        fetch_orig : bool
            Return the original data instead of a copy if possible.
        proc_func : Callable
            please refer to the doc of DataHandler.fetch

        Returns
        -------
        pd.DataFrame
            the dataframe fetched

        Raises
        ------
        NotImplementedError
            always; subclasses must override this method.
        """
        raise NotImplementedError("fetch method is not implemented!")

    @staticmethod
    def from_df(df: pd.DataFrame):
        """build a storage from a dataframe; subclasses must override this method.

        Raises
        ------
        NotImplementedError
            always, in this base class.
        """
        raise NotImplementedError("from_df method is not implemented!")

    def is_proc_func_supported(self):
        """whether the arg `proc_func` in `fetch` method is supported."""
        raise NotImplementedError("is_proc_func_supported method is not implemented!")
class HashingStockStorage(BaseHandlerStorage):
    """Hashing data storage for datahandler

    - The default data storage pandas.DataFrame is too slow when randomly accessing one stock's data
    - HashingStockStorage hashes the multiple stocks' data (pandas.DataFrame) by the key `stock_id`.
    - HashingStockStorage hashes the pandas.DataFrame into a dict, whose key is the stock_id (str) and
      whose value is this stock's data (pandas.DataFrame). It has the following format:
        {
            stock1_id: stock1_data,
            stock2_id: stock2_data,
            ...
            stockn_id: stockn_data,
        }
    - By the `fetch` method, users can access any stock data with much lower time cost than the default data storage
    """

    def __init__(self, df):
        # `stock_level` is used below as a positional index into tuple selectors
        # and to decide the order of the index names of an empty result.
        self.stock_level = get_level_index(df, "instrument")
        # One sub-frame per instrument, keyed by the instrument id.
        self.hash_df = {stock_id: stock_df for stock_id, stock_df in df.groupby(level="instrument")}
        # Kept so that an empty fetch result can still expose the original columns.
        self.columns = df.columns

    @staticmethod
    def from_df(df):
        """build a HashingStockStorage from the dataframe `df`."""
        return HashingStockStorage(df)

    def _fetch_hash_df_by_stock(self, selector, level):
        """fetch the data with stock selector

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str]
            describe how to select data by index
        level : Union[str, int]
            which index level to select the data
            - if level is None, apply selector to df directly
            - the `_fetch_hash_df_by_stock` will parse the stock selector in arg `selector`

        Returns
        -------
        Dict
            The dict whose key is stock_id, value is the stock's data

        Raises
        ------
        TypeError
            if the stock part of the selector is neither a str, a list, nor slice(None)
        """
        # Default: no restriction on stocks.
        stock_selector = slice(None)
        if level is None:
            # The selector addresses the whole index; pick out its stock component.
            if isinstance(selector, tuple) and self.stock_level < len(selector):
                stock_selector = selector[self.stock_level]
            elif isinstance(selector, (list, str)) and self.stock_level == 0:
                stock_selector = selector
        elif level in ("instrument", self.stock_level):
            if isinstance(selector, tuple):
                stock_selector = selector[0]
            elif isinstance(selector, (list, str)):
                stock_selector = selector

        # Check the type before comparing with slice(None): comparing an array-like
        # object with `==`/`!=` yields an element-wise result whose truth value is
        # ambiguous, which would hide the intended TypeError below.
        if isinstance(stock_selector, slice) and stock_selector == slice(None):
            return self.hash_df
        if not isinstance(stock_selector, (list, str)):
            raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")

        if isinstance(stock_selector, str):
            stock_selector = [stock_selector]

        # Unknown stock ids are silently skipped; the result follows sorted id order.
        return {
            each_stock: self.hash_df[each_stock]
            for each_stock in sorted(stock_selector)
            if each_stock in self.hash_df
        }

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
        fetch_orig: bool = True,
    ) -> pd.DataFrame:
        """fetch data from the hashed storage

        The stocks addressed by `selector` are looked up in the hash table first; then
        `fetch_df_by_col` and `fetch_df_by_index` are applied to each stock's dataframe and the
        results are concatenated.

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str]
            describe how to select data by index
        level : Union[str, int]
            which index level to select the data
            - if level is None, apply selector to df directly
        col_set : Union[str, List[str]]
            the column set(s) to select, passed to `fetch_df_by_col`
        fetch_orig : bool
            Return the original data instead of a copy if possible.

        Returns
        -------
        pd.DataFrame
            the dataframe fetched; an empty float32 dataframe with the stored columns and a
            two-level (instrument/datetime) index if no stock matches
        """
        fetch_stock_df_list = [
            fetch_df_by_index(
                df=fetch_df_by_col(df=stock_df, col_set=col_set),
                selector=selector,
                level=level,
                fetch_orig=fetch_orig,
            )
            for stock_df in self._fetch_hash_df_by_stock(selector=selector, level=level).values()
        ]

        if len(fetch_stock_df_list) == 0:
            index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
            return pd.DataFrame(
                index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
            )
        if len(fetch_stock_df_list) == 1:
            return fetch_stock_df_list[0]
        # `not`, rather than `~`: bitwise inversion of a bool gives -2 / -1, which are
        # both truthy, so `copy` would always be enabled regardless of `fetch_orig`.
        return pd.concat(fetch_stock_df_list, sort=False, copy=not fetch_orig)

    def is_proc_func_supported(self):
        """the arg `proc_func` in `fetch` method is not supported in HashingStockStorage"""
        return False