1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

add Handler Storage

This commit is contained in:
bxdd
2021-06-28 20:06:15 +00:00
parent 27f0db669f
commit e1b6f310c9
7 changed files with 115 additions and 22 deletions

View File

@@ -164,10 +164,6 @@ class Exchange:
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
# update quote: pd.DataFrame to dict, for search use
if get_level_index(quote_df, level="datetime") == 1:
quote_df = quote_df.swaplevel().sort_index()
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val

View File

@@ -408,7 +408,7 @@ class InfPosition(BasePosition):
"""
def skip_update(self) -> bool:
""" Updating state is meaningless for InfPosition """
"""Updating state is meaningless for InfPosition"""
return True
def check_stock(self, stock_id: str) -> bool:

View File

@@ -5,7 +5,6 @@ from typing import List, Tuple, Union
from ...utils.resam import resam_ts_data
from ...data.data import D
from ...data.dataset.utils import convert_index_format
from ...strategy.base import BaseStrategy
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
from ...backtest.exchange import Exchange
@@ -423,7 +422,6 @@ class SBBStrategyEMA(SBBStrategyBase):
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
@@ -515,7 +513,6 @@ class ACStrategy(BaseStrategy):
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["volatility"]
self.signal = {}

View File

@@ -17,7 +17,7 @@ from ...data import D
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
from .utils import fetch_df_by_index
from .utils import fetch_df_by_index, fetch_df_by_col
from pathlib import Path
from .loader import DataLoader
@@ -152,14 +152,6 @@ class DataHandler(Serializable):
CS_ALL = "__all" # return all columns with single-level index column
CS_RAW = "__raw" # return raw data with multi-level index column
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW:
return df
elif col_set == self.CS_ALL:
return df.droplevel(axis=1, level=0)
else:
return df.loc(axis=1)[col_set]
def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
@@ -213,7 +205,7 @@ class DataHandler(Serializable):
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
@@ -238,7 +230,7 @@ class DataHandler(Serializable):
list of column names
"""
df = self._data.head()
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_col(df, col_set)
return df.columns.to_list()
def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
@@ -525,7 +517,7 @@ class DataHandlerLP(DataHandler):
# Copy incase of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
@@ -545,5 +537,5 @@ class DataHandlerLP(DataHandler):
list of column names
"""
df = self._get_df_by_key(data_key).head()
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_col(df, col_set)
return df.columns.to_list()

View File

@@ -310,3 +310,12 @@ class CSZFillna(Processor):
cols = get_group_columns(df, self.fields_group)
df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
return df
class HashingStock(Processor):
"""Process the df into hasing stock storage"""
def __call__(self, df: pd.DataFrame):
from .storage import HasingStockStorage
return HasingStockStorage.from_df(df)

View File

@@ -0,0 +1,85 @@
import pandas as pd
import numpy as np
from .handler import DataHandler
from typing import Tuple, Union, List
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
class BaseHandlerStorage:
def fetch(
self,
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
**kwargs,
) -> pd.DataFrame:
raise NotImplementedError("fetch is method not implemented!")
@staticmethod
def from_df(df: pd.DataFrame):
raise NotImplementedError("from_df method is not implemented!")
class HasingStockStorage(BaseHandlerStorage):
def __init__(self, df):
self.hash_df = dict()
self.stock_level = get_level_index(df, "instrument")
for k, v in df.groupby(level="instrument"):
self.hash_df[k] = v
self.columns = df.columns
@staticmethod
def from_df(df):
return HasingStockStorage(df)
def _fetch_hash_df_by_stock(self, selector, level):
stock_selector = slice(None)
if level is None:
if isinstance(selector, tuple) and self.stock_level < len(selector):
stock_selector = selector[self.stock_level]
elif isinstance(selector, (list, str)) and self.stock_level == 0:
stock_selector = selector
elif level == "instrument" or level == self.stock_level:
if isinstance(selector, tuple):
stock_selector = selector[0]
elif isinstance(selector, (list, str)):
stock_selector = selector
if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None):
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
print(stock_selector)
if stock_selector == slice(None):
return self.hash_df
if isinstance(stock_selector, str):
stock_selector = [stock_selector]
select_dict = dict()
for each_stock in sorted(stock_selector):
if each_stock in self.hash_df:
select_dict[each_stock] = self.hash_df[each_stock]
return select_dict
def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
) -> pd.DataFrame:
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
for _index, stock_df in enumerate(fetch_stock_df_list):
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level)
fetch_stock_df_list[_index] = fetch_index_df
if len(fetch_stock_df_list) == 0:
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
return pd.DataFrame(
index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
)
elif len(fetch_stock_df_list) == 1:
return fetch_stock_df_list[0]
else:
return pd.concat(fetch_stock_df_list, axis=0, sort=False)

View File

@@ -1,5 +1,8 @@
from typing import Union
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Union, List
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
@@ -72,6 +75,17 @@ def fetch_df_by_index(
]
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
from .handler import DataHandler
if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:
return df
elif col_set == DataHandler.CS_ALL:
return df.droplevel(axis=1, level=0)
else:
return df.loc(axis=1)[col_set]
def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
"""
Convert the format of df.MultiIndex according to the following rules: