mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
@@ -53,6 +53,34 @@ Cache
|
||||
.. autoclass:: qlib.data.cache.DiskDatasetCache
|
||||
:members:
|
||||
|
||||
|
||||
Storage
|
||||
-------------
|
||||
.. autoclass:: qlib.data.storage.storage.BaseStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.CalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.FeatureStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
|
||||
:members:
|
||||
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class Position:
|
||||
def save_position(self, path, last_trade_date):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series(dtype=np.float)
|
||||
cash = pd.Series(dtype=float)
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["today_account_value"] = p["today_account_value"]
|
||||
|
||||
@@ -6,7 +6,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import copy
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
@@ -27,12 +29,41 @@ from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
class ProviderBackendMixin:
    """Mixin that builds a provider's storage backend from a config dict.

    The host class may expose a ``backend`` attribute (a dict with ``class``,
    ``module_path`` and optional ``kwargs``). When it is missing or empty, the
    default produced by :meth:`get_default_backend` is used instead.
    """

    def get_default_backend(self):
        """Return the default backend config derived from the provider class name.

        The second-to-last capitalised word of the class name selects the
        storage class, e.g. ``LocalCalendarProvider`` -> ``FileCalendarStorage``.
        """
        backend = {}
        provider_name: str = re.findall(r"[A-Z][^A-Z]*", self.__class__.__name__)[-2]
        # set default storage class
        backend.setdefault("class", f"File{provider_name}Storage")
        # set default storage module
        backend.setdefault("module_path", "qlib.data.storage.file_storage")
        return backend

    def backend_obj(self, **kwargs):
        """Instantiate the storage backend, forwarding ``kwargs`` to its constructor.

        Raises
        ------
        KeyError
            If ``provider_uri`` is not configured and the ``provider_uri_map``
            has no entry for the requested ``freq``.
        """
        # Tolerate host classes whose __init__ never assigned `self.backend`.
        configured = getattr(self, "backend", None)
        # Work on a copy so the provider's stored config is never mutated.
        backend = copy.deepcopy(configured if configured else self.get_default_backend())

        # set default storage kwargs
        backend_kwargs = backend.setdefault("kwargs", {})
        # default provider_uri map
        if "provider_uri" not in backend_kwargs:
            # if the user has no uri configured, use: uri = uri_map[freq]
            freq = kwargs.get("freq", "day")
            provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
            if freq not in provider_uri_map:
                raise KeyError(
                    f"freq `{freq}` is not configured in `provider_uri_map`: {list(provider_uri_map)}"
                )
            backend_kwargs["provider_uri"] = provider_uri_map[freq]
        backend_kwargs.update(**kwargs)
        return init_instance_by_config(backend)
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Calendar provider base class
|
||||
|
||||
Provide calendar data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
"""Get calendar of certain market in given time range.
|
||||
@@ -127,12 +158,15 @@ class CalendarProvider(abc.ABC):
|
||||
return hash_args(start_time, end_time, freq, future)
|
||||
|
||||
|
||||
class InstrumentProvider(abc.ABC):
|
||||
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Instrument provider base class
|
||||
|
||||
Provide instrument data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@staticmethod
|
||||
def instruments(market="all", filter_pipe=None):
|
||||
"""Get the general config dictionary for a base market adding several dynamic filters.
|
||||
@@ -215,12 +249,15 @@ class InstrumentProvider(abc.ABC):
|
||||
raise ValueError(f"Unknown instrument type {inst}")
|
||||
|
||||
|
||||
class FeatureProvider(abc.ABC):
|
||||
class FeatureProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Feature provider class
|
||||
|
||||
Provide feature data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def feature(self, instrument, field, start_time, end_time, freq):
|
||||
"""Get feature data.
|
||||
@@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalCalendarProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
@@ -517,21 +555,22 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
list
|
||||
list of timestamps
|
||||
"""
|
||||
if future:
|
||||
fname = self._uri_cal.format(freq + "_future")
|
||||
# if future calendar not exists, return current calendar
|
||||
if not os.path.exists(fname):
|
||||
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
|
||||
|
||||
try:
|
||||
backend_obj = self.backend_obj(freq=freq, future=future).data
|
||||
except ValueError:
|
||||
if future:
|
||||
get_module_logger("data").warning(
|
||||
f"load calendar error: freq={freq}, future={future}; return current calendar!"
|
||||
)
|
||||
get_module_logger("data").warning(
|
||||
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
|
||||
)
|
||||
fname = self._uri_cal.format(freq)
|
||||
else:
|
||||
fname = self._uri_cal.format(freq)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("calendar not exists for freq " + freq)
|
||||
with open(fname) as f:
|
||||
return [pd.Timestamp(x.strip()) for x in f]
|
||||
backend_obj = self.backend_obj(freq=freq, future=False).data
|
||||
else:
|
||||
raise
|
||||
|
||||
return [pd.Timestamp(x) for x in backend_obj]
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
_calendar, _calendar_index = self._get_calendar(freq, future)
|
||||
@@ -562,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
Provide instrument data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def _uri_inst(self):
|
||||
"""Instrument file uri."""
|
||||
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
|
||||
|
||||
def _load_instruments(self, market):
|
||||
fname = self._uri_inst.format(market)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("instruments not exists for market " + market)
|
||||
|
||||
_instruments = dict()
|
||||
df = pd.read_csv(
|
||||
fname,
|
||||
sep="\t",
|
||||
usecols=[0, 1, 2],
|
||||
names=["inst", "start_datetime", "end_datetime"],
|
||||
dtype={"inst": str},
|
||||
parse_dates=["start_datetime", "end_datetime"],
|
||||
)
|
||||
for row in df.itertuples(index=False):
|
||||
_instruments.setdefault(row[0], []).append((row[1], row[2]))
|
||||
return _instruments
|
||||
def _load_instruments(self, market, freq):
|
||||
return self.backend_obj(market=market, freq=freq).data
|
||||
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
market = instruments["market"]
|
||||
if market in H["i"]:
|
||||
_instruments = H["i"][market]
|
||||
else:
|
||||
_instruments = self._load_instruments(market)
|
||||
_instruments = self._load_instruments(market, freq=freq)
|
||||
H["i"][market] = _instruments
|
||||
# strip
|
||||
# use calendar boundary
|
||||
@@ -604,7 +625,7 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
inst: list(
|
||||
filter(
|
||||
lambda x: x[0] <= x[1],
|
||||
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
|
||||
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
|
||||
)
|
||||
)
|
||||
for inst, spans in _instruments.items()
|
||||
@@ -630,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalFeatureProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
@@ -641,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
instrument = code_to_fname(instrument)
|
||||
uri_data = self._uri_data.format(instrument.lower(), field, freq)
|
||||
if not os.path.exists(uri_data):
|
||||
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
|
||||
return pd.Series(dtype=np.float32)
|
||||
# raise ValueError('uri_data not found: ' + uri_data)
|
||||
# load
|
||||
series = read_bin(uri_data, start_index, end_index)
|
||||
return series
|
||||
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
|
||||
|
||||
class LocalExpressionProvider(ExpressionProvider):
|
||||
@@ -1065,7 +1080,8 @@ def register_all_wrappers(C):
|
||||
register_wrapper(Cal, _calendar_provider, "qlib.data")
|
||||
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
|
||||
|
||||
register_wrapper(Inst, C.instrument_provider, "qlib.data")
|
||||
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
|
||||
register_wrapper(Inst, _instrument_provider, "qlib.data")
|
||||
logger.debug(f"registering Inst {C.instrument_provider}")
|
||||
|
||||
if getattr(C, "feature_provider", None) is not None:
|
||||
|
||||
@@ -357,7 +357,7 @@ class TSDataSampler:
|
||||
# get the previous index of a line given index
|
||||
"""
|
||||
# object in case of pandas converting int to float
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
idx_df = lazy_sort_index(idx_df.unstack())
|
||||
# NOTE: the correctness of `__getitem__` depends on columns sorted here
|
||||
idx_df = lazy_sort_index(idx_df, axis=1)
|
||||
|
||||
4
qlib/data/storage/__init__.py
Normal file
4
qlib/data/storage/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT
|
||||
292
qlib/data/storage/file_storage.py
Normal file
292
qlib/data/storage/file_storage.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Union, Dict, Mapping, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
|
||||
|
||||
logger = get_module_logger("file_storage")
|
||||
|
||||
|
||||
class FileStorageMixin:
    """Path helpers shared by the file-based storages.

    The host class must provide ``kwargs`` (a dict), ``storage_name`` and
    ``file_name``.
    """

    @property
    def uri(self) -> Path:
        """Path of the backing file: ``<provider_uri>/<storage_name>s/<file_name>``.

        Raises
        ------
        ValueError
            If ``provider_uri`` is absent from ``self.kwargs``.
        """
        root = self.kwargs.get("provider_uri", None)
        if root is not None:
            sub_dir = f"{self.storage_name}s"
            return Path(root).expanduser().joinpath(sub_dir, self.file_name)
        raise ValueError(
            f"The `provider_uri` parameter is not found in {self.__class__.__name__}, "
            f'please specify `provider_uri` in the "provider\'s backend"'
        )

    def check(self):
        """Verify that ``self.uri`` exists.

        Raises
        -------
        ValueError
            If the backing file does not exist.
        """
        if self.uri.exists():
            return
        raise ValueError(f"{self.storage_name} not exists: {self.uri}")
|
||||
|
||||
|
||||
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
    """Calendar storage backed by a UTF-8 text file with one value per line.

    Mutating operations read the whole file into a Python list, modify the
    list, and write it back.
    """

    def __init__(self, freq: str, future: bool, **kwargs):
        super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
        # Lower-case the whole name; `.lower()` used to bind only to the
        # non-future branch because of operator precedence.
        self.file_name = (f"{freq}_future.txt" if future else f"{freq}.txt").lower()

    def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
        """Read calendar values, creating an empty file first if none exists.

        Parameters
        ----------
        skip_rows : int
            Number of leading lines of the file to skip.
        n_rows : int, optional
            Maximum number of values to return; ``None`` means no limit.
        """
        if not self.uri.exists():
            self._write_calendar(values=[])
        with self.uri.open("r", encoding="utf-8") as fp:
            raw_lines = fp.read().splitlines()
        # Plain line parsing keeps the result a list of str even for a file
        # with a single line, and does not depend on np.loadtxt accepting a
        # newline delimiter.
        entries = [line.strip() for line in raw_lines[skip_rows:]]
        entries = [line for line in entries if line]
        return entries if n_rows is None else entries[:n_rows]

    def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
        """Write ``values`` one per line using ``mode`` (``"wb"`` overwrite, ``"ab"`` append)."""
        text = "".join(f"{value}\n" for value in values)
        if "b" in mode:
            with self.uri.open(mode=mode) as fp:
                fp.write(text.encode("utf-8"))
        else:
            with self.uri.open(mode=mode, encoding="utf-8") as fp:
                fp.write(text)

    @property
    def data(self) -> List[CalVT]:
        """All calendar values; raises ValueError if the file does not exist."""
        self.check()
        return self._read_calendar()

    def extend(self, values: Iterable[CalVT]) -> None:
        """Append ``values`` to the end of the calendar file."""
        self._write_calendar(values, mode="ab")

    def clear(self) -> None:
        """Truncate the calendar file."""
        self._write_calendar(values=[])

    def index(self, value: CalVT) -> int:
        """Return the position of ``value``.

        Raises
        ------
        ValueError
            If the file does not exist or ``value`` is not in the calendar
            (same as ``list.index``).
        """
        self.check()
        # The calendar is a plain list, so an element-wise `== value` numpy
        # comparison is not available; use the list lookup instead.
        return self._read_calendar().index(value)

    def insert(self, index: int, value: CalVT):
        """Insert ``value`` before position ``index`` and rewrite the file."""
        calendar = self._read_calendar()
        # list.insert keeps the full string; np.insert on a fixed-width string
        # array could truncate a longer value.
        calendar.insert(index, value)
        self._write_calendar(values=calendar)

    def remove(self, value: CalVT) -> None:
        """Remove the first occurrence of ``value`` and rewrite the file."""
        self.check()
        calendar = self._read_calendar()
        calendar.remove(value)
        self._write_calendar(values=calendar)

    def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None:
        calendar = self._read_calendar()
        calendar[i] = values
        self._write_calendar(values=calendar)

    def __delitem__(self, i: Union[int, slice]) -> None:
        self.check()
        calendar = self._read_calendar()
        del calendar[i]
        self._write_calendar(values=calendar)

    def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:
        self.check()
        return self._read_calendar()[i]

    def __len__(self) -> int:
        return len(self.data)
|
||||
|
||||
|
||||
class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
    """Instrument storage backed by a separator-delimited text file.

    Each row holds ``instrument``, ``start_datetime`` and ``end_datetime`` (no
    header). An instrument may appear on several rows, one per span.
    """

    INSTRUMENT_SEP = "\t"
    INSTRUMENT_START_FIELD = "start_datetime"
    INSTRUMENT_END_FIELD = "end_datetime"
    SYMBOL_FIELD_NAME = "instrument"

    def __init__(self, market: str, **kwargs):
        super(FileInstrumentStorage, self).__init__(market, **kwargs)
        self.file_name = f"{market.lower()}.txt"

    def _read_instrument(self) -> Dict[InstKT, InstVT]:
        """Load the file into ``{instrument: [(start, end), ...]}``.

        Creates an empty file first if none exists.
        """
        if not self.uri.exists():
            self._write_instrument()
        # An empty file (freshly created or cleared) has nothing to parse;
        # return early instead of handing it to pd.read_csv.
        if self.uri.stat().st_size == 0:
            return {}

        _instruments = dict()
        df = pd.read_csv(
            self.uri,
            sep=self.INSTRUMENT_SEP,
            usecols=[0, 1, 2],
            names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
            dtype={self.SYMBOL_FIELD_NAME: str},
            parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
        )
        for row in df.itertuples(index=False):
            _instruments.setdefault(row[0], []).append((row[1], row[2]))
        return _instruments

    def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None:
        """Overwrite the file with ``data``; empty or ``None`` truncates it."""
        if not data:
            with self.uri.open("w") as _:
                pass
            return

        res = []
        for inst, v_list in data.items():
            _df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD])
            _df[self.SYMBOL_FIELD_NAME] = inst
            res.append(_df)

        df = pd.concat(res, sort=False)
        # Write exactly once, in the column order `_read_instrument` expects.
        # A second unordered `to_csv` used to overwrite this file with the
        # columns as (start, end, instrument).
        df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv(
            self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False
        )

    def clear(self) -> None:
        """Truncate the instrument file."""
        self._write_instrument(data={})

    @property
    def data(self) -> Dict[InstKT, InstVT]:
        """All instruments; raises ValueError if the file does not exist."""
        self.check()
        return self._read_instrument()

    def __setitem__(self, k: InstKT, v: InstVT) -> None:
        inst = self._read_instrument()
        inst[k] = v
        self._write_instrument(inst)

    def __delitem__(self, k: InstKT) -> None:
        self.check()
        inst = self._read_instrument()
        del inst[k]
        self._write_instrument(inst)

    def __getitem__(self, k: InstKT) -> InstVT:
        self.check()
        return self._read_instrument()[k]

    def update(self, *args, **kwargs) -> None:
        """Update the stored mapping with ``dict.update`` semantics and rewrite the file.

        Raises
        ------
        TypeError
            If more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(f"update expected at most 1 arguments, got {len(args)}")
        inst = self._read_instrument()
        # `inst` is a plain dict, so dict.update already handles mappings,
        # objects with .keys(), iterables of pairs and keyword arguments.
        inst.update(*args, **kwargs)
        self._write_instrument(inst)

    def __len__(self) -> int:
        return len(self.data)
