1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00
Files
qlib/qlib/data/storage/file_storage.py
2022-06-28 10:17:29 +08:00

372 lines
14 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import struct
from pathlib import Path
from typing import Iterable, Union, Dict, Mapping, Tuple, List
import numpy as np
import pandas as pd
from qlib.utils.time import Freq
from qlib.utils.resam import resam_calendar
from qlib.config import C
from qlib.data.cache import H
from qlib.log import get_module_logger
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
logger = get_module_logger("file_storage")
class FileStorageMixin:
"""FileStorageMixin, applicable to FileXXXStorage
Subclasses need to have provider_uri, freq, storage_name, file_name attributes
"""
# NOTE: provider_uri priority:
# 1. self._provider_uri : if provider_uri is provided.
# 2. provider_uri in qlib.config.C
@property
def provider_uri(self):
return C["provider_uri"] if getattr(self, "_provider_uri", None) is None else self._provider_uri
@property
def dpm(self):
return (
C.dpm
if getattr(self, "_provider_uri", None) is None
else C.DataPathManager(self._provider_uri, C.mount_path)
)
@property
def support_freq(self) -> List[str]:
_v = "_support_freq"
if hasattr(self, _v):
return getattr(self, _v)
if len(self.provider_uri) == 1 and C.DEFAULT_FREQ in self.provider_uri:
freq_l = filter(
lambda _freq: not _freq.endswith("_future"),
map(lambda x: x.stem, self.dpm.get_data_uri(C.DEFAULT_FREQ).joinpath("calendars").glob("*.txt")),
)
else:
freq_l = self.provider_uri.keys()
freq_l = [Freq(freq) for freq in freq_l]
setattr(self, _v, freq_l)
return freq_l
@property
def uri(self) -> Path:
if self.freq not in self.support_freq:
raise ValueError(f"{self.storage_name}: {self.provider_uri} does not contain data for {self.freq}")
return self.dpm.get_data_uri(self.freq).joinpath(f"{self.storage_name}s", self.file_name)
def check(self):
"""check self.uri
Raises
-------
ValueError
"""
if not self.uri.exists():
raise ValueError(f"{self.storage_name} not exists: {self.uri}")
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
def __init__(self, freq: str, future: bool, provider_uri: dict = None, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
self.future = future
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
self.enable_read_cache = True # TODO: make it configurable
self.region = C["region"]
@property
def file_name(self) -> str:
return f"{self._freq_file}_future.txt" if self.future else f"{self._freq_file}.txt".lower()
@property
def _freq_file(self) -> str:
"""the freq to read from file"""
if not hasattr(self, "_freq_file_cache"):
freq = Freq(self.freq)
if freq not in self.support_freq:
# NOTE: uri
# 1. If `uri` does not exist
# - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
# - Read data from `min_uri` and resample to `freq`
freq = Freq.get_recent_freq(freq, self.support_freq)
if freq is None:
raise ValueError(f"can't find a freq from {self.support_freq} that can resample to {self.freq}!")
self._freq_file_cache = freq
return self._freq_file_cache
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
if not self.uri.exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return [str(x) for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")]
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:
np.savetxt(fp, values, fmt="%s", encoding="utf-8")
@property
def uri(self) -> Path:
return self.dpm.get_data_uri(self._freq_file).joinpath(f"{self.storage_name}s", self.file_name)
@property
def data(self) -> List[CalVT]:
self.check()
# If cache is enabled, then return cache directly
if self.enable_read_cache:
key = "orig_file" + str(self.uri)
if key not in H["c"]:
H["c"][key] = self._read_calendar()
_calendar = H["c"][key]
else:
_calendar = self._read_calendar()
if Freq(self._freq_file) != Freq(self.freq):
_calendar = resam_calendar(
np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq, self.region
)
return _calendar
def _get_storage_freq(self) -> List[str]:
return sorted(set(map(lambda x: x.stem.split("_")[0], self.uri.parent.glob("*.txt"))))
def extend(self, values: Iterable[CalVT]) -> None:
self._write_calendar(values, mode="ab")
def clear(self) -> None:
self._write_calendar(values=[])
def index(self, value: CalVT) -> int:
self.check()
calendar = self._read_calendar()
return int(np.argwhere(calendar == value)[0])
def insert(self, index: int, value: CalVT):
calendar = self._read_calendar()
calendar = np.insert(calendar, index, value)
self._write_calendar(values=calendar)
def remove(self, value: CalVT) -> None:
self.check()
index = self.index(value)
calendar = self._read_calendar()
calendar = np.delete(calendar, index)
self._write_calendar(values=calendar)
def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None:
calendar = self._read_calendar()
calendar[i] = values
self._write_calendar(values=calendar)
def __delitem__(self, i: Union[int, slice]) -> None:
self.check()
calendar = self._read_calendar()
calendar = np.delete(calendar, i)
self._write_calendar(values=calendar)
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:
self.check()
return self._read_calendar()[i]
def __len__(self) -> int:
return len(self.data)
class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
INSTRUMENT_SEP = "\t"
INSTRUMENT_START_FIELD = "start_datetime"
INSTRUMENT_END_FIELD = "end_datetime"
SYMBOL_FIELD_NAME = "instrument"
def __init__(self, market: str, freq: str, provider_uri: dict = None, **kwargs):
super(FileInstrumentStorage, self).__init__(market, freq, **kwargs)
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
self.file_name = f"{market.lower()}.txt"
def _read_instrument(self) -> Dict[InstKT, InstVT]:
if not self.uri.exists():
self._write_instrument()
_instruments = dict()
df = pd.read_csv(
self.uri,
sep="\t",
usecols=[0, 1, 2],
names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
dtype={self.SYMBOL_FIELD_NAME: str},
parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None:
if not data:
with self.uri.open("w") as _:
pass
return
res = []
for inst, v_list in data.items():
_df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD])
_df[self.SYMBOL_FIELD_NAME] = inst
res.append(_df)
df = pd.concat(res, sort=False)
df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv(
self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False
)
df.to_csv(self.uri, sep="\t", encoding="utf-8", header=False, index=False)
def clear(self) -> None:
self._write_instrument(data={})
@property
def data(self) -> Dict[InstKT, InstVT]:
self.check()
return self._read_instrument()
def __setitem__(self, k: InstKT, v: InstVT) -> None:
inst = self._read_instrument()
inst[k] = v
self._write_instrument(inst)
def __delitem__(self, k: InstKT) -> None:
self.check()
inst = self._read_instrument()
del inst[k]
self._write_instrument(inst)
def __getitem__(self, k: InstKT) -> InstVT:
self.check()
return self._read_instrument()[k]
def update(self, *args, **kwargs) -> None:
if len(args) > 1:
raise TypeError(f"update expected at most 1 arguments, got {len(args)}")
inst = self._read_instrument()
if args:
other = args[0] # type: dict
if isinstance(other, Mapping):
for key in other:
inst[key] = other[key]
elif hasattr(other, "keys"):
for key in other.keys():
inst[key] = other[key]
else:
for key, value in other:
inst[key] = value
for key, value in kwargs.items():
inst[key] = value
self._write_instrument(inst)
def __len__(self) -> int:
return len(self.data)
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs):
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
def clear(self):
with self.uri.open("wb") as _:
pass
@property
def data(self) -> pd.Series:
return self[:]
def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
if len(data_array) == 0:
logger.info(
"len(data_array) == 0, write"
"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
)
return
if not self.uri.exists():
# write
index = 0 if index is None else index
with self.uri.open("wb") as fp:
np.hstack([index, data_array]).astype("<f").tofile(fp)
else:
if index is None or index > self.end_index:
# append
index = 0 if index is None else index
with self.uri.open("ab+") as fp:
np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype("<f").tofile(fp)
else:
# rewrite
with self.uri.open("rb+") as fp:
_old_data = np.fromfile(fp, dtype="<f")
_old_index = _old_data[0]
_old_df = pd.DataFrame(
_old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=["old"]
)
fp.seek(0)
_new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=["new"])
_df = pd.concat([_old_df, _new_df], sort=False, axis=1)
_df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))
_df["new"].fillna(_df["old"]).values.astype("<f").tofile(fp)
@property
def start_index(self) -> Union[int, None]:
if not self.uri.exists():
return None
with self.uri.open("rb") as fp:
index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
return index
@property
def end_index(self) -> Union[int, None]:
if not self.uri.exists():
return None
# The next data appending index point will be `end_index + 1`
return self.start_index + len(self) - 1
def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
if not self.uri.exists():
if isinstance(i, int):
return None, None
elif isinstance(i, slice):
return pd.Series(dtype=np.float32)
else:
raise TypeError(f"type(i) = {type(i)}")
storage_start_index = self.start_index
storage_end_index = self.end_index
with self.uri.open("rb") as fp:
if isinstance(i, int):
if storage_start_index > i:
raise IndexError(f"{i}: start index is {storage_start_index}")
fp.seek(4 * (i - storage_start_index) + 4)
return i, struct.unpack("f", fp.read(4))[0]
elif isinstance(i, slice):
start_index = storage_start_index if i.start is None else i.start
end_index = storage_end_index if i.stop is None else i.stop - 1
si = max(start_index, storage_start_index)
if si > end_index:
return pd.Series(dtype=np.float32)
fp.seek(4 * (si - storage_start_index) + 4)
# read n bytes
count = end_index - si + 1
data = np.frombuffer(fp.read(4 * count), dtype="<f")
return pd.Series(data, index=pd.RangeIndex(si, si + len(data)))
else:
raise TypeError(f"type(i) = {type(i)}")
def __len__(self) -> int:
self.check()
return self.uri.stat().st_size // 4 - 1