mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
update HashingStockStorage
This commit is contained in:
@@ -175,7 +175,7 @@ class DataHandler(Serializable):
|
||||
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
|
||||
if cal_set == CS_RAW:
|
||||
if col_set == CS_RAW:
|
||||
the raw dataset will be returned.
|
||||
|
||||
- if isinstance(col_set, List[str]):
|
||||
@@ -197,23 +197,33 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
if proc_func is None:
|
||||
df = self._data
|
||||
else:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
from .storage import HasingStockStorage
|
||||
|
||||
data_storage = self._data
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
elif isinstance(data_storage, HasingStockStorage):
|
||||
if proc_func is not None:
|
||||
warnings.warn(f"proc_func is not supported by the HasingStockStorage")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
df = df.squeeze()
|
||||
data_df = data_df.squeeze()
|
||||
# squeeze index
|
||||
if isinstance(selector, (str, pd.Timestamp)):
|
||||
df = df.reset_index(level=level, drop=True)
|
||||
return df
|
||||
data_df = data_df.reset_index(level=level, drop=True)
|
||||
return data_df
|
||||
|
||||
def get_cols(self, col_set=CS_ALL) -> list:
|
||||
"""
|
||||
@@ -511,14 +521,27 @@ class DataHandlerLP(DataHandler):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
from .storage import HasingStockStorage
|
||||
|
||||
data_storage = self._get_df_by_key(data_key)
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
elif isinstance(data_storage, HasingStockStorage):
|
||||
if proc_func is not None:
|
||||
warnings.warn(f"proc_func is not supported by the HasingStockStorage")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
return data_df
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@ import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from .handler import DataHandler
|
||||
from typing import Tuple, Union, List
|
||||
from typing import Tuple, Union, List, Callable
|
||||
|
||||
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
|
||||
|
||||
@@ -13,8 +13,29 @@ class BaseHandlerStorage:
|
||||
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""fetch data from the data storage
|
||||
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
describe how to select data by index
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
col_set : Union[str, List[str]]
|
||||
- if isinstance(col_set, str):
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
if col_set == DataHandler.CS_RAW:
|
||||
the raw dataset will be returned.
|
||||
- if isinstance(col_set, List[str]):
|
||||
select several sets of meaningful columns, the returned data has multiple level
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError("fetch is method not implemented!")
|
||||
|
||||
@staticmethod
|
||||
@@ -68,11 +89,12 @@ class HasingStockStorage(BaseHandlerStorage):
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
|
||||
for _index, stock_df in enumerate(fetch_stock_df_list):
|
||||
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
|
||||
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level)
|
||||
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
|
||||
fetch_stock_df_list[_index] = fetch_index_df
|
||||
if len(fetch_stock_df_list) == 0:
|
||||
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
|
||||
@@ -82,4 +104,4 @@ class HasingStockStorage(BaseHandlerStorage):
|
||||
elif len(fetch_stock_df_list) == 1:
|
||||
return fetch_stock_df_list[0]
|
||||
else:
|
||||
return pd.concat(fetch_stock_df_list, axis=0, sort=False)
|
||||
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
|
||||
|
||||
107
tests/test_handler_storage.py
Normal file
107
tests/test_handler_storage.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import unittest
|
||||
import qlib
|
||||
import time
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
|
||||
class TestHandler(DataHandlerLP):
    """Small ``DataHandlerLP`` subclass used by the handler-storage test.

    It loads six price/volume expressions through a ``QlibDataLoader``
    (daily frequency, ``swap_level=False``) and forwards the processor
    lists, after passing them through ``check_transform_proc``, to
    ``DataHandlerLP``.
    """

    def __init__(
        self,
        instruments="csi300",
        start_time=None,
        end_time=None,
        infer_processors=None,
        learn_processors=None,
        fit_start_time=None,
        fit_end_time=None,
        drop_raw=True,
    ):
        """Build the data loader config and initialise the parent handler.

        Parameters
        ----------
        instruments :
            forwarded unchanged to ``DataHandlerLP``.
        start_time, end_time :
            forwarded unchanged to ``DataHandlerLP``.
        infer_processors, learn_processors : list, optional
            processor configs; ``None`` (the default) is treated as an
            empty list.
        fit_start_time, fit_end_time :
            passed to ``check_transform_proc`` together with each
            processor list.
        drop_raw : bool
            forwarded unchanged to ``DataHandlerLP``.
        """
        # Use None sentinels instead of mutable `[]` defaults so that no list
        # object is shared between separate constructor calls.
        infer_processors = [] if infer_processors is None else infer_processors
        learn_processors = [] if learn_processors is None else learn_processors

        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "freq": "day",
                "config": self.get_feature_config(),
                "swap_level": False,
            },
        }

        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            drop_raw=drop_raw,
        )

    def get_feature_config(self):
        """Return the ``(fields, names)`` pair used as the loader ``config``."""
        fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"]
        names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"]
        return fields, names
|
||||
|
||||
|
||||
class MiniTimer:
    """Context manager that prints how long its ``with`` body took.

    Parameters
    ----------
    name : str
        label embedded in the printed message.

    After the block exits, ``start`` and ``end`` hold the clock readings
    taken on entry and exit; only their difference is meaningful.
    """

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        self.start = time.perf_counter()
        # Return the timer so `with MiniTimer(...) as t` binds a usable object.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.perf_counter()
        print(f"[MyTimer Info] <{self.name}> process costs {self.end - self.start} seconds")
|
||||
|
||||
|
||||
class TestHandlerStorage(TestAutoData):
    """Smoke test: build a ``TestHandler`` and print a few fetch results.

    The test makes no assertions; it only times handler construction and
    three ``fetch`` calls and prints what they return.
    """

    market = "all"

    start_time = "2020-01-01"
    end_time = "2020-12-31"
    train_end_time = "2020-05-31"
    test_start_time = "2020-06-01"

    # Keyword arguments handed to TestHandler in the test below.
    data_handler_kwargs = dict(
        start_time=start_time,
        end_time=end_time,
        fit_start_time=start_time,
        fit_end_time=train_end_time,
        instruments=market,
        infer_processors=["HashingStock"],
    )

    def test_handler_storage(self):
        """Time handler construction, then time and print three fetches."""
        with MiniTimer("init data hanlder"):
            handler = TestHandler(**self.data_handler_kwargs)

        with MiniTimer("random fetch"):
            # Each selector is an (instrument(s), datetime slice) tuple and is
            # fetched with level=None.
            selectors = (
                ("SH600170", slice(None)),
                ("SH600170", slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))),
                (["SH600170", "SH600383"], slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))),
            )
            for selector in selectors:
                print(handler.fetch(selector=selector, level=None))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the tests in this module when it is executed as a script.
    unittest.main()
|
||||
Reference in New Issue
Block a user