mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
add write method to FeatureStorage && remove extend
This commit is contained in:
@@ -8,6 +8,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import copy
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
@@ -31,19 +32,27 @@ from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_modu
|
||||
class ProviderBackendMixin:
|
||||
def get_default_backend(self):
|
||||
backend = {}
|
||||
provider_name = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] # type: str
|
||||
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
|
||||
# set default storage class
|
||||
backend.setdefault("class", f"File{provider_name}Storage")
|
||||
# set default storage module
|
||||
backend.setdefault("module_path", "qlib.data.storage.file_storage")
|
||||
# set default storage kwargs
|
||||
backend_kwargs = backend.setdefault("kwargs", {}) # type: dict
|
||||
backend_kwargs.setdefault("uri", os.path.join(C.get_data_path(), f"{provider_name.lower()}s"))
|
||||
return backend
|
||||
|
||||
@property
|
||||
def backend_obj(self):
|
||||
return init_instance_by_config(self.backend)
|
||||
def backend_obj(self, **kwargs):
|
||||
backend = self.backend if self.backend else self.get_default_backend()
|
||||
backend = copy.deepcopy(backend)
|
||||
|
||||
# set default storage kwargs
|
||||
backend_kwargs = backend.setdefault("kwargs", {})
|
||||
# default uri map
|
||||
if "uri" not in backend_kwargs:
|
||||
# if the user has no uri configured, use: uri = uri_map[freq]
|
||||
freq = kwargs.get("freq", "day")
|
||||
uri_map = backend_kwargs.setdefault("uri_map", {freq: C.get_data_path()})
|
||||
backend_kwargs["uri"] = uri_map[freq]
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
@@ -54,8 +63,6 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
if not self.backend:
|
||||
self.backend = self.get_default_backend()
|
||||
|
||||
@abc.abstractmethod
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
@@ -159,8 +166,6 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
if not self.backend:
|
||||
self.backend = self.get_default_backend()
|
||||
|
||||
@staticmethod
|
||||
def instruments(market="all", filter_pipe=None):
|
||||
@@ -252,8 +257,6 @@ class FeatureProvider(abc.ABC, ProviderBackendMixin):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
if not self.backend:
|
||||
self.backend = self.get_default_backend()
|
||||
|
||||
@abc.abstractmethod
|
||||
def feature(self, instrument, field, start_time, end_time, freq):
|
||||
@@ -552,8 +555,18 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
list
|
||||
list of timestamps
|
||||
"""
|
||||
self.backend.setdefault("kwargs", {}).update(freq=freq, future=future)
|
||||
return [pd.Timestamp(x) for x in self.backend_obj.data]
|
||||
|
||||
backend_obj = self.backend_obj(freq=freq, future=future)
|
||||
if future and not backend_obj.check_exists():
|
||||
get_module_logger("data").warning(
|
||||
f"load calendar error: freq={freq}, future={future}; return current calendar!"
|
||||
)
|
||||
get_module_logger("data").warning(
|
||||
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
|
||||
)
|
||||
backend_obj = self.backend_obj(freq=freq, future=False)
|
||||
|
||||
return [pd.Timestamp(x) for x in backend_obj.data]
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
_calendar, _calendar_index = self._get_calendar(freq, future)
|
||||
@@ -589,17 +602,15 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
"""Instrument file uri."""
|
||||
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
|
||||
|
||||
def _load_instruments(self, market):
|
||||
|
||||
self.backend.setdefault("kwargs", {}).update(market=market)
|
||||
return self.backend_obj.data
|
||||
def _load_instruments(self, market, freq):
|
||||
return self.backend_obj(market=market, freq=freq).data
|
||||
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
market = instruments["market"]
|
||||
if market in H["i"]:
|
||||
_instruments = H["i"][market]
|
||||
else:
|
||||
_instruments = self._load_instruments(market)
|
||||
_instruments = self._load_instruments(market, freq=freq)
|
||||
H["i"][market] = _instruments
|
||||
# strip
|
||||
# use calendar boundary
|
||||
@@ -648,9 +659,14 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
instrument = code_to_fname(instrument)
|
||||
|
||||
self.backend.setdefault("kwargs", {}).update(instrument=instrument, field=field, freq=freq)
|
||||
return self.backend_obj[start_index : end_index + 1]
|
||||
try:
|
||||
data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
except Exception as e:
|
||||
get_module_logger("data").warning(
|
||||
f"WARN: data not found for {instrument}.{field}\n\tException info: {str(e)}"
|
||||
)
|
||||
data = pd.Series(dtype=np.float32)
|
||||
return data
|
||||
|
||||
|
||||
class LocalExpressionProvider(ExpressionProvider):
|
||||
|
||||
@@ -3,25 +3,36 @@
|
||||
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Iterable, Union, Dict, Mapping, Tuple
|
||||
from typing import Iterable, Union, Dict, Mapping, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
|
||||
|
||||
logger = get_module_logger("file_storage")
|
||||
|
||||
class FileCalendarStorage(CalendarStorage):
|
||||
def __init__(self, freq: str, future: bool, uri: str):
|
||||
super(FileCalendarStorage, self).__init__(freq, future, uri)
|
||||
|
||||
class FileStorage:
    """Mixin that answers "does the backing file exist?" for file-based storages."""

    def check_exists(self):
        """Return the result of ``self.uri.exists()``.

        ``self.uri`` is not set here; the concrete storage classes in this
        module assign it (as a ``pathlib.Path``) in their ``__init__``.
        """
        target = self.uri
        return target.exists()
|
||||
|
||||
|
||||
class FileCalendarStorage(FileStorage, CalendarStorage):
|
||||
def __init__(self, freq: str, future: bool, uri: str, **kwargs):
|
||||
super(FileCalendarStorage, self).__init__(freq, future, uri, **kwargs)
|
||||
_file_name = f"{freq}_future.txt" if future else f"{freq}.txt"
|
||||
self.uri = Path(self.uri).expanduser().joinpath(_file_name.lower())
|
||||
self.uri = Path(self.uri).expanduser().joinpath("calendars", _file_name.lower())
|
||||
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> np.ndarray:
|
||||
if not self.uri.exists():
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> Iterable[CalVT]:
    """Read calendar entries from the calendar file, one entry per line.

    Parameters
    ----------
    skip_rows : int
        number of leading lines of the file to skip.
    n_rows : int
        maximum number of entries to return; ``None`` returns all of them.

    Returns
    -------
    Iterable[CalVT]
        list of the non-empty lines, as ``str``, in file order.
    """
    if not self.check_exists():
        # Create an empty calendar file so the read below cannot fail on a missing file.
        self._write_calendar(values=[])
    # Plain line reading replaces ``np.loadtxt(..., delimiter="\n")``:
    #   * newer NumPy releases reject a newline delimiter;
    #   * loadtxt returns a 0-d array for a one-line file, which cannot be iterated;
    #   * each line (which may contain spaces, e.g. a date plus a time) stays one entry.
    with self.uri.open("r", encoding="utf-8") as fp:
        rows = [line.strip() for line in fp.readlines()[skip_rows:]]
    rows = [row for row in rows if row]
    return rows if n_rows is None else rows[:n_rows]
|
||||
|
||||
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
|
||||
with self.uri.open(mode=mode) as fp:
|
||||
@@ -65,23 +76,17 @@ class FileCalendarStorage(CalendarStorage):
|
||||
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]:
|
||||
return self._read_calendar()[i]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._read_calendar())
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._read_calendar())
|
||||
|
||||
|
||||
class FileInstrumentStorage(InstrumentStorage):
|
||||
class FileInstrumentStorage(FileStorage, InstrumentStorage):
|
||||
|
||||
INSTRUMENT_SEP = "\t"
|
||||
INSTRUMENT_START_FIELD = "start_datetime"
|
||||
INSTRUMENT_END_FIELD = "end_datetime"
|
||||
SYMBOL_FIELD_NAME = "instrument"
|
||||
|
||||
def __init__(self, market: str, uri: str):
|
||||
super(FileInstrumentStorage, self).__init__(market, uri)
|
||||
self.uri = Path(self.uri).expanduser().joinpath(f"{market.lower()}.txt")
|
||||
def __init__(self, market: str, uri: str, **kwargs):
|
||||
super(FileInstrumentStorage, self).__init__(market, uri, **kwargs)
|
||||
self.uri = Path(self.uri).expanduser().joinpath("instruments", f"{market.lower()}.txt")
|
||||
|
||||
def _read_instrument(self) -> Dict[InstKT, InstVT]:
|
||||
if not self.uri.exists():
|
||||
@@ -138,14 +143,6 @@ class FileInstrumentStorage(InstrumentStorage):
|
||||
def __getitem__(self, k: InstKT) -> InstVT:
|
||||
return self._read_instrument()[k]
|
||||
|
||||
def __len__(self) -> int:
|
||||
inst = self._read_instrument()
|
||||
return len(inst)
|
||||
|
||||
def __iter__(self) -> Iterator[InstKT]:
|
||||
for _inst in self._read_instrument().keys():
|
||||
yield _inst
|
||||
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
|
||||
if len(args) > 1:
|
||||
@@ -168,11 +165,11 @@ class FileInstrumentStorage(InstrumentStorage):
|
||||
self._write_instrument(inst)
|
||||
|
||||
|
||||
class FileFeatureStorage(FeatureStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, uri: str):
|
||||
super(FileFeatureStorage, self).__init__(instrument, field, freq, uri)
|
||||
class FileFeatureStorage(FileStorage, FeatureStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs):
|
||||
super(FileFeatureStorage, self).__init__(instrument, field, freq, uri, **kwargs)
|
||||
self.uri = (
|
||||
Path(self.uri).expanduser().joinpath(instrument.lower()).joinpath(f"{field.lower()}.{freq.lower()}.bin")
|
||||
Path(self.uri).expanduser().joinpath("features", instrument.lower(), f"{field.lower()}.{freq.lower()}.bin")
|
||||
)
|
||||
|
||||
def clear(self):
|
||||
@@ -183,18 +180,45 @@ class FileFeatureStorage(FeatureStorage):
|
||||
def data(self) -> pd.Series:
|
||||
return self[:]
|
||||
|
||||
def extend(self, series: pd.Series) -> None:
|
||||
extend_start_index = self[0][0] + len(self) if self.uri.exists() else series.index[0]
|
||||
series = series.reindex(pd.RangeIndex(extend_start_index, series.index[-1] + 1))
|
||||
with self.uri.open("ab") as fp:
|
||||
np.array(series.values).astype("<f").tofile(fp)
|
||||
def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
    """Write ``data_array`` into the feature file, starting at ``index``.

    The file layout is little-endian float32: the first value is the start
    index of the series, the remaining values are the data.

    Parameters
    ----------
    data_array : Union[List, np.ndarray]
        values to write; an empty array is ignored (only a log message is emitted).
    index : int
        position of the first value of ``data_array``.
        If the file does not exist, a new file is created starting at ``index``
        (``0`` when ``index`` is None).
        If ``index`` is None or greater than ``self.end_index``, the data is
        appended and any gap is filled with ``np.nan``.
        Otherwise the overlapping range is overwritten with ``data_array``
        (including explicit ``np.nan`` values) and the file is rewritten.
    """
    if len(data_array) == 0:
        logger.info(
            "len(data_array) == 0, write is ignored; "
            "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
        )
        return

    if not self.uri.exists():
        # create: header (start index) followed by the data
        index = 0 if index is None else index
        with self.uri.open("wb") as fp:
            np.hstack([index, data_array]).astype("<f").tofile(fp)
        return

    if index is None or index > self.end_index:
        # append: pad the hole between the current end and ``index`` with NaN
        gap = 0 if index is None else index - self.end_index - 1
        with self.uri.open("ab+") as fp:
            np.hstack([[np.nan] * gap, data_array]).astype("<f").tofile(fp)
        return

    # overwrite: merge old and new values on a common index range
    with self.uri.open("rb") as fp:
        old_raw = np.fromfile(fp, dtype="<f")
    # the header is stored as float32; it must be an int to be used as an offset
    old_start = int(old_raw[0])
    old_values = old_raw[1:]
    new_values = np.asarray(data_array, dtype="<f")

    start = min(old_start, index)
    stop = max(old_start + len(old_values), index + len(new_values))
    merged = np.full(stop - start, np.nan, dtype="<f")
    merged[old_start - start : old_start - start + len(old_values)] = old_values
    # new values are applied last so they win, even when they are NaN
    merged[index - start : index - start + len(new_values)] = new_values

    with self.uri.open("wb") as fp:
        # keep the header, otherwise the first data value would be read back as the start index
        np.hstack([start, merged]).astype("<f").tofile(fp)
|
||||
|
||||
def rebase(self, series: pd.Series) -> None:
|
||||
origin_series = self[:]
|
||||
series = series.append(origin_series.loc[origin_series.index > series.index[-1]])
|
||||
series = series.reindex(pd.RangeIndex(series.index[0], series.index[-1]))
|
||||
with self.uri.open("wb") as fp:
|
||||
np.array(series.values).astype("<f").tofile(fp)
|
||||
@property
|
||||
def start_index(self) -> Union[int, None]:
|
||||
if len(self) == 0:
|
||||
return None
|
||||
with open(self.uri, "rb") as fp:
|
||||
index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
|
||||
return index
|
||||
|
||||
def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
|
||||
if not self.uri.exists():
|
||||
@@ -228,18 +252,4 @@ class FileFeatureStorage(FeatureStorage):
|
||||
raise TypeError(f"type(i) = {type(i)}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.uri.stat().st_size // 4 - 1 if self.uri.exists() else 0
|
||||
|
||||
def __iter__(self):
|
||||
if not self.uri.exists():
|
||||
return
|
||||
with open(self.uri, "rb") as fp:
|
||||
ref_start_index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
|
||||
fp.seek(4)
|
||||
while True:
|
||||
v = fp.read(4)
|
||||
if v:
|
||||
yield ref_start_index, struct.unpack("f", v)[0]
|
||||
ref_start_index += 1
|
||||
else:
|
||||
break
|
||||
return self.uri.stat().st_size // 4 - 1 if self.check_exists() else 0
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Iterable, overload, Tuple, List, Text, Iterator, Union, Dict
|
||||
import re
|
||||
from typing import Iterable, overload, Tuple, List, Text, Union, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
# calendar value type
|
||||
CalVT = str
|
||||
@@ -13,9 +16,91 @@ InstVT = List[Tuple[CalVT, CalVT]]
|
||||
# instrument key
|
||||
InstKT = Text
|
||||
|
||||
logger = get_module_logger("storage")
|
||||
|
||||
class CalendarStorage:
|
||||
def __init__(self, freq: str, future: bool, uri: str):
|
||||
"""
|
||||
If the user is only using it in `qlib`, you can customize Storage to implement only the following methods:
|
||||
|
||||
class UserCalendarStorage(CalendarStorage):
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
pass
|
||||
|
||||
class UserInstrumentStorage(InstrumentStorage):
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
pass
|
||||
|
||||
class UserFeatureStorage(FeatureStorage):
|
||||
|
||||
@check_storage
|
||||
def __getitem__(self, i: slice) -> pd.Series:
|
||||
pass
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class StorageMeta(type):
    """Metaclass that runs the storage existence check before item access.

    Calls to ``__iter__`` and ``__getitem__`` made through the ``obj[item]``
    syntax do not pass through ``__getattribute__``, so the check cannot be
    hooked there. Instead, the class-level ``__getitem__`` is wrapped when the
    class is created, so that ``obj._check()`` runs first.
    """

    def __new__(cls, name, bases, dict):
        class_obj = type.__new__(cls, name, bases, dict)

        _getitem_func = getattr(class_obj, "__getitem__", None)
        # Nothing to wrap when the class has no ``__getitem__``. An accessor that is
        # inherited from a parent created by this metaclass is already wrapped;
        # wrapping it again would run ``_check`` once per inheritance level.
        if _getitem_func is None or getattr(_getitem_func, "_storage_checked", False):
            return class_obj

        def _getitem(obj, item):
            _check_func = getattr(obj, "_check", None)
            if callable(_check_func):
                _check_func()
            return _getitem_func(obj, item)

        # marker used above to recognise an already wrapped accessor
        _getitem._storage_checked = True
        setattr(class_obj, "__getitem__", _getitem)
        return class_obj
|
||||
|
||||
|
||||
class BaseStorage(metaclass=StorageMeta):
    """Common base of the storage classes.

    Accessing ``self.data`` (through ``__getattribute__``) runs ``_check`` first;
    ``_check`` raises ``ValueError`` when ``check_exists()`` is false.
    """

    @property
    def storage_name(self) -> str:
        """Second-to-last capitalised word of the class name.

        e.g. ``FileCalendarStorage`` -> ``Calendar``
        """
        return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]

    def check_exists(self) -> bool:
        """check if storage(uri) exists, if not exists: return False"""
        raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method")

    def clear(self) -> None:
        """clear storage"""
        raise NotImplementedError("Subclass of BaseStorage must implement `clear` method")

    def __len__(self) -> int:
        """Return ``len(self.data)``, or 0 when the storage does not exist."""
        return len(self.data) if self.check_exists() else 0

    def __getitem__(self, item: Union[slice, Union[int, InstKT]]):
        raise NotImplementedError(
            "Subclass of BaseStorage must implement `__getitem__(i: Union[int, InstKT])`/`__getitem__(s: slice)` method"
        )

    def _check(self):
        """Raise ``ValueError`` (listing the instance attributes) if the storage does not exist."""
        # check storage(uri)
        if not self.check_exists():
            parameters_info = [f"{_k}={_v}" for _k, _v in self.__dict__.items()]
            raise ValueError(f"{self.storage_name.lower()} not exists, storage parameters: {parameters_info}")

    def __getattribute__(self, item):
        # ``data`` is validated on every access; all other attributes are untouched
        if item == "data":
            self._check()
        return super(BaseStorage, self).__getattribute__(item)
