1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

add write method to FeatureStorage && remove extend

This commit is contained in:
zhupr
2021-05-21 08:43:36 +08:00
parent 317357b50d
commit 4ba4512619
5 changed files with 321 additions and 169 deletions

View File

@@ -8,6 +8,7 @@ from __future__ import print_function
import os
import re
import abc
import copy
import time
import queue
import bisect
@@ -31,19 +32,27 @@ from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_modu
class ProviderBackendMixin:
def get_default_backend(self):
backend = {}
provider_name = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] # type: str
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
# set default storage class
backend.setdefault("class", f"File{provider_name}Storage")
# set default storage module
backend.setdefault("module_path", "qlib.data.storage.file_storage")
# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {}) # type: dict
backend_kwargs.setdefault("uri", os.path.join(C.get_data_path(), f"{provider_name.lower()}s"))
return backend
@property
def backend_obj(self):
return init_instance_by_config(self.backend)
def backend_obj(self, **kwargs):
backend = self.backend if self.backend else self.get_default_backend()
backend = copy.deepcopy(backend)
# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {})
# default uri map
if "uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
freq = kwargs.get("freq", "day")
uri_map = backend_kwargs.setdefault("uri_map", {freq: C.get_data_path()})
backend_kwargs["uri"] = uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
class CalendarProvider(abc.ABC, ProviderBackendMixin):
@@ -54,8 +63,6 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
if not self.backend:
self.backend = self.get_default_backend()
@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
@@ -159,8 +166,6 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
if not self.backend:
self.backend = self.get_default_backend()
@staticmethod
def instruments(market="all", filter_pipe=None):
@@ -252,8 +257,6 @@ class FeatureProvider(abc.ABC, ProviderBackendMixin):
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
if not self.backend:
self.backend = self.get_default_backend()
@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
@@ -552,8 +555,18 @@ class LocalCalendarProvider(CalendarProvider):
list
list of timestamps
"""
self.backend.setdefault("kwargs", {}).update(freq=freq, future=future)
return [pd.Timestamp(x) for x in self.backend_obj.data]
backend_obj = self.backend_obj(freq=freq, future=future)
if future and not backend_obj.check_exists():
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
backend_obj = self.backend_obj(freq=freq, future=False)
return [pd.Timestamp(x) for x in backend_obj.data]
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
@@ -589,17 +602,15 @@ class LocalInstrumentProvider(InstrumentProvider):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market):
self.backend.setdefault("kwargs", {}).update(market=market)
return self.backend_obj.data
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
else:
_instruments = self._load_instruments(market)
_instruments = self._load_instruments(market, freq=freq)
H["i"][market] = _instruments
# strip
# use calendar boundary
@@ -648,9 +659,14 @@ class LocalFeatureProvider(FeatureProvider):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
self.backend.setdefault("kwargs", {}).update(instrument=instrument, field=field, freq=freq)
return self.backend_obj[start_index : end_index + 1]
try:
data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
except Exception as e:
get_module_logger("data").warning(
f"WARN: data not found for {instrument}.{field}\n\tException info: {str(e)}"
)
data = pd.Series(dtype=np.float32)
return data
class LocalExpressionProvider(ExpressionProvider):

View File

@@ -3,25 +3,36 @@
import struct
from pathlib import Path
from typing import Iterator, Iterable, Union, Dict, Mapping, Tuple
from typing import Iterable, Union, Dict, Mapping, Tuple, List
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
logger = get_module_logger("file_storage")
class FileCalendarStorage(CalendarStorage):
def __init__(self, freq: str, future: bool, uri: str):
super(FileCalendarStorage, self).__init__(freq, future, uri)
class FileStorage:
def check_exists(self):
return self.uri.exists()
class FileCalendarStorage(FileStorage, CalendarStorage):
def __init__(self, freq: str, future: bool, uri: str, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, uri, **kwargs)
_file_name = f"{freq}_future.txt" if future else f"{freq}.txt"
self.uri = Path(self.uri).expanduser().joinpath(_file_name.lower())
self.uri = Path(self.uri).expanduser().joinpath("calendars", _file_name.lower())
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> np.ndarray:
if not self.uri.exists():
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> Iterable[CalVT]:
if not self.check_exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")
return [
str(x)
for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8")
]
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:
@@ -65,23 +76,17 @@ class FileCalendarStorage(CalendarStorage):
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]:
return self._read_calendar()[i]
def __len__(self) -> int:
return len(self._read_calendar())
def __iter__(self):
return iter(self._read_calendar())
class FileInstrumentStorage(InstrumentStorage):
class FileInstrumentStorage(FileStorage, InstrumentStorage):
INSTRUMENT_SEP = "\t"
INSTRUMENT_START_FIELD = "start_datetime"
INSTRUMENT_END_FIELD = "end_datetime"
SYMBOL_FIELD_NAME = "instrument"
def __init__(self, market: str, uri: str):
super(FileInstrumentStorage, self).__init__(market, uri)
self.uri = Path(self.uri).expanduser().joinpath(f"{market.lower()}.txt")
def __init__(self, market: str, uri: str, **kwargs):
super(FileInstrumentStorage, self).__init__(market, uri, **kwargs)
self.uri = Path(self.uri).expanduser().joinpath("instruments", f"{market.lower()}.txt")
def _read_instrument(self) -> Dict[InstKT, InstVT]:
if not self.uri.exists():
@@ -138,14 +143,6 @@ class FileInstrumentStorage(InstrumentStorage):
def __getitem__(self, k: InstKT) -> InstVT:
return self._read_instrument()[k]
def __len__(self) -> int:
inst = self._read_instrument()
return len(inst)
def __iter__(self) -> Iterator[InstKT]:
for _inst in self._read_instrument().keys():
yield _inst
def update(self, *args, **kwargs) -> None:
if len(args) > 1:
@@ -168,11 +165,11 @@ class FileInstrumentStorage(InstrumentStorage):
self._write_instrument(inst)
class FileFeatureStorage(FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, uri: str):
super(FileFeatureStorage, self).__init__(instrument, field, freq, uri)
class FileFeatureStorage(FileStorage, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs):
super(FileFeatureStorage, self).__init__(instrument, field, freq, uri, **kwargs)
self.uri = (
Path(self.uri).expanduser().joinpath(instrument.lower()).joinpath(f"{field.lower()}.{freq.lower()}.bin")
Path(self.uri).expanduser().joinpath("features", instrument.lower(), f"{field.lower()}.{freq.lower()}.bin")
)
def clear(self):
@@ -183,18 +180,45 @@ class FileFeatureStorage(FeatureStorage):
def data(self) -> pd.Series:
return self[:]
def extend(self, series: pd.Series) -> None:
extend_start_index = self[0][0] + len(self) if self.uri.exists() else series.index[0]
series = series.reindex(pd.RangeIndex(extend_start_index, series.index[-1] + 1))
with self.uri.open("ab") as fp:
np.array(series.values).astype("<f").tofile(fp)
def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
if len(data_array) == 0:
logger.info(
"len(data_array) == 0, write"
"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
)
return
if not self.uri.exists():
# write
index = 0 if index is None else index
with self.uri.open("wb") as fp:
np.hstack([index, data_array]).astype("<f").tofile(fp)
else:
if index is None or index > self.end_index:
# append
index = 0 if index is None else index
with self.uri.open("ab+") as fp:
np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype("<f").tofile(fp)
else:
# rewrite
with self.uri.open("rb+") as fp:
_old_data = np.fromfile(fp, dtype="<f")
_old_index = _old_data[0]
_old_df = pd.DataFrame(
_old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=["old"]
)
fp.seek(0)
_new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=["new"])
_df = pd.concat([_old_df, _new_df], sort=False, axis=1)
_df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))
_df["new"].fillna(_df["old"]).values.astype("<f").tofile(fp)
def rebase(self, series: pd.Series) -> None:
origin_series = self[:]
series = series.append(origin_series.loc[origin_series.index > series.index[-1]])
series = series.reindex(pd.RangeIndex(series.index[0], series.index[-1]))
with self.uri.open("wb") as fp:
np.array(series.values).astype("<f").tofile(fp)
@property
def start_index(self) -> Union[int, None]:
if len(self) == 0:
return None
with open(self.uri, "rb") as fp:
index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
return index
def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
if not self.uri.exists():
@@ -228,18 +252,4 @@ class FileFeatureStorage(FeatureStorage):
raise TypeError(f"type(i) = {type(i)}")
def __len__(self) -> int:
return self.uri.stat().st_size // 4 - 1 if self.uri.exists() else 0
def __iter__(self):
if not self.uri.exists():
return
with open(self.uri, "rb") as fp:
ref_start_index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
fp.seek(4)
while True:
v = fp.read(4)
if v:
yield ref_start_index, struct.unpack("f", v)[0]
ref_start_index += 1
else:
break
return self.uri.stat().st_size // 4 - 1 if self.check_exists() else 0

View File

@@ -1,9 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Iterable, overload, Tuple, List, Text, Iterator, Union, Dict
import re
from typing import Iterable, overload, Tuple, List, Text, Union, Dict
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
# calendar value type
CalVT = str
@@ -13,9 +16,91 @@ InstVT = List[Tuple[CalVT, CalVT]]
# instrument key
InstKT = Text
logger = get_module_logger("storage")
class CalendarStorage:
def __init__(self, freq: str, future: bool, uri: str):
"""
If the user is only using it in `qlib`, you can customize Storage to implement only the following methods:
class UserCalendarStorage(CalendarStorage):
@property
def data(self):
pass
class UserInstrumentStorage(InstrumentStorage):
@property
def data(self):
pass
class UserFeatureStorage(FeatureStorage):
@check_storage
def __getitem__(self, i: slice) -> pd.Series:
pass
"""
class StorageMeta(type):
"""unified management of raise when storage is not exists"""
def __new__(cls, name, bases, dict):
class_obj = type.__new__(cls, name, bases, dict)
# The calls to __iter__ and __getitem__ do not pass through __getattribute__.
# In order to throw an exception before calling __getitem__, use the metaclass
_getitem_func = getattr(class_obj, "__getitem__")
def _getitem(obj, item):
_check_func = getattr(obj, "_check")
if callable(_check_func):
_check_func()
return _getitem_func(obj, item)
setattr(class_obj, "__getitem__", _getitem)
return class_obj
class BaseStorage(metaclass=StorageMeta):
@property
def storage_name(self) -> str:
return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
def check_exists(self) -> bool:
"""check if storage(uri) exists, if not exists: return False"""
raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method")
def clear(self) -> None:
"""clear storage"""
raise NotImplementedError("Subclass of BaseStorage must implement `clear` method")
def __len__(self) -> 0:
return len(self.data) if self.check_exists() else 0
def __getitem__(self, item: Union[slice, Union[int, InstKT]]):
raise NotImplementedError(
"Subclass of BaseStorage must implement `__getitem__(i: Union[int, InstKT])`/`__getitem__(s: slice)` method"
)
def _check(self):
# check storage(uri)
if not self.check_exists():
parameters_info = [f"{_k}={_v}" for _k, _v in self.__dict__.items()]
raise ValueError(f"{self.storage_name.lower()} not exists, storage parameters: {parameters_info}")
def __getattribute__(self, item):
if item == "data":
self._check()
return super(BaseStorage, self).__getattribute__(item)
class CalendarStorage(BaseStorage):
"""
The behavior of CalendarStorage's methods and List's methods of the same name remain consistent
"""
def __init__(self, freq: str, future: bool, uri: str, **kwargs):
self.freq = freq
self.future = future
self.uri = uri
@@ -28,9 +113,6 @@ class CalendarStorage:
def extend(self, iterable: Iterable[CalVT]) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method")
def index(self, value: CalVT) -> int:
raise NotImplementedError("Subclass of CalendarStorage must implement `index` method")
@@ -85,16 +167,9 @@ class CalendarStorage:
"Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
)
def __len__(self) -> int:
"""x.__len__() <==> len(x)"""
raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method")
def __iter__(self):
raise NotImplementedError("Subclass of CalendarStorage must implement `__iter__` method")
class InstrumentStorage:
def __init__(self, market: str, uri: str):
class InstrumentStorage(BaseStorage):
def __init__(self, market: str, uri: str, **kwargs):
self.market = market
self.uri = uri
@@ -103,9 +178,6 @@ class InstrumentStorage:
"""get all data"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method")
def update(self, *args, **kwargs) -> None:
"""D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
If E present and has a .keys() method, does: for k in E: D[k] = E[k]
@@ -126,17 +198,9 @@ class InstrumentStorage:
""" x.__getitem__(k) <==> x[k] """
raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method")
def __len__(self) -> int:
""" Return len(self). """
raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method")
def __iter__(self) -> Iterator[InstKT]:
""" Return iter(self). """
raise NotImplementedError("Subclass of InstrumentStorage must implement `__iter__` method")
class FeatureStorage:
def __init__(self, instrument: str, field: str, freq: str, uri: str):
class FeatureStorage(BaseStorage):
def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs):
self.instrument = instrument
self.field = field
self.freq = freq
@@ -147,12 +211,25 @@ class FeatureStorage:
"""get all data"""
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
def clear(self):
""" Remove all items from FeatureStorage. """
raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method")
@property
def start_index(self) -> Union[int, None]:
"""get FeatureStorage start index
If len(self) == 0; return None
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
@property
def end_index(self) -> Union[int, None]:
if len(self) == 0:
return None
return None if len(self) == 0 else self.start_index + len(self) - 1
def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):
"""Write data_array to FeatureStorage starting from index.
If index is None, append data_array to feature.
If len(data_array) == 0; return
If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan
def extend(self, series: pd.Series):
"""Extend feature by appending elements from the series.
Examples:
@@ -161,21 +238,42 @@ class FeatureStorage:
4 5
5 6
>>> self.extend(pd.Series({7: 8, 9:10}))
>>> self.write([6, 7], index=6)
feature:
3 4
4 5
5 6
6 np.nan
7 8
9 10
6 6
7 7
>>> self.write([8], index=9)
feature:
3 4
4 5
5 6
6 6
7 7
8 np.nan
9 8
>>> self.write([1, np.nan], index=3)
feature:
3 1
4 np.nan
5 6
6 6
7 7
8 np.nan
9 8
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `extend` method")
raise NotImplementedError("Subclass of FeatureStorage must implement `write` method")
def rebase(self, series: pd.Series):
"""Rebase feature header from the series.
def rebase(self, start_index: int = None, end_index: int = None):
"""Rebase the start_index and end_index of the FeatureStorage.
Examples:
@@ -184,30 +282,85 @@ class FeatureStorage:
4 5
5 6
>>> self.rebase(pd.Series({1: 2}))
>>> self.rebase(start_index=4)
feature:
1 2
2 np.nan
3 4
4 5
5 6
>>> self.rebase(pd.Series({5: 6, 7: 8, 9: 10}))
>>> self.rebase(start_index=3)
feature:
3 np.nan
4 5
5 6
7 8
9 10
>>> self.rebase(pd.Series({11: 12, 12: 13,}))
>>> self.write([3], index=3)
feature:
11 12
12 13
3 3
4 5
5 6
>>> self.rebase(end_index=4)
feature:
3 3
4 5
>>> self.write([6, 7, 8], index=4)
feature:
3 3
4 6
5 7
6 8
>>> self.rebase(start_index=4, end_index=5)
feature:
4 6
5 7
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `rebase` method")
if start_index is None and end_index is None:
logger.warning("both start_index and end_index are None, rebase is ignored")
return
if start_index < 0 or end_index < 0:
logger.warning("start_index or end_index cannot be less than 0")
return
if start_index > end_index:
logger.warning(
f"start_index({start_index}) > end_index({end_index}), rebase is ignored; "
f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
)
return
start_index = self.start_index if start_index is None else end_index
end_index = self.end_index if end_index is None else end_index
if start_index <= self.start_index:
self.write([np.nan] * (self.start_index - start_index), start_index)
else:
self.rewrite(self[start_index:].values, start_index)
if end_index >= self.end_index:
self.write([np.nan] * (end_index - self.end_index))
else:
self.rewrite(self[: end_index + 1].values, self.start_index)
def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):
"""overwrite all data in FeatureStorage with data
Parameters
----------
data: Union[List, np.ndarray, Tuple]
data
index: int
data start index
"""
self.clear()
self.write(data, index)
@overload
def __getitem__(self, s: slice) -> pd.Series:
@@ -224,11 +377,3 @@ class FeatureStorage:
raise NotImplementedError(
"Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
)
def __len__(self) -> int:
"""len(feature) <==> feature.__len__() """
raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method")
def __iter__(self) -> Iterable[Tuple[int, float]]:
"""iter(feature)"""
raise NotImplementedError("Subclass of FeatureStorage must implement `__iter__` method")

