1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

update handler_storage test

This commit is contained in:
bxdd
2021-06-29 15:51:41 +00:00
committed by you-n-g
parent 9985befe69
commit 8d1b1979d9
4 changed files with 59 additions and 37 deletions

View File

@@ -206,13 +206,14 @@ class DataHandler(Serializable):
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
data_df = fetch_df_by_col(data_df, col_set)
else:
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
warnings.warn(f"proc_func is not supported by the HasingStockStorage")
raise ValueError("proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
@@ -530,13 +531,15 @@ class DataHandlerLP(DataHandler):
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
data_df = fetch_df_by_col(data_df, col_set)
else:
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
warnings.warn(f"proc_func is not supported by the HasingStockStorage")
raise ValueError("proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")

View File

@@ -312,8 +312,8 @@ class CSZFillna(Processor):
return df
class HashingStock(Processor):
"""Process the df into hasing stock storage"""
class HashStockFormat(Processor):
"""Process the storage of from df into hasing stock format"""
def __call__(self, df: pd.DataFrame):
from .storage import HasingStockStorage

View File

@@ -71,7 +71,7 @@ class HasingStockStorage(BaseHandlerStorage):
if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None):
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
print(stock_selector)
if stock_selector == slice(None):
return self.hash_df

View File

@@ -1,15 +1,11 @@
import unittest
import qlib
import time
import pandas as pd
import numpy as np
from qlib.data import D
from qlib.tests import TestAutoData
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.processor import Processor
from qlib.contrib.data.handler import check_transform_proc
from qlib.utils import init_instance_by_config
from qlib.log import TimeInspector
@@ -63,17 +59,17 @@ class MiniTimer:
def __exit__(self, exc_type, exc_val, exc_tb):
self.end = time.time()
print(f"[MyTimer Info] <{self.name}> process costs {self.end - self.start} seconds")
print(f"[Timer Info] <{self.name}> process costs {self.end - self.start} seconds")
class TestHandlerStorage(TestAutoData):
market = "all"
start_time = "2020-01-01"
start_time = "2010-01-01"
end_time = "2020-12-31"
train_end_time = "2020-05-31"
test_start_time = "2020-06-01"
train_end_time = "2015-12-31"
test_start_time = "2016-01-01"
data_handler_kwargs = {
"start_time": start_time,
@@ -81,26 +77,49 @@ class TestHandlerStorage(TestAutoData):
"fit_start_time": start_time,
"fit_end_time": train_end_time,
"instruments": market,
"infer_processors": ["HashingStock"],
}
def test_handler_storage(self):
with MiniTimer("init data hanlder"):
data_handler = TestHandler(**self.data_handler_kwargs)
# init data handler
data_handler = TestHandler(**self.data_handler_kwargs)
with MiniTimer("random fetch"):
print(data_handler.fetch(selector=("SH600170", slice(None)), level=None))
print(
data_handler.fetch(
selector=("SH600170", slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))), level=None
)
)
print(
data_handler.fetch(
selector=(["SH600170", "SH600383"], slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))),
level=None,
)
)
# init data handler with hasing storage
data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=["HashStockFormat"])
fetch_start_time = "2019-01-01"
fetch_end_time = "2019-12-31"
instruments = D.instruments(market=self.market)
instruments = D.list_instruments(
instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True
)
with TimeInspector.logt("random fetch with DataFrame Storage"):
# single stock
for i in range(100):
random_index = np.random.randint(len(instruments), size=1)[0]
fetch_stock = instruments[random_index]
data_handler.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
# multi stocks
for i in range(100):
random_indexs = np.random.randint(len(instruments), size=5)
fetch_stocks = [instruments[_index] for _index in random_indexs]
data_handler.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
with TimeInspector.logt("random fetch with HasingStock Storage"):
# single stock
for i in range(100):
random_index = np.random.randint(len(instruments), size=1)[0]
fetch_stock = instruments[random_index]
data_handler_hs.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
# multi stocks
for i in range(100):
random_indexs = np.random.randint(len(instruments), size=5)
fetch_stocks = [instruments[_index] for _index in random_indexs]
data_handler_hs.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
if __name__ == "__main__":