1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

Merge branch 'nested_decision_exe' of https://github.com/microsoft/qlib into rl-dummy

This commit is contained in:
v-mingzhehan
2021-06-30 08:30:33 +00:00
12 changed files with 585 additions and 51 deletions

View File

@@ -263,11 +263,11 @@ class Account:
elif atomic is False and inner_order_indicators is None:
raise ValueError("inner_order_indicators is necessary in unatomic executor")
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
self.update_bar_count()
self.update_current(trade_start_time, trade_end_time, trade_exchange)
if generate_report:
# report is portfolio related analysis
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
self.update_bar_count()
self.update_current(trade_start_time, trade_end_time, trade_exchange)
self.update_report(trade_start_time, trade_end_time)
# indicator is trading (e.g. high-frequency order execution) related analysis

View File

@@ -14,7 +14,7 @@ from ..data.dataset.utils import get_level_index
from ..config import C, REG_CN
from ..utils.resam import resam_ts_data
from ..log import get_module_logger
from .order import Order
from .order import Order, OrderDir, OrderHelper
class Exchange:
@@ -164,10 +164,6 @@ class Exchange:
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
# update quote: pd.DataFrame to dict, for search use
if get_level_index(quote_df, level="datetime") == 1:
quote_df = quote_df.swaplevel().sort_index()
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val
@@ -530,3 +526,9 @@ class Exchange:
raise NotImplementedError("order type {} error".format(order.type))
return trade_val, trade_cost
def get_order_helper(self) -> OrderHelper:
if not hasattr(self, "_order_helper"):
# cache to avoid recreate the same instance
self._order_helper = OrderHelper(self)
return self._order_helper

View File

@@ -2,12 +2,14 @@
# Licensed under the MIT License.
# TODO: rename it with decision.py
from __future__ import annotations
from enum import IntEnum
# try to fix circular imports when enabling type hints
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
from qlib.backtest.exchange import Exchange
from qlib.backtest.utils import TradeCalendarManager
import warnings
import pandas as pd
@@ -15,6 +17,12 @@ from dataclasses import dataclass, field
from typing import ClassVar, Optional, Union, List, Set, Tuple
class OrderDir(IntEnum):
# Order direction
SELL = 0
BUY = 1
@dataclass
class Order:
"""
@@ -32,19 +40,98 @@ class Order:
stock_id: str
amount: float
# The interval of the order which belongs to (NOTE: this is not the expected order dealing range time)
start_time: pd.Timestamp
end_time: pd.Timestamp
direction: int
factor: float
deal_amount: Optional[float] = None
SELL: ClassVar[int] = 0
BUY: ClassVar[int] = 1
deal_amount: float = field(init=False)
# FIXME:
# for compatible now.
# Plese remove them in the future
SELL: ClassVar[OrderDir] = OrderDir.SELL
BUY: ClassVar[OrderDir] = OrderDir.BUY
def __post_init__(self):
if self.direction not in {Order.SELL, Order.BUY}:
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
self.deal_amount = 0
@staticmethod
def parse_dir(direction: Union[str, int, OrderDir]) -> OrderDir:
if isinstance(direction, OrderDir):
return direction
elif isinstance(direction, int):
return OrderDir(direction)
elif isinstance(direction, str):
dl = direction.lower()
if dl.strip() == "sell":
return OrderDir.SELL
elif dl.strip() == "buy":
return OrderDir.BUY
else:
raise NotImplementedError(f"This type of input is not supported")
else:
raise NotImplementedError(f"This type of input is not supported")
class OrderHelper:
"""
Motivation
- Make generating order easier
- User may have no knowledge about the adjust-factor information about the system.
- It involves to much interaction with the exchange when generating orders.
"""
def __init__(self, exchange: Exchange):
self.exchange = exchange
def create(
self,
code: str,
amount: float,
direction: OrderDir,
start_time: Union[str, pd.Timestamp],
end_time: Union[str, pd.Timestamp],
) -> Order:
"""
help to create a order
# TODO: create order for unadjusted amount order
Parameters
----------
code : str
the id of the instrument
amount : float
**adjusted trading amount**
direction : OrderDir
trading direction
start_time : Union[str, pd.Timestamp]
The interval of the order which belongs to
end_time : Union[str, pd.Timestamp]
The interval of the order which belongs to
Returns
-------
Order:
The created order
"""
start_time = pd.Timestamp(start_time)
end_time = pd.Timestamp(end_time)
return Order(
stock_id=code,
amount=amount,
start_time=start_time,
end_time=end_time,
direction=direction,
factor=self.exchange.get_factor(code, start_time, end_time),
)
class BaseTradeDecision:
"""

View File

@@ -408,7 +408,7 @@ class InfPosition(BasePosition):
"""
def skip_update(self) -> bool:
""" Updating state is meaningless for InfPosition """
"""Updating state is meaningless for InfPosition"""
return True
def check_stock(self, stock_id: str) -> bool:

View File

@@ -1,15 +1,19 @@
from pathlib import Path
import warnings
import numpy as np
import pandas as pd
from typing import List, Tuple, Union
from typing import IO, List, Tuple, Union
from qlib.data.dataset.utils import convert_index_format
from qlib.utils import lazy_sort_index
from ...utils.resam import resam_ts_data
from ...data.data import D
from ...data.dataset.utils import convert_index_format
from ...strategy.base import BaseStrategy
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
from ...backtest.exchange import Exchange
from ...backtest.exchange import Exchange, OrderHelper
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
from qlib.utils.file import get_io_object
def get_start_end_idx(strategy: BaseStrategy, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
@@ -423,7 +427,6 @@ class SBBStrategyEMA(SBBStrategyBase):
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
@@ -515,7 +518,6 @@ class ACStrategy(BaseStrategy):
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["volatility"]
self.signal = {}
@@ -656,6 +658,9 @@ class RandomOrderStrategy(BaseStrategy):
index_range : Tuple
the intra day time index range of the orders
the left and right is closed.
If you want to get the index_range in intra-day
- `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range easier
# TODO: this is a index_range level limitation. We'll implement a more detailed limitation later.
sample_ratio : float
the ratio of all orders are sampled
@@ -687,7 +692,9 @@ class RandomOrderStrategy(BaseStrategy):
if step_time_start in self.volume_df:
for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():
order_list.append(
self.common_infra.get("trade_exchange").create_order(
self.common_infra.get("trade_exchange")
.get_order_helper()
.create(
code=stock_id,
amount=volume * self.volume_ratio,
start_time=step_time_start,
@@ -696,3 +703,53 @@ class RandomOrderStrategy(BaseStrategy):
)
)
return TradeDecisionWO(order_list, self, self.index_range)
class FileOrderStrategy(BaseStrategy):
"""
Motivtaion:
- This class provides an interface for user to read orders from csv files.
- It is supposed to be used in
"""
def __init__(self, file: Union[IO, str, Path], index_range: Tuple[int, int] = None, *args, **kwargs):
super().__init__(*args, **kwargs)
with get_io_object(file) as f:
self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
self.order_df = self.order_df.set_index(["datetime", "instrument"])
# make sure the datetime is the first level for fast indexing
self.order_df = lazy_sort_index(convert_index_format(self.order_df, level="datetime"))
self.index_range = index_range
def generate_trade_decision(self, execute_result=None) -> TradeDecisionWO:
"""
Parameters
----------
execute_result :
execute_result will be ignored in FileOrderStrategy
"""
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
tc = self.trade_calendar
step = tc.get_trade_step()
start, end = tc.get_step_time(step)
# CONVERSION: the bar is indexed by the time
try:
df = self.order_df.loc(axis=0)[start]
except KeyError:
return TradeDecisionWO([], self)
else:
order_list = []
for idx, row in df.iterrows():
order_list.append(
oh.create(
code=idx,
amount=row["amount"],
direction=Order.parse_dir(row["direction"]),
start_time=start,
end_time=end,
)
)
return TradeDecisionWO(order_list, self, self.index_range)

View File

@@ -17,7 +17,7 @@ from ...data import D
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
from .utils import fetch_df_by_index
from .utils import fetch_df_by_index, fetch_df_by_col
from pathlib import Path
from .loader import DataLoader
@@ -152,14 +152,6 @@ class DataHandler(Serializable):
CS_ALL = "__all" # return all columns with single-level index column
CS_RAW = "__raw" # return raw data with multi-level index column
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW:
return df
elif col_set == self.CS_ALL:
return df.droplevel(axis=1, level=0)
else:
return df.loc(axis=1)[col_set]
def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
@@ -183,7 +175,7 @@ class DataHandler(Serializable):
select a set of meaningful columns.(e.g. features, columns)
if cal_set == CS_RAW:
if col_set == CS_RAW:
the raw dataset will be returned.
- if isinstance(col_set, List[str]):
@@ -205,23 +197,34 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
if proc_func is None:
df = self._data
else:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
from .storage import HasingStockStorage
data_storage = self._data
if isinstance(data_storage, pd.DataFrame):
data_df = data_storage
if proc_func is not None:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
data_df = fetch_df_by_col(data_df, col_set)
else:
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
raise ValueError("proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
df = df.squeeze()
data_df = data_df.squeeze()
# squeeze index
if isinstance(selector, (str, pd.Timestamp)):
df = df.reset_index(level=level, drop=True)
return df
data_df = data_df.reset_index(level=level, drop=True)
return data_df
def get_cols(self, col_set=CS_ALL) -> list:
"""
@@ -238,7 +241,7 @@ class DataHandler(Serializable):
list of column names
"""
df = self._data.head()
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_col(df, col_set)
return df.columns.to_list()
def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
@@ -519,14 +522,29 @@ class DataHandlerLP(DataHandler):
-------
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
from .storage import HasingStockStorage
data_storage = self._get_df_by_key(data_key)
if isinstance(data_storage, pd.DataFrame):
data_df = data_storage
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
data_df = fetch_df_by_col(data_df, col_set)
else:
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
raise ValueError("proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
return data_df
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
"""
@@ -545,5 +563,5 @@ class DataHandlerLP(DataHandler):
list of column names
"""
df = self._get_df_by_key(data_key).head()
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_col(df, col_set)
return df.columns.to_list()

View File

@@ -310,3 +310,12 @@ class CSZFillna(Processor):
cols = get_group_columns(df, self.fields_group)
df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
return df
class HashStockFormat(Processor):
"""Process the storage of from df into hasing stock format"""
def __call__(self, df: pd.DataFrame):
from .storage import HasingStockStorage
return HasingStockStorage.from_df(df)

View File

@@ -0,0 +1,107 @@
import pandas as pd
import numpy as np
from .handler import DataHandler
from typing import Tuple, Union, List, Callable
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
class BaseHandlerStorage:
def fetch(
self,
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
**kwargs,
) -> pd.DataFrame:
"""fetch data from the data storage
Parameters
----------
selector : Union[pd.Timestamp, slice, str]
describe how to select data by index
level : Union[str, int]
which index level to select the data
col_set : Union[str, List[str]]
- if isinstance(col_set, str):
select a set of meaningful columns.(e.g. features, columns)
if col_set == DataHandler.CS_RAW:
the raw dataset will be returned.
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple level
fetch_orig : bool
Return the original data instead of copy if possible.
"""
raise NotImplementedError("fetch is method not implemented!")
@staticmethod
def from_df(df: pd.DataFrame):
raise NotImplementedError("from_df method is not implemented!")
class HasingStockStorage(BaseHandlerStorage):
def __init__(self, df):
self.hash_df = dict()
self.stock_level = get_level_index(df, "instrument")
for k, v in df.groupby(level="instrument"):
self.hash_df[k] = v
self.columns = df.columns
@staticmethod
def from_df(df):
return HasingStockStorage(df)
def _fetch_hash_df_by_stock(self, selector, level):
stock_selector = slice(None)
if level is None:
if isinstance(selector, tuple) and self.stock_level < len(selector):
stock_selector = selector[self.stock_level]
elif isinstance(selector, (list, str)) and self.stock_level == 0:
stock_selector = selector
elif level == "instrument" or level == self.stock_level:
if isinstance(selector, tuple):
stock_selector = selector[0]
elif isinstance(selector, (list, str)):
stock_selector = selector
if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None):
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
if stock_selector == slice(None):
return self.hash_df
if isinstance(stock_selector, str):
stock_selector = [stock_selector]
select_dict = dict()
for each_stock in sorted(stock_selector):
if each_stock in self.hash_df:
select_dict[each_stock] = self.hash_df[each_stock]
return select_dict
def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
) -> pd.DataFrame:
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
for _index, stock_df in enumerate(fetch_stock_df_list):
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
fetch_stock_df_list[_index] = fetch_index_df
if len(fetch_stock_df_list) == 0:
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
return pd.DataFrame(
index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
)
elif len(fetch_stock_df_list) == 1:
return fetch_stock_df_list[0]
else:
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)

View File

@@ -1,5 +1,8 @@
from typing import Union
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Union, List
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
@@ -72,12 +75,26 @@ def fetch_df_by_index(
]
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
from .handler import DataHandler
if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:
return df
elif col_set == DataHandler.CS_ALL:
return df.droplevel(axis=1, level=0)
else:
return df.loc(axis=1)[col_set]
def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
"""
Convert the format of df.MultiIndex according to the following rules:
- If `level` is the first level of df.MultiIndex, do nothing
- If `level` is the second level of df.MultiIndex, swap the level of index.
NOTE:
the number of levels of df.MultiIndex should be 2
Parameters
----------
df : Union[pd.DataFrame, pd.Series]