|
||||
|
||||
|
||||
class CalendarStorage(BaseStorage):
|
||||
"""
|
||||
The behavior of CalendarStorage's methods and List's methods of the same name remain consistent
|
||||
"""
|
||||
|
||||
def __init__(self, freq: str, future: bool, uri: str, **kwargs):
|
||||
self.freq = freq
|
||||
self.future = future
|
||||
self.uri = uri
|
||||
@@ -28,9 +113,6 @@ class CalendarStorage:
|
||||
def extend(self, iterable: Iterable[CalVT]) -> None:
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method")
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method")
|
||||
|
||||
def index(self, value: CalVT) -> int:
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `index` method")
|
||||
|
||||
@@ -85,16 +167,9 @@ class CalendarStorage:
|
||||
"Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""x.__len__() <==> len(x)"""
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method")
|
||||
|
||||
def __iter__(self):
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `__iter__` method")
|
||||
|
||||
|
||||
class InstrumentStorage:
|
||||
def __init__(self, market: str, uri: str):
|
||||
class InstrumentStorage(BaseStorage):
|
||||
def __init__(self, market: str, uri: str, **kwargs):
|
||||
self.market = market
|
||||
self.uri = uri
|
||||
|
||||
@@ -103,9 +178,6 @@ class InstrumentStorage:
|
||||
"""get all data"""
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method")
|
||||
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
"""D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
|
||||
If E present and has a .keys() method, does: for k in E: D[k] = E[k]
|
||||
@@ -126,17 +198,9 @@ class InstrumentStorage:
|
||||
""" x.__getitem__(k) <==> x[k] """
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method")
|
||||
|
||||
def __len__(self) -> int:
|
||||
""" Return len(self). """
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method")
|
||||
|
||||
def __iter__(self) -> Iterator[InstKT]:
|
||||
""" Return iter(self). """
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `__iter__` method")
|
||||
|
||||
|
||||
class FeatureStorage:
|
||||
def __init__(self, instrument: str, field: str, freq: str, uri: str):
|
||||
class FeatureStorage(BaseStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs):
|
||||
self.instrument = instrument
|
||||
self.field = field
|
||||
self.freq = freq
|
||||
@@ -147,12 +211,25 @@ class FeatureStorage:
|
||||
"""get all data"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
|
||||
|
||||
def clear(self):
|
||||
""" Remove all items from FeatureStorage. """
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method")
|
||||
@property
|
||||
def start_index(self) -> Union[int, None]:
|
||||
"""get FeatureStorage start index
|
||||
If len(self) == 0; return None
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
|
||||
|
||||
@property
|
||||
def end_index(self) -> Union[int, None]:
|
||||
if len(self) == 0:
|
||||
return None
|
||||
return None if len(self) == 0 else self.start_index + len(self) - 1
|
||||
|
||||
def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):
|
||||
"""Write data_array to FeatureStorage starting from index.
|
||||
If index is None, append data_array to feature.
|
||||
If len(data_array) == 0; return
|
||||
If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan
|
||||
|
||||
def extend(self, series: pd.Series):
|
||||
"""Extend feature by appending elements from the series.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -161,21 +238,42 @@ class FeatureStorage:
|
||||
4 5
|
||||
5 6
|
||||
|
||||
>>> self.extend(pd.Series({7: 8, 9:10}))
|
||||
>>> self.write([6, 7], index=6)
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
6 np.nan
|
||||
7 8
|
||||
9 10
|
||||
6 6
|
||||
7 7
|
||||
|
||||
>>> self.write([8], index=9)
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
8 np.nan
|
||||
9 8
|
||||
|
||||
>>> self.write([1, np.nan], index=3)
|
||||
|
||||
feature:
|
||||
3 1
|
||||
4 np.nan
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
8 np.nan
|
||||
9 8
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `extend` method")
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `write` method")
|
||||
|
||||
def rebase(self, series: pd.Series):
|
||||
"""Rebase feature header from the series.
|
||||
def rebase(self, start_index: int = None, end_index: int = None):
    """Rebase the start_index and end_index of the FeatureStorage.

    A bound that is None keeps its current value. Data outside the new
    ``[start_index, end_index]`` range is dropped; positions inside the range
    that have no data are filled with ``np.nan``.

    Examples:

        feature:
            3   4
            4   5
            5   6

        >>> self.rebase(start_index=4)

        feature:
            4   5
            5   6

        >>> self.rebase(start_index=3)

        feature:
            3   np.nan
            4   5
            5   6

        >>> self.write([3], index=3)

        feature:
            3   3
            4   5
            5   6

        >>> self.rebase(end_index=4)

        feature:
            3   3
            4   5

        >>> self.write([6, 7, 8], index=4)

        feature:
            3   3
            4   6
            5   7
            6   8

        >>> self.rebase(start_index=4, end_index=5)

        feature:
            4   6
            5   7

    """
    if start_index is None and end_index is None:
        logger.warning("both start_index and end_index are None, rebase is ignored")
        return

    # Only validate the bounds that were actually given; comparing None raises TypeError.
    if (start_index is not None and start_index < 0) or (end_index is not None and end_index < 0):
        logger.warning("start_index or end_index cannot be less than 0")
        return

    if len(self) == 0:
        # start_index/end_index of an empty storage are None: there is nothing to rebase
        logger.warning("FeatureStorage is empty, rebase is ignored")
        return

    start_index = self.start_index if start_index is None else start_index
    end_index = self.end_index if end_index is None else end_index

    # Checked after resolving the defaults so that a single out-of-range bound is caught too.
    if start_index > end_index:
        logger.warning(
            f"start_index({start_index}) > end_index({end_index}), rebase is ignored; "
            f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
        )
        return

    if start_index > self.end_index or end_index < self.start_index:
        # The new range does not overlap the stored data: nothing is kept.
        self.rewrite([np.nan] * (end_index - start_index + 1), start_index)
        return

    if start_index <= self.start_index:
        # extend to the left with NaN (an empty list is a no-op in ``write``)
        self.write([np.nan] * (self.start_index - start_index), start_index)
    else:
        # drop everything before the new start
        self.rewrite(self[start_index:].values, start_index)

    if end_index >= self.end_index:
        # extend to the right with NaN
        self.write([np.nan] * (end_index - self.end_index))
    else:
        # drop everything after the new end
        self.rewrite(self[: end_index + 1].values, self.start_index)
