1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Merge pull request #372 from zhupr/data_storage

add data storage
This commit is contained in:
you-n-g
2021-05-26 14:30:46 +08:00
committed by GitHub
11 changed files with 1070 additions and 55 deletions

View File

@@ -53,6 +53,34 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:
Storage
-------------
.. autoclass:: qlib.data.storage.storage.BaseStorage
:members:
.. autoclass:: qlib.data.storage.storage.CalendarStorage
:members:
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
:members:
.. autoclass:: qlib.data.storage.storage.FeatureStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
:members:
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
:members:
Dataset
---------------

View File

@@ -166,7 +166,7 @@ class Position:
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash = pd.Series(dtype=float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]

View File

@@ -6,7 +6,9 @@ from __future__ import division
from __future__ import print_function
import os
import re
import abc
import copy
import time
import queue
import bisect
@@ -27,12 +29,41 @@ from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
class CalendarProvider(abc.ABC):
class ProviderBackendMixin:
def get_default_backend(self):
backend = {}
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
# set default storage class
backend.setdefault("class", f"File{provider_name}Storage")
# set default storage module
backend.setdefault("module_path", "qlib.data.storage.file_storage")
return backend
def backend_obj(self, **kwargs):
backend = self.backend if self.backend else self.get_default_backend()
backend = copy.deepcopy(backend)
# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {})
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
freq = kwargs.get("freq", "day")
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
class CalendarProvider(abc.ABC, ProviderBackendMixin):
"""Calendar provider base class
Provide calendar data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
@@ -127,12 +158,15 @@ class CalendarProvider(abc.ABC):
return hash_args(start_time, end_time, freq, future)
class InstrumentProvider(abc.ABC):
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class
Provide instrument data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@staticmethod
def instruments(market="all", filter_pipe=None):
"""Get the general config dictionary for a base market adding several dynamic filters.
@@ -215,12 +249,15 @@ class InstrumentProvider(abc.ABC):
raise ValueError(f"Unknown instrument type {inst}")
class FeatureProvider(abc.ABC):
class FeatureProvider(abc.ABC, ProviderBackendMixin):
"""Feature provider class
Provide feature data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
"""Get feature data.
@@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
"""
def __init__(self, **kwargs):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
@@ -517,21 +555,22 @@ class LocalCalendarProvider(CalendarProvider):
list
list of timestamps
"""
if future:
fname = self._uri_cal.format(freq + "_future")
# if future calendar not exists, return current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
try:
backend_obj = self.backend_obj(freq=freq, future=future).data
except ValueError:
if future:
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
if not os.path.exists(fname):
raise ValueError("calendar not exists for freq " + freq)
with open(fname) as f:
return [pd.Timestamp(x.strip()) for x in f]
backend_obj = self.backend_obj(freq=freq, future=False).data
else:
raise
return [pd.Timestamp(x) for x in backend_obj]
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
@@ -562,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""
def __init__(self):
pass
@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)
_instruments = dict()
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
else:
_instruments = self._load_instruments(market)
_instruments = self._load_instruments(market, freq=freq)
H["i"][market] = _instruments
# strip
# use calendar boundary
@@ -604,7 +625,7 @@ class LocalInstrumentProvider(InstrumentProvider):
inst: list(
filter(
lambda x: x[0] <= x[1],
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
)
)
for inst, spans in _instruments.items()
@@ -630,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider):
"""
def __init__(self, **kwargs):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
@@ -641,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
series = read_bin(uri_data, start_index, end_index)
return series
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
class LocalExpressionProvider(ExpressionProvider):
@@ -1065,7 +1080,8 @@ def register_all_wrappers(C):
register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
register_wrapper(Inst, C.instrument_provider, "qlib.data")
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
register_wrapper(Inst, _instrument_provider, "qlib.data")
logger.debug(f"registering Inst {C.instrument_provider}")
if getattr(C, "feature_provider", None) is not None:

View File

@@ -357,7 +357,7 @@ class TSDataSampler:
# get the previous index of a line given index
"""
# object incase of pandas converting int to flaot
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT

View File

@@ -0,0 +1,292 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import struct
from pathlib import Path
from typing import Iterable, Union, Dict, Mapping, Tuple, List
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
logger = get_module_logger("file_storage")
class FileStorageMixin:
@property
def uri(self) -> Path:
_provider_uri = self.kwargs.get("provider_uri", None)
if _provider_uri is None:
raise ValueError(
f"The `provider_uri` parameter is not found in {self.__class__.__name__}, "
f'please specify `provider_uri` in the "provider\'s backend"'
)
return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name)
def check(self):
"""check self.uri
Raises
-------
ValueError
"""
if not self.uri.exists():
raise ValueError(f"{self.storage_name} not exists: {self.uri}")
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
def __init__(self, freq: str, future: bool, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower()
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
if not self.uri.exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return [
str(x)
for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8")
]
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:
np.savetxt(fp, values, fmt="%s", encoding="utf-8")
@property
def data(self) -> List[CalVT]:
self.check()
return self._read_calendar()
def extend(self, values: Iterable[CalVT]) -> None:
self._write_calendar(values, mode="ab")
def clear(self) -> None:
self._write_calendar(values=[])
def index(self, value: CalVT) -> int:
self.check()
calendar = self._read_calendar()
return int(np.argwhere(calendar == value)[0])
def insert(self, index: int, value: CalVT):
calendar = self._read_calendar()
calendar = np.insert(calendar, index, value)
self._write_calendar(values=calendar)
def remove(self, value: CalVT) -> None:
self.check()
index = self.index(value)
calendar = self._read_calendar()
calendar = np.delete(calendar, index)
self._write_calendar(values=calendar)
def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None:
calendar = self._read_calendar()
calendar[i] = values
self._write_calendar(values=calendar)
def __delitem__(self, i: Union[int, slice]) -> None:
self.check()
calendar = self._read_calendar()
calendar = np.delete(calendar, i)
self._write_calendar(values=calendar)
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:
self.check()
return self._read_calendar()[i]
def __len__(self) -> int:
return len(self.data)
class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
INSTRUMENT_SEP = "\t"
INSTRUMENT_START_FIELD = "start_datetime"
INSTRUMENT_END_FIELD = "end_datetime"
SYMBOL_FIELD_NAME = "instrument"
def __init__(self, market: str, **kwargs):
super(FileInstrumentStorage, self).__init__(market, **kwargs)
self.file_name = f"{market.lower()}.txt"
def _read_instrument(self) -> Dict[InstKT, InstVT]:
if not self.uri.exists():
self._write_instrument()
_instruments = dict()
df = pd.read_csv(
self.uri,
sep="\t",
usecols=[0, 1, 2],
names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
dtype={self.SYMBOL_FIELD_NAME: str},
parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None:
if not data:
with self.uri.open("w") as _:
pass
return
res = []
for inst, v_list in data.items():
_df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD])
_df[self.SYMBOL_FIELD_NAME] = inst
res.append(_df)
df = pd.concat(res, sort=False)
df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv(
self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False
)
df.to_csv(self.uri, sep="\t", encoding="utf-8", header=False, index=False)
def clear(self) -> None:
self._write_instrument(data={})
@property
def data(self) -> Dict[InstKT, InstVT]:
self.check()
return self._read_instrument()
def __setitem__(self, k: InstKT, v: InstVT) -> None:
inst = self._read_instrument()
inst[k] = v
self._write_instrument(inst)
def __delitem__(self, k: InstKT) -> None:
self.check()
inst = self._read_instrument()
del inst[k]
self._write_instrument(inst)
def __getitem__(self, k: InstKT) -> InstVT:
self.check()
return self._read_instrument()[k]
def update(self, *args, **kwargs) -> None:
if len(args) > 1:
raise TypeError(f"update expected at most 1 arguments, got {len(args)}")
inst = self._read_instrument()
if args:
other = args[0] # type: dict
if isinstance(other, Mapping):
for key in other:
inst[key] = other[key]
elif hasattr(other, "keys"):
for key in other.keys():
inst[key] = other[key]
else:
for key, value in other:
inst[key] = value
for key, value in kwargs.items():
inst[key] = value
self._write_instrument(inst)
def __len__(self) -> int:
return len(self.data)
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
def clear(self):
with self.uri.open("wb") as _:
pass
@property
def data(self) -> pd.Series:
return self[:]
def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
if len(data_array) == 0:
logger.info(
"len(data_array) == 0, write"
"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
)
return
if not self.uri.exists():
# write
index = 0 if index is None else index
with self.uri.open("wb") as fp:
np.hstack([index, data_array]).astype("<f").tofile(fp)
else:
if index is None or index > self.end_index:
# append
index = 0 if index is None else index
with self.uri.open("ab+") as fp:
np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype("<f").tofile(fp)
else:
# rewrite
with self.uri.open("rb+") as fp:
_old_data = np.fromfile(fp, dtype="<f")
_old_index = _old_data[0]
_old_df = pd.DataFrame(
_old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=["old"]
)
fp.seek(0)
_new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=["new"])
_df = pd.concat([_old_df, _new_df], sort=False, axis=1)
_df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))
_df["new"].fillna(_df["old"]).values.astype("<f").tofile(fp)
@property
def start_index(self) -> Union[int, None]:
if not self.uri.exists():
return None
with self.uri.open("rb") as fp:
index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
return index
@property
def end_index(self) -> Union[int, None]:
if not self.uri.exists():
return None
# The next data appending index point will be `end_index + 1`
return self.start_index + len(self) - 1
def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
if not self.uri.exists():
if isinstance(i, int):
return None, None
elif isinstance(i, slice):
return pd.Series(dtype=np.float32)
else:
raise TypeError(f"type(i) = {type(i)}")
storage_start_index = self.start_index
storage_end_index = self.end_index
with self.uri.open("rb") as fp:
if isinstance(i, int):
if storage_start_index > i:
raise IndexError(f"{i}: start index is {storage_start_index}")
fp.seek(4 * (i - storage_start_index) + 4)
return i, struct.unpack("f", fp.read(4))[0]
elif isinstance(i, slice):
start_index = storage_start_index if i.start is None else i.start
end_index = storage_end_index if i.stop is None else i.stop - 1
si = max(start_index, storage_start_index)
if si > end_index:
return pd.Series(dtype=np.float32)
fp.seek(4 * (si - storage_start_index) + 4)
# read n bytes
count = end_index - si + 1
data = np.frombuffer(fp.read(4 * count), dtype="<f")
return pd.Series(data, index=pd.RangeIndex(si, si + len(data)))
else:
raise TypeError(f"type(i) = {type(i)}")
def __len__(self) -> int:
self.check()
return self.uri.stat().st_size // 4 - 1

