1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

update HashingStockStorage

This commit is contained in:
bxdd
2021-06-29 12:02:27 +00:00
committed by you-n-g
parent 90bbf2b7c6
commit 9985befe69
3 changed files with 176 additions and 24 deletions

View File

@@ -175,7 +175,7 @@ class DataHandler(Serializable):
select a set of meaningful columns.(e.g. features, columns)
if cal_set == CS_RAW:
if col_set == CS_RAW:
the raw dataset will be returned.
- if isinstance(col_set, List[str]):
@@ -197,23 +197,33 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
if proc_func is None:
df = self._data
else:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
from .storage import HasingStockStorage
data_storage = self._data
if isinstance(data_storage, pd.DataFrame):
data_df = data_storage
if proc_func is not None:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
warnings.warn(f"proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
# Fetch column first will be more friendly to SepDataFrame
df = fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
df = df.squeeze()
data_df = data_df.squeeze()
# squeeze index
if isinstance(selector, (str, pd.Timestamp)):
df = df.reset_index(level=level, drop=True)
return df
data_df = data_df.reset_index(level=level, drop=True)
return data_df
def get_cols(self, col_set=CS_ALL) -> list:
"""
@@ -511,14 +521,27 @@ class DataHandlerLP(DataHandler):
-------
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
from .storage import HasingStockStorage
data_storage = self._get_df_by_key(data_key)
if isinstance(data_storage, pd.DataFrame):
data_df = data_storage
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, HasingStockStorage):
if proc_func is not None:
warnings.warn(f"proc_func is not supported by the HasingStockStorage")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
return data_df
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
"""

View File

@@ -2,7 +2,7 @@ import pandas as pd
import numpy as np
from .handler import DataHandler
from typing import Tuple, Union, List
from typing import Tuple, Union, List, Callable
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
@@ -13,8 +13,29 @@ class BaseHandlerStorage:
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
**kwargs,
) -> pd.DataFrame:
"""fetch data from the data storage
Parameters
----------
selector : Union[pd.Timestamp, slice, str]
describe how to select data by index
level : Union[str, int]
which index level to select the data
col_set : Union[str, List[str]]
- if isinstance(col_set, str):
select a set of meaningful columns.(e.g. features, columns)
if col_set == DataHandler.CS_RAW:
the raw dataset will be returned.
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple level
fetch_orig : bool
Return the original data instead of copy if possible.
"""
raise NotImplementedError("fetch is method not implemented!")
@staticmethod
@@ -68,11 +89,12 @@ class HasingStockStorage(BaseHandlerStorage):
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
) -> pd.DataFrame:
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
for _index, stock_df in enumerate(fetch_stock_df_list):
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level)
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
fetch_stock_df_list[_index] = fetch_index_df
if len(fetch_stock_df_list) == 0:
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
@@ -82,4 +104,4 @@ class HasingStockStorage(BaseHandlerStorage):
elif len(fetch_stock_df_list) == 1:
return fetch_stock_df_list[0]
else:
return pd.concat(fetch_stock_df_list, axis=0, sort=False)
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)

View File

@@ -0,0 +1,107 @@
import unittest
import qlib
import time
import pandas as pd
from qlib.data import D
from qlib.tests import TestAutoData
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.processor import Processor
from qlib.contrib.data.handler import check_transform_proc
from qlib.utils import init_instance_by_config
from qlib.log import TimeInspector
class TestHandler(DataHandlerLP):
    """Small `DataHandlerLP` subclass used by the handler-storage test.

    It wires a `QlibDataLoader` config (daily frequency, ``swap_level=False``)
    to the feature set returned by `get_feature_config`.
    """

    def __init__(
        self,
        instruments="csi300",
        start_time=None,
        end_time=None,
        infer_processors=None,
        learn_processors=None,
        fit_start_time=None,
        fit_end_time=None,
        drop_raw=True,
    ):
        """Build the data-loader config and delegate to `DataHandlerLP.__init__`.

        Parameters
        ----------
        instruments, start_time, end_time, drop_raw :
            forwarded unchanged to the parent constructor.
        infer_processors, learn_processors : list, optional
            processor configs; ``None`` (the default) is treated as an empty list.
            Both are passed through `check_transform_proc` together with the
            fit window before being handed to the parent constructor.
        fit_start_time, fit_end_time :
            fit window given to `check_transform_proc`.
        """
        # `None` sentinels replace the former `[]` defaults: a list default is
        # created once at definition time and shared by every call.
        if infer_processors is None:
            infer_processors = []
        if learn_processors is None:
            learn_processors = []

        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "freq": "day",
                "config": self.get_feature_config(),
                "swap_level": False,
            },
        }

        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            drop_raw=drop_raw,
        )

    def get_feature_config(self):
        """Return the ``(fields, names)`` pair used as the data-loader config.

        ``fields`` are the expressions to load and ``names`` the column names
        assigned to them, position by position.
        """
        fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"]
        names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"]
        return fields, names
class MiniTimer:
    """Context manager that prints the wall-clock time spent inside its block."""

    def __init__(self, name):
        """Store ``name``, the label shown in the timing message."""
        self.name = name

    def __enter__(self):
        """Record the start time and return the timer itself."""
        self.start = time.time()
        # Returning self lets `with MiniTimer(...) as timer:` bind the timer
        # (and its start/end attributes) instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the end time and print the elapsed seconds.

        Returns None, so an exception raised inside the block still propagates.
        """
        self.end = time.time()
        print(f"[MyTimer Info] <{self.name}> process costs {self.end - self.start} seconds")
class TestHandlerStorage(TestAutoData):
    """Smoke test that builds a `TestHandler` with the ``HashingStock``
    infer processor and prints the result of a few `fetch` calls."""

    market = "all"

    # Data range and the split points used for fitting.
    start_time = "2020-01-01"
    end_time = "2020-12-31"
    train_end_time = "2020-05-31"
    test_start_time = "2020-06-01"

    # Keyword arguments handed to `TestHandler` in the test below.
    data_handler_kwargs = dict(
        start_time=start_time,
        end_time=end_time,
        fit_start_time=start_time,
        fit_end_time=train_end_time,
        instruments=market,
        infer_processors=["HashingStock"],
    )

    def test_handler_storage(self):
        """Time handler construction, then time three fetches by (instrument, time) selector."""
        with MiniTimer("init data hanlder"):
            handler = TestHandler(**self.data_handler_kwargs)

        with MiniTimer("random fetch"):
            month_window = slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))
            selectors = (
                # one instrument, every date
                ("SH600170", slice(None)),
                # one instrument, one-month window
                ("SH600170", month_window),
                # two instruments, same window
                (["SH600170", "SH600383"], month_window),
            )
            for selector in selectors:
                print(handler.fetch(selector=selector, level=None))
if __name__ == "__main__":
unittest.main()