|
||||
|
||||
|
||||
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
    """Feature storage backed by a binary file of little-endian float32 values.

    The first value in the file is the start index of the series; the
    remaining values are the data for consecutive indices from there.
    """

    def __init__(self, instrument: str, field: str, freq: str, **kwargs):
        super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
        self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"

    def _has_data(self) -> bool:
        """Return True when the file exists and holds at least the 4-byte header."""
        uri = self.uri
        return uri.exists() and uri.stat().st_size >= 4

    def clear(self):
        """Truncate the feature file to zero bytes."""
        with self.uri.open("wb") as _:
            pass

    @property
    def data(self) -> pd.Series:
        """The whole series; empty if the storage holds no data."""
        return self[:]

    def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
        """Write ``data_array`` starting at ``index``.

        ``index=None`` appends. An ``index`` past the current end pads the gap
        with NaN. An ``index`` inside or before the stored range overwrites the
        overlap with the new values, NaN included, and extends the range as
        needed. An empty ``data_array`` is ignored.
        """
        if len(data_array) == 0:
            logger.info(
                "len(data_array) == 0, write ignored; "
                "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
            )
            return

        if not self._has_data():
            # Fresh write; also covers a zero-byte file left behind by clear().
            index = 0 if index is None else index
            with self.uri.open("wb") as fp:
                np.hstack([index, data_array]).astype("<f").tofile(fp)
            return

        end_index = self.end_index
        if index is None or index > end_index:
            # append, padding any gap between the current end and `index`
            gap = 0 if index is None else index - end_index - 1
            with self.uri.open("ab") as fp:
                np.hstack([[np.nan] * gap, data_array]).astype("<f").tofile(fp)
            return

        # rewrite: merge old and new values over the union of both ranges
        with self.uri.open("rb+") as fp:
            old = np.fromfile(fp, dtype="<f")
            # The header is stored as a float; range/slice arithmetic needs an int.
            old_start = int(old[0])
            old_values = old[1:]
            new_start = min(old_start, index)
            new_stop = max(old_start + len(old_values), index + len(data_array))
            merged = np.full(new_stop - new_start, np.nan, dtype="<f")
            merged[old_start - new_start : old_start - new_start + len(old_values)] = old_values
            # Plain assignment lets a new NaN replace an old value, as the
            # FeatureStorage.write examples describe.
            merged[index - new_start : index - new_start + len(data_array)] = data_array
            fp.seek(0)
            # Re-emit the header so the start index survives the rewrite.
            np.hstack([new_start, merged]).astype("<f").tofile(fp)
            fp.truncate()

    @property
    def start_index(self) -> Union[int, None]:
        """Start index read from the file header, or None if there is no data."""
        if not self._has_data():
            return None
        with self.uri.open("rb") as fp:
            index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
        return index

    @property
    def end_index(self) -> Union[int, None]:
        """Index of the last stored value, or None if there is no data."""
        if not self._has_data():
            return None
        # The next data appending index point will be `end_index + 1`
        return self.start_index + len(self) - 1

    def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
        """Read one value as ``(index, value)`` or a slice as a ``pd.Series``.

        With no stored data, an int returns ``(None, None)`` and a slice
        returns an empty float32 Series. The slice step is ignored.

        Raises
        ------
        IndexError
            If an int index lies outside the stored range.
        TypeError
            If ``i`` is neither int nor slice.
        """
        if not self._has_data():
            if isinstance(i, int):
                return None, None
            elif isinstance(i, slice):
                return pd.Series(dtype=np.float32)
            else:
                raise TypeError(f"type(i) = {type(i)}")

        storage_start_index = self.start_index
        storage_end_index = self.end_index
        with self.uri.open("rb") as fp:
            if isinstance(i, int):
                if storage_start_index > i:
                    raise IndexError(f"{i}: start index is {storage_start_index}")
                if i > storage_end_index:
                    raise IndexError(f"{i}: end index is {storage_end_index}")
                # skip the 4-byte header, then 4 bytes per value
                fp.seek(4 * (i - storage_start_index) + 4)
                # "<f" matches the little-endian dtype used when writing
                return i, struct.unpack("<f", fp.read(4))[0]
            elif isinstance(i, slice):
                start_index = storage_start_index if i.start is None else i.start
                end_index = storage_end_index if i.stop is None else i.stop - 1
                si = max(start_index, storage_start_index)
                if si > end_index:
                    return pd.Series(dtype=np.float32)
                fp.seek(4 * (si - storage_start_index) + 4)
                # read n bytes; a stop past the end of the file just reads fewer
                count = end_index - si + 1
                data = np.frombuffer(fp.read(4 * count), dtype="<f")
                return pd.Series(data, index=pd.RangeIndex(si, si + len(data)))
            else:
                raise TypeError(f"type(i) = {type(i)}")

    def __len__(self) -> int:
        """Number of stored values (header excluded); raises ValueError if the file is missing."""
        self.check()
        # A zero-byte file after clear() has no header; never report a negative length.
        return max(self.uri.stat().st_size // 4 - 1, 0)
|
||||
501
qlib/data/storage/storage.py
Normal file
501
qlib/data/storage/storage.py
Normal file
@@ -0,0 +1,501 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
from typing import Iterable, overload, Tuple, List, Text, Union, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
# calendar value type
|
||||
CalVT = str
|
||||
|
||||
# instrument value
|
||||
InstVT = List[Tuple[CalVT, CalVT]]
|
||||
# instrument key
|
||||
InstKT = Text
|
||||
|
||||
logger = get_module_logger("storage")
|
||||
|
||||
"""
|
||||
If the user is only using it in `qlib`, you can customize Storage to implement only the following methods:
|
||||
|
||||
class UserCalendarStorage(CalendarStorage):
|
||||
|
||||
@property
|
||||
def data(self) -> Iterable[CalVT]:
|
||||
'''get all data
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
'''
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")
|
||||
|
||||
|
||||
class UserInstrumentStorage(InstrumentStorage):
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[InstKT, InstVT]:
|
||||
'''get all data
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
'''
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
|
||||
|
||||
|
||||
class UserFeatureStorage(FeatureStorage):
|
||||
|
||||
def __getitem__(self, s: slice) -> pd.Series:
|
||||
'''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.Series(values, index=pd.RangeIndex(start, start + len(values)))
|
||||
|
||||
Notes
|
||||
-------
|
||||
if data(storage) does not exist:
|
||||
if isinstance(i, int):
|
||||
return (None, None)
|
||||
if isinstance(i, slice):
|
||||
# return empty pd.Series
|
||||
return pd.Series(dtype=np.float32)
|
||||
'''
|
||||
raise NotImplementedError(
|
||||
"Subclass of FeatureStorage must implement `__getitem__(s: slice)` method"
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class BaseStorage:
    """Common base of the storage classes; derives a short name from the class name."""

    @property
    def storage_name(self) -> str:
        """Second-to-last capitalised word of the class name, lower-cased.

        For example ``FileCalendarStorage`` gives ``"calendar"``.
        """
        words = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)
        kind = words[-2]
        return kind.lower()
|
||||
|
||||
|
||||
class CalendarStorage(BaseStorage):
    """Interface for calendar storages.

    Methods that share a name with ``list`` methods are meant to behave like
    their ``list`` counterparts. Every method here raises
    ``NotImplementedError`` and must be provided by a subclass.
    """

    def __init__(self, freq: str, future: bool, **kwargs):
        self.kwargs = kwargs
        self.future = future
        self.freq = freq

    @property
    def data(self) -> Iterable[CalVT]:
        """Return all calendar values.

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")

    def clear(self) -> None:
        """Remove every value from the storage."""
        raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method")

    def extend(self, iterable: Iterable[CalVT]) -> None:
        """Append every value of ``iterable`` to the storage."""
        raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method")

    def index(self, value: CalVT) -> int:
        """Return the position of ``value``.

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError("Subclass of CalendarStorage must implement `index` method")

    def insert(self, index: int, value: CalVT) -> None:
        """Insert ``value`` before position ``index``."""
        raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method")

    def remove(self, value: CalVT) -> None:
        """Remove ``value`` from the storage."""
        raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method")

    @overload
    def __setitem__(self, i: int, value: CalVT) -> None:
        """x[i] = o"""
        ...

    @overload
    def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None:
        """x[s] = o"""
        ...

    def __setitem__(self, i, value) -> None:
        raise NotImplementedError(
            "Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method"
        )

    @overload
    def __delitem__(self, i: int) -> None:
        """del x[i]"""
        ...

    @overload
    def __delitem__(self, i: slice) -> None:
        """del x[start:stop:step]"""
        ...

    def __delitem__(self, i) -> None:
        """Delete the value(s) selected by ``i``.

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError(
            "Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method"
        )

    @overload
    def __getitem__(self, s: slice) -> Iterable[CalVT]:
        """x[start:stop:step]"""
        ...

    @overload
    def __getitem__(self, i: int) -> CalVT:
        """x[i]"""
        ...

    def __getitem__(self, i) -> CalVT:
        """Return the value(s) selected by ``i``.

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError(
            "Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
        )

    def __len__(self) -> int:
        """Return the number of stored values.

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method")
|
||||
|
||||
|
||||
class InstrumentStorage(BaseStorage):
    """Interface for instrument storages, a dict-like mapping of instrument to spans.

    Every method here raises ``NotImplementedError`` and must be provided by a
    subclass.
    """

    def __init__(self, market: str, **kwargs):
        self.kwargs = kwargs
        self.market = market

    @property
    def data(self) -> Dict[InstKT, InstVT]:
        """Return all instruments.

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")

    def clear(self) -> None:
        """Remove every instrument from the storage."""
        raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method")

    def update(self, *args, **kwargs) -> None:
        """D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.

        Notes
        ------
        With E present and having a ``.keys()`` method: ``for k in E: D[k] = E[k]``.

        With E present and lacking ``.keys()``: ``for (k, v) in E: D[k] = v``.

        Either way this is followed by: ``for k, v in F.items(): D[k] = v``.
        """
        raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method")

    def __setitem__(self, k: InstKT, v: InstVT) -> None:
        """x[k] = v"""
        raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method")

    def __delitem__(self, k: InstKT) -> None:
        """del x[k]

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method")

    def __getitem__(self, k: InstKT) -> InstVT:
        """x[k]"""
        raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method")

    def __len__(self) -> int:
        """Return the number of stored instruments.

        Raises
        ------
        ValueError
            If the data (storage) does not exist.
        """
        raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method")
|
||||
|
||||
|
||||
class FeatureStorage(BaseStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
|
||||
self.instrument = instrument
|
||||
self.field = field
|
||||
self.freq = freq
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
def data(self) -> pd.Series:
|
||||
"""get all data
|
||||
|
||||
Notes
|
||||
------
|
||||
if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)`
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
|
||||
|
||||
@property
|
||||
def start_index(self) -> Union[int, None]:
|
||||
"""get FeatureStorage start index
|
||||
|
||||
Notes
|
||||
-----
|
||||
If the data(storage) does not exist, return None
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method")
|
||||
|
||||
@property
|
||||
def end_index(self) -> Union[int, None]:
|
||||
"""get FeatureStorage end index
|
||||
|
||||
Notes
|
||||
-----
|
||||
The right index of the data range (both sides are closed)
|
||||
|
||||
The next data appending point will be `end_index + 1`
|
||||
|
||||
If the data(storage) does not exist, return None
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method")
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method")
|
||||
|
||||
def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):
|
||||
"""Write data_array to FeatureStorage starting from index.
|
||||
|
||||
Notes
|
||||
------
|
||||
If index is None, append data_array to feature.
|
||||
|
||||
If len(data_array) == 0; return
|
||||
|
||||
If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan
|
||||
|
||||
Examples
|
||||
---------
|
||||
.. code-block::
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
|
||||
|
||||
>>> self.write([6, 7], index=6)
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
|
||||
>>> self.write([8], index=9)
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
8 np.nan
|
||||
9 8
|
||||
|
||||
>>> self.write([1, np.nan], index=3)
|
||||
|
||||
feature:
|
||||
3 1
|
||||
4 np.nan
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
8 np.nan
|
||||
9 8
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `write` method")
|
||||
|
||||
def rebase(self, start_index: int = None, end_index: int = None):
    """Rebase the start_index and end_index of the FeatureStorage.

    start_index and end_index are closed intervals: [start_index, end_index]

    Parameters
    ----------
    start_index: int
        new first index; ``None`` keeps the current ``self.start_index``
    end_index: int
        new last index; ``None`` keeps the current ``self.end_index``

    Raises
    ------
    ValueError
        If ``self.start_index`` or ``self.end_index`` is None (the storage may
        not exist).

    Notes
    -----
    - Positions inside the new range that hold no data are filled with np.nan.
    - If the new range starts after the current end of the storage, none of the
      stored values fall inside it, so the storage is replaced by np.nan over
      ``[start_index, end_index]``.
    - A negative bound, or ``start_index > end_index``, only logs a warning and
      leaves the storage untouched.

    Examples
    --------
    .. code-block::

        feature:
            3   4
            4   5
            5   6

        >>> self.rebase(start_index=4)

        feature:
            4   5
            5   6

        >>> self.rebase(start_index=3)

        feature:
            3   np.nan
            4   5
            5   6

        >>> self.write([3], index=3)

        feature:
            3   3
            4   5
            5   6

        >>> self.rebase(end_index=4)

        feature:
            3   3
            4   5

        >>> self.write([6, 7, 8], index=4)

        feature:
            3   3
            4   6
            5   7
            6   8

        >>> self.rebase(start_index=4, end_index=5)

        feature:
            4   6
            5   7

    """
    storage_si = self.start_index
    storage_ei = self.end_index
    if storage_si is None or storage_ei is None:
        raise ValueError("storage.start_index or storage.end_index is None, storage may not exist")

    # An omitted bound keeps the matching bound of the existing storage, so
    # neither value can be None past this point.
    start_index = storage_si if start_index is None else start_index
    end_index = storage_ei if end_index is None else end_index

    if start_index < 0 or end_index < 0:
        logger.warning("start_index or end_index cannot be less than 0")
        return
    if start_index > end_index:
        logger.warning(
            f"start_index({start_index}) > end_index({end_index}), rebase is ignored; "
            f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
        )
        return

    if start_index > storage_ei:
        # The new range lies entirely after the stored data. Trimming the head
        # would rewrite the storage with an empty array (clearing it without
        # writing anything), after which `self.end_index` is None and cannot
        # be compared. Write the all-NaN range directly instead.
        self.rewrite([np.nan] * (end_index - start_index + 1), start_index)
        return

    # Step 1: move the left bound (pad with NaN or drop the leading values).
    if start_index <= storage_si:
        self.write([np.nan] * (storage_si - start_index), start_index)
    else:
        self.rewrite(self[start_index:].values, start_index)

    # Step 2: move the right bound, measured against the storage as it is now.
    current_ei = self.end_index
    if end_index >= current_ei:
        self.write([np.nan] * (end_index - current_ei))
    else:
        self.rewrite(self[: end_index + 1].values, start_index)
|
||||
|
||||
def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):
    """Replace the whole content of the FeatureStorage with ``data``.

    Parameters
    ----------
    data: Union[List, np.ndarray, Tuple]
        the values that become the new content
    index: int
        index at which the first value of ``data`` is stored
    """
    # Clear first so that nothing from the previous content survives, then
    # store the new values starting at `index`.
    self.clear()
    self.write(data, index)
|
||||
|
||||
@overload
def __getitem__(self, s: slice) -> pd.Series:
    """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]

    Returns
    -------
    pd.Series
        the selected values; presumably indexed by a ``pd.RangeIndex`` that
        begins at the slice start (to be confirmed against the subclass)
    """
    ...

@overload
def __getitem__(self, i: int) -> Tuple[int, float]:
    """x.__getitem__(y) <==> x[y]"""
    ...

def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]:
    """x.__getitem__(y) <==> x[y]

    Notes
    -----
    Contract for subclasses when the data (storage) does not exist:

    - ``isinstance(i, int)``: return ``(None, None)``
    - ``isinstance(i, slice)``: return an empty series,
      ``pd.Series(dtype=np.float32)``

    Raises
    ------
    NotImplementedError
        Always, in this base definition.
    """
    message = "Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
    raise NotImplementedError(message)
