diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index cffa98ba6..a759dbd86 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -164,10 +164,6 @@ class Exchange: assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"} quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) - # update quote: pd.DataFrame to dict, for search use - if get_level_index(quote_df, level="datetime") == 1: - quote_df = quote_df.swaplevel().sort_index() - quote_dict = {} for stock_id, stock_val in quote_df.groupby(level="instrument"): quote_dict[stock_id] = stock_val diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index 0f36e4959..b1037d460 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -408,7 +408,7 @@ class InfPosition(BasePosition): """ def skip_update(self) -> bool: - """ Updating state is meaningless for InfPosition """ + """Updating state is meaningless for InfPosition""" return True def check_stock(self, stock_id: str) -> bool: diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 0fb98e8ac..20099d4d3 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -5,7 +5,6 @@ from typing import List, Tuple, Union from ...utils.resam import resam_ts_data from ...data.data import D -from ...data.dataset.utils import convert_index_format from ...strategy.base import BaseStrategy from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO from ...backtest.exchange import Exchange @@ -423,7 +422,6 @@ class SBBStrategyEMA(SBBStrategyBase): signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) - signal_df = convert_index_format(signal_df) signal_df.columns = ["signal"] self.signal = {} @@ -515,7 +513,6 @@ class ACStrategy(BaseStrategy): signal_df = D.features( self.instruments, fields, 
start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) - signal_df = convert_index_format(signal_df) signal_df.columns = ["volatility"] self.signal = {} diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index c6338832a..30cfa7732 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -17,7 +17,7 @@ from ...data import D from ...config import C from ...utils import parse_config, transform_end_date, init_instance_by_config from ...utils.serial import Serializable -from .utils import fetch_df_by_index +from .utils import fetch_df_by_index, fetch_df_by_col from pathlib import Path from .loader import DataLoader @@ -152,14 +152,6 @@ class DataHandler(Serializable): CS_ALL = "__all" # return all columns with single-level index column CS_RAW = "__raw" # return raw data with multi-level index column - def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame: - if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW: - return df - elif col_set == self.CS_ALL: - return df.droplevel(axis=1, level=0) - else: - return df.loc(axis=1)[col_set] - def fetch( self, selector: Union[pd.Timestamp, slice, str] = slice(None, None), @@ -213,7 +205,7 @@ class DataHandler(Serializable): df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy()) # Fetch column first will be more friendly to SepDataFrame - df = self._fetch_df_by_col(df, col_set) + df = fetch_df_by_col(df, col_set) df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns @@ -238,7 +230,7 @@ class DataHandler(Serializable): list of column names """ df = self._data.head() - df = self._fetch_df_by_col(df, col_set) + df = fetch_df_by_col(df, col_set) return df.columns.to_list() def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice: @@ -525,7 +517,7 @@ class DataHandlerLP(DataHandler): # Copy incase of `proc_func` 
class HashingStock(Processor):
    """Processor that turns a DataFrame into per-stock hashed storage."""

    def __call__(self, df: pd.DataFrame):
        """Wrap ``df`` in a ``HasingStockStorage`` and return that storage.

        The storage class is imported at call time rather than at module
        load. NOTE(review): presumably this avoids a circular import between
        the processor and storage modules -- confirm.
        """
        from .storage import HasingStockStorage as _StockStorage

        hashed_storage = _StockStorage.from_df(df)
        return hashed_storage
class HasingStockStorage(BaseHandlerStorage):
    """Handler storage that keeps one sub-DataFrame per instrument in a dict.

    The source frame is split once, at construction time, with
    ``df.groupby(level="instrument")``. A fetch that names specific
    instruments then only touches the frames stored under those keys.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, df):
        """Split ``df`` by its "instrument" index level and store the pieces.

        Parameters
        ----------
        df : pd.DataFrame
            Frame whose index has a level named "instrument".
        """
        # Integer returned by get_level_index for the "instrument" level; it is
        # used below as the position of the stock entry inside tuple selectors.
        self.stock_level = get_level_index(df, "instrument")
        # instrument id -> rows of `df` belonging to that instrument
        self.hash_df = {stock_id: stock_df for stock_id, stock_df in df.groupby(level="instrument")}
        # Kept so an empty result can still expose the original columns.
        self.columns = df.columns

    @staticmethod
    def from_df(df):
        """Build a ``HasingStockStorage`` from ``df``."""
        return HasingStockStorage(df)

    def _fetch_hash_df_by_stock(self, selector, level):
        """Return the ``{instrument: DataFrame}`` subset matched by the selector.

        Parameters
        ----------
        selector
            Index selector as passed to ``fetch``.
        level
            Index level the selector applies to, or ``None`` when the selector
            addresses the whole index.

        Returns
        -------
        dict
            ``self.hash_df`` itself when no stock restriction applies,
            otherwise a new dict with the requested instruments that exist in
            the storage, inserted in sorted order.

        Raises
        ------
        TypeError
            If the stock part of the selector is neither ``str``, ``list`` nor
            ``slice(None)``.
        """
        stock_selector = slice(None)

        if level is None:
            # The selector addresses the whole index: pick out the stock part.
            if isinstance(selector, tuple) and self.stock_level < len(selector):
                stock_selector = selector[self.stock_level]
            elif isinstance(selector, (list, str)) and self.stock_level == 0:
                stock_selector = selector
        elif level == "instrument" or level == self.stock_level:
            if isinstance(selector, tuple):
                stock_selector = selector[0]
            elif isinstance(selector, (list, str)):
                stock_selector = selector

        # Check the type before comparing with slice(None): comparing an
        # arbitrary object (e.g. an array) to a slice may be elementwise or
        # ambiguous instead of yielding a plain bool.
        select_all = isinstance(stock_selector, slice) and stock_selector == slice(None)
        if not select_all and not isinstance(stock_selector, (list, str)):
            raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")

        if select_all:
            return self.hash_df

        if isinstance(stock_selector, str):
            stock_selector = [stock_selector]

        # Unknown instruments are skipped silently; sorted order keeps the
        # concatenated result deterministic.
        return {
            each_stock: self.hash_df[each_stock]
            for each_stock in sorted(stock_selector)
            if each_stock in self.hash_df
        }

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
    ) -> pd.DataFrame:
        """Fetch data from the per-stock frames.

        Each candidate stock frame is first reduced to ``col_set`` with
        ``fetch_df_by_col`` and then filtered with ``fetch_df_by_index``.

        Parameters
        ----------
        selector
            Index selector forwarded to ``fetch_df_by_index``.
        level
            Index level the selector applies to.
        col_set
            Column set forwarded to ``fetch_df_by_col``.

        Returns
        -------
        pd.DataFrame
            The single matching frame, the concatenation of all matching
            frames, or an empty float32 frame with ``self.columns`` and a
            two-level ("instrument"/"datetime") index when nothing matches.
        """
        candidate_frames = self._fetch_hash_df_by_stock(selector=selector, level=level).values()
        fetch_stock_df_list = [
            fetch_df_by_index(
                df=fetch_df_by_col(df=stock_df, col_set=col_set),
                selector=selector,
                level=level,
            )
            for stock_df in candidate_frames
        ]

        if not fetch_stock_df_list:
            # Keep the level order consistent with the stored frames.
            index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
            return pd.DataFrame(
                index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
            )
        if len(fetch_stock_df_list) == 1:
            return fetch_stock_df_list[0]
        return pd.concat(fetch_stock_df_list, axis=0, sort=False)