class OrderDir(IntEnum):
    """Trading direction of an order."""

    SELL = 0
    BUY = 1


@dataclass
class Order:
    """
    A trading order for a single instrument.

    Parameters
    ----------
    stock_id : str
        the id of the instrument
    amount : float
        the amount to trade
    start_time : pd.Timestamp
    end_time : pd.Timestamp
        The interval which the order belongs to
        (NOTE: this is not the expected order dealing time range)
    direction : int
        `Order.SELL` for sell, `Order.BUY` for buy; anything else is rejected in `__post_init__`
    factor : float
        the factor attached to the order (`OrderHelper.create` fills it with `exchange.get_factor(...)`)

    `deal_amount` is not an init argument; it is reset to 0 whenever an order is created.
    """

    stock_id: str
    amount: float

    # The interval which the order belongs to (NOTE: this is not the expected order dealing time range)
    start_time: pd.Timestamp
    end_time: pd.Timestamp

    direction: int
    factor: float
    deal_amount: float = field(init=False)

    # FIXME:
    # kept for compatibility for now.
    # Please remove them in the future
    SELL: ClassVar[OrderDir] = OrderDir.SELL
    BUY: ClassVar[OrderDir] = OrderDir.BUY

    def __post_init__(self):
        if self.direction not in {Order.SELL, Order.BUY}:
            raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
        self.deal_amount = 0

    @staticmethod
    def parse_dir(direction: Union[str, int, OrderDir]) -> OrderDir:
        """
        Convert a user supplied direction into an `OrderDir`.

        Parameters
        ----------
        direction : Union[str, int, OrderDir]
            - OrderDir: returned unchanged
            - integer (builtin `int` or any `numbers.Integral`, e.g. a NumPy integer
              coming out of a pandas row): converted with `OrderDir(int(direction))`
            - str: "sell" / "buy", case-insensitive, surrounding whitespace ignored

        Returns
        -------
        OrderDir

        Raises
        ------
        NotImplementedError
            if the string is neither "sell" nor "buy", or the input type is not supported
        ValueError
            if an integer is not a valid `OrderDir` value
        """
        # Local import: `numbers` is not part of this module's top-level imports.
        import numbers

        if isinstance(direction, OrderDir):
            return direction
        if isinstance(direction, numbers.Integral):
            # `isinstance(x, int)` is False for NumPy integers, which is what pandas
            # yields for integer columns; `numbers.Integral` covers both.
            return OrderDir(int(direction))
        if isinstance(direction, str):
            name = direction.strip().lower()
            if name == "sell":
                return OrderDir.SELL
            if name == "buy":
                return OrderDir.BUY
            raise NotImplementedError("This type of input is not supported")
        raise NotImplementedError("This type of input is not supported")


class OrderHelper:
    """
    Motivation
    - Make generating order easier
        - User may have no knowledge about the adjust-factor information about the system.
        - It involves too much interaction with the exchange when generating orders.
    """

    def __init__(self, exchange: "Exchange"):
        # The exchange is only used to look up the factor of an order in `create`.
        self.exchange = exchange

    def create(
        self,
        code: str,
        amount: float,
        direction: OrderDir,
        start_time: Union[str, pd.Timestamp],
        end_time: Union[str, pd.Timestamp],
    ) -> Order:
        """
        help to create an order

        # TODO: create order for unadjusted amount order

        Parameters
        ----------
        code : str
            the id of the instrument
        amount : float
            **adjusted trading amount**
        direction : OrderDir
            trading direction
        start_time : Union[str, pd.Timestamp]
            The interval which the order belongs to
        end_time : Union[str, pd.Timestamp]
            The interval which the order belongs to

        Returns
        -------
        Order:
            The created order
        """
        # Normalize once so that the factor lookup and the order share the same timestamps.
        start_time = pd.Timestamp(start_time)
        end_time = pd.Timestamp(end_time)
        return Order(
            stock_id=code,
            amount=amount,
            start_time=start_time,
            end_time=end_time,
            direction=direction,
            factor=self.exchange.get_factor(code, start_time, end_time),
        )