|
||||
|
||||
def __len__(self) -> int:
    """Length of the FeatureStorage.

    Raises
    ------
    ValueError
        Contract for subclasses: raised if the data (storage) does not exist.
    NotImplementedError
        Always, in this base definition.
    """
    message = "Subclass of FeatureStorage must implement `__len__` method"
    raise NotImplementedError(message)
|
||||
@@ -9,19 +9,19 @@ from ..config import REG_CN
|
||||
class TestAutoData(unittest.TestCase):
|
||||
|
||||
_setup_kwargs = {}
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
if not exists_qlib_data(cls.provider_uri):
|
||||
print(f"Qlib data is not found in {cls.provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple",
|
||||
region="cn",
|
||||
interval="1d",
|
||||
target_dir=provider_uri,
|
||||
target_dir=cls.provider_uri,
|
||||
delete_old=False,
|
||||
)
|
||||
init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
|
||||
@@ -668,7 +668,10 @@ def exists_qlib_data(qlib_dir):
|
||||
return False
|
||||
# check calendar bin
|
||||
for _calendar in calendars_dir.iterdir():
|
||||
if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")):
|
||||
|
||||
if ("_future" not in _calendar.name) and (
|
||||
not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin"))
|
||||
):
|
||||
return False
|
||||
|
||||
# check instruments
|
||||
|
||||
@@ -120,7 +120,7 @@ class DumpDataBase:
|
||||
else:
|
||||
df = file_or_df
|
||||
if df.empty or self.date_field_name not in df.columns.tolist():
|
||||
_calendars = pd.Series()
|
||||
_calendars = pd.Series(dtype=np.float32)
|
||||
else:
|
||||
_calendars = df[self.date_field_name]
|
||||
|
||||
|
||||
171
tests/storage_tests/test_storage.py
Normal file
171
tests/storage_tests/test_storage.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterable
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from qlib.data.storage.file_storage import (
|
||||
FileCalendarStorage as CalendarStorage,
|
||||
FileInstrumentStorage as InstrumentStorage,
|
||||
FileFeatureStorage as FeatureStorage,
|
||||
)
|
||||
|
||||
# Scratch directories for this test module: "<module name>_data/qlib" next to
# the test file. The directory is created at import time.
_this_file = Path(__file__)
_file_name = _this_file.name.split(".")[0]
DATA_DIR = _this_file.parent / f"{_file_name}_data"
QLIB_DIR = DATA_DIR / "qlib"
QLIB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class TestStorage(TestAutoData):
    """Checks of the file-based calendar / instrument / feature storages.

    Each test first reads from the provider prepared by ``TestAutoData`` and
    then repeats the access against a provider URI that does not exist.
    """

    def test_calendar_storage(self):
        """Calendar storage is iterable; a missing provider raises ValueError."""
        calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri)
        cls_name = calendar.__class__.__name__
        assert isinstance(calendar[:], Iterable), f"{cls_name}.__getitem__(s: slice) is not Iterable"
        assert isinstance(calendar.data, Iterable), f"{cls_name}.data is not Iterable"

        print(f"calendar[1: 5]: {calendar[1:5]}")
        print(f"calendar[0]: {calendar[0]}")
        print(f"calendar[-1]: {calendar[-1]}")

        # Every kind of access must fail when the provider cannot be found.
        missing_calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
        with pytest.raises(ValueError):
            print(missing_calendar.data)
        with pytest.raises(ValueError):
            print(missing_calendar[:])
        with pytest.raises(ValueError):
            print(missing_calendar[0])

    def test_instrument_storage(self):
        """
        The meaning of instrument, such as CSI500:

            CSI500 composition changes:

                date            add         remove
                2005-01-01      SH600000
                2005-01-01      SH600001
                2005-01-01      SH600002
                2005-02-01      SH600003    SH600000
                2005-02-15      SH600000    SH600002

            Calendar:
                pd.date_range(start="2005-01-01", end="2005-03-01", freq="1D")

            Instrument:
                symbol      start_time      end_time
                SH600000    2005-01-01      2005-01-31 (last trading day before 2005-02-01)
                SH600000    2005-02-15      2005-03-01
                SH600001    2005-01-01      2005-03-01
                SH600002    2005-01-01      2005-02-14 (last trading day before 2005-02-15)
                SH600003    2005-02-01      2005-03-01

            InstrumentStorage:
                {
                    "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
                    "SH600001": [(2005-01-01, 2005-03-01)],
                    "SH600002": [(2005-01-01, 2005-02-14)],
                    "SH600003": [(2005-02-01, 2005-03-01)],
                }

        """
        instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri)
        cls_name = instrument.__class__.__name__

        for inst, spans in instrument.data.items():
            # keys are symbols, values are iterables of (start, end) tuples
            assert isinstance(inst, str) and isinstance(spans, Iterable), f"{cls_name} value is not Iterable"
            for s_e in spans:
                assert isinstance(s_e, tuple) and len(s_e) == 2, f"{cls_name}.__getitem__(k) TypeError"

        print(f"instrument['SH600000']: {instrument['SH600000']}")

        missing_instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
        with pytest.raises(ValueError):
            print(missing_instrument.data)
        with pytest.raises(ValueError):
            print(missing_instrument["sSH600000"])

    def test_feature_storage(self):
        """
        Calendar:
            pd.date_range(start="2005-01-01", end="2005-03-01", freq="1D")

        Instrument:
            {
                "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
                "SH600001": [(2005-01-01, 2005-03-01)],
                "SH600002": [(2005-01-01, 2005-02-14)],
                "SH600003": [(2005-02-01, 2005-03-01)],
            }

        Feature:
            Stock data(close):
                        2005-01-01  ...   2005-02-01   ...   2005-02-14  2005-02-15  ...  2005-03-01
                SH600000     1      ...      3         ...      4           5               6
                SH600001     1      ...      4         ...      5           6               7
                SH600002     1      ...      5         ...      6           nan             nan
                SH600003     nan    ...      1         ...      2           3               4

            FeatureStorage(SH600000, close):

                [
                    (calendar.index("2005-01-01"), 1),
                    ...,
                    (calendar.index("2005-03-01"), 6)
                ]

                ====> [(0, 1), ..., (59, 6)]

            FeatureStorage(SH600002, close):

                [
                    (calendar.index("2005-01-01"), 1),
                    ...,
                    (calendar.index("2005-02-14"), 6)
                ]

                ===> [(0, 1), ..., (44, 6)]

            FeatureStorage(SH600003, close):

                [
                    (calendar.index("2005-02-01"), 1),
                    ...,
                    (calendar.index("2005-03-01"), 4)
                ]

                ===> [(31, 1), ..., (59, 4)]

        """
        feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)
        cls_name = feature.__class__.__name__

        # Index 0 is expected to be outside the stored range of this feature.
        with pytest.raises(IndexError):
            print(feature[0])

        assert isinstance(feature[815][1], (float, np.float32)), f"{cls_name}.__getitem__(i: int) error"
        assert len(feature[815:818]) == 3, f"{cls_name}.__getitem__(s: slice) error"
        print(f"feature[815: 818]: \n{feature[815: 818]}")
        print(f"feature[:].tail(): \n{feature[:].tail()}")

        # A storage that does not exist yields "empty" results instead of raising.
        missing_feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount")

        assert missing_feature[0] == (
            None,
            None,
        ), "FeatureStorage does not exist, feature[i] should return `(None, None)`"
        assert missing_feature[
            :
        ].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`"
        assert (
            missing_feature.data.empty
        ), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`"
|
||||
Reference in New Issue
Block a user