37
qlib/utils/file.py Normal file
View File

@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TODO: move file related utils into this module
import contextlib
from typing import IO, Union
from pathlib import Path
@contextlib.contextmanager
def get_io_object(file: Union[IO, str, Path], *args, **kwargs) -> IO:
"""
providing a easy interface to get an IO object
Parameters
----------
file : Union[IO, str, Path]
a object representing the file
Returns
-------
IO:
a IO-like object
Raises
------
NotImplementedError:
"""
if isinstance(file, IO):
yield file
else:
if isinstance(file, str):
file = Path(file)
if not isinstance(file, Path):
raise NotImplementedError(f"This type[{type(file)}] of input is not supported")
with file.open(*args, **kwargs) as f:
yield f

View File

@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
from qlib.backtest import backtest, order
from qlib.tests import TestAutoData
import pandas as pd
from pathlib import Path
DIRNAME = Path(__file__).absolute().resolve().parent
class FileStrTest(TestAutoData):
TEST_INST = "SH600519"
EXAMPLE_FILE = DIRNAME / "order_example.csv"
def _gen_orders(self) -> pd.DataFrame:
headers = [
"datetime",
"instrument",
"amount",
"direction",
]
orders = [
["20200102", self.TEST_INST, "1000", "sell"],
["20200103", self.TEST_INST, "1000", "buy"],
["20200106", self.TEST_INST, "1000", "sell"],
]
return pd.DataFrame(orders, columns=headers).set_index(["datetime", "instrument"])
def test_file_str(self):
orders = self._gen_orders()
print(orders)
orders.to_csv(self.EXAMPLE_FILE)
orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"])
strategy_config = {
"class": "FileOrderStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {"file": self.EXAMPLE_FILE},
}
freq = "day"
start_time = "2020-01-01"
end_time = "2020-01-16"
codes = [self.TEST_INST]
backtest_config = {
"start_time": start_time,
"end_time": end_time,
"account": 100000000,
"benchmark": None, # benchmark is not required here for trading
"exchange_kwargs": {
"freq": freq,
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
"codes": codes,
},
# "pos_type": "InfPosition" # Position with infinitive position
}
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": freq,
"generate_report": False,
"verbose": True,
"indicator_config": {
"show_indicator": False,
},
},
}
backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
self.EXAMPLE_FILE.unlink()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,114 @@
import unittest
import time
import numpy as np
from qlib.data import D
from qlib.tests import TestAutoData
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.data.handler import check_transform_proc
from qlib.log import TimeInspector
class TestHandler(DataHandlerLP):
def __init__(
self,
instruments="csi300",
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
drop_raw=True,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"freq": "day",
"config": self.get_feature_config(),
"swap_level": False,
},
}
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
drop_raw=drop_raw,
)
def get_feature_config(self):
fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"]
names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"]
return fields, names
class TestHandlerStorage(TestAutoData):
market = "all"
start_time = "2010-01-01"
end_time = "2020-12-31"
train_end_time = "2015-12-31"
test_start_time = "2016-01-01"
data_handler_kwargs = {
"start_time": start_time,
"end_time": end_time,
"fit_start_time": start_time,
"fit_end_time": train_end_time,
"instruments": market,
}
def test_handler_storage(self):
# init data handler
data_handler = TestHandler(**self.data_handler_kwargs)
# init data handler with hasing storage
data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=["HashStockFormat"])
fetch_start_time = "2019-01-01"
fetch_end_time = "2019-12-31"
instruments = D.instruments(market=self.market)
instruments = D.list_instruments(
instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True
)
with TimeInspector.logt("random fetch with DataFrame Storage"):
# single stock
for i in range(100):
random_index = np.random.randint(len(instruments), size=1)[0]
fetch_stock = instruments[random_index]
data_handler.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
# multi stocks
for i in range(100):
random_indexs = np.random.randint(len(instruments), size=5)
fetch_stocks = [instruments[_index] for _index in random_indexs]
data_handler.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
with TimeInspector.logt("random fetch with HasingStock Storage"):
# single stock
for i in range(100):
random_index = np.random.randint(len(instruments), size=1)[0]
fetch_stock = instruments[random_index]
data_handler_hs.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
# multi stocks
for i in range(100):
random_indexs = np.random.randint(len(instruments), size=5)
fetch_stocks = [instruments[_index] for _index in random_indexs]
data_handler_hs.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
if __name__ == "__main__":
unittest.main()