View File

@@ -0,0 +1,501 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
from typing import Iterable, overload, Tuple, List, Text, Union, Dict
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
# calendar value type
CalVT = str
# instrument value
InstVT = List[Tuple[CalVT, CalVT]]
# instrument key
InstKT = Text
logger = get_module_logger("storage")
"""
If the user is only using it in `qlib`, you can customize Storage to implement only the following methods:
class UserCalendarStorage(CalendarStorage):
@property
def data(self) -> Iterable[CalVT]:
'''get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
'''
raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")
class UserInstrumentStorage(InstrumentStorage):
@property
def data(self) -> Dict[InstKT, InstVT]:
'''get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
'''
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
class UserFeatureStorage(FeatureStorage):
def __getitem__(self, s: slice) -> pd.Series:
'''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
Returns
-------
pd.Series(values, index=pd.RangeIndex(start, len(values))
Notes
-------
if data(storage) does not exist:
if isinstance(i, int):
return (None, None)
if isinstance(i, slice):
# return empty pd.Series
return pd.Series(dtype=np.float32)
'''
raise NotImplementedError(
"Subclass of FeatureStorage must implement `__getitem__(s: slice)` method"
)
"""
class BaseStorage:
@property
def storage_name(self) -> str:
return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2].lower()
class CalendarStorage(BaseStorage):
"""
The behavior of CalendarStorage's methods and List's methods of the same name remain consistent
"""
def __init__(self, freq: str, future: bool, **kwargs):
self.freq = freq
self.future = future
self.kwargs = kwargs
@property
def data(self) -> Iterable[CalVT]:
"""get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method")
def extend(self, iterable: Iterable[CalVT]) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method")
def index(self, value: CalVT) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of CalendarStorage must implement `index` method")
def insert(self, index: int, value: CalVT) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method")
def remove(self, value: CalVT) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method")
@overload
def __setitem__(self, i: int, value: CalVT) -> None:
"""x.__setitem__(i, o) <==> (x[i] = o)"""
...
@overload
def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None:
"""x.__setitem__(s, o) <==> (x[s] = o)"""
...
def __setitem__(self, i, value) -> None:
raise NotImplementedError(
"Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method"
)
@overload
def __delitem__(self, i: int) -> None:
"""x.__delitem__(i) <==> del x[i]"""
...
@overload
def __delitem__(self, i: slice) -> None:
"""x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]"""
...
def __delitem__(self, i) -> None:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError(
"Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method"
)
@overload
def __getitem__(self, s: slice) -> Iterable[CalVT]:
"""x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]"""
...
@overload
def __getitem__(self, i: int) -> CalVT:
"""x.__getitem__(i) <==> x[i]"""
...
def __getitem__(self, i) -> CalVT:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError(
"Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
)
def __len__(self) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method")
class InstrumentStorage(BaseStorage):
def __init__(self, market: str, **kwargs):
self.market = market
self.kwargs = kwargs
@property
def data(self) -> Dict[InstKT, InstVT]:
"""get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method")
def update(self, *args, **kwargs) -> None:
"""D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
Notes
------
If E present and has a .keys() method, does: for k in E: D[k] = E[k]
If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v
In either case, this is followed by: for k, v in F.items(): D[k] = v
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method")
def __setitem__(self, k: InstKT, v: InstVT) -> None:
"""Set self[key] to value."""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method")
def __delitem__(self, k: InstKT) -> None:
"""Delete self[key].
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method")
def __getitem__(self, k: InstKT) -> InstVT:
"""x.__getitem__(k) <==> x[k]"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method")
def __len__(self) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method")
class FeatureStorage(BaseStorage):
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
self.instrument = instrument
self.field = field
self.freq = freq
self.kwargs = kwargs
@property
def data(self) -> pd.Series:
"""get all data
Notes
------
if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)`
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
@property
def start_index(self) -> Union[int, None]:
"""get FeatureStorage start index
Notes
-----
If the data(storage) does not exist, return None
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method")
@property
def end_index(self) -> Union[int, None]:
"""get FeatureStorage end index
Notes
-----
The right index of the data range (both sides are closed)
The next data appending point will be `end_index + 1`
If the data(storage) does not exist, return None
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method")
def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):
"""Write data_array to FeatureStorage starting from index.
Notes
------
If index is None, append data_array to feature.
If len(data_array) == 0; return
If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan
Examples
---------
.. code-block::
feature:
3 4
4 5
5 6
>>> self.write([6, 7], index=6)
feature:
3 4
4 5
5 6
6 6
7 7
>>> self.write([8], index=9)
feature:
3 4
4 5
5 6
6 6
7 7
8 np.nan
9 8
>>> self.write([1, np.nan], index=3)
feature:
3 1
4 np.nan
5 6
6 6
7 7
8 np.nan
9 8
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `write` method")
def rebase(self, start_index: int = None, end_index: int = None):
"""Rebase the start_index and end_index of the FeatureStorage.
start_index and end_index are closed intervals: [start_index, end_index]
Examples
---------
.. code-block::
feature:
3 4
4 5
5 6
>>> self.rebase(start_index=4)
feature:
4 5
5 6
>>> self.rebase(start_index=3)
feature:
3 np.nan
4 5
5 6
>>> self.write([3], index=3)
feature:
3 3
4 5
5 6
>>> self.rebase(end_index=4)
feature:
3 3
4 5
>>> self.write([6, 7, 8], index=4)
feature:
3 3
4 6
5 7
6 8
>>> self.rebase(start_index=4, end_index=5)
feature:
4 6
5 7
"""
storage_si = self.start_index
storage_ei = self.end_index
if storage_si is None or storage_ei is None:
raise ValueError("storage.start_index or storage.end_index is None, storage may not exist")
start_index = storage_si if start_index is None else start_index
end_index = storage_ei if end_index is None else end_index
if start_index is None or end_index is None:
logger.warning("both start_index and end_index are None, or storage does not exist; rebase is ignored")
return
if start_index < 0 or end_index < 0:
logger.warning("start_index or end_index cannot be less than 0")
return
if start_index > end_index:
logger.warning(
f"start_index({start_index}) > end_index({end_index}), rebase is ignored; "
f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
)
return
if start_index <= storage_si:
self.write([np.nan] * (storage_si - start_index), start_index)
else:
self.rewrite(self[start_index:].values, start_index)
if end_index >= self.end_index:
self.write([np.nan] * (end_index - self.end_index))
else:
self.rewrite(self[: end_index + 1].values, start_index)
def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):
"""overwrite all data in FeatureStorage with data
Parameters
----------
data: Union[List, np.ndarray, Tuple]
data
index: int
data start index
"""
self.clear()
self.write(data, index)
@overload
def __getitem__(self, s: slice) -> pd.Series:
"""x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
Returns
-------
pd.Series(values, index=pd.RangeIndex(start, len(values))
"""
...
@overload
def __getitem__(self, i: int) -> Tuple[int, float]:
"""x.__getitem__(y) <==> x[y]"""
...
def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]:
"""x.__getitem__(y) <==> x[y]
Notes
-------
if data(storage) does not exist:
if isinstance(i, int):
return (None, None)
if isinstance(i, slice):
# return empty pd.Series
return pd.Series(dtype=np.float32)
"""
raise NotImplementedError(
"Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
)
def __len__(self) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method")

View File

@@ -9,19 +9,19 @@ from ..config import REG_CN
class TestAutoData(unittest.TestCase):
_setup_kwargs = {}
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
@classmethod
def setUpClass(cls) -> None:
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
if not exists_qlib_data(cls.provider_uri):
print(f"Qlib data is not found in {cls.provider_uri}")
GetData().qlib_data(
name="qlib_data_simple",
region="cn",
interval="1d",
target_dir=provider_uri,
target_dir=cls.provider_uri,
delete_old=False,
)
init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs)
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)

View File

@@ -668,7 +668,10 @@ def exists_qlib_data(qlib_dir):
return False
# check calendar bin
for _calendar in calendars_dir.iterdir():
if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")):
if ("_future" not in _calendar.name) and (
not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin"))
):
return False
# check instruments

View File

@@ -120,7 +120,7 @@ class DumpDataBase:
else:
df = file_or_df
if df.empty or self.date_field_name not in df.columns.tolist():
_calendars = pd.Series()
_calendars = pd.Series(dtype=np.float32)
else:
_calendars = df[self.date_field_name]

View File

@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from collections.abc import Iterable
import pytest
import numpy as np
from qlib.tests import TestAutoData
from qlib.data.storage.file_storage import (
FileCalendarStorage as CalendarStorage,
FileInstrumentStorage as InstrumentStorage,
FileFeatureStorage as FeatureStorage,
)
_file_name = Path(__file__).name.split(".")[0]
DATA_DIR = Path(__file__).parent.joinpath(f"{_file_name}_data")
QLIB_DIR = DATA_DIR.joinpath("qlib")
QLIB_DIR.mkdir(exist_ok=True, parents=True)
class TestStorage(TestAutoData):
def test_calendar_storage(self):
calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri)
assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable"
print(f"calendar[1: 5]: {calendar[1:5]}")
print(f"calendar[0]: {calendar[0]}")
print(f"calendar[-1]: {calendar[-1]}")
calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
with pytest.raises(ValueError):
print(calendar.data)
with pytest.raises(ValueError):
print(calendar[:])
with pytest.raises(ValueError):
print(calendar[0])
def test_instrument_storage(self):
"""
The meaning of instrument, such as CSI500:
CSI500 composition changes:
date add remove
2005-01-01 SH600000
2005-01-01 SH600001
2005-01-01 SH600002
2005-02-01 SH600003 SH600000
2005-02-15 SH600000 SH600002
Calendar:
pd.date_range(start="2020-01-01", stop="2020-03-01", freq="1D")
Instrument:
symbol start_time end_time
SH600000 2005-01-01 2005-01-31 (2005-02-01 Last trading day)
SH600000 2005-02-15 2005-03-01
SH600001 2005-01-01 2005-03-01
SH600002 2005-01-01 2005-02-14 (2005-02-15 Last trading day)
SH600003 2005-02-01 2005-03-01
InstrumentStorage:
{
"SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
"SH600001": [(2005-01-01, 2005-03-01)],
"SH600002": [(2005-01-01, 2005-02-14)],
"SH600003": [(2005-02-01, 2005-03-01)],
}
"""
instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri)
for inst, spans in instrument.data.items():
assert isinstance(inst, str) and isinstance(
spans, Iterable
), f"{instrument.__class__.__name__} value is not Iterable"
for s_e in spans:
assert (
isinstance(s_e, tuple) and len(s_e) == 2
), f"{instrument.__class__.__name__}.__getitem__(k) TypeError"
print(f"instrument['SH600000']: {instrument['SH600000']}")
instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
with pytest.raises(ValueError):
print(instrument.data)
with pytest.raises(ValueError):
print(instrument["sSH600000"])
def test_feature_storage(self):
"""
Calendar:
pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D")
Instrument:
{
"SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
"SH600001": [(2005-01-01, 2005-03-01)],
"SH600002": [(2005-01-01, 2005-02-14)],
"SH600003": [(2005-02-01, 2005-03-01)],
}
Feature:
Stock data(close):
2005-01-01 ... 2005-02-01 ... 2005-02-14 2005-02-15 ... 2005-03-01
SH600000 1 ... 3 ... 4 5 6
SH600001 1 ... 4 ... 5 6 7
SH600002 1 ... 5 ... 6 nan nan
SH600003 nan ... 1 ... 2 3 4
FeatureStorage(SH600000, close):
[
(calendar.index("2005-01-01"), 1),
...,
(calendar.index("2005-03-01"), 6)
]
====> [(0, 1), ..., (59, 6)]
FeatureStorage(SH600002, close):
[
(calendar.index("2005-01-01"), 1),
...,
(calendar.index("2005-02-14"), 6)
]
===> [(0, 1), ..., (44, 6)]
FeatureStorage(SH600003, close):
[
(calendar.index("2005-02-01"), 1),
...,
(calendar.index("2005-03-01"), 4)
]
===> [(31, 1), ..., (59, 4)]
"""
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)
with pytest.raises(IndexError):
print(feature[0])
assert isinstance(
feature[815][1], (float, np.float32)
), f"{feature.__class__.__name__}.__getitem__(i: int) error"
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
print(f"feature[815: 818]: \n{feature[815: 818]}")
print(f"feature[:].tail(): \n{feature[:].tail()}")
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount")
assert feature[0] == (None, None), "FeatureStorage does not exist, feature[i] should return `(None, None)`"
assert feature[:].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`"
assert (
feature.data.empty
), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`"