1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

Fix FileStorage

This commit is contained in:
zhupr
2021-03-27 01:15:33 +08:00
parent d395c904f2
commit 9b8acd9a82
4 changed files with 73 additions and 71 deletions

View File

@@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage

View File

@@ -1,14 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import struct
from pathlib import Path
from typing import Iterator, Iterable, Type, List, Tuple, Text, Union
from data.storage.storage import FeatureVT
from .storage import FeatureVT
import numpy as np
import pandas as pd
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage
from . import CalendarStorage, InstrumentStorage, FeatureStorage
CalVT = Type[pd.Timestamp]
@@ -70,9 +71,9 @@ class FileFeatureStorage(FeatureStorage):
if isinstance(i, int):
if ref_start_index > i:
raise IndexError(f"{i}")
raise IndexError(f"{i}: start index is {ref_start_index}")
fp.seek(4 * (i - ref_start_index) + 4)
return i, float(fp.read(4))
return i, struct.unpack("f", fp.read(4))
elif isinstance(i, slice):
start_index = i.start
end_index = i.stop - 1
@@ -83,9 +84,18 @@ class FileFeatureStorage(FeatureStorage):
# read n bytes
count = end_index - si + 1
data = np.frombuffer(fp.read(4 * count), dtype="<f")
return zip(range(si, si + len(data)), data)
return list(zip(range(si, si + len(data)), data))
else:
raise TypeError(f"type(i) = {type(i)}")
def __len__(self) -> int:
    """Return the number of feature values stored in the file at ``self._uri``.

    The file size is counted in 4-byte slots; one slot is subtracted for the
    leading header slot, which is not a feature value.
    """
    slot_count = Path(self._uri).stat().st_size // 4
    return slot_count - 1
def __iter__(self):
    """Yield ``(index, value)`` pairs for every value stored in ``self._uri``.

    The first 4 bytes of the file are read as a little-endian float32 and
    converted to ``int``; that number is the index assigned to the first
    value. Every remaining 4-byte group is one little-endian float32 value,
    and indices increase by one per value.

    Yields:
        tuple: ``(int index, numpy.float32 value)``.

    Raises:
        IndexError: if the file is too short to contain the 4-byte header.
        ValueError: if the payload length is not a multiple of 4 bytes.
    """
    # Read everything up front so the file handle is closed before the first
    # item is yielded; a consumer that abandons the iterator early then does
    # not keep the descriptor open.
    with open(self._uri, "rb") as fp:
        header = np.frombuffer(fp.read(4), dtype="<f")
        # After reading the header the position is already at byte 4, so no
        # explicit seek is needed before reading the payload.
        data = np.frombuffer(fp.read(), dtype="<f")
    ref_start_index = int(header[0])
    yield from zip(range(ref_start_index, ref_start_index + len(data)), data)

View File

@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import abstractmethod
from collections.abc import MutableSequence, MutableMapping, Sequence
from typing import Iterable, overload, TypeVar, Tuple, List, Text, Iterator

View File

@@ -1,54 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import shutil
from pathlib import Path
from importlib.util import spec_from_file_location, module_from_spec
from collections.abc import Iterable
import pandas as pd
import pytest
import numpy as np
from qlib.tests.data import GetData
from qlib.data.storage.file_storage import (
FileCalendarStorage as CalendarStorage,
FileInstrumentStorage as InstrumentStorage,
FileFeatureStorage as FeatureStorage,
)
DATA_DIR = Path(__file__).parent.joinpath("test_get_data")
QLIB_DIR = DATA_DIR.joinpath("qlib")
QLIB_DIR.mkdir(exist_ok=True, parents=True)
# TODO: set STORAGE_NAME
STORAGE_NAME = ""
STORAGE_FILE_PATH = Path("")
# TODO: set value
CALENDAR_URI = ""
INSTRUMENT_URI = ""
FEATURE_URI = ""
CALENDAR_URI = QLIB_DIR.joinpath("calendars").joinpath("day.txt")
INSTRUMENT_URI = QLIB_DIR.joinpath("instruments").joinpath("csi300.txt")
FEATURE_URI = QLIB_DIR.joinpath("features").joinpath("SH600004").joinpath("close.day.bin")
def get_module(module_path: Path):
    """Load and execute the Python source file at ``module_path`` as a module.

    The module is created with an empty name and is not registered in
    ``sys.modules``; each call executes the file again and returns a fresh
    module object.

    Args:
        module_path: path of the file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import loader is available for ``module_path``
            (for example, the file suffix is not a recognised module suffix).
    """
    module_spec = spec_from_file_location("", module_path)
    # spec_from_file_location returns None when no loader matches the path;
    # fail with a message naming the path instead of an AttributeError on None.
    if module_spec is None or module_spec.loader is None:
        raise ImportError(f"cannot load a module from {module_path!r}")
    module = module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module