View File

@@ -9,19 +9,19 @@ from ..config import REG_CN
class TestAutoData(unittest.TestCase):
_setup_kwargs = {}
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
@classmethod
def setUpClass(cls) -> None:
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
if not exists_qlib_data(cls.provider_uri):
print(f"Qlib data is not found in {cls.provider_uri}")
GetData().qlib_data(
name="qlib_data_simple",
region="cn",
interval="1d",
target_dir=provider_uri,
target_dir=cls.provider_uri,
delete_old=False,
)
init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs)
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)

View File

@@ -2,13 +2,12 @@
# Licensed under the MIT License.
import shutil
from pathlib import Path
from collections.abc import Iterable
import pytest
import numpy as np
from qlib.tests.data import GetData
from qlib.tests import TestAutoData
from qlib.data.storage.file_storage import (
FileCalendarStorage as CalendarStorage,
@@ -22,25 +21,10 @@ QLIB_DIR = DATA_DIR.joinpath("qlib")
QLIB_DIR.mkdir(exist_ok=True, parents=True)
# TODO: set value
CALENDAR_URI = QLIB_DIR.joinpath("calendars")
INSTRUMENT_URI = QLIB_DIR.joinpath("instruments")
FEATURE_URI = QLIB_DIR.joinpath("features")
class TestStorage:
@classmethod
def setup_class(cls):
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False)
@classmethod
def teardown_class(cls):
shutil.rmtree(str(DATA_DIR.resolve()))
class TestStorage(TestAutoData):
def test_calendar_storage(self):
calendar = CalendarStorage(freq="day", future=False, uri=CALENDAR_URI)
assert isinstance(calendar, Iterable), f"{calendar.__class__.__name__} is not Iterable"
calendar = CalendarStorage(freq="day", future=False, uri=self.provider_uri)
assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable"
@@ -82,9 +66,7 @@ class TestStorage:
"""
instrument = InstrumentStorage(market="csi300", uri=INSTRUMENT_URI)
assert isinstance(instrument, Iterable), f"{instrument.__class__.__name__} is not Iterable"
instrument = InstrumentStorage(market="csi300", uri=self.provider_uri)
for inst, spans in instrument.data.items():
assert isinstance(inst, str) and isinstance(
@@ -151,13 +133,12 @@ class TestStorage:
"""
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=FEATURE_URI)
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=self.provider_uri)
assert isinstance(feature, Iterable), f"{feature.__class__.__name__} is not Iterable"
with pytest.raises(IndexError):
print(feature[0])
assert isinstance(
feature[815][1], (np.float, np.float32)
feature[815][1], (float, np.float32)
), f"{feature.__class__.__name__}.__getitem__(i: int) error"
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
print(f"feature[815: 818]: {feature[815: 818]}")
@@ -167,5 +148,5 @@ class TestStorage:
isinstance(_item, tuple) and len(_item) == 2
), f"{feature.__class__.__name__}.__iter__ item type error"
assert isinstance(_item[0], int) and isinstance(
_item[1], (np.float, np.float32)
_item[1], (float, np.float32)
), f"{feature.__class__.__name__}.__iter__ value type error"