...data.dataset.utils import convert_index_format from ...strategy.base import BaseStrategy from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO -from ...backtest.exchange import Exchange +from ...backtest.exchange import Exchange, OrderHelper from ...backtest.utils import CommonInfrastructure, LevelInfrastructure +from qlib.utils.file import get_io_object def get_start_end_idx(strategy: BaseStrategy, outer_trade_decision: BaseTradeDecision) -> Union[int, int]: @@ -423,7 +427,6 @@ class SBBStrategyEMA(SBBStrategyBase): signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) - signal_df = convert_index_format(signal_df) signal_df.columns = ["signal"] self.signal = {} @@ -515,7 +518,6 @@ class ACStrategy(BaseStrategy): signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) - signal_df = convert_index_format(signal_df) signal_df.columns = ["volatility"] self.signal = {} @@ -656,6 +658,9 @@ class RandomOrderStrategy(BaseStrategy): index_range : Tuple the intra day time index range of the orders the left and right is closed. + + If you want to get the index_range in intra-day + - `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range easier # TODO: this is a index_range level limitation. We'll implement a more detailed limitation later. 
class FileOrderStrategy(BaseStrategy):
    """
    Motivation:
    - This class provides an interface for user to read orders from csv files.

    The csv is expected to provide the columns `datetime`, `instrument`, `amount` and `direction`:
    `datetime` and `instrument` become the index of `self.order_df`, while `amount` and `direction`
    are read for every order that is generated.
    """

    def __init__(self, file: Union[IO, str, Path], index_range: Tuple[int, int] = None, *args, **kwargs):
        """
        Parameters
        ----------
        file : Union[IO, str, Path]
            the csv file (or an IO object of it) which contains the orders
        index_range : Tuple[int, int]
            passed to the `TradeDecisionWO` of every step which has orders
        """
        super().__init__(*args, **kwargs)
        with get_io_object(file) as f:
            # `np.str` was only an alias of the builtin `str`; it was deprecated in NumPy 1.20
            # and removed in NumPy 1.24, so the builtin is used directly.
            self.order_df = pd.read_csv(f, dtype={"datetime": str})

        self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
        self.order_df = self.order_df.set_index(["datetime", "instrument"])

        # make sure the datetime is the first level for fast indexing
        self.order_df = lazy_sort_index(convert_index_format(self.order_df, level="datetime"))
        self.index_range = index_range

    def generate_trade_decision(self, execute_result=None) -> TradeDecisionWO:
        """
        Parameters
        ----------
        execute_result :
            execute_result will be ignored in FileOrderStrategy
        """
        oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
        tc = self.trade_calendar
        step = tc.get_trade_step()
        start, end = tc.get_step_time(step)
        # CONVENTION: the bar is indexed by its start time
        try:
            df = self.order_df.loc(axis=0)[start]
        except KeyError:
            # no order is recorded for this step
            return TradeDecisionWO([], self)

        # after selecting `start` on the first level, the remaining index is the instrument
        order_list = [
            oh.create(
                code=idx,
                amount=row["amount"],
                direction=Order.parse_dir(row["direction"]),
                start_time=start,
                end_time=end,
            )
            for idx, row in df.iterrows()
        ]
        return TradeDecisionWO(order_list, self, self.index_range)
- df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy()) + from .storage import HasingStockStorage + + data_storage = self._data + if isinstance(data_storage, pd.DataFrame): + data_df = data_storage + if proc_func is not None: + # FIXME: fetching by time first will be more friendly to `proc_func` + # Copy in case of `proc_func` changing the data inplace.... + data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) + data_df = fetch_df_by_col(data_df, col_set) + else: + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, HasingStockStorage): + if proc_func is not None: + raise ValueError("proc_func is not supported by the HasingStockStorage") + data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + else: + raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") - # Fetch column first will be more friendly to SepDataFrame - df = self._fetch_df_by_col(df, col_set) - df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns - df = df.squeeze() + data_df = data_df.squeeze() # squeeze index if isinstance(selector, (str, pd.Timestamp)): - df = df.reset_index(level=level, drop=True) - return df + data_df = data_df.reset_index(level=level, drop=True) + return data_df def get_cols(self, col_set=CS_ALL) -> list: """ @@ -238,7 +241,7 @@ class DataHandler(Serializable): list of column names """ df = self._data.head() - df = self._fetch_df_by_col(df, col_set) + df = fetch_df_by_col(df, col_set) return df.columns.to_list() def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice: @@ -519,14 +522,29 @@ class DataHandlerLP(DataHandler): ------- pd.DataFrame: 
""" - df = self._get_df_by_key(data_key) - if proc_func is not None: - # FIXME: fetch by time first will be more friendly to proc_func - # Copy incase of `proc_func` changing the data inplace.... - df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy()) - # Fetch column first will be more friendly to SepDataFrame - df = self._fetch_df_by_col(df, col_set) - return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) + from .storage import HasingStockStorage + + data_storage = self._get_df_by_key(data_key) + if isinstance(data_storage, pd.DataFrame): + data_df = data_storage + if proc_func is not None: + # FIXME: fetch by time first will be more friendly to proc_func + # Copy incase of `proc_func` changing the data inplace.... + data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) + data_df = fetch_df_by_col(data_df, col_set) + else: + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + + elif isinstance(data_storage, HasingStockStorage): + if proc_func is not None: + raise ValueError("proc_func is not supported by the HasingStockStorage") + data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + else: + raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") + + return data_df def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: """ @@ -545,5 +563,5 @@ class DataHandlerLP(DataHandler): list of column names """ df = self._get_df_by_key(data_key).head() - df = self._fetch_df_by_col(df, col_set) + df = fetch_df_by_col(df, col_set) return df.columns.to_list() diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index fce22ddfc..cc6dcdfd3 100644 --- a/qlib/data/dataset/processor.py +++ 
import pandas as pd
import numpy as np

from .handler import DataHandler
from typing import Tuple, Union, List, Callable

from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col


class BaseHandlerStorage:
    """Interface of the storages a data handler can fetch data from."""

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
        fetch_orig: bool = True,
        **kwargs,
    ) -> pd.DataFrame:
        """fetch data from the data storage

        Parameters
        ----------
        selector : Union[pd.Timestamp, slice, str]
            describe how to select data by index
        level : Union[str, int]
            which index level to select the data
        col_set : Union[str, List[str]]
            - if isinstance(col_set, str):
                select a set of meaningful columns.(e.g. features, columns)
                if col_set == DataHandler.CS_RAW:
                    the raw dataset will be returned.
            - if isinstance(col_set, List[str]):
                select several sets of meaningful columns, the returned data has multiple level
        fetch_orig : bool
            Return the original data instead of copy if possible.

        Returns
        -------
        pd.DataFrame
            the fetched dataframe
        """

        raise NotImplementedError("fetch is method not implemented!")

    @staticmethod
    def from_df(df: pd.DataFrame):
        """Build the storage from a dataframe; must be implemented by subclasses."""
        raise NotImplementedError("from_df method is not implemented!")


class HasingStockStorage(BaseHandlerStorage):
    """
    Storage which keeps one sub-dataframe per instrument in a dict (`self.hash_df`),
    keyed by the value of the "instrument" index level.
    """

    def __init__(self, df):
        # position of the "instrument" level in the index of `df`
        self.stock_level = get_level_index(df, "instrument")
        self.hash_df = {stock: stock_df for stock, stock_df in df.groupby(level="instrument")}
        # kept to build an empty result with the same columns when nothing is selected
        self.columns = df.columns

    @staticmethod
    def from_df(df):
        return HasingStockStorage(df)

    def _fetch_hash_df_by_stock(self, selector, level):
        """
        Select the per-instrument dataframes which `selector` refers to.

        The instrument part is extracted from `selector` only when `level` is None or
        refers to the instrument level; otherwise every instrument is returned.

        Returns
        -------
        dict
            instrument -> dataframe; `self.hash_df` itself when all instruments are selected

        Raises
        ------
        TypeError
            if the instrument part of the selector is not str|list or slice(None)
        """
        stock_selector = slice(None)

        if level is None:
            # `selector` addresses the whole index: pick the part of the instrument level
            if isinstance(selector, tuple) and self.stock_level < len(selector):
                stock_selector = selector[self.stock_level]
            elif isinstance(selector, (list, str)) and self.stock_level == 0:
                stock_selector = selector
        elif level == "instrument" or level == self.stock_level:
            if isinstance(selector, tuple):
                stock_selector = selector[0]
            elif isinstance(selector, (list, str)):
                stock_selector = selector

        if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None):
            raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")

        if stock_selector == slice(None):
            return self.hash_df

        if isinstance(stock_selector, str):
            stock_selector = [stock_selector]

        # instruments which are not stored are silently skipped
        return {stock: self.hash_df[stock] for stock in sorted(stock_selector) if stock in self.hash_df}

    def fetch(
        self,
        selector: Union[pd.Timestamp, slice, str] = slice(None, None),
        level: Union[str, int] = "datetime",
        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
        fetch_orig: bool = True,
    ) -> pd.DataFrame:
        """
        Fetch data by first narrowing down to the selected instruments and then applying
        the column and index selection to each per-instrument dataframe.

        See `BaseHandlerStorage.fetch` for the meaning of the parameters.
        """
        stock_dfs = [
            fetch_df_by_index(
                df=fetch_df_by_col(df=stock_df, col_set=col_set),
                selector=selector,
                level=level,
                fetch_orig=fetch_orig,
            )
            for stock_df in self._fetch_hash_df_by_stock(selector=selector, level=level).values()
        ]
        if not stock_dfs:
            # nothing selected: return an empty frame with the stored columns and index layout
            index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
            return pd.DataFrame(
                index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
            )
        if len(stock_dfs) == 1:
            return stock_dfs[0]
        # `~fetch_orig` is a bitwise NOT on a bool (-2 / -1, both truthy) and therefore always
        # requested a copy; the logical negation is what is meant here.
        return pd.concat(stock_dfs, sort=False, copy=not fetch_orig)
# Licensed under the MIT License.

# TODO: move file related utils into this module
import contextlib
import io
from typing import IO, Union
from pathlib import Path


@contextlib.contextmanager
def get_io_object(file: Union[IO, str, Path], *args, **kwargs) -> IO:
    """
    providing an easy interface to get an IO object

    Parameters
    ----------
    file : Union[IO, str, Path]
        an object representing the file:
        - an already opened IO object is yielded as it is and is NOT closed on exit
        - a str or Path is opened with `Path.open(*args, **kwargs)` and closed on exit

    Returns
    -------
    IO:
        an IO-like object

    Raises
    ------
    NotImplementedError:
        if `file` is neither an IO object, a str nor a Path
    """
    # Real file objects (and io.StringIO / io.BytesIO) derive from `io.IOBase`, not from
    # `typing.IO`, so checking `typing.IO` alone rejects every opened file.
    # `typing.IO` is still accepted for objects which explicitly subclass it.
    if isinstance(file, (io.IOBase, IO)):
        yield file
        return

    if isinstance(file, str):
        file = Path(file)
    if not isinstance(file, Path):
        raise NotImplementedError(f"This type[{type(file)}] of input is not supported")
    with file.open(*args, **kwargs) as f:
        yield f
import unittest
from qlib.backtest import backtest, order
from qlib.tests import TestAutoData
import pandas as pd
from pathlib import Path

DIRNAME = Path(__file__).absolute().resolve().parent


class FileStrTest(TestAutoData):
    """Run a daily backtest whose orders are read from a csv file by `FileOrderStrategy`."""

    TEST_INST = "SH600519"

    # the csv is generated by the test itself and removed afterwards
    EXAMPLE_FILE = DIRNAME / "order_example.csv"

    def _gen_orders(self) -> pd.DataFrame:
        """Build the example orders, indexed by (datetime, instrument)."""
        headers = [
            "datetime",
            "instrument",
            "amount",
            "direction",
        ]
        orders = [
            ["20200102", self.TEST_INST, "1000", "sell"],
            ["20200103", self.TEST_INST, "1000", "buy"],
            ["20200106", self.TEST_INST, "1000", "sell"],
        ]
        return pd.DataFrame(orders, columns=headers).set_index(["datetime", "instrument"])

    def test_file_str(self):
        orders = self._gen_orders()
        print(orders)

        strategy_config = {
            "class": "FileOrderStrategy",
            "module_path": "qlib.contrib.strategy.rule_strategy",
            "kwargs": {"file": self.EXAMPLE_FILE},
        }

        freq = "day"
        start_time = "2020-01-01"
        end_time = "2020-01-16"
        codes = [self.TEST_INST]

        backtest_config = {
            "start_time": start_time,
            "end_time": end_time,
            "account": 100000000,
            "benchmark": None,  # benchmark is not required here for trading
            "exchange_kwargs": {
                "freq": freq,
                "limit_threshold": 0.095,
                "deal_price": "close",
                "open_cost": 0.0005,
                "close_cost": 0.0015,
                "min_cost": 5,
                "codes": codes,
            },
            # "pos_type": "InfPosition"  # Position with infinite position
        }
        executor_config = {
            "class": "SimulatorExecutor",
            "module_path": "qlib.backtest.executor",
            "kwargs": {
                "time_per_step": freq,
                "generate_report": False,
                "verbose": True,
                "indicator_config": {
                    "show_indicator": False,
                },
            },
        }

        try:
            orders.to_csv(self.EXAMPLE_FILE)
            backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
        finally:
            # remove the generated csv even when the backtest fails, so that a failing run
            # does not leave the fixture behind in the test directory
            if self.EXAMPLE_FILE.exists():
                self.EXAMPLE_FILE.unlink()


if __name__ == "__main__":
    unittest.main()
import unittest
import time
import numpy as np
from qlib.data import D
from qlib.tests import TestAutoData

from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.data.handler import check_transform_proc
from qlib.log import TimeInspector


class TestHandler(DataHandlerLP):
    """Small daily handler with six price/volume fields, loaded through `QlibDataLoader`."""

    def __init__(
        self,
        instruments="csi300",
        start_time=None,
        end_time=None,
        infer_processors=None,
        learn_processors=None,
        fit_start_time=None,
        fit_end_time=None,
        drop_raw=True,
    ):
        # `None` instead of `[]` as default: a list default would be shared between calls
        infer_processors = [] if infer_processors is None else infer_processors
        learn_processors = [] if learn_processors is None else learn_processors

        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "freq": "day",
                "config": self.get_feature_config(),
                "swap_level": False,
            },
        }

        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            drop_raw=drop_raw,
        )

    def get_feature_config(self):
        """Return the (fields, names) pair used as the loader config."""
        fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"]
        names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"]
        return fields, names


class TestHandlerStorage(TestAutoData):
    """Time random fetches on the plain DataFrame storage and on the hashing stock storage."""

    market = "all"

    start_time = "2010-01-01"
    end_time = "2020-12-31"
    train_end_time = "2015-12-31"
    test_start_time = "2016-01-01"

    data_handler_kwargs = {
        "start_time": start_time,
        "end_time": end_time,
        "fit_start_time": start_time,
        "fit_end_time": train_end_time,
        "instruments": market,
    }

    @staticmethod
    def _random_fetch(handler, instruments, fetch_start_time, fetch_end_time, rounds=100):
        """Fetch `rounds` random single stocks, then `rounds` random groups of 5 stocks."""
        time_slice = slice(fetch_start_time, fetch_end_time)

        # single stock
        for _ in range(rounds):
            random_index = np.random.randint(len(instruments), size=1)[0]
            fetch_stock = instruments[random_index]
            handler.fetch(selector=(fetch_stock, time_slice), level=None)

        # multi stocks
        for _ in range(rounds):
            random_indexs = np.random.randint(len(instruments), size=5)
            fetch_stocks = [instruments[_index] for _index in random_indexs]
            handler.fetch(selector=(fetch_stocks, time_slice), level=None)

    def test_handler_storage(self):
        # init data handler
        data_handler = TestHandler(**self.data_handler_kwargs)

        # init data handler with hashing storage
        data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=["HashStockFormat"])

        fetch_start_time = "2019-01-01"
        fetch_end_time = "2019-12-31"
        instruments = D.instruments(market=self.market)
        instruments = D.list_instruments(
            instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True
        )

        with TimeInspector.logt("random fetch with DataFrame Storage"):
            self._random_fetch(data_handler, instruments, fetch_start_time, fetch_end_time)

        with TimeInspector.logt("random fetch with HasingStock Storage"):
            self._random_fetch(data_handler_hs, instruments, fetch_start_time, fetch_end_time)


if __name__ == "__main__":
    unittest.main()