class TestStorage:
@classmethod
def setup_class(cls):
    """Fetch the ``qlib_data_simple`` dataset (region ``cn``, interval ``1d``) into ``QLIB_DIR``.

    ``delete_old=False`` is passed through to ``GetData().qlib_data``.
    """
    download_options = dict(
        name="qlib_data_simple",
        target_dir=QLIB_DIR,
        region="cn",
        interval="1d",
        delete_old=False,
    )
    GetData().qlib_data(**download_options)
@classmethod
def teardown_class(cls):
    """Recursively delete the resolved ``DATA_DIR`` directory tree."""
    data_root = DATA_DIR.resolve()
    shutil.rmtree(str(data_root))
STORAGE_MODULE = get_module(STORAGE_FILE_PATH)
CalendarStorage = getattr(STORAGE_MODULE, f"{STORAGE_NAME.title()}CalendarStorage")
InstrumentStorage = getattr(STORAGE_MODULE, f"{STORAGE_NAME.title()}InstrumentStorage")
FeatureStorage = getattr(STORAGE_MODULE, f"{STORAGE_NAME.title()}FeatureStorage")
class TestCalendarStorage:
def test_calendar_storage(self):
# calendar value: pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D")
start_date = "2005-01-01"
end_date = "2005-03-01"
values = pd.date_range(start_date, end_date, freq="1D")
calendar = CalendarStorage(uri=CALENDAR_URI)
# test `__iter__`
for _s, _t in zip(calendar, values):
assert pd.Timestamp(_s) == pd.Timestamp(_t), f"{calendar.__name__}.__iter__ error"
assert isinstance(calendar, Iterable), f"{calendar.__class__.__name__} is not Iterable"
assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
# test `__getitem__(self, s: slice)`
for _s, _t in zip(calendar[1:3], values[1:3]):
assert pd.Timestamp(_s) == pd.Timestamp(_t), f"{calendar.__name__}.__getitem__(s: slice) error"
# test `__getitem__(self, i)`
assert pd.Timestamp(calendar[0]) == pd.Timestamp(values[0]), f"{calendar.__name__}.__getitem__(i: int) error"
print(f"calendar[1: 5]: {calendar[1:5]}")
print(f"calendar[0]: {calendar[0]}")
print(f"calendar[-1]: {calendar[-1]}")
def test_instrument_storage(self):
"""
@@ -83,24 +79,21 @@ class TestCalendarStorage:
}
"""
base_instrument = {
"SH600000": [("2005-01-01", "2005-01-31"), ("2005-02-15", "2005-03-01")],
"SH600001": [("2005-01-01", "2005-03-01")],
"SH600002": [("2005-01-01", "2005-02-14")],
"SH600003": [("2005-02-01", "2005-03-01")],
}
instrument = InstrumentStorage(uri=INSTRUMENT_URI)
# test `keys`
assert sorted(instrument.keys()) == sorted(base_instrument.keys()), f"{instrument.__name__}.keys error"
# test `__getitem__`
assert instrument["SH600000"] == base_instrument["SH600000"], f"{instrument.__name__}.__getitem__ error"
# test `get`
assert instrument.get("SH600001") == base_instrument.get("SH600001"), f"{instrument.__name__}.get error"
# test `items`
for _item in instrument.items():
assert base_instrument[_item[0]] == _item[1]
assert len(instrument.items()) == len(instrument) == len(base_instrument), f"{instrument.__name__}.items error"
assert isinstance(instrument, Iterable), f"{instrument.__class__.__name__} is not Iterable"
for inst, spans in instrument.items():
assert isinstance(inst, str) and isinstance(
spans, Iterable
), f"{instrument.__class__.__name__} value is not Iterable"
for s_e in spans:
assert (
isinstance(s_e, tuple) and len(s_e) == 2
), f"{instrument.__class__.__name__}.__getitem__(k) TypeError"
print(f"instrument['SH600000']: {instrument['SH600000']}")
def test_feature_storage(self):
"""
@@ -156,19 +149,18 @@ class TestCalendarStorage:
"""
# FeatureStorage(SH600003, close)
feature = FeatureStorage(uri=FEATURE_URI)
# 2005-02-01 and 2005-03-01
assert feature[31] == 1 and feature[59] == 4, f"{feature.__name__}.__getitem__(i: int) error"
# 2005-02-01, 2005-02-02, 2005-02-03
# close_items: [(31, 1), ..., (33, <value>)]
close_items = feature[31:34]
assert isinstance(feature, Iterable), f"{feature.__class__.__name__} is not Iterable"
with pytest.raises(IndexError):
print(feature[0])
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
print(f"feature[815: 818]: {feature[815: 818]}")
# 2005-02-01, ..., 2005-03-01
# feature: [(31, 1), ..., (59, 4)]
print(feature)
assert (
len(feature) == len(feature[:]) == len(feature[31:60]) == 29
), f"{feature.__name__}.items/__getitem__(s: slice) error"
for _item in feature:
assert (
isinstance(_item, tuple) and len(_item) == 2
), f"{feature.__class__.__name__}.__iter__ item type error"
assert isinstance(_item[0], int) and isinstance(
_item[1], (float, np.float, np.float32)
), f"{feature.__class__.__name__}.__iter__ value type error"