From e1b6f310c9ea11d33e89a0bcc50be9b884f79159 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 28 Jun 2021 20:06:15 +0000 Subject: [PATCH 1/6] add Handler Storage --- qlib/backtest/exchange.py | 4 -- qlib/backtest/position.py | 2 +- qlib/contrib/strategy/rule_strategy.py | 3 - qlib/data/dataset/handler.py | 18 ++---- qlib/data/dataset/processor.py | 9 +++ qlib/data/dataset/storage.py | 85 ++++++++++++++++++++++++++ qlib/data/dataset/utils.py | 16 ++++- 7 files changed, 115 insertions(+), 22 deletions(-) create mode 100644 qlib/data/dataset/storage.py diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index cffa98ba6..a759dbd86 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -164,10 +164,6 @@ class Exchange: assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"} quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) - # update quote: pd.DataFrame to dict, for search use - if get_level_index(quote_df, level="datetime") == 1: - quote_df = quote_df.swaplevel().sort_index() - quote_dict = {} for stock_id, stock_val in quote_df.groupby(level="instrument"): quote_dict[stock_id] = stock_val diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index 0f36e4959..b1037d460 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -408,7 +408,7 @@ class InfPosition(BasePosition): """ def skip_update(self) -> bool: - """ Updating state is meaningless for InfPosition """ + """Updating state is meaningless for InfPosition""" return True def check_stock(self, stock_id: str) -> bool: diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 0fb98e8ac..20099d4d3 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -5,7 +5,6 @@ from typing import List, Tuple, Union from ...utils.resam import resam_ts_data from ...data.data import D -from ...data.dataset.utils import convert_index_format from 
...strategy.base import BaseStrategy from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO from ...backtest.exchange import Exchange @@ -423,7 +422,6 @@ class SBBStrategyEMA(SBBStrategyBase): signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) - signal_df = convert_index_format(signal_df) signal_df.columns = ["signal"] self.signal = {} @@ -515,7 +513,6 @@ class ACStrategy(BaseStrategy): signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) - signal_df = convert_index_format(signal_df) signal_df.columns = ["volatility"] self.signal = {} diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index c6338832a..30cfa7732 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -17,7 +17,7 @@ from ...data import D from ...config import C from ...utils import parse_config, transform_end_date, init_instance_by_config from ...utils.serial import Serializable -from .utils import fetch_df_by_index +from .utils import fetch_df_by_index, fetch_df_by_col from pathlib import Path from .loader import DataLoader @@ -152,14 +152,6 @@ class DataHandler(Serializable): CS_ALL = "__all" # return all columns with single-level index column CS_RAW = "__raw" # return raw data with multi-level index column - def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame: - if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW: - return df - elif col_set == self.CS_ALL: - return df.droplevel(axis=1, level=0) - else: - return df.loc(axis=1)[col_set] - def fetch( self, selector: Union[pd.Timestamp, slice, str] = slice(None, None), @@ -213,7 +205,7 @@ class DataHandler(Serializable): df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy()) # Fetch column first will be more friendly to SepDataFrame - df = self._fetch_df_by_col(df, 
col_set) + df = fetch_df_by_col(df, col_set) df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns @@ -238,7 +230,7 @@ class DataHandler(Serializable): list of column names """ df = self._data.head() - df = self._fetch_df_by_col(df, col_set) + df = fetch_df_by_col(df, col_set) return df.columns.to_list() def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice: @@ -525,7 +517,7 @@ class DataHandlerLP(DataHandler): # Copy incase of `proc_func` changing the data inplace.... df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy()) # Fetch column first will be more friendly to SepDataFrame - df = self._fetch_df_by_col(df, col_set) + df = fetch_df_by_col(df, col_set) return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: @@ -545,5 +537,5 @@ class DataHandlerLP(DataHandler): list of column names """ df = self._get_df_by_key(data_key).head() - df = self._fetch_df_by_col(df, col_set) + df = fetch_df_by_col(df, col_set) return df.columns.to_list() diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index fce22ddfc..1e1ed8dfb 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -310,3 +310,12 @@ class CSZFillna(Processor): cols = get_group_columns(df, self.fields_group) df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean())) return df + + +class HashingStock(Processor): + """Process the df into hasing stock storage""" + + def __call__(self, df: pd.DataFrame): + from .storage import HasingStockStorage + + return HasingStockStorage.from_df(df) diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py new file mode 100644 index 000000000..1849b6fcb --- /dev/null +++ b/qlib/data/dataset/storage.py @@ -0,0 +1,85 @@ +import pandas as pd +import numpy as np + +from .handler import 
DataHandler +from typing import Tuple, Union, List + +from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col + + +class BaseHandlerStorage: + def fetch( + self, + selector: Union[pd.Timestamp, slice, str, list] = slice(None, None), + level: Union[str, int] = "datetime", + col_set: Union[str, List[str]] = DataHandler.CS_ALL, + **kwargs, + ) -> pd.DataFrame: + raise NotImplementedError("fetch is method not implemented!") + + @staticmethod + def from_df(df: pd.DataFrame): + raise NotImplementedError("from_df method is not implemented!") + + +class HasingStockStorage(BaseHandlerStorage): + def __init__(self, df): + self.hash_df = dict() + self.stock_level = get_level_index(df, "instrument") + for k, v in df.groupby(level="instrument"): + self.hash_df[k] = v + self.columns = df.columns + + @staticmethod + def from_df(df): + return HasingStockStorage(df) + + def _fetch_hash_df_by_stock(self, selector, level): + stock_selector = slice(None) + + if level is None: + if isinstance(selector, tuple) and self.stock_level < len(selector): + stock_selector = selector[self.stock_level] + elif isinstance(selector, (list, str)) and self.stock_level == 0: + stock_selector = selector + elif level == "instrument" or level == self.stock_level: + if isinstance(selector, tuple): + stock_selector = selector[0] + elif isinstance(selector, (list, str)): + stock_selector = selector + + if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None): + raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}") + print(stock_selector) + if stock_selector == slice(None): + return self.hash_df + + if isinstance(stock_selector, str): + stock_selector = [stock_selector] + + select_dict = dict() + for each_stock in sorted(stock_selector): + if each_stock in self.hash_df: + select_dict[each_stock] = self.hash_df[each_stock] + return select_dict + + def fetch( + self, + selector: Union[pd.Timestamp, slice, str] = slice(None, 
None), + level: Union[str, int] = "datetime", + col_set: Union[str, List[str]] = DataHandler.CS_ALL, + ) -> pd.DataFrame: + fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values()) + for _index, stock_df in enumerate(fetch_stock_df_list): + fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set) + fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level) + fetch_stock_df_list[_index] = fetch_index_df + if len(fetch_stock_df_list) == 0: + index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument") + return pd.DataFrame( + index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32 + ) + elif len(fetch_stock_df_list) == 1: + return fetch_stock_df_list[0] + else: + return pd.concat(fetch_stock_df_list, axis=0, sort=False) diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index f7b07d563..3cb4dd3e2 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -1,5 +1,8 @@ -from typing import Union +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ import pandas as pd +from typing import Union, List def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int: @@ -72,6 +75,17 @@ def fetch_df_by_index( ] +def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame: + from .handler import DataHandler + + if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW: + return df + elif col_set == DataHandler.CS_ALL: + return df.droplevel(axis=1, level=0) + else: + return df.loc(axis=1)[col_set] + + def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]: """ Convert the format of df.MultiIndex according to the following rules: From 90bbf2b7c6a6456f3dbe8ac237c0f0ae0f33c19b Mon Sep 17 00:00:00 2001 From: you-n-g Date: Wed, 30 Jun 2021 08:29:47 +0800 Subject: [PATCH 2/6] Fix account update bar_count bug --- qlib/backtest/account.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index a6ef2f6b8..6167ee407 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -263,11 +263,11 @@ class Account: elif atomic is False and inner_order_indicators is None: raise ValueError("inner_order_indicators is necessary in unatomic executor") + # TODO: `update_bar_count` and `update_current` should placed in Position and be merged. + self.update_bar_count() + self.update_current(trade_start_time, trade_end_time, trade_exchange) if generate_report: # report is portfolio related analysis - # TODO: `update_bar_count` and `update_current` should placed in Position and be merged. - self.update_bar_count() - self.update_current(trade_start_time, trade_end_time, trade_exchange) self.update_report(trade_start_time, trade_end_time) # indicator is trading (e.g. 
high-frequency order execution) related analysis From 9985befe6955c4953e1bb8b57854171b5df24181 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 29 Jun 2021 12:02:27 +0000 Subject: [PATCH 3/6] update HashingStockStorage --- qlib/data/dataset/handler.py | 65 ++++++++++++++------- qlib/data/dataset/storage.py | 28 ++++++++- tests/test_handler_storage.py | 107 ++++++++++++++++++++++++++++++++++ 3 files changed, 176 insertions(+), 24 deletions(-) create mode 100644 tests/test_handler_storage.py diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 30cfa7732..475601625 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -175,7 +175,7 @@ class DataHandler(Serializable): select a set of meaningful columns.(e.g. features, columns) - if cal_set == CS_RAW: + if col_set == CS_RAW: the raw dataset will be returned. - if isinstance(col_set, List[str]): @@ -197,23 +197,33 @@ class DataHandler(Serializable): ------- pd.DataFrame. """ - if proc_func is None: - df = self._data - else: - # FIXME: fetching by time first will be more friendly to `proc_func` - # Copy in case of `proc_func` changing the data inplace.... - df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy()) + from .storage import HasingStockStorage + + data_storage = self._data + if isinstance(data_storage, pd.DataFrame): + data_df = data_storage + if proc_func is not None: + # FIXME: fetching by time first will be more friendly to `proc_func` + # Copy in case of `proc_func` changing the data inplace.... 
+ data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) + + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, HasingStockStorage): + if proc_func is not None: + warnings.warn(f"proc_func is not supported by the HasingStockStorage") + data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + else: + raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") - # Fetch column first will be more friendly to SepDataFrame - df = fetch_df_by_col(df, col_set) - df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns - df = df.squeeze() + data_df = data_df.squeeze() # squeeze index if isinstance(selector, (str, pd.Timestamp)): - df = df.reset_index(level=level, drop=True) - return df + data_df = data_df.reset_index(level=level, drop=True) + return data_df def get_cols(self, col_set=CS_ALL) -> list: """ @@ -511,14 +521,27 @@ class DataHandlerLP(DataHandler): ------- pd.DataFrame: """ - df = self._get_df_by_key(data_key) - if proc_func is not None: - # FIXME: fetch by time first will be more friendly to proc_func - # Copy incase of `proc_func` changing the data inplace.... 
- df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy()) - # Fetch column first will be more friendly to SepDataFrame - df = fetch_df_by_col(df, col_set) - return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) + from .storage import HasingStockStorage + + data_storage = self._get_df_by_key(data_key) + if isinstance(data_storage, pd.DataFrame): + data_df = data_storage + if proc_func is not None: + # FIXME: fetch by time first will be more friendly to proc_func + # Copy incase of `proc_func` changing the data inplace.... + data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + + elif isinstance(data_storage, HasingStockStorage): + if proc_func is not None: + warnings.warn(f"proc_func is not supported by the HasingStockStorage") + data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + else: + raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") + + return data_df def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: """ diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 1849b6fcb..66895cfe7 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -2,7 +2,7 @@ import pandas as pd import numpy as np from .handler import DataHandler -from typing import Tuple, Union, List +from typing import Tuple, Union, List, Callable from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col @@ -13,8 +13,29 @@ class BaseHandlerStorage: selector: Union[pd.Timestamp, slice, str, list] = slice(None, None), level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, + fetch_orig: bool = 
True, **kwargs, ) -> pd.DataFrame: + """fetch data from the data storage + + Parameters + ---------- + selector : Union[pd.Timestamp, slice, str] + describe how to select data by index + level : Union[str, int] + which index level to select the data + col_set : Union[str, List[str]] + - if isinstance(col_set, str): + select a set of meaningful columns.(e.g. features, columns) + if col_set == DataHandler.CS_RAW: + the raw dataset will be returned. + - if isinstance(col_set, List[str]): + select several sets of meaningful columns, the returned data has multiple level + fetch_orig : bool + Return the original data instead of copy if possible. + + """ + raise NotImplementedError("fetch is method not implemented!") @staticmethod @@ -68,11 +89,12 @@ class HasingStockStorage(BaseHandlerStorage): selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, + fetch_orig: bool = True, ) -> pd.DataFrame: fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values()) for _index, stock_df in enumerate(fetch_stock_df_list): fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set) - fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level) + fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig) fetch_stock_df_list[_index] = fetch_index_df if len(fetch_stock_df_list) == 0: index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument") @@ -82,4 +104,4 @@ class HasingStockStorage(BaseHandlerStorage): elif len(fetch_stock_df_list) == 1: return fetch_stock_df_list[0] else: - return pd.concat(fetch_stock_df_list, axis=0, sort=False) + return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) diff --git a/tests/test_handler_storage.py b/tests/test_handler_storage.py new file mode 100644 index 000000000..be36788bd --- /dev/null +++ 
b/tests/test_handler_storage.py @@ -0,0 +1,107 @@ +import unittest +import qlib +import time +import pandas as pd + +from qlib.data import D +from qlib.tests import TestAutoData + +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.dataset.processor import Processor +from qlib.contrib.data.handler import check_transform_proc +from qlib.utils import init_instance_by_config +from qlib.log import TimeInspector + + +class TestHandler(DataHandlerLP): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + drop_raw=True, + ): + + infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) + learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "freq": "day", + "config": self.get_feature_config(), + "swap_level": False, + }, + } + + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + drop_raw=drop_raw, + ) + + def get_feature_config(self): + fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"] + names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"] + return fields, names + + +class MiniTimer: + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = time.time() + print(f"[MyTimer Info] <{self.name}> process costs {self.end - self.start} seconds") + + +class TestHandlerStorage(TestAutoData): + + market = "all" + + start_time = "2020-01-01" + end_time = "2020-12-31" + train_end_time = "2020-05-31" + test_start_time = "2020-06-01" + + data_handler_kwargs = { + "start_time": start_time, + "end_time": end_time, + 
"fit_start_time": start_time, + "fit_end_time": train_end_time, + "instruments": market, + "infer_processors": ["HashingStock"], + } + + def test_handler_storage(self): + with MiniTimer("init data hanlder"): + data_handler = TestHandler(**self.data_handler_kwargs) + + with MiniTimer("random fetch"): + print(data_handler.fetch(selector=("SH600170", slice(None)), level=None)) + print( + data_handler.fetch( + selector=("SH600170", slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))), level=None + ) + ) + print( + data_handler.fetch( + selector=(["SH600170", "SH600383"], slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))), + level=None, + ) + ) + + +if __name__ == "__main__": + unittest.main() From 8d1b1979d9f69211e65adf62486b539d8f1284d4 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 29 Jun 2021 15:51:41 +0000 Subject: [PATCH 4/6] update handler_storage test --- qlib/data/dataset/handler.py | 21 ++++++----- qlib/data/dataset/processor.py | 4 +- qlib/data/dataset/storage.py | 2 +- tests/test_handler_storage.py | 69 ++++++++++++++++++++++------------ 4 files changed, 59 insertions(+), 37 deletions(-) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 475601625..edcc1ede2 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -206,13 +206,14 @@ class DataHandler(Serializable): # FIXME: fetching by time first will be more friendly to `proc_func` # Copy in case of `proc_func` changing the data inplace.... 
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) - - # Fetch column first will be more friendly to SepDataFrame - data_df = fetch_df_by_col(data_df, col_set) - data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + data_df = fetch_df_by_col(data_df, col_set) + else: + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) elif isinstance(data_storage, HasingStockStorage): if proc_func is not None: - warnings.warn(f"proc_func is not supported by the HasingStockStorage") + raise ValueError("proc_func is not supported by the HasingStockStorage") data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") @@ -530,13 +531,15 @@ class DataHandlerLP(DataHandler): # FIXME: fetch by time first will be more friendly to proc_func # Copy incase of `proc_func` changing the data inplace.... 
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) - # Fetch column first will be more friendly to SepDataFrame - data_df = fetch_df_by_col(data_df, col_set) - data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + data_df = fetch_df_by_col(data_df, col_set) + else: + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) elif isinstance(data_storage, HasingStockStorage): if proc_func is not None: - warnings.warn(f"proc_func is not supported by the HasingStockStorage") + raise ValueError("proc_func is not supported by the HasingStockStorage") data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 1e1ed8dfb..cc6dcdfd3 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -312,8 +312,8 @@ class CSZFillna(Processor): return df -class HashingStock(Processor): - """Process the df into hasing stock storage""" +class HashStockFormat(Processor): + """Process the storage of from df into hasing stock format""" def __call__(self, df: pd.DataFrame): from .storage import HasingStockStorage diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 66895cfe7..247970481 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -71,7 +71,7 @@ class HasingStockStorage(BaseHandlerStorage): if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None): raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}") - print(stock_selector) + if stock_selector == slice(None): return self.hash_df diff --git 
a/tests/test_handler_storage.py b/tests/test_handler_storage.py index be36788bd..e41286cb2 100644 --- a/tests/test_handler_storage.py +++ b/tests/test_handler_storage.py @@ -1,15 +1,11 @@ import unittest -import qlib import time -import pandas as pd - +import numpy as np from qlib.data import D from qlib.tests import TestAutoData from qlib.data.dataset.handler import DataHandlerLP -from qlib.data.dataset.processor import Processor from qlib.contrib.data.handler import check_transform_proc -from qlib.utils import init_instance_by_config from qlib.log import TimeInspector @@ -63,17 +59,17 @@ class MiniTimer: def __exit__(self, exc_type, exc_val, exc_tb): self.end = time.time() - print(f"[MyTimer Info] <{self.name}> process costs {self.end - self.start} seconds") + print(f"[Timer Info] <{self.name}> process costs {self.end - self.start} seconds") class TestHandlerStorage(TestAutoData): market = "all" - start_time = "2020-01-01" + start_time = "2010-01-01" end_time = "2020-12-31" - train_end_time = "2020-05-31" - test_start_time = "2020-06-01" + train_end_time = "2015-12-31" + test_start_time = "2016-01-01" data_handler_kwargs = { "start_time": start_time, @@ -81,26 +77,49 @@ class TestHandlerStorage(TestAutoData): "fit_start_time": start_time, "fit_end_time": train_end_time, "instruments": market, - "infer_processors": ["HashingStock"], } def test_handler_storage(self): - with MiniTimer("init data hanlder"): - data_handler = TestHandler(**self.data_handler_kwargs) + # init data handler + data_handler = TestHandler(**self.data_handler_kwargs) - with MiniTimer("random fetch"): - print(data_handler.fetch(selector=("SH600170", slice(None)), level=None)) - print( - data_handler.fetch( - selector=("SH600170", slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))), level=None - ) - ) - print( - data_handler.fetch( - selector=(["SH600170", "SH600383"], slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))), - level=None, - ) - ) + # init data handler with 
hasing storage + data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=["HashStockFormat"]) + + fetch_start_time = "2019-01-01" + fetch_end_time = "2019-12-31" + instruments = D.instruments(market=self.market) + instruments = D.list_instruments( + instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True + ) + + with TimeInspector.logt("random fetch with DataFrame Storage"): + + # single stock + for i in range(100): + random_index = np.random.randint(len(instruments), size=1)[0] + fetch_stock = instruments[random_index] + data_handler.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None) + + # multi stocks + for i in range(100): + random_indexs = np.random.randint(len(instruments), size=5) + fetch_stocks = [instruments[_index] for _index in random_indexs] + data_handler.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None) + + with TimeInspector.logt("random fetch with HasingStock Storage"): + + # single stock + for i in range(100): + random_index = np.random.randint(len(instruments), size=1)[0] + fetch_stock = instruments[random_index] + data_handler_hs.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None) + + # multi stocks + for i in range(100): + random_indexs = np.random.randint(len(instruments), size=5) + fetch_stocks = [instruments[_index] for _index in random_indexs] + data_handler_hs.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None) if __name__ == "__main__": From b242d6e1e1f9bfb063b7e2ccf2e3d1df6f8079bc Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 29 Jun 2021 15:54:20 +0000 Subject: [PATCH 5/6] delMiniTimer in haandler storage test --- tests/test_handler_storage.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/test_handler_storage.py b/tests/test_handler_storage.py index e41286cb2..056595063 100644 --- a/tests/test_handler_storage.py +++ 
b/tests/test_handler_storage.py @@ -50,18 +50,6 @@ class TestHandler(DataHandlerLP): return fields, names -class MiniTimer: - def __init__(self, name): - self.name = name - - def __enter__(self): - self.start = time.time() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.end = time.time() - print(f"[Timer Info] <{self.name}> process costs {self.end - self.start} seconds") - - class TestHandlerStorage(TestAutoData): market = "all" From bbf5d1bbbb9d7337a19d6c2bd2f69e23f2898781 Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 30 Jun 2021 07:34:23 +0000 Subject: [PATCH 6/6] add file order strategy --- qlib/backtest/exchange.py | 8 ++- qlib/backtest/order.py | 90 +++++++++++++++++++++++++- qlib/contrib/strategy/rule_strategy.py | 66 ++++++++++++++++++- qlib/data/dataset/utils.py | 3 + qlib/utils/file.py | 37 +++++++++++ tests/backtest/test_file_strategy.py | 86 ++++++++++++++++++++++++ 6 files changed, 284 insertions(+), 6 deletions(-) create mode 100644 qlib/utils/file.py create mode 100644 tests/backtest/test_file_strategy.py diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index a759dbd86..8177d53ee 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -14,7 +14,7 @@ from ..data.dataset.utils import get_level_index from ..config import C, REG_CN from ..utils.resam import resam_ts_data from ..log import get_module_logger -from .order import Order +from .order import Order, OrderDir, OrderHelper class Exchange: @@ -526,3 +526,9 @@ class Exchange: raise NotImplementedError("order type {} error".format(order.type)) return trade_val, trade_cost + + def get_order_helper(self) -> OrderHelper: + if not hasattr(self, "_order_helper"): + # cache to avoid recreate the same instance + self._order_helper = OrderHelper(self) + return self._order_helper diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index 19ea807c1..9df162263 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -2,12 +2,14 @@ # Licensed 
under the MIT License. # TODO: rename it with decision.py from __future__ import annotations +from enum import IntEnum # try to fix circular imports when enabling type hints from typing import TYPE_CHECKING if TYPE_CHECKING: from qlib.strategy.base import BaseStrategy + from qlib.backtest.exchange import Exchange from qlib.backtest.utils import TradeCalendarManager import warnings import pandas as pd @@ -15,6 +17,12 @@ from dataclasses import dataclass, field from typing import ClassVar, Union, List, Set, Tuple +class OrderDir(IntEnum): + # Order direction + SELL = 0 + BUY = 1 + + @dataclass class Order: """ @@ -32,19 +40,97 @@ class Order: stock_id: str amount: float + + # The interval of the order which belongs to (NOTE: this is not the expected order dealing range time) start_time: pd.Timestamp end_time: pd.Timestamp + direction: int factor: float deal_amount: float = field(init=False) - SELL: ClassVar[int] = 0 - BUY: ClassVar[int] = 1 + + # FIXME: + # for compatible now. + # Plese remove them in the future + SELL: ClassVar[OrderDir] = OrderDir.SELL + BUY: ClassVar[OrderDir] = OrderDir.BUY def __post_init__(self): if self.direction not in {Order.SELL, Order.BUY}: raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") self.deal_amount = 0 + @staticmethod + def parse_dir(direction: Union[str, int, OrderDir]) -> OrderDir: + if isinstance(direction, OrderDir): + return direction + elif isinstance(direction, int): + return OrderDir(direction) + elif isinstance(direction, str): + dl = direction.lower() + if dl.strip() == "sell": + return OrderDir.SELL + elif dl.strip() == "buy": + return OrderDir.BUY + else: + raise NotImplementedError(f"This type of input is not supported") + else: + raise NotImplementedError(f"This type of input is not supported") + + +class OrderHelper: + """ + Motivation + - Make generating order easier + - User may have no knowledge about the adjust-factor information about the system. 
+ - It involves too much interaction with the exchange when generating orders. + """ + + def __init__(self, exchange: Exchange): + self.exchange = exchange + + def create( + self, + code: str, + amount: float, + direction: OrderDir, + start_time: Union[str, pd.Timestamp], + end_time: Union[str, pd.Timestamp], + ) -> Order: + """ + help to create an order + + # TODO: create order for unadjusted amount order + + Parameters + ---------- + code : str + the id of the instrument + amount : float + **adjusted trading amount** + direction : OrderDir + trading direction + start_time : Union[str, pd.Timestamp] + The interval to which the order belongs + end_time : Union[str, pd.Timestamp] + The interval to which the order belongs + + Returns + ------- + Order: + The created order + """ + start_time = pd.Timestamp(start_time) + end_time = pd.Timestamp(end_time) + return Order( + stock_id=code, + amount=amount, + start_time=start_time, + end_time=end_time, + direction=direction, + factor=self.exchange.get_factor(code, start_time, end_time), + ) + class BaseTradeDecision: """ diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 20099d4d3..22483a79c 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -1,14 +1,19 @@ +from pathlib import Path import warnings import numpy as np import pandas as pd -from typing import List, Tuple, Union +from typing import IO, List, Tuple, Union +from qlib.data.dataset.utils import convert_index_format + +from qlib.utils import lazy_sort_index from ...utils.resam import resam_ts_data from ...data.data import D from ...strategy.base import BaseStrategy from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO -from ...backtest.exchange import Exchange +from ...backtest.exchange import Exchange, OrderHelper from ...backtest.utils import CommonInfrastructure, LevelInfrastructure +from qlib.utils.file import get_io_object def
get_start_end_idx(strategy: BaseStrategy, outer_trade_decision: BaseTradeDecision) -> Union[int, int]: @@ -653,6 +658,9 @@ class RandomOrderStrategy(BaseStrategy): index_range : Tuple the intra day time index range of the orders the left and right is closed. + + If you want to get the index_range in intra-day + - `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range more easily # TODO: this is a index_range level limitation. We'll implement a more detailed limitation later. sample_ratio : float the ratio of all orders are sampled @@ -684,7 +692,9 @@ class RandomOrderStrategy(BaseStrategy): if step_time_start in self.volume_df: for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items(): order_list.append( - self.common_infra.get("trade_exchange").create_order( + self.common_infra.get("trade_exchange") + .get_order_helper() + .create( code=stock_id, amount=volume * self.volume_ratio, start_time=step_time_start, @@ -693,3 +703,53 @@ class RandomOrderStrategy(BaseStrategy): ) ) return TradeDecisionWO(order_list, self, self.index_range) + + +class FileOrderStrategy(BaseStrategy): + """ + Motivation: + - This class provides an interface for users to read orders from csv files.
+ - It is supposed to be used in + """ + + def __init__(self, file: Union[IO, str, Path], index_range: Tuple[int, int] = None, *args, **kwargs): + super().__init__(*args, **kwargs) + with get_io_object(file) as f: + self.order_df = pd.read_csv(f, dtype={"datetime": np.str}) + + self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp) + self.order_df = self.order_df.set_index(["datetime", "instrument"]) + + # make sure the datetime is the first level for fast indexing + self.order_df = lazy_sort_index(convert_index_format(self.order_df, level="datetime")) + self.index_range = index_range + + def generate_trade_decision(self, execute_result=None) -> TradeDecisionWO: + """ + Parameters + ---------- + execute_result : + execute_result will be ignored in FileOrderStrategy + """ + oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() + tc = self.trade_calendar + step = tc.get_trade_step() + start, end = tc.get_step_time(step) + # CONVERSION: the bar is indexed by the time + try: + df = self.order_df.loc(axis=0)[start] + except KeyError: + return TradeDecisionWO([], self) + else: + order_list = [] + for idx, row in df.iterrows(): + order_list.append( + oh.create( + code=idx, + amount=row["amount"], + direction=Order.parse_dir(row["direction"]), + start_time=start, + end_time=end, + ) + ) + return TradeDecisionWO(order_list, self, self.index_range) diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 3cb4dd3e2..c6b3d97b6 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -92,6 +92,9 @@ def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datet - If `level` is the first level of df.MultiIndex, do nothing - If `level` is the second level of df.MultiIndex, swap the level of index. 
+ NOTE: + the number of levels of df.MultiIndex should be 2 + Parameters ---------- df : Union[pd.DataFrame, pd.Series] diff --git a/qlib/utils/file.py b/qlib/utils/file.py new file mode 100644 index 000000000..611260c86 --- /dev/null +++ b/qlib/utils/file.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# TODO: move file related utils into this module +import contextlib +from typing import IO, Union +from pathlib import Path + + +@contextlib.contextmanager +def get_io_object(file: Union[IO, str, Path], *args, **kwargs) -> IO: + """ + providing an easy interface to get an IO object + + Parameters + ---------- + file : Union[IO, str, Path] + an object representing the file + + Returns + ------- + IO: + an IO-like object + + Raises + ------ + NotImplementedError: + """ + if isinstance(file, IO): + yield file + else: + if isinstance(file, str): + file = Path(file) + if not isinstance(file, Path): + raise NotImplementedError(f"This type[{type(file)}] of input is not supported") + with file.open(*args, **kwargs) as f: + yield f diff --git a/tests/backtest/test_file_strategy.py b/tests/backtest/test_file_strategy.py new file mode 100644 index 000000000..da52b0d53 --- /dev/null +++ b/tests/backtest/test_file_strategy.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License.
+ +import unittest +from qlib.backtest import backtest, order +from qlib.tests import TestAutoData +import pandas as pd +from pathlib import Path + +DIRNAME = Path(__file__).absolute().resolve().parent + + +class FileStrTest(TestAutoData): + + TEST_INST = "SH600519" + + EXAMPLE_FILE = DIRNAME / "order_example.csv" + + def _gen_orders(self) -> pd.DataFrame: + headers = [ + "datetime", + "instrument", + "amount", + "direction", + ] + orders = [ + ["20200102", self.TEST_INST, "1000", "sell"], + ["20200103", self.TEST_INST, "1000", "buy"], + ["20200106", self.TEST_INST, "1000", "sell"], + ] + return pd.DataFrame(orders, columns=headers).set_index(["datetime", "instrument"]) + + def test_file_str(self): + + orders = self._gen_orders() + print(orders) + orders.to_csv(self.EXAMPLE_FILE) + + orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"]) + + strategy_config = { + "class": "FileOrderStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + "kwargs": {"file": self.EXAMPLE_FILE}, + } + + freq = "day" + start_time = "2020-01-01" + end_time = "2020-01-16" + codes = [self.TEST_INST] + + backtest_config = { + "start_time": start_time, + "end_time": end_time, + "account": 100000000, + "benchmark": None, # benchmark is not required here for trading + "exchange_kwargs": { + "freq": freq, + "limit_threshold": 0.095, + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + "codes": codes, + }, + # "pos_type": "InfPosition" # Position with infinitive position + } + executor_config = { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": freq, + "generate_report": False, + "verbose": True, + "indicator_config": { + "show_indicator": False, + }, + }, + } + backtest(executor=executor_config, strategy=strategy_config, **backtest_config) + + self.EXAMPLE_FILE.unlink() + + +if __name__ == "__main__": + unittest.main()