mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add Handler Storage
This commit is contained in:
@@ -164,10 +164,6 @@ class Exchange:
|
||||
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
|
||||
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
# update quote: pd.DataFrame to dict, for search use
|
||||
if get_level_index(quote_df, level="datetime") == 1:
|
||||
quote_df = quote_df.swaplevel().sort_index()
|
||||
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val
|
||||
|
||||
@@ -408,7 +408,7 @@ class InfPosition(BasePosition):
|
||||
"""
|
||||
|
||||
def skip_update(self) -> bool:
    """Tell the caller that state updates can be skipped.

    Updating state is meaningless for InfPosition, so this method
    unconditionally answers ``True``.

    Returns
    -------
    bool
        Always ``True``.
    """
    return True
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import List, Tuple, Union
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
|
||||
from ...backtest.exchange import Exchange
|
||||
@@ -423,7 +422,6 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
|
||||
@@ -515,7 +513,6 @@ class ACStrategy(BaseStrategy):
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["volatility"]
|
||||
self.signal = {}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import fetch_df_by_index
|
||||
from .utils import fetch_df_by_index, fetch_df_by_col
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -152,14 +152,6 @@ class DataHandler(Serializable):
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
CS_RAW = "__raw" # return raw data with multi-level index column
|
||||
|
||||
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
    """Pick a column set out of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to select columns from.
    col_set : str
        ``self.CS_RAW``, ``self.CS_ALL``, or a key used to index the columns.

    Returns
    -------
    pd.DataFrame
        * ``df`` itself when its columns are not a ``pd.MultiIndex`` or when
          ``col_set`` equals ``self.CS_RAW``;
        * ``df`` with column level 0 dropped when ``col_set`` equals
          ``self.CS_ALL``;
        * ``df.loc(axis=1)[col_set]`` otherwise.
    """
    grouped_columns = isinstance(df.columns, pd.MultiIndex)

    # Nothing to select from when the columns carry a single level.
    if not grouped_columns:
        return df
    # Raw mode hands back the multi-level columns untouched.
    if col_set == self.CS_RAW:
        return df
    # "All" mode flattens the columns by discarding the outer level.
    if col_set == self.CS_ALL:
        return df.droplevel(axis=1, level=0)
    return df.loc(axis=1)[col_set]
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
@@ -213,7 +205,7 @@ class DataHandler(Serializable):
|
||||
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
|
||||
# Fetching columns first is friendlier to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
@@ -238,7 +230,7 @@ class DataHandler(Serializable):
|
||||
list of column names
|
||||
"""
|
||||
df = self._data.head()
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
|
||||
@@ -525,7 +517,7 @@ class DataHandlerLP(DataHandler):
|
||||
# Copy in case `proc_func` changes the data in place.
|
||||
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetching columns first is friendlier to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
@@ -545,5 +537,5 @@ class DataHandlerLP(DataHandler):
|
||||
list of column names
|
||||
"""
|
||||
df = self._get_df_by_key(data_key).head()
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
@@ -310,3 +310,12 @@ class CSZFillna(Processor):
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
|
||||
return df
|
||||
|
||||
|
||||
class HashingStock(Processor):
    """Processor that turns a DataFrame into a hashing stock storage.

    Calling the processor returns ``HasingStockStorage.from_df(df)`` instead
    of a DataFrame.
    """

    def __call__(self, df: pd.DataFrame):
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import with `.storage` -- TODO confirm).
        from .storage import HasingStockStorage as storage_cls

        storage = storage_cls.from_df(df)
        return storage
|
||||
|
||||
85
qlib/data/dataset/storage.py
Normal file
85
qlib/data/dataset/storage.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from .handler import DataHandler
|
||||
from typing import Tuple, Union, List
|
||||
|
||||
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
|
||||
|
||||
|
||||
class BaseHandlerStorage:
    """Interface for handler storage backends.

    Both methods raise ``NotImplementedError`` here; subclasses are expected
    to override them.
    """

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
        **kwargs,
    ) -> pd.DataFrame:
        """Fetch data from the storage.

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str, list]
            Describes which index entries to select.
        level : Union[str, int]
            Index level the selector applies to.
        col_set : Union[str, List[str]]
            Column set to return; defaults to ``DataHandler.CS_ALL``.
        **kwargs
            Extra, implementation-specific options.

        Returns
        -------
        pd.DataFrame

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("fetch method is not implemented!")

    @staticmethod
    def from_df(df: pd.DataFrame):
        """Build a storage object from ``df``.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("from_df method is not implemented!")
|
||||
|
||||
|
||||
class HasingStockStorage(BaseHandlerStorage):
    """Handler storage that keeps one sub-DataFrame per instrument in a dict.

    The source DataFrame is grouped by its ``"instrument"`` index level and
    each group is stored under its instrument key, so a fetch restricted to a
    few instruments only touches those groups.

    Attributes are set in ``__init__``:

    - ``hash_df``: dict mapping instrument key -> that instrument's DataFrame
    - ``stock_level``: position of the ``"instrument"`` level, as returned by
      ``get_level_index``
    - ``columns``: the columns of the source DataFrame
    """

    def __init__(self, df):
        self.stock_level = get_level_index(df, "instrument")
        self.hash_df = {stock_id: stock_df for stock_id, stock_df in df.groupby(level="instrument")}
        self.columns = df.columns

    @staticmethod
    def from_df(df):
        """Build a ``HasingStockStorage`` from ``df``."""
        return HasingStockStorage(df)

    def _fetch_hash_df_by_stock(self, selector, level):
        """Return the per-instrument frames addressed by ``selector``.

        Parameters
        ----------
        selector
            Index selector as passed to ``fetch``.
        level
            Index level the selector applies to; ``None`` means the selector
            addresses the whole index.

        Returns
        -------
        dict
            ``self.hash_df`` itself when no instrument restriction can be
            derived from the selector; otherwise a new dict, ordered by
            sorted instrument key, with only the requested instruments that
            exist in the storage.

        Raises
        ------
        TypeError
            If the derived instrument selector is neither a ``str``, a
            ``list`` nor ``slice(None)``.
        """
        stock_selector = slice(None)

        if level is None:
            # The selector addresses the whole index: take its instrument part.
            if isinstance(selector, tuple) and self.stock_level < len(selector):
                stock_selector = selector[self.stock_level]
            elif isinstance(selector, (list, str)) and self.stock_level == 0:
                stock_selector = selector
        elif level == "instrument" or level == self.stock_level:
            # The selector addresses the instrument level directly.
            if isinstance(selector, tuple):
                stock_selector = selector[0]
            elif isinstance(selector, (list, str)):
                stock_selector = selector

        if isinstance(stock_selector, slice):
            # Only the "select everything" slice is supported.
            if stock_selector != slice(None):
                raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
            return self.hash_df
        if not isinstance(stock_selector, (list, str)):
            raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")

        if isinstance(stock_selector, str):
            stock_selector = [stock_selector]

        # Unknown instruments are silently skipped.
        return {
            each_stock: self.hash_df[each_stock] for each_stock in sorted(stock_selector) if each_stock in self.hash_df
        }

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
    ) -> pd.DataFrame:
        """Fetch data for the selected index entries and column set.

        Each candidate per-instrument frame is first reduced to ``col_set``
        with ``fetch_df_by_col`` and then filtered with ``fetch_df_by_index``.

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str, list]
            Index selector.
        level : Union[str, int]
            Index level the selector applies to.
        col_set : Union[str, List[str]]
            Column set to return; defaults to ``DataHandler.CS_ALL``.

        Returns
        -------
        pd.DataFrame
            The single matching frame, the concatenation of all matching
            frames, or an empty float32 frame when no instrument matches.
        """
        fetch_stock_df_list = [
            fetch_df_by_index(
                df=fetch_df_by_col(df=stock_df, col_set=col_set),
                selector=selector,
                level=level,
            )
            for stock_df in self._fetch_hash_df_by_stock(selector=selector, level=level).values()
        ]

        if not fetch_stock_df_list:
            index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
            empty_df = pd.DataFrame(
                index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
            )
            # Apply the same column selection as the non-empty path so the
            # returned columns do not depend on whether any instrument matched.
            return fetch_df_by_col(df=empty_df, col_set=col_set)
        if len(fetch_stock_df_list) == 1:
            return fetch_stock_df_list[0]
        return pd.concat(fetch_stock_df_list, axis=0, sort=False)
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import Union
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
@@ -72,6 +75,17 @@ def fetch_df_by_index(
|
||||
]
|
||||
|
||||
|
||||
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
    """Pick a column set out of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to select columns from.
    col_set : Union[str, List[str]]
        ``DataHandler.CS_RAW``, ``DataHandler.CS_ALL``, or a key (or list of
        keys) used to index the columns.

    Returns
    -------
    pd.DataFrame
        * ``df`` itself when its columns are not a ``pd.MultiIndex`` or when
          ``col_set`` equals ``DataHandler.CS_RAW``;
        * ``df`` with column level 0 dropped when ``col_set`` equals
          ``DataHandler.CS_ALL``;
        * ``df.loc(axis=1)[col_set]`` otherwise.
    """
    # Imported inside the function (presumably because `.handler` itself
    # imports from this module -- TODO confirm).
    from .handler import DataHandler

    multi_level = isinstance(df.columns, pd.MultiIndex)

    # Single-level columns: there is nothing to select.
    if not multi_level:
        return df
    # Raw mode hands back the multi-level columns untouched.
    if col_set == DataHandler.CS_RAW:
        return df
    # "All" mode flattens the columns by discarding the outer level.
    if col_set == DataHandler.CS_ALL:
        return df.droplevel(axis=1, level=0)
    return df.loc(axis=1)[col_set]
|
||||
|
||||
|
||||
def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
|
||||
"""
|
||||
Convert the format of df.MultiIndex according to the following rules:
|
||||
|
||||
Reference in New Issue
Block a user