|
||||
|
||||
def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):
    """Replace the whole content of the FeatureStorage with ``data``.

    The storage is cleared first, then ``data`` is written starting at ``index``.

    Parameters
    ----------
    data: Union[List, np.ndarray, Tuple]
        the values that become the new content
    index: int
        index of the first value of ``data``
    """
    # order matters: wipe the old content before writing the new one
    self.clear()
    self.write(data, index)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, s: slice) -> pd.Series:
|
||||
@@ -224,11 +377,3 @@ class FeatureStorage:
|
||||
raise NotImplementedError(
|
||||
"Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""len(feature) <==> feature.__len__() """
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method")
|
||||
|
||||
def __iter__(self) -> Iterable[Tuple[int, float]]:
|
||||
"""iter(feature)"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `__iter__` method")
|
||||
|
||||
@@ -9,19 +9,19 @@ from ..config import REG_CN
|
||||
class TestAutoData(unittest.TestCase):
|
||||
|
||||
_setup_kwargs = {}
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
if not exists_qlib_data(cls.provider_uri):
|
||||
print(f"Qlib data is not found in {cls.provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple",
|
||||
region="cn",
|
||||
interval="1d",
|
||||
target_dir=provider_uri,
|
||||
target_dir=cls.provider_uri,
|
||||
delete_old=False,
|
||||
)
|
||||
init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterable
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from qlib.data.storage.file_storage import (
|
||||
FileCalendarStorage as CalendarStorage,
|
||||
@@ -22,25 +21,10 @@ QLIB_DIR = DATA_DIR.joinpath("qlib")
|
||||
QLIB_DIR.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
# TODO: set value
|
||||
CALENDAR_URI = QLIB_DIR.joinpath("calendars")
|
||||
INSTRUMENT_URI = QLIB_DIR.joinpath("instruments")
|
||||
FEATURE_URI = QLIB_DIR.joinpath("features")
|
||||
|
||||
|
||||
class TestStorage:
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
shutil.rmtree(str(DATA_DIR.resolve()))
|
||||
|
||||
class TestStorage(TestAutoData):
|
||||
def test_calendar_storage(self):
|
||||
|
||||
calendar = CalendarStorage(freq="day", future=False, uri=CALENDAR_URI)
|
||||
assert isinstance(calendar, Iterable), f"{calendar.__class__.__name__} is not Iterable"
|
||||
calendar = CalendarStorage(freq="day", future=False, uri=self.provider_uri)
|
||||
assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
|
||||
assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable"
|
||||
|
||||
@@ -82,9 +66,7 @@ class TestStorage:
|
||||
|
||||
"""
|
||||
|
||||
instrument = InstrumentStorage(market="csi300", uri=INSTRUMENT_URI)
|
||||
|
||||
assert isinstance(instrument, Iterable), f"{instrument.__class__.__name__} is not Iterable"
|
||||
instrument = InstrumentStorage(market="csi300", uri=self.provider_uri)
|
||||
|
||||
for inst, spans in instrument.data.items():
|
||||
assert isinstance(inst, str) and isinstance(
|
||||
@@ -151,13 +133,12 @@ class TestStorage:
|
||||
|
||||
"""
|
||||
|
||||
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=FEATURE_URI)
|
||||
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=self.provider_uri)
|
||||
|
||||
assert isinstance(feature, Iterable), f"{feature.__class__.__name__} is not Iterable"
|
||||
with pytest.raises(IndexError):
|
||||
print(feature[0])
|
||||
assert isinstance(
|
||||
feature[815][1], (np.float, np.float32)
|
||||
feature[815][1], (float, np.float32)
|
||||
), f"{feature.__class__.__name__}.__getitem__(i: int) error"
|
||||
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
|
||||
print(f"feature[815: 818]: {feature[815: 818]}")
|
||||
@@ -167,5 +148,5 @@ class TestStorage:
|
||||
isinstance(_item, tuple) and len(_item) == 2
|
||||
), f"{feature.__class__.__name__}.__iter__ item type error"
|
||||
assert isinstance(_item[0], int) and isinstance(
|
||||
_item[1], (np.float, np.float32)
|
||||
_item[1], (float, np.float32)
|
||||
), f"{feature.__class__.__name__}.__iter__ value type error"
|
||||
|
||||
Reference in New Issue
Block a user