From 8264033a72bdb8df545e9372f406d0a2397f61e8 Mon Sep 17 00:00:00 2001 From: zhupr Date: Thu, 25 Mar 2021 17:22:05 +0800 Subject: [PATCH 01/24] add data-storage --- qlib/storage/__init__.py | 0 qlib/storage/storage.py | 154 ++++++++++++++++++++++++ tests/storage_tests/test_storage.py | 174 ++++++++++++++++++++++++++++ 3 files changed, 328 insertions(+) create mode 100644 qlib/storage/__init__.py create mode 100644 qlib/storage/storage.py create mode 100644 tests/storage_tests/test_storage.py diff --git a/qlib/storage/__init__.py b/qlib/storage/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/storage/storage.py b/qlib/storage/storage.py new file mode 100644 index 000000000..dac0e167d --- /dev/null +++ b/qlib/storage/storage.py @@ -0,0 +1,154 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import abc + +from typing import ( + Iterable, + overload, + TypeVar, + Tuple, + List, + Text, + Optional, + AbstractSet, + Mapping, + Iterator, +) + + +# calendar value type +CalVT = TypeVar("CalVT") + +# instrument value +InstVT = List[Tuple[CalVT, CalVT]] +# instrument key +InstKT = Text + + +FeatureVT = Tuple[int, float] + + +class CalendarStorage: + def __init__(self, uri: str): + self._uri = uri + + def append(self, obj: CalVT) -> None: + """ Append object to the end of the CalendarStorage. """ + raise NotImplementedError("Subclass of CalendarStorage must implement `append` method") + + def clear(self): + """ Remove all items from CalendarStorage. """ + raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") + + def extend(self, iterable: Iterable[CalVT]): + """ Extend list by appending elements from the iterable. """ + raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") + + @overload + @abc.abstractmethod + def __getitem__(self, s: slice) -> Iterable[CalVT]: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" + raise NotImplementedError("Subclass of CalendarStorage must implement `__getitem__(s: slice)` method") + + @abc.abstractmethod + def __getitem__(self, i: int) -> CalVT: + """x.__getitem__(y) <==> x[y]""" + + raise NotImplementedError("Subclass of CalendarStorage must implement `__getitem__(i: int)` method") + + @abc.abstractmethod + def __iter__(self) -> Iterator[CalVT]: + """ Implement iter(self). """ + raise NotImplementedError("Subclass of CalendarStorage must implement `__iter__` method") + + def __len__(self) -> int: + raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") + + +class InstrumentStorage: + def __init__(self, uri: str): + self._uri = uri + + def clear(self) -> None: + """ D.clear() -> None. Remove all items from D. """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") + + @abc.abstractmethod + def get(self, k: InstKT) -> Optional[InstVT]: + """D.get(k) -> InstV or None""" + raise NotImplementedError("Subclass of InstrumentStorage must implement `get` method") + + @abc.abstractmethod + def items(self) -> AbstractSet[Tuple[InstKT, InstVT]]: + """ D.items() -> a set-like object providing a view on D's items """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `items` method") + + @abc.abstractmethod + def keys(self) -> AbstractSet[InstKT]: + """ D.keys() -> a set-like object providing a view on D's keys """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `keys` method") + + def update(self, e: Mapping[InstKT, InstVT] = None, **f: InstVT) -> None: + """ + D.update([e, ]**f) -> None. Update D from dict/iterable e and f. + If e is present and has a .keys() method, then does: for k in e: D[k] = e[k] + If e is present and lacks a .keys() method, then does: for k, v in e: D[k] = v + In either case, this is followed by: for k in f: D[k] = f[k] + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") + + def __setitem__(self, k: InstKT, v: InstVT) -> None: + """ Set self[key] to value. """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") + + def __delitem__(self, k: InstKT) -> None: + """ Delete self[key]. """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") + + @abc.abstractmethod + def __getitem__(self, k: InstKT) -> InstVT: + """ x.__getitem__(y) <==> x[y] """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") + + def __len__(self) -> int: + """ Return len(self). """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") + + +class FeatureStorage: + def __init__(self, uri: str): + self._uri = uri + + def append(self, obj: FeatureVT) -> None: + """ Append object to the end of the FeatureStorage. """ + raise NotImplementedError("Subclass of FeatureStorage must implement `append` method") + + def clear(self): + """ Remove all items from FeatureStorage. """ + raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") + + def extend(self, iterable: Iterable[FeatureVT]): + """ Extend list by appending elements from the iterable. """ + raise NotImplementedError("Subclass of FeatureStorage must implement `extend` method") + + @overload + @abc.abstractmethod + def __getitem__(self, s: slice) -> Iterable[FeatureVT]: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" + raise NotImplementedError("Subclass of FeatureStorage must implement `__getitem__(s: slice)` method") + + @abc.abstractmethod + def __getitem__(self, i: int) -> float: + """x.__getitem__(y) <==> x[y]""" + + raise NotImplementedError("Subclass of FeatureStorage must implement `__getitem__(i: int)` method") + + def __len__(self) -> int: + raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") + + @abc.abstractmethod + def __iter__(self) -> Iterator[FeatureVT]: + """ Implement iter(self). """ + raise NotImplementedError("Subclass of FeatureStorage must implement `__iter__` method") diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py new file mode 100644 index 000000000..d4b37be77 --- /dev/null +++ b/tests/storage_tests/test_storage.py @@ -0,0 +1,174 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path +from importlib.util import spec_from_file_location, module_from_spec + +import pandas as pd + + +# TODO: set STORAGE_NAME +STORAGE_NAME = "" +STORAGE_FILE_PATH = Path("") +# TODO: set value +CALENDAR_URI = "" +INSTRUMENT_URI = "" +FEATURE_URI = "" + + +def get_module(module_path: Path): + module_spec = spec_from_file_location("", module_path) + module = module_from_spec(module_spec) + module_spec.loader.exec_module(module) + return module + + +STORAGE_MODULE = get_module(STORAGE_FILE_PATH) + + +CalendarStorage = getattr(STORAGE_MODULE, f"{STORAGE_NAME.title()}CalendarStorage") +InstrumentStorage = getattr(STORAGE_MODULE, f"{STORAGE_NAME.title()}InstrumentStorage") +FeatureStorage = getattr(STORAGE_MODULE, f"{STORAGE_NAME.title()}FeatureStorage") + + +class TestCalendarStorage: + def test_calendar_storage(self): + # calendar value: pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D") + start_date = "2005-01-01" + end_date = "2005-03-01" + values = pd.date_range(start_date, end_date, freq="1D") + + calendar = CalendarStorage(uri=CALENDAR_URI) + # test `__iter__` + for _s, _t in zip(calendar, values): + assert pd.Timestamp(_s) == pd.Timestamp(_t), f"{calendar.__name__}.__iter__ error" + + # test `__getitem__(self, s: slice)` + for _s, _t in zip(calendar[1:3], values[1:3]): + assert pd.Timestamp(_s) == pd.Timestamp(_t), f"{calendar.__name__}.__getitem__(s: slice) error" + + # test `__getitem__(self, i)` + assert pd.Timestamp(calendar[0]) == pd.Timestamp(values[0]), f"{calendar.__name__}.__getitem__(i: int) error" + + def test_instrument_storage(self): + """ + The meaning of instrument, such as CSI500: + + CSI500 composition changes: + + date add remove + 2005-01-01 SH600000 + 2005-01-01 SH600001 + 2005-01-01 SH600002 + 2005-02-01 SH600003 SH600000 + 2005-02-15 SH600000 SH600002 + + Calendar: + pd.date_range(start="2020-01-01", stop="2020-03-01", freq="1D") + + Instrument: + symbol start_time end_time + SH600000 2005-01-01 2005-01-31 (2005-02-01 Last trading day) + SH600000 2005-02-15 2005-03-01 + SH600001 2005-01-01 2005-03-01 + SH600002 2005-01-01 2005-02-14 (2005-02-15 Last trading day) + SH600003 2005-02-01 2005-03-01 + + InstrumentStorage: + { + "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)], + "SH600001": [(2005-01-01, 2005-03-01)], + "SH600002": [(2005-01-01, 2005-02-14)], + "SH600003": [(2005-02-01, 2005-03-01)], + } + + """ + base_instrument = { + "SH600000": [("2005-01-01", "2005-01-31"), ("2005-02-15", "2005-03-01")], + "SH600001": [("2005-01-01", "2005-03-01")], + "SH600002": [("2005-01-01", "2005-02-14")], + "SH600003": [("2005-02-01", "2005-03-01")], + } + instrument = InstrumentStorage(uri=INSTRUMENT_URI) + + # test `keys` + assert sorted(instrument.keys()) == sorted(base_instrument.keys()), f"{instrument.__name__}.keys error" + # test `__getitem__` + assert instrument["SH600000"] == base_instrument["SH600000"], f"{instrument.__name__}.__getitem__ error" + # test `get` + assert instrument.get("SH600001") == base_instrument.get("SH600001"), f"{instrument.__name__}.get error" + # test `items` + for _item in instrument.items(): + assert base_instrument[_item[0]] == _item[1] + assert len(instrument.items()) == len(instrument) == len(base_instrument), f"{instrument.__name__}.items error" + + def test_feature_storage(self): + """ + Calendar: + pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D") + + Instrument: + { + "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)], + "SH600001": [(2005-01-01, 2005-03-01)], + "SH600002": [(2005-01-01, 2005-02-14)], + "SH600003": [(2005-02-01, 2005-03-01)], + } + + Feature: + Stock data(close): + 2005-01-01 ... 2005-02-01 ... 2005-02-14 2005-02-15 ... 2005-03-01 + SH600000 1 ... 3 ... 4 5 6 + SH600001 1 ... 4 ... 5 6 7 + SH600002 1 ... 5 ... 6 nan nan + SH600003 nan ... 1 ... 2 3 4 + + FeatureStorage(SH600000, close): + + [ + (calendar.index("2005-01-01"), 1), + ..., + (calendar.index("2005-03-01"), 6) + ] + + ====> [(0, 1), ..., (59, 6)] + + + FeatureStorage(SH600002, close): + + [ + (calendar.index("2005-01-01"), 1), + ..., + (calendar.index("2005-02-14"), 6) + ] + + ===> [(0, 1), ..., (44, 6)] + + FeatureStorage(SH600003, close): + + [ + (calendar.index("2005-02-01"), 1), + ..., + (calendar.index("2005-03-01"), 4) + ] + + ===> [(31, 1), ..., (59, 4)] + + """ + + # FeatureStorage(SH600003, close) + feature = FeatureStorage(uri=FEATURE_URI) + # 2005-02-01 and 2005-03-01 + assert feature[31] == 1 and feature[59] == 4, f"{feature.__name__}.__getitem__(i: int) error" + + # 2005-02-01, 2005-02-02, 2005-02-03 + # close_items: [(31, 1), ..., (33, )] + close_items = feature[31:34] + + # 2005-02-01, ..., 2005-03-01 + # feature: [(31, 1), ..., (59, 4)] + print(feature) + + assert ( + len(feature) == len(feature[:]) == len(feature[31:60]) == 29 + ), f"{feature.__name__}.items/__getitem__(s: slice) error" From d395c904f2821f7b3b59359a12a2e4195f7d8fb1 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 26 Mar 2021 16:14:45 +0800 Subject: [PATCH 02/24] Add FileStorage --- qlib/data/storage/__init__.py | 4 + qlib/data/storage/file_storage.py | 91 ++++++++++++++++++ qlib/data/storage/storage.py | 135 ++++++++++++++++++++++++++ qlib/storage/__init__.py | 0 qlib/storage/storage.py | 154 ------------------------------ 5 files changed, 230 insertions(+), 154 deletions(-) create mode 100644 qlib/data/storage/__init__.py create mode 100644 qlib/data/storage/file_storage.py create mode 100644 qlib/data/storage/storage.py delete mode 100644 qlib/storage/__init__.py delete mode 100644 qlib/storage/storage.py diff --git a/qlib/data/storage/__init__.py b/qlib/data/storage/__init__.py new file mode 100644 index 000000000..eb513714b --- /dev/null +++ b/qlib/data/storage/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .storage import CalendarStorage, InstrumentStorage, FeatureStorage \ No newline at end of file diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py new file mode 100644 index 000000000..9d98545ce --- /dev/null +++ b/qlib/data/storage/file_storage.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path +from typing import Iterator, Iterable, Type, List, Tuple, Text, Union + +from data.storage.storage import FeatureVT + +import numpy as np +import pandas as pd +from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage + + +CalVT = Type[pd.Timestamp] +# instrument value +InstVT = List[Tuple[CalVT, CalVT]] +# instrument key +InstKT = Text + + +class FileCalendarStorage(CalendarStorage): + def __init__(self, uri: str): + super(FileCalendarStorage, self).__init__(uri=uri) + with open(uri) as f: + self._data = [pd.Timestamp(x.strip()) for x in f] + + def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]: + if isinstance(i, (int, slice)): + return self._data[i] + else: + raise TypeError(f"type(i) = {type(i)}") + + def __len__(self) -> int: + return len(self._data) + + +class FileInstrumentStorage(InstrumentStorage): + def __init__(self, uri: str): + super(FileInstrumentStorage, self).__init__(uri=uri) + self._data = self._load_data() + + def _load_data(self): + _instruments = dict() + df = pd.read_csv( + self._uri, + sep="\t", + usecols=[0, 1, 2], + names=["inst", "start_datetime", "end_datetime"], + dtype={"inst": str}, + parse_dates=["start_datetime", "end_datetime"], + ) + for row in df.itertuples(index=False): + _instruments.setdefault(row[0], []).append((row[1], row[2])) + return _instruments + + def __getitem__(self, k: InstKT) -> InstVT: + return self._data[k] + + def __len__(self) -> int: + return len(self._data) + + def __iter__(self) -> Iterator[InstKT]: + return self._data.__iter__() + + +class FileFeatureStorage(FeatureStorage): + def __getitem__(self, i: Union[int, slice]) -> Union[FeatureVT, Iterable[FeatureVT]]: + with open(self._uri, "rb") as fp: + ref_start_index = int(np.frombuffer(fp.read(4), dtype=" i: + raise IndexError(f"{i}") + fp.seek(4 * (i - ref_start_index) + 4) + return i, float(fp.read(4)) + elif isinstance(i, slice): + start_index = i.start + end_index = i.stop - 1 + si = max(ref_start_index, start_index) + if si > end_index: + return [] + fp.seek(4 * (si - ref_start_index) + 4) + # read n bytes + count = end_index - si + 1 + data = np.frombuffer(fp.read(4 * count), dtype=" int: + return Path(self._uri).stat().st_size // 4 - 1 diff --git a/qlib/data/storage/storage.py b/qlib/data/storage/storage.py new file mode 100644 index 000000000..7848c243f --- /dev/null +++ b/qlib/data/storage/storage.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import abstractmethod +from collections.abc import MutableSequence, MutableMapping, Sequence +from typing import Iterable, overload, TypeVar, Tuple, List, Text, Iterator + + +# calendar value type +CalVT = TypeVar("CalVT") + +# instrument value +InstVT = List[Tuple[CalVT, CalVT]] +# instrument key +InstKT = Text + + +FeatureVT = Tuple[int, float] + + +class CalendarStorage(MutableSequence): + def __init__(self, uri: str): + self._uri = uri + + def insert(self, index: int, o: CalVT) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method") + + @overload + def __setitem__(self, i: int, o: CalVT) -> None: + """x.__setitem__(i, o) <==> x[i] = o""" + ... + + @overload + def __setitem__(self, s: slice, o: Iterable[CalVT]) -> None: + """x.__setitem__(s, o) <==> x[s] = o""" + ... + + def __setitem__(self, i, o) -> None: + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method" + ) + + @overload + def __delitem__(self, i: int) -> None: + """x.__delitem__(i) <==> del x[i]""" + ... + + @overload + def __delitem__(self, i: slice) -> None: + """x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]""" + ... + + def __delitem__(self, i) -> None: + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method" + ) + + @overload + def __getitem__(self, s: slice) -> Iterable[CalVT]: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" + ... + + @overload + def __getitem__(self, i: int) -> CalVT: + """x.__getitem__(i) <==> x[i]""" + ... + + def __getitem__(self, i) -> CalVT: + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" + ) + + def __len__(self) -> int: + """x.__len__() <==> len(x)""" + raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") + + +class InstrumentStorage(MutableMapping): + def __init__(self, uri: str): + self._uri = uri + + def __setitem__(self, k: InstKT, v: InstVT) -> None: + """ Set self[key] to value. """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") + + def __delitem__(self, k: InstKT) -> None: + """ Delete self[key]. """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") + + def __getitem__(self, k: InstKT) -> InstVT: + """ x.__getitem__(k) <==> x[k] """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") + + def __len__(self) -> int: + """ Return len(self). """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") + + def __iter__(self) -> Iterator[InstKT]: + """ Return iter(self). """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__iter__` method") + + +class FeatureStorage(Sequence): + def __init__(self, uri: str): + self._uri = uri + + def append(self, obj: FeatureVT) -> None: + """ Append object to the end of the FeatureStorage. """ + raise NotImplementedError("Subclass of FeatureStorage must implement `append` method") + + def clear(self): + """ Remove all items from FeatureStorage. """ + raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") + + def extend(self, iterable: Iterable[FeatureVT]): + """ Extend list by appending elements from the iterable. """ + raise NotImplementedError("Subclass of FeatureStorage must implement `extend` method") + + @overload + def __getitem__(self, s: slice) -> Iterable[FeatureVT]: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" + ... + + @overload + def __getitem__(self, i: int) -> float: + """x.__getitem__(y) <==> x[y]""" + ... + + def __getitem__(self, i) -> float: + """x.__getitem__(y) <==> x[y]""" + raise NotImplementedError( + "Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" + ) + + def __len__(self) -> int: + raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") diff --git a/qlib/storage/__init__.py b/qlib/storage/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/qlib/storage/storage.py b/qlib/storage/storage.py deleted file mode 100644 index dac0e167d..000000000 --- a/qlib/storage/storage.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import abc - -from typing import ( - Iterable, - overload, - TypeVar, - Tuple, - List, - Text, - Optional, - AbstractSet, - Mapping, - Iterator, -) - - -# calendar value type -CalVT = TypeVar("CalVT") - -# instrument value -InstVT = List[Tuple[CalVT, CalVT]] -# instrument key -InstKT = Text - - -FeatureVT = Tuple[int, float] - - -class CalendarStorage: - def __init__(self, uri: str): - self._uri = uri - - def append(self, obj: CalVT) -> None: - """ Append object to the end of the CalendarStorage. """ - raise NotImplementedError("Subclass of CalendarStorage must implement `append` method") - - def clear(self): - """ Remove all items from CalendarStorage. """ - raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") - - def extend(self, iterable: Iterable[CalVT]): - """ Extend list by appending elements from the iterable. """ - raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") - - @overload - @abc.abstractmethod - def __getitem__(self, s: slice) -> Iterable[CalVT]: - """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" - raise NotImplementedError("Subclass of CalendarStorage must implement `__getitem__(s: slice)` method") - - @abc.abstractmethod - def __getitem__(self, i: int) -> CalVT: - """x.__getitem__(y) <==> x[y]""" - - raise NotImplementedError("Subclass of CalendarStorage must implement `__getitem__(i: int)` method") - - @abc.abstractmethod - def __iter__(self) -> Iterator[CalVT]: - """ Implement iter(self). """ - raise NotImplementedError("Subclass of CalendarStorage must implement `__iter__` method") - - def __len__(self) -> int: - raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") - - -class InstrumentStorage: - def __init__(self, uri: str): - self._uri = uri - - def clear(self) -> None: - """ D.clear() -> None. Remove all items from D. """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") - - @abc.abstractmethod - def get(self, k: InstKT) -> Optional[InstVT]: - """D.get(k) -> InstV or None""" - raise NotImplementedError("Subclass of InstrumentStorage must implement `get` method") - - @abc.abstractmethod - def items(self) -> AbstractSet[Tuple[InstKT, InstVT]]: - """ D.items() -> a set-like object providing a view on D's items """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `items` method") - - @abc.abstractmethod - def keys(self) -> AbstractSet[InstKT]: - """ D.keys() -> a set-like object providing a view on D's keys """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `keys` method") - - def update(self, e: Mapping[InstKT, InstVT] = None, **f: InstVT) -> None: - """ - D.update([e, ]**f) -> None. Update D from dict/iterable e and f. - If e is present and has a .keys() method, then does: for k in e: D[k] = e[k] - If e is present and lacks a .keys() method, then does: for k, v in e: D[k] = v - In either case, this is followed by: for k in f: D[k] = f[k] - """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") - - def __setitem__(self, k: InstKT, v: InstVT) -> None: - """ Set self[key] to value. """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") - - def __delitem__(self, k: InstKT) -> None: - """ Delete self[key]. """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") - - @abc.abstractmethod - def __getitem__(self, k: InstKT) -> InstVT: - """ x.__getitem__(y) <==> x[y] """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") - - def __len__(self) -> int: - """ Return len(self). """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") - - -class FeatureStorage: - def __init__(self, uri: str): - self._uri = uri - - def append(self, obj: FeatureVT) -> None: - """ Append object to the end of the FeatureStorage. """ - raise NotImplementedError("Subclass of FeatureStorage must implement `append` method") - - def clear(self): - """ Remove all items from FeatureStorage. """ - raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") - - def extend(self, iterable: Iterable[FeatureVT]): - """ Extend list by appending elements from the iterable. """ - raise NotImplementedError("Subclass of FeatureStorage must implement `extend` method") - - @overload - @abc.abstractmethod - def __getitem__(self, s: slice) -> Iterable[FeatureVT]: - """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" - raise NotImplementedError("Subclass of FeatureStorage must implement `__getitem__(s: slice)` method") - - @abc.abstractmethod - def __getitem__(self, i: int) -> float: - """x.__getitem__(y) <==> x[y]""" - - raise NotImplementedError("Subclass of FeatureStorage must implement `__getitem__(i: int)` method") - - def __len__(self) -> int: - raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") - - @abc.abstractmethod - def __iter__(self) -> Iterator[FeatureVT]: - """ Implement iter(self). """ - raise NotImplementedError("Subclass of FeatureStorage must implement `__iter__` method") From 9b8acd9a82f451c3a5f3c145020e2e11f66ed05a Mon Sep 17 00:00:00 2001 From: zhupr Date: Sat, 27 Mar 2021 01:15:33 +0800 Subject: [PATCH 03/24] Fix FileStorage --- qlib/data/storage/__init__.py | 2 +- qlib/data/storage/file_storage.py | 20 +++-- qlib/data/storage/storage.py | 2 +- tests/storage_tests/test_storage.py | 120 +++++++++++++--------------- 4 files changed, 73 insertions(+), 71 deletions(-) diff --git a/qlib/data/storage/__init__.py b/qlib/data/storage/__init__.py index eb513714b..f42504791 100644 --- a/qlib/data/storage/__init__.py +++ b/qlib/data/storage/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .storage import CalendarStorage, InstrumentStorage, FeatureStorage \ No newline at end of file +from .storage import CalendarStorage, InstrumentStorage, FeatureStorage diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index 9d98545ce..aadc918c3 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -1,14 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import struct from pathlib import Path from typing import Iterator, Iterable, Type, List, Tuple, Text, Union -from data.storage.storage import FeatureVT +from .storage import FeatureVT import numpy as np import pandas as pd -from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage +from . import CalendarStorage, InstrumentStorage, FeatureStorage CalVT = Type[pd.Timestamp] @@ -70,9 +71,9 @@ class FileFeatureStorage(FeatureStorage): if isinstance(i, int): if ref_start_index > i: - raise IndexError(f"{i}") + raise IndexError(f"{i}: start index is {ref_start_index}") fp.seek(4 * (i - ref_start_index) + 4) - return i, float(fp.read(4)) + return i, struct.unpack("f", fp.read(4)) elif isinstance(i, slice): start_index = i.start end_index = i.stop - 1 @@ -83,9 +84,18 @@ class FileFeatureStorage(FeatureStorage): # read n bytes count = end_index - si + 1 data = np.frombuffer(fp.read(4 * count), dtype=" int: return Path(self._uri).stat().st_size // 4 - 1 + + def __iter__(self): + with open(self._uri, "rb") as fp: + ref_start_index = int(np.frombuffer(fp.read(4), dtype=")] - close_items = feature[31:34] + assert isinstance(feature, Iterable), f"{feature.__class__.__name__} is not Iterable" + with pytest.raises(IndexError): + print(feature[0]) + assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error" + print(f"feature[815: 818]: {feature[815: 818]}") - # 2005-02-01, ..., 2005-03-01 - # feature: [(31, 1), ..., (59, 4)] - print(feature) - - assert ( - len(feature) == len(feature[:]) == len(feature[31:60]) == 29 - ), f"{feature.__name__}.items/__getitem__(s: slice) error" + for _item in feature: + assert ( + isinstance(_item, tuple) and len(_item) == 2 + ), f"{feature.__class__.__name__}.__iter__ item type error" + assert isinstance(_item[0], int) and isinstance( + _item[1], (float, np.float, np.float32) + ), f"{feature.__class__.__name__}.__iter__ value type error" From 70fc58104bfb3a4666382b3441887591c40d3c9c Mon Sep 17 00:00:00 2001 From: zhupr Date: Thu, 1 Apr 2021 12:58:34 +0800 Subject: [PATCH 04/24] Modify FileStorage --- qlib/data/storage/__init__.py | 2 +- qlib/data/storage/file_storage.py | 197 ++++++++++++++++++++++++------ qlib/data/storage/storage.py | 121 ++++++++++++++---- 3 files changed, 262 insertions(+), 58 deletions(-) diff --git a/qlib/data/storage/__init__.py b/qlib/data/storage/__init__.py index f42504791..552e1e3e8 100644 --- a/qlib/data/storage/__init__.py +++ b/qlib/data/storage/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .storage import CalendarStorage, InstrumentStorage, FeatureStorage +from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index aadc918c3..e2e5bd3e7 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -3,69 +3,193 @@ import struct from pathlib import Path -from typing import Iterator, Iterable, Type, List, Tuple, Text, Union - -from .storage import FeatureVT +from typing import Iterator, Iterable, Union, Dict, Mapping, Tuple import numpy as np import pandas as pd -from . import CalendarStorage, InstrumentStorage, FeatureStorage - -CalVT = Type[pd.Timestamp] -# instrument value -InstVT = List[Tuple[CalVT, CalVT]] -# instrument key -InstKT = Text +from . import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT class FileCalendarStorage(CalendarStorage): def __init__(self, uri: str): - super(FileCalendarStorage, self).__init__(uri=uri) - with open(uri) as f: - self._data = [pd.Timestamp(x.strip()) for x in f] + super(FileCalendarStorage, self).__init__(uri) + self._uri = Path(self._uri).expanduser().resolve() + + def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> np.ndarray: + if not self._uri.exists(): + self._write_calendar(values=[]) + with self._uri.open("rb") as fp: + return np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8") + + def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"): + with self._uri.open(mode=mode) as fp: + np.savetxt(fp, values, fmt="%s", encoding="utf-8") + + def extend(self, values: Iterable[CalVT]) -> None: + self._write_calendar(values, mode="ab") + + def clear(self) -> None: + self._write_calendar(values=[]) + + def index(self, value: CalVT) -> int: + calendar = self._read_calendar() + return int(np.argwhere(calendar == value)[0]) + + def insert(self, index: int, value: CalVT): + calendar = self._read_calendar() + calendar = np.insert(calendar, index, value) + self._write_calendar(values=calendar) + + def remove(self, value: CalVT) -> None: + index = self.index(value) + calendar = self._read_calendar() + calendar = np.delete(calendar, index) + self._write_calendar(values=calendar) + + def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None: + calendar = self._read_calendar() + calendar[i] = values + self._write_calendar(values=calendar) + + def __delitem__(self, i: Union[int, slice]) -> None: + calendar = self._read_calendar() + calendar = np.delete(calendar, i) + self._write_calendar(values=calendar) def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]: - if isinstance(i, (int, slice)): - return self._data[i] - else: - raise TypeError(f"type(i) = {type(i)}") + return self._read_calendar()[i] def __len__(self) -> int: - return len(self._data) + return len(self._read_calendar()) + + def __iter__(self): + with self._uri.open("r") as fp: + yield fp.readline() class FileInstrumentStorage(InstrumentStorage): + INSTRUMENT_SEP = "\t" + INSTRUMENT_START_FIELD = "start_datetime" + INSTRUMENT_END_FIELD = "end_datetime" + SYMBOL_FIELD_NAME = "instrument" + def __init__(self, uri: str): super(FileInstrumentStorage, self).__init__(uri=uri) - self._data = self._load_data() + self._uri = Path(self._uri).expanduser().resolve() + + def _read_instrument(self) -> Dict[InstKT, InstVT]: + if not self._uri.exists(): + self._write_instrument() - def _load_data(self): _instruments = dict() df = pd.read_csv( self._uri, sep="\t", usecols=[0, 1, 2], - names=["inst", "start_datetime", "end_datetime"], - dtype={"inst": str}, - parse_dates=["start_datetime", "end_datetime"], + names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD], + dtype={self.SYMBOL_FIELD_NAME: str}, + parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD], ) for row in df.itertuples(index=False): _instruments.setdefault(row[0], []).append((row[1], row[2])) return _instruments + def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None: + if not data: + with self._uri.open("w") as _: + pass + return + + res = [] + for inst, v_list in data.items(): + _df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]) + _df[self.SYMBOL_FIELD_NAME] = inst + res.append(_df) + + df = pd.concat(res, sort=False) + df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv( + self._uri, header=False, sep=self.INSTRUMENT_SEP, index=False + ) + df.to_csv(self._uri, sep="\t", encoding="utf-8", header=False, index=False) + + def clear(self) -> None: + self._write_instrument(data={}) + + def __setitem__(self, k: InstKT, v: InstVT) -> None: + inst = self._read_instrument() + inst[k] = v + self._write_instrument(inst) + + def __delitem__(self, k: InstKT) -> None: + inst = self._read_instrument() + del inst[k] + self._write_instrument(inst) + def __getitem__(self, k: InstKT) -> InstVT: - return self._data[k] + return self._read_instrument()[k] def __len__(self) -> int: - return len(self._data) + inst = self._read_instrument() + return len(inst) def __iter__(self) -> Iterator[InstKT]: - return self._data.__iter__() + for _inst in self._read_instrument().keys(): + yield _inst + + def update(self, *args, **kwargs) -> None: + + if len(args) > 1: + raise TypeError(f"update expected at most 1 arguments, got {len(args)}") + inst = self._read_instrument() + if args: + other = args[0] + if isinstance(other, Mapping): + for key in other: + inst[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + inst[key] = other[key] + else: + for key, value in other: + inst[key] = value + for key, value in kwargs.items(): + inst[key] = value + + self._write_instrument(inst) class FileFeatureStorage(FeatureStorage): - def __getitem__(self, i: Union[int, slice]) -> Union[FeatureVT, Iterable[FeatureVT]]: + def __init__(self, uri: str): + super(FileFeatureStorage, self).__init__(uri=uri) + self._uri = Path(self._uri) + + def clear(self): + with self._uri.open("wb") as _: + pass + + def extend(self, series: pd.Series) -> None: + extend_start_index = self[0][0] + len(self) if self._uri.exists() else series.index[0] + series = series.reindex(pd.RangeIndex(extend_start_index, series.index[-1] + 1)) + with self._uri.open("ab") as fp: + np.array(series.values).astype(" None: + origin_series = self[:] + series = series.append(origin_series.loc[origin_series.index > series.index[-1]]) + series = series.reindex(pd.RangeIndex(series.index[0], series.index[-1])) + with self._uri.open("wb") as fp: + np.array(series.values).astype(" Union[Tuple[int, float], pd.Series]: + if not self._uri.exists(): + if isinstance(i, int): + return None, None + elif isinstance(i, slice): + return pd.Series() + else: + raise TypeError(f"type(i) = {type(i)}") + with open(self._uri, "rb") as fp: ref_start_index = int(np.frombuffer(fp.read(4), dtype=" end_index: - return [] + return pd.Series() fp.seek(4 * (si - ref_start_index) + 4) # read n bytes count = end_index - si + 1 data = np.frombuffer(fp.read(4 * count), dtype=" int: - return Path(self._uri).stat().st_size // 4 - 1 + return self._uri.stat().st_size // 4 - 1 if self._uri.exists() else 0 def __iter__(self): + if not self._uri.exists(): + return with open(self._uri, "rb") as fp: ref_start_index = int(np.frombuffer(fp.read(4), dtype=" None: + def extend(self, iterable: Iterable[CalVT]) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") + + def index(self, value: CalVT) -> int: + raise NotImplementedError("Subclass of CalendarStorage must implement `index` method") + + def insert(self, index: int, value: CalVT) -> None: raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method") + def remove(self, value: CalVT) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method") + @overload - def __setitem__(self, i: int, o: CalVT) -> None: + def __setitem__(self, i: int, value: CalVT) -> None: """x.__setitem__(i, o) <==> x[i] = o""" ... @overload - def __setitem__(self, s: slice, o: Iterable[CalVT]) -> None: + def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None: """x.__setitem__(s, o) <==> x[s] = o""" ... - def __setitem__(self, i, o) -> None: + def __setitem__(self, i, value) -> None: raise NotImplementedError( "Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method" ) @@ -74,10 +83,21 @@ class CalendarStorage(MutableSequence): raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") -class InstrumentStorage(MutableMapping): +class InstrumentStorage: def __init__(self, uri: str): self._uri = uri + def clear(self) -> None: + raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") + + def update(self, *args, **kwargs) -> None: + """D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. + If E present and has a .keys() method, does: for k in E: D[k] = E[k] + If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v + In either case, this is followed by: for k, v in F.items(): D[k] = v + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") + def __setitem__(self, k: InstKT, v: InstVT) -> None: """ Set self[key] to value. """ raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") @@ -99,37 +119,92 @@ class InstrumentStorage(MutableMapping): raise NotImplementedError("Subclass of InstrumentStorage must implement `__iter__` method") -class FeatureStorage(Sequence): +class FeatureStorage: def __init__(self, uri: str): self._uri = uri - def append(self, obj: FeatureVT) -> None: - """ Append object to the end of the FeatureStorage. """ - raise NotImplementedError("Subclass of FeatureStorage must implement `append` method") - def clear(self): """ Remove all items from FeatureStorage. """ raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") - def extend(self, iterable: Iterable[FeatureVT]): - """ Extend list by appending elements from the iterable. """ + def extend(self, series: pd.Series): + """Extend feature by appending elements from the series. + + Examples: + + feature: + 3 4 + 4 5 + 5 6 + + >>> self.extend(pd.Series({7: 8, 9:10})) + + feature: + 3 4 + 4 5 + 5 6 + 6 np.nan + 7 8 + 9 10 + + """ raise NotImplementedError("Subclass of FeatureStorage must implement `extend` method") + def rebase(self, series: pd.Series): + """Rebase feature header from the series. + + Examples: + + feature: + 3 4 + 4 5 + 5 6 + + >>> self.rebase(pd.Series({1: 2})) + + feature: + 1 2 + 2 np.nan + 3 4 + 4 5 + 5 6 + + >>> self.rebase(pd.Series({5: 6, 7: 8, 9: 10})) + + feature: + 5 6 + 7 8 + 9 10 + + >>> self.rebase(pd.Series({11: 12, 12: 13,})) + + feature: + 11 12 + 12 13 + + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `rebase` method") + @overload - def __getitem__(self, s: slice) -> Iterable[FeatureVT]: - """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" + def __getitem__(self, s: slice) -> pd.Series: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] == pd.Series(values, index=pd.RangeIndex(start, len(values))""" ... @overload - def __getitem__(self, i: int) -> float: + def __getitem__(self, i: int) -> Tuple[int, float]: """x.__getitem__(y) <==> x[y]""" ... - def __getitem__(self, i) -> float: + def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]: """x.__getitem__(y) <==> x[y]""" raise NotImplementedError( "Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" ) def __len__(self) -> int: + """len(feature) <==> feature.__len__() """ raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") + + def __iter__(self) -> Iterable[Tuple[int, float]]: + """iter(feature)""" + raise NotImplementedError("Subclass of FeatureStorage must implement `__iter__` method") From 317357b50d2e54a2663e9fe5f14def3956423248 Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 13 Apr 2021 10:47:01 +0800 Subject: [PATCH 05/24] Modify data.storage --- qlib/data/data.py | 92 +++++++++++++++-------------- qlib/data/storage/file_storage.py | 79 +++++++++++++++---------- qlib/data/storage/storage.py | 38 +++++++++--- tests/storage_tests/test_storage.py | 23 +++++--- 4 files changed, 141 insertions(+), 91 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 000bd1196..1a0ca616e 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function import os +import re import abc import time import queue @@ -27,12 +28,35 @@ from .cache import DiskDatasetCache, DiskExpressionCache from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path -class CalendarProvider(abc.ABC): +class ProviderBackendMixin: + def get_default_backend(self): + backend = {} + provider_name = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] # type: str + # set default storage class + backend.setdefault("class", f"File{provider_name}Storage") + # set default storage module + backend.setdefault("module_path", "qlib.data.storage.file_storage") + # set default storage kwargs + backend_kwargs = backend.setdefault("kwargs", {}) # type: dict + backend_kwargs.setdefault("uri", os.path.join(C.get_data_path(), f"{provider_name.lower()}s")) + return backend + + @property + def backend_obj(self): + return init_instance_by_config(self.backend) + + +class CalendarProvider(abc.ABC, ProviderBackendMixin): """Calendar provider base class Provide calendar data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + if not self.backend: + self.backend = self.get_default_backend() + @abc.abstractmethod def calendar(self, start_time=None, end_time=None, freq="day", future=False): """Get calendar of certain market in given time range. @@ -127,12 +151,17 @@ class CalendarProvider(abc.ABC): return hash_args(start_time, end_time, freq, future) -class InstrumentProvider(abc.ABC): +class InstrumentProvider(abc.ABC, ProviderBackendMixin): """Instrument provider base class Provide instrument data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + if not self.backend: + self.backend = self.get_default_backend() + @staticmethod def instruments(market="all", filter_pipe=None): """Get the general config dictionary for a base market adding several dynamic filters. @@ -215,12 +244,17 @@ class InstrumentProvider(abc.ABC): raise ValueError(f"Unknown instrument type {inst}") -class FeatureProvider(abc.ABC): +class FeatureProvider(abc.ABC, ProviderBackendMixin): """Feature provider class Provide feature data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + if not self.backend: + self.backend = self.get_default_backend() + @abc.abstractmethod def feature(self, instrument, field, start_time, end_time, freq): """Get feature data. @@ -497,6 +531,7 @@ class LocalCalendarProvider(CalendarProvider): """ def __init__(self, **kwargs): + super(LocalCalendarProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) @property @@ -517,18 +552,8 @@ class LocalCalendarProvider(CalendarProvider): list list of timestamps """ - if future: - fname = self._uri_cal.format(freq + "_future") - # if future calendar not exists, return current calendar - if not os.path.exists(fname): - get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") - fname = self._uri_cal.format(freq) - else: - fname = self._uri_cal.format(freq) - if not os.path.exists(fname): - raise ValueError("calendar not exists for freq " + freq) - with open(fname) as f: - return [pd.Timestamp(x.strip()) for x in f] + self.backend.setdefault("kwargs", {}).update(freq=freq, future=future) + return [pd.Timestamp(x) for x in self.backend_obj.data] def calendar(self, start_time=None, end_time=None, freq="day", future=False): _calendar, _calendar_index = self._get_calendar(freq, future) @@ -559,31 +584,15 @@ class LocalInstrumentProvider(InstrumentProvider): Provide instrument data from local data source. """ - def __init__(self): - pass - @property def _uri_inst(self): """Instrument file uri.""" return os.path.join(C.get_data_path(), "instruments", "{}.txt") def _load_instruments(self, market): - fname = self._uri_inst.format(market) - if not os.path.exists(fname): - raise ValueError("instruments not exists for market " + market) - _instruments = dict() - df = pd.read_csv( - fname, - sep="\t", - usecols=[0, 1, 2], - names=["inst", "start_datetime", "end_datetime"], - dtype={"inst": str}, - parse_dates=["start_datetime", "end_datetime"], - ) - for row in df.itertuples(index=False): - _instruments.setdefault(row[0], []).append((row[1], row[2])) - return _instruments + self.backend.setdefault("kwargs", {}).update(market=market) + return self.backend_obj.data def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): market = instruments["market"] @@ -601,7 +610,7 @@ class LocalInstrumentProvider(InstrumentProvider): inst: list( filter( lambda x: x[0] <= x[1], - [(max(start_time, x[0]), min(end_time, x[1])) for x in spans], + [(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans], ) ) for inst, spans in _instruments.items() @@ -627,6 +636,7 @@ class LocalFeatureProvider(FeatureProvider): """ def __init__(self, **kwargs): + super(LocalFeatureProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) @property @@ -638,14 +648,9 @@ class LocalFeatureProvider(FeatureProvider): # validate field = str(field).lower()[1:] instrument = code_to_fname(instrument) - uri_data = self._uri_data.format(instrument.lower(), field, freq) - if not os.path.exists(uri_data): - get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field)) - return pd.Series(dtype=np.float32) - # raise ValueError('uri_data not found: ' + uri_data) - # load - series = read_bin(uri_data, start_index, end_index) - return series + + self.backend.setdefault("kwargs", {}).update(instrument=instrument, field=field, freq=freq) + return self.backend_obj[start_index : end_index + 1] class LocalExpressionProvider(ExpressionProvider): @@ -1061,7 +1066,8 @@ def register_all_wrappers(C): register_wrapper(Cal, _calendar_provider, "qlib.data") logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}") - register_wrapper(Inst, C.instrument_provider, "qlib.data") + _instrument_provider = init_instance_by_config(C.instrument_provider, module) + register_wrapper(Inst, _instrument_provider, "qlib.data") logger.debug(f"registering Inst {C.instrument_provider}") if getattr(C, "feature_provider", None) is not None: diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index e2e5bd3e7..4090e3230 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -8,24 +8,29 @@ from typing import Iterator, Iterable, Union, Dict, Mapping, Tuple import numpy as np import pandas as pd -from . import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT +from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT class FileCalendarStorage(CalendarStorage): - def __init__(self, uri: str): - super(FileCalendarStorage, self).__init__(uri) - self._uri = Path(self._uri).expanduser().resolve() + def __init__(self, freq: str, future: bool, uri: str): + super(FileCalendarStorage, self).__init__(freq, future, uri) + _file_name = f"{freq}_future.txt" if future else f"{freq}.txt" + self.uri = Path(self.uri).expanduser().joinpath(_file_name.lower()) def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> np.ndarray: - if not self._uri.exists(): + if not self.uri.exists(): self._write_calendar(values=[]) - with self._uri.open("rb") as fp: + with self.uri.open("rb") as fp: return np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8") def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"): - with self._uri.open(mode=mode) as fp: + with self.uri.open(mode=mode) as fp: np.savetxt(fp, values, fmt="%s", encoding="utf-8") + @property + def data(self) -> Iterable[CalVT]: + return self._read_calendar() + def extend(self, values: Iterable[CalVT]) -> None: self._write_calendar(values, mode="ab") @@ -64,27 +69,27 @@ class FileCalendarStorage(CalendarStorage): return len(self._read_calendar()) def __iter__(self): - with self._uri.open("r") as fp: - yield fp.readline() + return iter(self._read_calendar()) class FileInstrumentStorage(InstrumentStorage): + INSTRUMENT_SEP = "\t" INSTRUMENT_START_FIELD = "start_datetime" INSTRUMENT_END_FIELD = "end_datetime" SYMBOL_FIELD_NAME = "instrument" - def __init__(self, uri: str): - super(FileInstrumentStorage, self).__init__(uri=uri) - self._uri = Path(self._uri).expanduser().resolve() + def __init__(self, market: str, uri: str): + super(FileInstrumentStorage, self).__init__(market, uri) + self.uri = Path(self.uri).expanduser().joinpath(f"{market.lower()}.txt") def _read_instrument(self) -> Dict[InstKT, InstVT]: - if not self._uri.exists(): + if not self.uri.exists(): self._write_instrument() _instruments = dict() df = pd.read_csv( - self._uri, + self.uri, sep="\t", usecols=[0, 1, 2], names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD], @@ -97,7 +102,7 @@ class FileInstrumentStorage(InstrumentStorage): def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None: if not data: - with self._uri.open("w") as _: + with self.uri.open("w") as _: pass return @@ -109,13 +114,17 @@ class FileInstrumentStorage(InstrumentStorage): df = pd.concat(res, sort=False) df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv( - self._uri, header=False, sep=self.INSTRUMENT_SEP, index=False + self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False ) - df.to_csv(self._uri, sep="\t", encoding="utf-8", header=False, index=False) + df.to_csv(self.uri, sep="\t", encoding="utf-8", header=False, index=False) def clear(self) -> None: self._write_instrument(data={}) + @property + def data(self) -> Dict[InstKT, InstVT]: + return self._read_instrument() + def __setitem__(self, k: InstKT, v: InstVT) -> None: inst = self._read_instrument() inst[k] = v @@ -143,7 +152,7 @@ class FileInstrumentStorage(InstrumentStorage): raise TypeError(f"update expected at most 1 arguments, got {len(args)}") inst = self._read_instrument() if args: - other = args[0] + other = args[0] # type: dict if isinstance(other, Mapping): for key in other: inst[key] = other[key] @@ -160,29 +169,35 @@ class FileInstrumentStorage(InstrumentStorage): class FileFeatureStorage(FeatureStorage): - def __init__(self, uri: str): - super(FileFeatureStorage, self).__init__(uri=uri) - self._uri = Path(self._uri) + def __init__(self, instrument: str, field: str, freq: str, uri: str): + super(FileFeatureStorage, self).__init__(instrument, field, freq, uri) + self.uri = ( + Path(self.uri).expanduser().joinpath(instrument.lower()).joinpath(f"{field.lower()}.{freq.lower()}.bin") + ) def clear(self): - with self._uri.open("wb") as _: + with self.uri.open("wb") as _: pass + @property + def data(self) -> pd.Series: + return self[:] + def extend(self, series: pd.Series) -> None: - extend_start_index = self[0][0] + len(self) if self._uri.exists() else series.index[0] + extend_start_index = self[0][0] + len(self) if self.uri.exists() else series.index[0] series = series.reindex(pd.RangeIndex(extend_start_index, series.index[-1] + 1)) - with self._uri.open("ab") as fp: + with self.uri.open("ab") as fp: np.array(series.values).astype(" None: origin_series = self[:] series = series.append(origin_series.loc[origin_series.index > series.index[-1]]) series = series.reindex(pd.RangeIndex(series.index[0], series.index[-1])) - with self._uri.open("wb") as fp: + with self.uri.open("wb") as fp: np.array(series.values).astype(" Union[Tuple[int, float], pd.Series]: - if not self._uri.exists(): + if not self.uri.exists(): if isinstance(i, int): return None, None elif isinstance(i, slice): @@ -190,14 +205,14 @@ class FileFeatureStorage(FeatureStorage): else: raise TypeError(f"type(i) = {type(i)}") - with open(self._uri, "rb") as fp: + with open(self.uri, "rb") as fp: ref_start_index = int(np.frombuffer(fp.read(4), dtype=" i: raise IndexError(f"{i}: start index is {ref_start_index}") fp.seek(4 * (i - ref_start_index) + 4) - return i, struct.unpack("f", fp.read(4)) + return i, struct.unpack("f", fp.read(4))[0] elif isinstance(i, slice): start_index = i.start end_index = i.stop - 1 @@ -213,18 +228,18 @@ class FileFeatureStorage(FeatureStorage): raise TypeError(f"type(i) = {type(i)}") def __len__(self) -> int: - return self._uri.stat().st_size // 4 - 1 if self._uri.exists() else 0 + return self.uri.stat().st_size // 4 - 1 if self.uri.exists() else 0 def __iter__(self): - if not self._uri.exists(): + if not self.uri.exists(): return - with open(self._uri, "rb") as fp: + with open(self.uri, "rb") as fp: ref_start_index = int(np.frombuffer(fp.read(4), dtype=" Iterable[CalVT]: + """get all data""" + raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") def extend(self, iterable: Iterable[CalVT]) -> None: raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") @@ -82,10 +89,19 @@ class CalendarStorage: """x.__len__() <==> len(x)""" raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") + def __iter__(self): + raise NotImplementedError("Subclass of CalendarStorage must implement `__iter__` method") + class InstrumentStorage: - def __init__(self, uri: str): - self._uri = uri + def __init__(self, market: str, uri: str): + self.market = market + self.uri = uri + + @property + def data(self) -> Dict[InstKT, InstVT]: + """get all data""" + raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") def clear(self) -> None: raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") @@ -120,8 +136,16 @@ class InstrumentStorage: class FeatureStorage: - def __init__(self, uri: str): - self._uri = uri + def __init__(self, instrument: str, field: str, freq: str, uri: str): + self.instrument = instrument + self.field = field + self.freq = freq + self.uri = uri + + @property + def data(self) -> pd.Series: + """get all data""" + raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") def clear(self): """ Remove all items from FeatureStorage. """ diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py index a70ce82ea..8ce3f5081 100644 --- a/tests/storage_tests/test_storage.py +++ b/tests/storage_tests/test_storage.py @@ -16,15 +16,16 @@ from qlib.data.storage.file_storage import ( FileFeatureStorage as FeatureStorage, ) -DATA_DIR = Path(__file__).parent.joinpath("test_get_data") +_file_name = Path(__file__).name.split(".")[0] +DATA_DIR = Path(__file__).parent.joinpath(f"{_file_name}_data") QLIB_DIR = DATA_DIR.joinpath("qlib") QLIB_DIR.mkdir(exist_ok=True, parents=True) # TODO: set value -CALENDAR_URI = QLIB_DIR.joinpath("calendars").joinpath("day.txt") -INSTRUMENT_URI = QLIB_DIR.joinpath("instruments").joinpath("csi300.txt") -FEATURE_URI = QLIB_DIR.joinpath("features").joinpath("SH600004").joinpath("close.day.bin") +CALENDAR_URI = QLIB_DIR.joinpath("calendars") +INSTRUMENT_URI = QLIB_DIR.joinpath("instruments") +FEATURE_URI = QLIB_DIR.joinpath("features") class TestStorage: @@ -38,9 +39,10 @@ class TestStorage: def test_calendar_storage(self): - calendar = CalendarStorage(uri=CALENDAR_URI) + calendar = CalendarStorage(freq="day", future=False, uri=CALENDAR_URI) assert isinstance(calendar, Iterable), f"{calendar.__class__.__name__} is not Iterable" assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable" + assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable" print(f"calendar[1: 5]: {calendar[1:5]}") print(f"calendar[0]: {calendar[0]}") @@ -80,11 +82,11 @@ class TestStorage: """ - instrument = InstrumentStorage(uri=INSTRUMENT_URI) + instrument = InstrumentStorage(market="csi300", uri=INSTRUMENT_URI) assert isinstance(instrument, Iterable), f"{instrument.__class__.__name__} is not Iterable" - for inst, spans in instrument.items(): + for inst, spans in instrument.data.items(): assert isinstance(inst, str) and isinstance( spans, Iterable ), f"{instrument.__class__.__name__} value is not Iterable" @@ -149,11 +151,14 @@ class TestStorage: """ - feature = FeatureStorage(uri=FEATURE_URI) + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=FEATURE_URI) assert isinstance(feature, Iterable), f"{feature.__class__.__name__} is not Iterable" with pytest.raises(IndexError): print(feature[0]) + assert isinstance( + feature[815][1], (np.float, np.float32) + ), f"{feature.__class__.__name__}.__getitem__(i: int) error" assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error" print(f"feature[815: 818]: {feature[815: 818]}") @@ -162,5 +167,5 @@ class TestStorage: isinstance(_item, tuple) and len(_item) == 2 ), f"{feature.__class__.__name__}.__iter__ item type error" assert isinstance(_item[0], int) and isinstance( - _item[1], (float, np.float, np.float32) + _item[1], (np.float, np.float32) ), f"{feature.__class__.__name__}.__iter__ value type error" From 4ba4512619560229456c3df664a4374cdc72be18 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 21 May 2021 08:43:36 +0800 Subject: [PATCH 06/24] add write method to FeatureStorage && remove extend --- qlib/data/data.py | 62 ++++--- qlib/data/storage/file_storage.py | 122 +++++++------ qlib/data/storage/storage.py | 263 +++++++++++++++++++++------- qlib/tests/__init__.py | 10 +- tests/storage_tests/test_storage.py | 33 +--- 5 files changed, 321 insertions(+), 169 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 1a0ca616e..3848a6823 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -8,6 +8,7 @@ from __future__ import print_function import os import re import abc +import copy import time import queue import bisect @@ -31,19 +32,27 @@ from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_modu class ProviderBackendMixin: def get_default_backend(self): backend = {} - provider_name = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] # type: str + provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] # set default storage class backend.setdefault("class", f"File{provider_name}Storage") # set default storage module backend.setdefault("module_path", "qlib.data.storage.file_storage") - # set default storage kwargs - backend_kwargs = backend.setdefault("kwargs", {}) # type: dict - backend_kwargs.setdefault("uri", os.path.join(C.get_data_path(), f"{provider_name.lower()}s")) return backend - @property - def backend_obj(self): - return init_instance_by_config(self.backend) + def backend_obj(self, **kwargs): + backend = self.backend if self.backend else self.get_default_backend() + backend = copy.deepcopy(backend) + + # set default storage kwargs + backend_kwargs = backend.setdefault("kwargs", {}) + # default uri map + if "uri" not in backend_kwargs: + # if the user has no uri configured, use: uri = uri_map[freq] + freq = kwargs.get("freq", "day") + uri_map = backend_kwargs.setdefault("uri_map", {freq: C.get_data_path()}) + backend_kwargs["uri"] = uri_map[freq] + backend.setdefault("kwargs", {}).update(**kwargs) + return init_instance_by_config(backend) class CalendarProvider(abc.ABC, ProviderBackendMixin): @@ -54,8 +63,6 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin): def __init__(self, *args, **kwargs): self.backend = kwargs.get("backend", {}) - if not self.backend: - self.backend = self.get_default_backend() @abc.abstractmethod def calendar(self, start_time=None, end_time=None, freq="day", future=False): @@ -159,8 +166,6 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin): def __init__(self, *args, **kwargs): self.backend = kwargs.get("backend", {}) - if not self.backend: - self.backend = self.get_default_backend() @staticmethod def instruments(market="all", filter_pipe=None): @@ -252,8 +257,6 @@ class FeatureProvider(abc.ABC, ProviderBackendMixin): def __init__(self, *args, **kwargs): self.backend = kwargs.get("backend", {}) - if not self.backend: - self.backend = self.get_default_backend() @abc.abstractmethod def feature(self, instrument, field, start_time, end_time, freq): @@ -552,8 +555,18 @@ class LocalCalendarProvider(CalendarProvider): list list of timestamps """ - self.backend.setdefault("kwargs", {}).update(freq=freq, future=future) - return [pd.Timestamp(x) for x in self.backend_obj.data] + + backend_obj = self.backend_obj(freq=freq, future=future) + if future and not backend_obj.check_exists(): + get_module_logger("data").warning( + f"load calendar error: freq={freq}, future={future}; return current calendar!" + ) + get_module_logger("data").warning( + "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" + ) + backend_obj = self.backend_obj(freq=freq, future=False) + + return [pd.Timestamp(x) for x in backend_obj.data] def calendar(self, start_time=None, end_time=None, freq="day", future=False): _calendar, _calendar_index = self._get_calendar(freq, future) @@ -589,17 +602,15 @@ class LocalInstrumentProvider(InstrumentProvider): """Instrument file uri.""" return os.path.join(C.get_data_path(), "instruments", "{}.txt") - def _load_instruments(self, market): - - self.backend.setdefault("kwargs", {}).update(market=market) - return self.backend_obj.data + def _load_instruments(self, market, freq): + return self.backend_obj(market=market, freq=freq).data def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): market = instruments["market"] if market in H["i"]: _instruments = H["i"][market] else: - _instruments = self._load_instruments(market) + _instruments = self._load_instruments(market, freq=freq) H["i"][market] = _instruments # strip # use calendar boundary @@ -648,9 +659,14 @@ class LocalFeatureProvider(FeatureProvider): # validate field = str(field).lower()[1:] instrument = code_to_fname(instrument) - - self.backend.setdefault("kwargs", {}).update(instrument=instrument, field=field, freq=freq) - return self.backend_obj[start_index : end_index + 1] + try: + data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] + except Exception as e: + get_module_logger("data").warning( + f"WARN: data not found for {instrument}.{field}\n\tException info: {str(e)}" + ) + data = pd.Series(dtype=np.float32) + return data class LocalExpressionProvider(ExpressionProvider): diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index 4090e3230..e55105f57 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -3,25 +3,36 @@ import struct from pathlib import Path -from typing import Iterator, Iterable, Union, Dict, Mapping, Tuple +from typing import Iterable, Union, Dict, Mapping, Tuple, List import numpy as np import pandas as pd +from qlib.log import get_module_logger from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT +logger = get_module_logger("file_storage") -class FileCalendarStorage(CalendarStorage): - def __init__(self, freq: str, future: bool, uri: str): - super(FileCalendarStorage, self).__init__(freq, future, uri) + +class FileStorage: + def check_exists(self): + return self.uri.exists() + + +class FileCalendarStorage(FileStorage, CalendarStorage): + def __init__(self, freq: str, future: bool, uri: str, **kwargs): + super(FileCalendarStorage, self).__init__(freq, future, uri, **kwargs) _file_name = f"{freq}_future.txt" if future else f"{freq}.txt" - self.uri = Path(self.uri).expanduser().joinpath(_file_name.lower()) + self.uri = Path(self.uri).expanduser().joinpath("calendars", _file_name.lower()) - def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> np.ndarray: - if not self.uri.exists(): + def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> Iterable[CalVT]: + if not self.check_exists(): self._write_calendar(values=[]) with self.uri.open("rb") as fp: - return np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8") + return [ + str(x) + for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8") + ] def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"): with self.uri.open(mode=mode) as fp: @@ -65,23 +76,17 @@ class FileCalendarStorage(CalendarStorage): def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]: return self._read_calendar()[i] - def __len__(self) -> int: - return len(self._read_calendar()) - def __iter__(self): - return iter(self._read_calendar()) - - -class FileInstrumentStorage(InstrumentStorage): +class FileInstrumentStorage(FileStorage, InstrumentStorage): INSTRUMENT_SEP = "\t" INSTRUMENT_START_FIELD = "start_datetime" INSTRUMENT_END_FIELD = "end_datetime" SYMBOL_FIELD_NAME = "instrument" - def __init__(self, market: str, uri: str): - super(FileInstrumentStorage, self).__init__(market, uri) - self.uri = Path(self.uri).expanduser().joinpath(f"{market.lower()}.txt") + def __init__(self, market: str, uri: str, **kwargs): + super(FileInstrumentStorage, self).__init__(market, uri, **kwargs) + self.uri = Path(self.uri).expanduser().joinpath("instruments", f"{market.lower()}.txt") def _read_instrument(self) -> Dict[InstKT, InstVT]: if not self.uri.exists(): @@ -138,14 +143,6 @@ class FileInstrumentStorage(InstrumentStorage): def __getitem__(self, k: InstKT) -> InstVT: return self._read_instrument()[k] - def __len__(self) -> int: - inst = self._read_instrument() - return len(inst) - - def __iter__(self) -> Iterator[InstKT]: - for _inst in self._read_instrument().keys(): - yield _inst - def update(self, *args, **kwargs) -> None: if len(args) > 1: @@ -168,11 +165,11 @@ class FileInstrumentStorage(InstrumentStorage): self._write_instrument(inst) -class FileFeatureStorage(FeatureStorage): - def __init__(self, instrument: str, field: str, freq: str, uri: str): - super(FileFeatureStorage, self).__init__(instrument, field, freq, uri) +class FileFeatureStorage(FileStorage, FeatureStorage): + def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs): + super(FileFeatureStorage, self).__init__(instrument, field, freq, uri, **kwargs) self.uri = ( - Path(self.uri).expanduser().joinpath(instrument.lower()).joinpath(f"{field.lower()}.{freq.lower()}.bin") + Path(self.uri).expanduser().joinpath("features", instrument.lower(), f"{field.lower()}.{freq.lower()}.bin") ) def clear(self): @@ -183,18 +180,45 @@ class FileFeatureStorage(FeatureStorage): def data(self) -> pd.Series: return self[:] - def extend(self, series: pd.Series) -> None: - extend_start_index = self[0][0] + len(self) if self.uri.exists() else series.index[0] - series = series.reindex(pd.RangeIndex(extend_start_index, series.index[-1] + 1)) - with self.uri.open("ab") as fp: - np.array(series.values).astype(" None: + if len(data_array) == 0: + logger.info( + "len(data_array) == 0, write" + "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear" + ) + return + if not self.uri.exists(): + # write + index = 0 if index is None else index + with self.uri.open("wb") as fp: + np.hstack([index, data_array]).astype(" self.end_index: + # append + index = 0 if index is None else index + with self.uri.open("ab+") as fp: + np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype(" None: - origin_series = self[:] - series = series.append(origin_series.loc[origin_series.index > series.index[-1]]) - series = series.reindex(pd.RangeIndex(series.index[0], series.index[-1])) - with self.uri.open("wb") as fp: - np.array(series.values).astype(" Union[int, None]: + if len(self) == 0: + return None + with open(self.uri, "rb") as fp: + index = int(np.frombuffer(fp.read(4), dtype=" Union[Tuple[int, float], pd.Series]: if not self.uri.exists(): @@ -228,18 +252,4 @@ class FileFeatureStorage(FeatureStorage): raise TypeError(f"type(i) = {type(i)}") def __len__(self) -> int: - return self.uri.stat().st_size // 4 - 1 if self.uri.exists() else 0 - - def __iter__(self): - if not self.uri.exists(): - return - with open(self.uri, "rb") as fp: - ref_start_index = int(np.frombuffer(fp.read(4), dtype=" pd.Series: + pass + +""" + + +class StorageMeta(type): + """unified management of raise when storage is not exists""" + + def __new__(cls, name, bases, dict): + class_obj = type.__new__(cls, name, bases, dict) + + # The calls to __iter__ and __getitem__ do not pass through __getattribute__. + # In order to throw an exception before calling __getitem__, use the metaclass + _getitem_func = getattr(class_obj, "__getitem__") + + def _getitem(obj, item): + _check_func = getattr(obj, "_check") + if callable(_check_func): + _check_func() + return _getitem_func(obj, item) + + setattr(class_obj, "__getitem__", _getitem) + return class_obj + + +class BaseStorage(metaclass=StorageMeta): + @property + def storage_name(self) -> str: + return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] + + def check_exists(self) -> bool: + """check if storage(uri) exists, if not exists: return False""" + raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") + + def clear(self) -> None: + """clear storage""" + raise NotImplementedError("Subclass of BaseStorage must implement `clear` method") + + def __len__(self) -> 0: + return len(self.data) if self.check_exists() else 0 + + def __getitem__(self, item: Union[slice, Union[int, InstKT]]): + raise NotImplementedError( + "Subclass of BaseStorage must implement `__getitem__(i: Union[int, InstKT])`/`__getitem__(s: slice)` method" + ) + + def _check(self): + # check storage(uri) + if not self.check_exists(): + parameters_info = [f"{_k}={_v}" for _k, _v in self.__dict__.items()] + raise ValueError(f"{self.storage_name.lower()} not exists, storage parameters: {parameters_info}") + + def __getattribute__(self, item): + if item == "data": + self._check() + return super(BaseStorage, self).__getattribute__(item) + + +class CalendarStorage(BaseStorage): + """ + The behavior of CalendarStorage's methods and List's methods of the same name remain consistent + """ + + def __init__(self, freq: str, future: bool, uri: str, **kwargs): self.freq = freq self.future = future self.uri = uri @@ -28,9 +113,6 @@ class CalendarStorage: def extend(self, iterable: Iterable[CalVT]) -> None: raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") - def clear(self) -> None: - raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") - def index(self, value: CalVT) -> int: raise NotImplementedError("Subclass of CalendarStorage must implement `index` method") @@ -85,16 +167,9 @@ class CalendarStorage: "Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" ) - def __len__(self) -> int: - """x.__len__() <==> len(x)""" - raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") - def __iter__(self): - raise NotImplementedError("Subclass of CalendarStorage must implement `__iter__` method") - - -class InstrumentStorage: - def __init__(self, market: str, uri: str): +class InstrumentStorage(BaseStorage): + def __init__(self, market: str, uri: str, **kwargs): self.market = market self.uri = uri @@ -103,9 +178,6 @@ class InstrumentStorage: """get all data""" raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") - def clear(self) -> None: - raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") - def update(self, *args, **kwargs) -> None: """D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. If E present and has a .keys() method, does: for k in E: D[k] = E[k] @@ -126,17 +198,9 @@ class InstrumentStorage: """ x.__getitem__(k) <==> x[k] """ raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") - def __len__(self) -> int: - """ Return len(self). """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") - def __iter__(self) -> Iterator[InstKT]: - """ Return iter(self). """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__iter__` method") - - -class FeatureStorage: - def __init__(self, instrument: str, field: str, freq: str, uri: str): +class FeatureStorage(BaseStorage): + def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs): self.instrument = instrument self.field = field self.freq = freq @@ -147,12 +211,25 @@ class FeatureStorage: """get all data""" raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") - def clear(self): - """ Remove all items from FeatureStorage. """ - raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") + @property + def start_index(self) -> Union[int, None]: + """get FeatureStorage start index + If len(self) == 0; return None + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") + + @property + def end_index(self) -> Union[int, None]: + if len(self) == 0: + return None + return None if len(self) == 0 else self.start_index + len(self) - 1 + + def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None): + """Write data_array to FeatureStorage starting from index. + If index is None, append data_array to feature. + If len(data_array) == 0; return + If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan - def extend(self, series: pd.Series): - """Extend feature by appending elements from the series. Examples: @@ -161,21 +238,42 @@ class FeatureStorage: 4 5 5 6 - >>> self.extend(pd.Series({7: 8, 9:10})) + >>> self.write([6, 7], index=6) feature: 3 4 4 5 5 6 - 6 np.nan - 7 8 - 9 10 + 6 6 + 7 7 + + >>> self.write([8], index=9) + + feature: + 3 4 + 4 5 + 5 6 + 6 6 + 7 7 + 8 np.nan + 9 8 + + >>> self.write([1, np.nan], index=3) + + feature: + 3 1 + 4 np.nan + 5 6 + 6 6 + 7 7 + 8 np.nan + 9 8 """ - raise NotImplementedError("Subclass of FeatureStorage must implement `extend` method") + raise NotImplementedError("Subclass of FeatureStorage must implement `write` method") - def rebase(self, series: pd.Series): - """Rebase feature header from the series. + def rebase(self, start_index: int = None, end_index: int = None): + """Rebase the start_index and end_index of the FeatureStorage. Examples: @@ -184,30 +282,85 @@ class FeatureStorage: 4 5 5 6 - >>> self.rebase(pd.Series({1: 2})) + >>> self.rebase(start_index=4) feature: - 1 2 - 2 np.nan - 3 4 4 5 5 6 - >>> self.rebase(pd.Series({5: 6, 7: 8, 9: 10})) + >>> self.rebase(start_index=3) feature: + 3 np.nan + 4 5 5 6 - 7 8 - 9 10 - >>> self.rebase(pd.Series({11: 12, 12: 13,})) + >>> self.write([3], index=3) feature: - 11 12 - 12 13 + 3 3 + 4 5 + 5 6 + + >>> self.rebase(end_index=4) + + feature: + 3 3 + 4 5 + + >>> self.write([6, 7, 8], index=4) + + feature: + 3 3 + 4 6 + 5 7 + 6 8 + + >>> self.rebase(start_index=4, end_index=5) + + feature: + 4 6 + 5 7 """ - raise NotImplementedError("Subclass of FeatureStorage must implement `rebase` method") + if start_index is None and end_index is None: + logger.warning("both start_index and end_index are None, rebase is ignored") + return + + if start_index < 0 or end_index < 0: + logger.warning("start_index or end_index cannot be less than 0") + return + if start_index > end_index: + logger.warning( + f"start_index({start_index}) > end_index({end_index}), rebase is ignored; " + f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear" + ) + return + + start_index = self.start_index if start_index is None else end_index + end_index = self.end_index if end_index is None else end_index + if start_index <= self.start_index: + self.write([np.nan] * (self.start_index - start_index), start_index) + else: + self.rewrite(self[start_index:].values, start_index) + + if end_index >= self.end_index: + self.write([np.nan] * (end_index - self.end_index)) + else: + self.rewrite(self[: end_index + 1].values, self.start_index) + + def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int): + """overwrite all data in FeatureStorage with data + + Parameters + ---------- + data: Union[List, np.ndarray, Tuple] + data + index: int + data start index + """ + self.clear() + self.write(data, index) @overload def __getitem__(self, s: slice) -> pd.Series: @@ -224,11 +377,3 @@ class FeatureStorage: raise NotImplementedError( "Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" ) - - def __len__(self) -> int: - """len(feature) <==> feature.__len__() """ - raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") - - def __iter__(self) -> Iterable[Tuple[int, float]]: - """iter(feature)""" - raise NotImplementedError("Subclass of FeatureStorage must implement `__iter__` method") diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index f92e72787..8b53bc53a 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -9,19 +9,19 @@ from ..config import REG_CN class TestAutoData(unittest.TestCase): _setup_kwargs = {} + provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir @classmethod def setUpClass(cls) -> None: # use default data - provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") + if not exists_qlib_data(cls.provider_uri): + print(f"Qlib data is not found in {cls.provider_uri}") GetData().qlib_data( name="qlib_data_simple", region="cn", interval="1d", - target_dir=provider_uri, + target_dir=cls.provider_uri, delete_old=False, ) - init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs) + init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs) diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py index 8ce3f5081..79ad78b82 100644 --- a/tests/storage_tests/test_storage.py +++ b/tests/storage_tests/test_storage.py @@ -2,13 +2,12 @@ # Licensed under the MIT License. -import shutil from pathlib import Path from collections.abc import Iterable import pytest import numpy as np -from qlib.tests.data import GetData +from qlib.tests import TestAutoData from qlib.data.storage.file_storage import ( FileCalendarStorage as CalendarStorage, @@ -22,25 +21,10 @@ QLIB_DIR = DATA_DIR.joinpath("qlib") QLIB_DIR.mkdir(exist_ok=True, parents=True) -# TODO: set value -CALENDAR_URI = QLIB_DIR.joinpath("calendars") -INSTRUMENT_URI = QLIB_DIR.joinpath("instruments") -FEATURE_URI = QLIB_DIR.joinpath("features") - - -class TestStorage: - @classmethod - def setup_class(cls): - GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False) - - @classmethod - def teardown_class(cls): - shutil.rmtree(str(DATA_DIR.resolve())) - +class TestStorage(TestAutoData): def test_calendar_storage(self): - calendar = CalendarStorage(freq="day", future=False, uri=CALENDAR_URI) - assert isinstance(calendar, Iterable), f"{calendar.__class__.__name__} is not Iterable" + calendar = CalendarStorage(freq="day", future=False, uri=self.provider_uri) assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable" assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable" @@ -82,9 +66,7 @@ class TestStorage: """ - instrument = InstrumentStorage(market="csi300", uri=INSTRUMENT_URI) - - assert isinstance(instrument, Iterable), f"{instrument.__class__.__name__} is not Iterable" + instrument = InstrumentStorage(market="csi300", uri=self.provider_uri) for inst, spans in instrument.data.items(): assert isinstance(inst, str) and isinstance( @@ -151,13 +133,12 @@ class TestStorage: """ - feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=FEATURE_URI) + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=self.provider_uri) - assert isinstance(feature, Iterable), f"{feature.__class__.__name__} is not Iterable" with pytest.raises(IndexError): print(feature[0]) assert isinstance( - feature[815][1], (np.float, np.float32) + feature[815][1], (float, np.float32) ), f"{feature.__class__.__name__}.__getitem__(i: int) error" assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error" print(f"feature[815: 818]: {feature[815: 818]}") @@ -167,5 +148,5 @@ class TestStorage: isinstance(_item, tuple) and len(_item) == 2 ), f"{feature.__class__.__name__}.__iter__ item type error" assert isinstance(_item[0], int) and isinstance( - _item[1], (np.float, np.float32) + _item[1], (float, np.float32) ), f"{feature.__class__.__name__}.__iter__ value type error" From 9e296a8a4ef4651ac607f87953acea2ba4f38fe7 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 21 May 2021 08:56:44 +0800 Subject: [PATCH 07/24] replace the type of numpy deprecated --- qlib/contrib/backtest/position.py | 2 +- qlib/data/dataset/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 6c269d505..5a6b102b2 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -166,7 +166,7 @@ class Position: def save_position(self, path, last_trade_date): path = pathlib.Path(path) p = copy.deepcopy(self.position) - cash = pd.Series(dtype=np.float) + cash = pd.Series(dtype=float) cash["init_cash"] = self.init_cash cash["cash"] = p["cash"] cash["today_account_value"] = p["today_account_value"] diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 0f5d2baba..ef7bfa67e 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -299,7 +299,7 @@ class TSDataSampler: # get the previous index of a line given index """ # object incase of pandas converting int to flaot - idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object) + idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object) idx_df = lazy_sort_index(idx_df.unstack()) # NOTE: the correctness of `__getitem__` depends on columns sorted here idx_df = lazy_sort_index(idx_df, axis=1) From b887d2ec32ab7eafd795b90a14b7a8b046f2d906 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 21 May 2021 10:03:02 +0800 Subject: [PATCH 08/24] code for formatting storage.py using black(v21.5) --- qlib/data/storage/storage.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/qlib/data/storage/storage.py b/qlib/data/storage/storage.py index 8207b5c9a..25f0ce6f1 100644 --- a/qlib/data/storage/storage.py +++ b/qlib/data/storage/storage.py @@ -124,12 +124,12 @@ class CalendarStorage(BaseStorage): @overload def __setitem__(self, i: int, value: CalVT) -> None: - """x.__setitem__(i, o) <==> x[i] = o""" + """x.__setitem__(i, o) <==> (x[i] = o)""" ... @overload def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None: - """x.__setitem__(s, o) <==> x[s] = o""" + """x.__setitem__(s, o) <==> (x[s] = o)""" ... def __setitem__(self, i, value) -> None: @@ -187,15 +187,15 @@ class InstrumentStorage(BaseStorage): raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") def __setitem__(self, k: InstKT, v: InstVT) -> None: - """ Set self[key] to value. """ + """Set self[key] to value.""" raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") def __delitem__(self, k: InstKT) -> None: - """ Delete self[key]. """ + """Delete self[key].""" raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") def __getitem__(self, k: InstKT) -> InstVT: - """ x.__getitem__(k) <==> x[k] """ + """x.__getitem__(k) <==> x[k]""" raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") @@ -364,7 +364,12 @@ class FeatureStorage(BaseStorage): @overload def __getitem__(self, s: slice) -> pd.Series: - """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] == pd.Series(values, index=pd.RangeIndex(start, len(values))""" + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] + + Returns + ------- + pd.Series(values, index=pd.RangeIndex(start, len(values)) + """ ... @overload From 669f6bd6f58d159aae0a4f56142960a287bc4ac4 Mon Sep 17 00:00:00 2001 From: zhupr Date: Sat, 22 May 2021 02:03:50 +0800 Subject: [PATCH 09/24] modify exception message hint for storage.py && fix FileFeatureStorage[:] bug --- qlib/data/data.py | 2 +- qlib/data/storage/file_storage.py | 15 +++++++------- qlib/data/storage/storage.py | 31 ++++++++++++++++++++--------- tests/storage_tests/test_storage.py | 12 +++-------- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index e1c969247..3a74a2027 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -663,7 +663,7 @@ class LocalFeatureProvider(FeatureProvider): data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] except Exception as e: get_module_logger("data").warning( - f"WARN: data not found for {instrument}.{field}\n\tException info: {str(e)}" + f"WARN: data not found for {instrument}.{field}\n\tFeature exception info: {str(e)}" ) data = pd.Series(dtype=np.float32) return data diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index e55105f57..90e4178ff 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -230,20 +230,19 @@ class FileFeatureStorage(FileStorage, FeatureStorage): raise TypeError(f"type(i) = {type(i)}") with open(self.uri, "rb") as fp: - ref_start_index = int(np.frombuffer(fp.read(4), dtype=" i: - raise IndexError(f"{i}: start index is {ref_start_index}") - fp.seek(4 * (i - ref_start_index) + 4) + if self.start_index > i: + raise IndexError(f"{i}: start index is {self.start_index}") + fp.seek(4 * (i - self.start_index) + 4) return i, struct.unpack("f", fp.read(4))[0] elif isinstance(i, slice): - start_index = i.start - end_index = i.stop - 1 - si = max(ref_start_index, start_index) + start_index = self.start_index if i.start is None else i.start + end_index = self.end_index if i.stop is None else i.stop - 1 + si = max(self.start_index, start_index) if si > end_index: return pd.Series() - fp.seek(4 * (si - ref_start_index) + 4) + fp.seek(4 * (si - self.start_index) + 4) # read n bytes count = end_index - si + 1 data = np.frombuffer(fp.read(4 * count), dtype=" str: - return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] + return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2].lower() + + @property + def raise_info(self): + parameters_info = [ + f"{_k}={_v}" + for _k, _v in self.__dict__.items() + if not isinstance(_v, (dict,)) or (hasattr(_v, "__len__") and len(_v) < 3) + ] + return f"{self.storage_name.lower()} not exists, storage parameters: {parameters_info}" def check_exists(self) -> bool: """check if storage(uri) exists, if not exists: return False""" @@ -84,15 +95,17 @@ class BaseStorage(metaclass=StorageMeta): ) def _check(self): - # check storage(uri) if not self.check_exists(): - parameters_info = [f"{_k}={_v}" for _k, _v in self.__dict__.items()] - raise ValueError(f"{self.storage_name.lower()} not exists, storage parameters: {parameters_info}") + raise ValueError(self.raise_info) def __getattribute__(self, item): if item == "data": self._check() - return super(BaseStorage, self).__getattribute__(item) + try: + res = super(BaseStorage, self).__getattribute__(item) + except Exception as e: + raise ValueError(f"{self.raise_info}\n\tStorage exception info: {str(e)}") + return res class CalendarStorage(BaseStorage): diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py index 79ad78b82..e7bac658c 100644 --- a/tests/storage_tests/test_storage.py +++ b/tests/storage_tests/test_storage.py @@ -135,18 +135,12 @@ class TestStorage(TestAutoData): feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=self.provider_uri) - with pytest.raises(IndexError): + with pytest.raises(ValueError): print(feature[0]) assert isinstance( feature[815][1], (float, np.float32) ), f"{feature.__class__.__name__}.__getitem__(i: int) error" assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error" - print(f"feature[815: 818]: {feature[815: 818]}") + print(f"feature[815: 818]: \n{feature[815: 818]}") - for _item in feature: - assert ( - isinstance(_item, tuple) and len(_item) == 2 - ), f"{feature.__class__.__name__}.__iter__ item type error" - assert isinstance(_item[0], int) and isinstance( - _item[1], (float, np.float32) - ), f"{feature.__class__.__name__}.__iter__ value type error" + print(f"feature[:].tail(): \n{feature[:].tail()}") From 602f78b568acb09f13d1c8dba2050388553627df Mon Sep 17 00:00:00 2001 From: zhupr Date: Sat, 22 May 2021 08:30:12 +0800 Subject: [PATCH 10/24] add documentation on which storage methods are used in qlib --- qlib/data/storage/storage.py | 37 +++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/qlib/data/storage/storage.py b/qlib/data/storage/storage.py index 766af30be..dcf6da9ed 100644 --- a/qlib/data/storage/storage.py +++ b/qlib/data/storage/storage.py @@ -24,20 +24,43 @@ If the user is only using it in `qlib`, you can customize Storage to implement o class UserCalendarStorage(CalendarStorage): @property - def data(self): - pass + def data(self) -> Iterable[CalVT]: + '''get all data''' + raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + + def check_exists(self) -> bool: + '''check if storage(uri) exists, if not exists: return False''' + raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") + class UserInstrumentStorage(InstrumentStorage): @property - def data(self): - pass + def data(self) -> Dict[InstKT, InstVT]: + '''get all data''' + raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + + def check_exists(self) -> bool: + '''check if storage(uri) exists, if not exists: return False''' + raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") + class UserFeatureStorage(FeatureStorage): - @check_storage - def __getitem__(self, i: slice) -> pd.Series: - pass + def __getitem__(self, s: slice) -> pd.Series: + '''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] + + Returns + ------- + pd.Series(values, index=pd.RangeIndex(start, len(values)) + ''' + raise NotImplementedError( + "Subclass of FeatureStorage must implement `__getitem__(s: slice)` method" + ) + + def check_exists(self) -> bool: + '''check if storage(uri) exists, if not exists: return False''' + raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") """ From 5da33562ddae7f78439d211269232b40f5231857 Mon Sep 17 00:00:00 2001 From: zhupr Date: Wed, 26 May 2021 01:01:36 +0800 Subject: [PATCH 11/24] remove uri parameter from storage && modify file_storage --- docs/reference/api.rst | 28 +++ qlib/data/data.py | 41 ++-- qlib/data/storage/file_storage.py | 110 ++++++--- qlib/data/storage/storage.py | 355 +++++++++++++++++----------- qlib/utils/__init__.py | 5 +- scripts/dump_bin.py | 2 +- tests/storage_tests/test_storage.py | 33 ++- 7 files changed, 373 insertions(+), 201 deletions(-) diff --git a/docs/reference/api.rst b/docs/reference/api.rst index 57f61f18b..5e6e50b0b 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -53,6 +53,34 @@ Cache .. autoclass:: qlib.data.cache.DiskDatasetCache :members: + +Storage +------------- +.. autoclass:: qlib.data.storage.storage.BaseStorage + :members: + +.. autoclass:: qlib.data.storage.storage.CalendarStorage + :members: + +.. autoclass:: qlib.data.storage.storage.InstrumentStorage + :members: + +.. autoclass:: qlib.data.storage.storage.FeatureStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage + :members: + + Dataset --------------- diff --git a/qlib/data/data.py b/qlib/data/data.py index 3a74a2027..eb7fbe0ea 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -45,12 +45,12 @@ class ProviderBackendMixin: # set default storage kwargs backend_kwargs = backend.setdefault("kwargs", {}) - # default uri map - if "uri" not in backend_kwargs: + # default provider_uri map + if "provider_uri" not in backend_kwargs: # if the user has no uri configured, use: uri = uri_map[freq] freq = kwargs.get("freq", "day") - uri_map = backend_kwargs.setdefault("uri_map", {freq: C.get_data_path()}) - backend_kwargs["uri"] = uri_map[freq] + provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()}) + backend_kwargs["provider_uri"] = provider_uri_map[freq] backend.setdefault("kwargs", {}).update(**kwargs) return init_instance_by_config(backend) @@ -556,17 +556,21 @@ class LocalCalendarProvider(CalendarProvider): list of timestamps """ - backend_obj = self.backend_obj(freq=freq, future=future) - if future and not backend_obj.check_exists(): - get_module_logger("data").warning( - f"load calendar error: freq={freq}, future={future}; return current calendar!" - ) - get_module_logger("data").warning( - "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" - ) - backend_obj = self.backend_obj(freq=freq, future=False) + try: + backend_obj = self.backend_obj(freq=freq, future=future).data + except ValueError: + if future: + get_module_logger("data").warning( + f"load calendar error: freq={freq}, future={future}; return current calendar!" + ) + get_module_logger("data").warning( + "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" + ) + backend_obj = self.backend_obj(freq=freq, future=False).data + else: + raise - return [pd.Timestamp(x) for x in backend_obj.data] + return [pd.Timestamp(x) for x in backend_obj] def calendar(self, start_time=None, end_time=None, freq="day", future=False): _calendar, _calendar_index = self._get_calendar(freq, future) @@ -659,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider): # validate field = str(field).lower()[1:] instrument = code_to_fname(instrument) - try: - data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] - except Exception as e: - get_module_logger("data").warning( - f"WARN: data not found for {instrument}.{field}\n\tFeature exception info: {str(e)}" - ) - data = pd.Series(dtype=np.float32) - return data + return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] class LocalExpressionProvider(ExpressionProvider): diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index 90e4178ff..a2b145c4d 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -14,19 +14,35 @@ from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage logger = get_module_logger("file_storage") -class FileStorage: - def check_exists(self): - return self.uri.exists() +class FileStorageMixin: + @property + def uri(self) -> Path: + _provider_uri = self.kwargs.get("provider_uri", None) + if _provider_uri is None: + raise ValueError( + f"The `provider_uri` parameter is not found in {self.__class__.__name__}, " + f'please specify `provider_uri` in the "provider\'s backend"' + ) + return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name) + + def check(self): + """check self.uri + + Raises + ------- + ValueError + """ + if not self.uri.exists(): + raise ValueError(f"{self.storage_name} not exists: {self.uri}") -class FileCalendarStorage(FileStorage, CalendarStorage): - def __init__(self, freq: str, future: bool, uri: str, **kwargs): - super(FileCalendarStorage, self).__init__(freq, future, uri, **kwargs) - _file_name = f"{freq}_future.txt" if future else f"{freq}.txt" - self.uri = Path(self.uri).expanduser().joinpath("calendars", _file_name.lower()) +class FileCalendarStorage(FileStorageMixin, CalendarStorage): + def __init__(self, freq: str, future: bool, **kwargs): + super(FileCalendarStorage, self).__init__(freq, future, **kwargs) + self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower() - def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> Iterable[CalVT]: - if not self.check_exists(): + def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]: + if not self.uri.exists(): self._write_calendar(values=[]) with self.uri.open("rb") as fp: return [ @@ -39,7 +55,8 @@ class FileCalendarStorage(FileStorage, CalendarStorage): np.savetxt(fp, values, fmt="%s", encoding="utf-8") @property - def data(self) -> Iterable[CalVT]: + def data(self) -> List[CalVT]: + self.check() return self._read_calendar() def extend(self, values: Iterable[CalVT]) -> None: @@ -49,6 +66,7 @@ class FileCalendarStorage(FileStorage, CalendarStorage): self._write_calendar(values=[]) def index(self, value: CalVT) -> int: + self.check() calendar = self._read_calendar() return int(np.argwhere(calendar == value)[0]) @@ -58,6 +76,7 @@ class FileCalendarStorage(FileStorage, CalendarStorage): self._write_calendar(values=calendar) def remove(self, value: CalVT) -> None: + self.check() index = self.index(value) calendar = self._read_calendar() calendar = np.delete(calendar, index) @@ -69,24 +88,29 @@ class FileCalendarStorage(FileStorage, CalendarStorage): self._write_calendar(values=calendar) def __delitem__(self, i: Union[int, slice]) -> None: + self.check() calendar = self._read_calendar() calendar = np.delete(calendar, i) self._write_calendar(values=calendar) - def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]: + def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]: + self.check() return self._read_calendar()[i] + def __len__(self) -> int: + return len(self.data) -class FileInstrumentStorage(FileStorage, InstrumentStorage): + +class FileInstrumentStorage(FileStorageMixin, InstrumentStorage): INSTRUMENT_SEP = "\t" INSTRUMENT_START_FIELD = "start_datetime" INSTRUMENT_END_FIELD = "end_datetime" SYMBOL_FIELD_NAME = "instrument" - def __init__(self, market: str, uri: str, **kwargs): - super(FileInstrumentStorage, self).__init__(market, uri, **kwargs) - self.uri = Path(self.uri).expanduser().joinpath("instruments", f"{market.lower()}.txt") + def __init__(self, market: str, **kwargs): + super(FileInstrumentStorage, self).__init__(market, **kwargs) + self.file_name = f"{market.lower()}.txt" def _read_instrument(self) -> Dict[InstKT, InstVT]: if not self.uri.exists(): @@ -128,6 +152,7 @@ class FileInstrumentStorage(FileStorage, InstrumentStorage): @property def data(self) -> Dict[InstKT, InstVT]: + self.check() return self._read_instrument() def __setitem__(self, k: InstKT, v: InstVT) -> None: @@ -136,11 +161,13 @@ class FileInstrumentStorage(FileStorage, InstrumentStorage): self._write_instrument(inst) def __delitem__(self, k: InstKT) -> None: + self.check() inst = self._read_instrument() del inst[k] self._write_instrument(inst) def __getitem__(self, k: InstKT) -> InstVT: + self.check() return self._read_instrument()[k] def update(self, *args, **kwargs) -> None: @@ -164,13 +191,14 @@ class FileInstrumentStorage(FileStorage, InstrumentStorage): self._write_instrument(inst) + def __len__(self) -> int: + return len(self.data) -class FileFeatureStorage(FileStorage, FeatureStorage): - def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs): - super(FileFeatureStorage, self).__init__(instrument, field, freq, uri, **kwargs) - self.uri = ( - Path(self.uri).expanduser().joinpath("features", instrument.lower(), f"{field.lower()}.{freq.lower()}.bin") - ) + +class FileFeatureStorage(FileStorageMixin, FeatureStorage): + def __init__(self, instrument: str, field: str, freq: str, **kwargs): + super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs) + self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin" def clear(self): with self.uri.open("wb") as _: @@ -214,35 +242,44 @@ class FileFeatureStorage(FileStorage, FeatureStorage): @property def start_index(self) -> Union[int, None]: - if len(self) == 0: + if not self.uri.exists(): return None - with open(self.uri, "rb") as fp: + with self.uri.open("rb") as fp: index = int(np.frombuffer(fp.read(4), dtype=" Union[int, None]: + if not self.uri.exists(): + return None + # The next data appending index point will be `end_index + 1` + return self.start_index + len(self) - 1 + def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]: if not self.uri.exists(): if isinstance(i, int): return None, None elif isinstance(i, slice): - return pd.Series() + return pd.Series(dtype=np.float32) else: raise TypeError(f"type(i) = {type(i)}") - with open(self.uri, "rb") as fp: - + storage_start_index = self.start_index + storage_end_index = self.end_index + with self.uri.open("rb") as fp: if isinstance(i, int): - if self.start_index > i: - raise IndexError(f"{i}: start index is {self.start_index}") - fp.seek(4 * (i - self.start_index) + 4) + + if storage_start_index > i: + raise IndexError(f"{i}: start index is {storage_start_index}") + fp.seek(4 * (i - storage_start_index) + 4) return i, struct.unpack("f", fp.read(4))[0] elif isinstance(i, slice): - start_index = self.start_index if i.start is None else i.start - end_index = self.end_index if i.stop is None else i.stop - 1 - si = max(self.start_index, start_index) + start_index = storage_start_index if i.start is None else i.start + end_index = storage_end_index if i.stop is None else i.stop - 1 + si = max(start_index, storage_start_index) if si > end_index: - return pd.Series() - fp.seek(4 * (si - self.start_index) + 4) + return pd.Series(dtype=np.float32) + fp.seek(4 * (si - storage_start_index) + 4) # read n bytes count = end_index - si + 1 data = np.frombuffer(fp.read(4 * count), dtype=" int: - return self.uri.stat().st_size // 4 - 1 if self.check_exists() else 0 + self.check() + return self.uri.stat().st_size // 4 - 1 diff --git a/qlib/data/storage/storage.py b/qlib/data/storage/storage.py index dcf6da9ed..8426ebe66 100644 --- a/qlib/data/storage/storage.py +++ b/qlib/data/storage/storage.py @@ -25,24 +25,28 @@ class UserCalendarStorage(CalendarStorage): @property def data(self) -> Iterable[CalVT]: - '''get all data''' - raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + '''get all data - def check_exists(self) -> bool: - '''check if storage(uri) exists, if not exists: return False''' - raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + ''' + raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") class UserInstrumentStorage(InstrumentStorage): @property def data(self) -> Dict[InstKT, InstVT]: - '''get all data''' - raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + '''get all data - def check_exists(self) -> bool: - '''check if storage(uri) exists, if not exists: return False''' - raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + ''' + raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") class UserFeatureStorage(FeatureStorage): @@ -53,103 +57,64 @@ class UserFeatureStorage(FeatureStorage): Returns ------- pd.Series(values, index=pd.RangeIndex(start, len(values)) + + Notes + ------- + if data(storage) does not exist: + if isinstance(i, int): + return (None, None) + if isinstance(i, slice): + # return empty pd.Series + return pd.Series(dtype=np.float32) ''' raise NotImplementedError( "Subclass of FeatureStorage must implement `__getitem__(s: slice)` method" ) - def check_exists(self) -> bool: - '''check if storage(uri) exists, if not exists: return False''' - raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") """ -class StorageMeta(type): - """unified management of raise when storage is not exists""" - - def __new__(cls, name, bases, dict): - class_obj = type.__new__(cls, name, bases, dict) - - # The calls to __iter__ and __getitem__ do not pass through __getattribute__. - # In order to throw an exception before calling __getitem__, use the metaclass - _getitem_func = getattr(class_obj, "__getitem__") - - def _getitem(obj, item): - getattr(obj, "_check")() - try: - res = _getitem_func(obj, item) - except Exception as e: - raise ValueError(f"{obj.raise_info}\n\tStorage exception info: {str(e)}") - return res - - setattr(class_obj, "__getitem__", _getitem) - return class_obj - - -class BaseStorage(metaclass=StorageMeta): +class BaseStorage: @property def storage_name(self) -> str: return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2].lower() - @property - def raise_info(self): - parameters_info = [ - f"{_k}={_v}" - for _k, _v in self.__dict__.items() - if not isinstance(_v, (dict,)) or (hasattr(_v, "__len__") and len(_v) < 3) - ] - return f"{self.storage_name.lower()} not exists, storage parameters: {parameters_info}" - - def check_exists(self) -> bool: - """check if storage(uri) exists, if not exists: return False""" - raise NotImplementedError("Subclass of BaseStorage must implement `check_exists` method") - - def clear(self) -> None: - """clear storage""" - raise NotImplementedError("Subclass of BaseStorage must implement `clear` method") - - def __len__(self) -> 0: - return len(self.data) if self.check_exists() else 0 - - def __getitem__(self, item: Union[slice, Union[int, InstKT]]): - raise NotImplementedError( - "Subclass of BaseStorage must implement `__getitem__(i: Union[int, InstKT])`/`__getitem__(s: slice)` method" - ) - - def _check(self): - if not self.check_exists(): - raise ValueError(self.raise_info) - - def __getattribute__(self, item): - if item == "data": - self._check() - try: - res = super(BaseStorage, self).__getattribute__(item) - except Exception as e: - raise ValueError(f"{self.raise_info}\n\tStorage exception info: {str(e)}") - return res - class CalendarStorage(BaseStorage): """ The behavior of CalendarStorage's methods and List's methods of the same name remain consistent """ - def __init__(self, freq: str, future: bool, uri: str, **kwargs): + def __init__(self, freq: str, future: bool, **kwargs): self.freq = freq self.future = future - self.uri = uri + self.kwargs = kwargs @property def data(self) -> Iterable[CalVT]: - """get all data""" + """get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + def clear(self) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") + def extend(self, iterable: Iterable[CalVT]) -> None: raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") def index(self, value: CalVT) -> int: + """ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ raise NotImplementedError("Subclass of CalendarStorage must implement `index` method") def insert(self, index: int, value: CalVT) -> None: @@ -184,6 +149,12 @@ class CalendarStorage(BaseStorage): ... def __delitem__(self, i) -> None: + """ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ raise NotImplementedError( "Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method" ) @@ -199,26 +170,60 @@ class CalendarStorage(BaseStorage): ... def __getitem__(self, i) -> CalVT: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ raise NotImplementedError( "Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" ) + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") + class InstrumentStorage(BaseStorage): - def __init__(self, market: str, uri: str, **kwargs): + def __init__(self, market: str, **kwargs): self.market = market - self.uri = uri + self.kwargs = kwargs @property def data(self) -> Dict[InstKT, InstVT]: - """get all data""" + """get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + def clear(self) -> None: + raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") + def update(self, *args, **kwargs) -> None: """D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. - If E present and has a .keys() method, does: for k in E: D[k] = E[k] - If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v - In either case, this is followed by: for k, v in F.items(): D[k] = v + + Notes + ------ + If E present and has a .keys() method, does: for k in E: D[k] = E[k] + + If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v + + In either case, this is followed by: for k, v in F.items(): D[k] = v + """ raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") @@ -227,53 +232,96 @@ class InstrumentStorage(BaseStorage): raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") def __delitem__(self, k: InstKT) -> None: - """Delete self[key].""" + """Delete self[key]. + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") def __getitem__(self, k: InstKT) -> InstVT: """x.__getitem__(k) <==> x[k]""" raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") + class FeatureStorage(BaseStorage): - def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs): + def __init__(self, instrument: str, field: str, freq: str, **kwargs): self.instrument = instrument self.field = field self.freq = freq - self.uri = uri + self.kwargs = kwargs @property def data(self) -> pd.Series: - """get all data""" + """get all data + + Notes + ------ + if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)` + """ raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") @property def start_index(self) -> Union[int, None]: """get FeatureStorage start index - If len(self) == 0; return None + + Notes + ----- + If the data(storage) does not exist, return None """ - raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") + raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method") @property def end_index(self) -> Union[int, None]: - if len(self) == 0: - return None - return None if len(self) == 0 else self.start_index + len(self) - 1 + """get FeatureStorage end index + + Notes + ----- + The right index of the data range (both sides are closed) + + The next data appending point will be `end_index + 1` + + If the data(storage) does not exist, return None + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None): """Write data_array to FeatureStorage starting from index. - If index is None, append data_array to feature. - If len(data_array) == 0; return - If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan + Notes + ------ + If index is None, append data_array to feature. - Examples: + If len(data_array) == 0; return + + If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan + + Examples + --------- + .. code-block:: feature: 3 4 4 5 5 6 + >>> self.write([6, 7], index=6) feature: @@ -311,56 +359,70 @@ class FeatureStorage(BaseStorage): def rebase(self, start_index: int = None, end_index: int = None): """Rebase the start_index and end_index of the FeatureStorage. - Examples: + start_index and end_index are closed intervals: [start_index, end_index] - feature: - 3 4 - 4 5 - 5 6 + Examples + --------- - >>> self.rebase(start_index=4) + .. code-block:: - feature: - 4 5 - 5 6 + feature: + 3 4 + 4 5 + 5 6 - >>> self.rebase(start_index=3) - feature: - 3 np.nan - 4 5 - 5 6 + >>> self.rebase(start_index=4) - >>> self.write([3], index=3) + feature: + 4 5 + 5 6 - feature: - 3 3 - 4 5 - 5 6 + >>> self.rebase(start_index=3) - >>> self.rebase(end_index=4) + feature: + 3 np.nan + 4 5 + 5 6 - feature: - 3 3 - 4 5 + >>> self.write([3], index=3) - >>> self.write([6, 7, 8], index=4) + feature: + 3 3 + 4 5 + 5 6 - feature: - 3 3 - 4 6 - 5 7 - 6 8 + >>> self.rebase(end_index=4) - >>> self.rebase(start_index=4, end_index=5) + feature: + 3 3 + 4 5 - feature: - 4 6 - 5 7 + >>> self.write([6, 7, 8], index=4) + + feature: + 3 3 + 4 6 + 5 7 + 6 8 + + >>> self.rebase(start_index=4, end_index=5) + + feature: + 4 6 + 5 7 """ - if start_index is None and end_index is None: - logger.warning("both start_index and end_index are None, rebase is ignored") + storage_si = self.start_index + storage_ei = self.end_index + if storage_si is None or storage_ei is None: + raise ValueError("storage.start_index or storage.end_index is None, storage may not exist") + + start_index = storage_si if start_index is None else start_index + end_index = storage_ei if end_index is None else end_index + + if start_index is None or end_index is None: + logger.warning("both start_index and end_index are None, or storage does not exist; rebase is ignored") return if start_index < 0 or end_index < 0: @@ -373,17 +435,15 @@ class FeatureStorage(BaseStorage): ) return - start_index = self.start_index if start_index is None else end_index - end_index = self.end_index if end_index is None else end_index - if start_index <= self.start_index: - self.write([np.nan] * (self.start_index - start_index), start_index) + if start_index <= storage_si: + self.write([np.nan] * (storage_si - start_index), start_index) else: self.rewrite(self[start_index:].values, start_index) if end_index >= self.end_index: self.write([np.nan] * (end_index - self.end_index)) else: - self.rewrite(self[: end_index + 1].values, self.start_index) + self.rewrite(self[: end_index + 1].values, start_index) def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int): """overwrite all data in FeatureStorage with data @@ -414,7 +474,28 @@ class FeatureStorage(BaseStorage): ... def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]: - """x.__getitem__(y) <==> x[y]""" + """x.__getitem__(y) <==> x[y] + + Notes + ------- + if data(storage) does not exist: + if isinstance(i, int): + return (None, None) + if isinstance(i, slice): + # return empty pd.Series + return pd.Series(dtype=np.float32) + """ raise NotImplementedError( "Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" ) + + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 77857182d..686f0fc00 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -665,7 +665,10 @@ def exists_qlib_data(qlib_dir): return False # check calendar bin for _calendar in calendars_dir.iterdir(): - if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")): + + if ("_future" not in _calendar.name) and ( + not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")) + ): return False # check instruments diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 0b063fdda..b3a18cc90 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -120,7 +120,7 @@ class DumpDataBase: else: df = file_or_df if df.empty or self.date_field_name not in df.columns.tolist(): - _calendars = pd.Series() + _calendars = pd.Series(dtype=np.float32) else: _calendars = df[self.date_field_name] diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py index e7bac658c..aad8d11e4 100644 --- a/tests/storage_tests/test_storage.py +++ b/tests/storage_tests/test_storage.py @@ -24,7 +24,7 @@ QLIB_DIR.mkdir(exist_ok=True, parents=True) class TestStorage(TestAutoData): def test_calendar_storage(self): - calendar = CalendarStorage(freq="day", future=False, uri=self.provider_uri) + calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri) assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable" assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable" @@ -32,6 +32,16 @@ class TestStorage(TestAutoData): print(f"calendar[0]: {calendar[0]}") print(f"calendar[-1]: {calendar[-1]}") + calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found") + with pytest.raises(ValueError): + print(calendar.data) + + with pytest.raises(ValueError): + print(calendar[:]) + + with pytest.raises(ValueError): + print(calendar[0]) + def test_instrument_storage(self): """ The meaning of instrument, such as CSI500: @@ -66,7 +76,7 @@ class TestStorage(TestAutoData): """ - instrument = InstrumentStorage(market="csi300", uri=self.provider_uri) + instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri) for inst, spans in instrument.data.items(): assert isinstance(inst, str) and isinstance( @@ -79,6 +89,13 @@ class TestStorage(TestAutoData): print(f"instrument['SH600000']: {instrument['SH600000']}") + instrument = InstrumentStorage(market="csi300", provider_uri="not_found") + with pytest.raises(ValueError): + print(instrument.data) + + with pytest.raises(ValueError): + print(instrument["sSH600000"]) + def test_feature_storage(self): """ Calendar: @@ -133,9 +150,9 @@ class TestStorage(TestAutoData): """ - feature = FeatureStorage(instrument="SH600004", field="close", freq="day", uri=self.provider_uri) + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri) - with pytest.raises(ValueError): + with pytest.raises(IndexError): print(feature[0]) assert isinstance( feature[815][1], (float, np.float32) @@ -144,3 +161,11 @@ class TestStorage(TestAutoData): print(f"feature[815: 818]: \n{feature[815: 818]}") print(f"feature[:].tail(): \n{feature[:].tail()}") + + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount") + + assert feature[0] == (None, None), "FeatureStorage does not exist, feature[i] should return `(None, None)`" + assert feature[:].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`" + assert ( + feature.data.empty + ), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`" From 6222940b9c387f0cca8a61dc25f93bc8bd678f97 Mon Sep 17 00:00:00 2001 From: al Date: Wed, 26 May 2021 17:50:49 +0800 Subject: [PATCH 12/24] Update README.md fix typo --- scripts/data_collector/yahoo/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index b9fd9123c..0413f32b6 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -121,7 +121,7 @@ df = D.features(D.instruments("all"), ["$close"], freq="day") ### Help ```bash -pythono collector.py collector_data --help +python collector.py collector_data --help ``` ## Parameters From b884c8c571602482a7355df955ff2f14ee16b5f7 Mon Sep 17 00:00:00 2001 From: al Date: Wed, 26 May 2021 18:00:23 +0800 Subject: [PATCH 13/24] Update collector.py fix typo --- scripts/data_collector/yahoo/collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index a6e06613e..b92f7773b 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -191,7 +191,7 @@ class YahooCollector(BaseCollector): class YahooCollectorCN(YahooCollector, ABC): def get_instrument_list(self): - logger.info("get HS stock symbos......") + logger.info("get HS stock symbols......") symbols = get_hs_stock_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols From 114162693fd185fa5c4bc8a2ab4eccf540975b36 Mon Sep 17 00:00:00 2001 From: zhupr Date: Wed, 26 May 2021 18:29:41 +0800 Subject: [PATCH 14/24] Fix YahooCollector can't download 1min data --- scripts/data_collector/yahoo/collector.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index a6e06613e..2cd080199 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -581,7 +581,6 @@ class Run(BaseRun): delay=0, start=None, end=None, - interval="1d", check_data_length=False, limit_nums=None, ): @@ -593,8 +592,6 @@ class Run(BaseRun): default 2 delay: float time.sleep(delay), default 0 - interval: str - freq, value from [1min, 1d], default 1d start: str start datetime, default "2000-01-01" end: str @@ -611,8 +608,9 @@ class Run(BaseRun): # get 1m data $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ - - super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums) + super(Run, self).download_data( + max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums + ) def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data From 9b431bc5035c64a60e02b21ede72e5d7788c76f7 Mon Sep 17 00:00:00 2001 From: al Date: Wed, 26 May 2021 22:01:15 +0800 Subject: [PATCH 15/24] Update 1min demo data in CSV format --- docs/component/data.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/component/data.rst b/docs/component/data.rst index 0a650c523..6bc55cf6c 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -72,12 +72,14 @@ Converting CSV Format into Qlib Format ``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format. -Users can download the demo china-stock data in CSV format as follows for reference to the CSV format. +Users can download the 1 day demo china-stock data in CSV format as follows for reference to the CSV format. .. code-block:: bash python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data +For 1min demo, please refer to the script `here `_. + Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions: - CSV file is named after a specific stock *or* the CSV file includes a column of the stock name From 5a382d7e99a96c7af05948f6e586dc47a9662ab2 Mon Sep 17 00:00:00 2001 From: al Date: Thu, 27 May 2021 12:40:55 +0800 Subject: [PATCH 16/24] Update data.rst Update csv format according to feedback --- docs/component/data.rst | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/component/data.rst b/docs/component/data.rst index 6bc55cf6c..ff0d0c2d7 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -72,13 +72,18 @@ Converting CSV Format into Qlib Format ``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format. -Users can download the 1 day demo china-stock data in CSV format as follows for reference to the CSV format. +Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format. +Here are some example: -.. code-block:: bash +for daily data: + .. code-block:: bash python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data -For 1min demo, please refer to the script `here `_. +for 1min data: + .. code-block:: bash + + python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10 Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions: From 0a4e2416089f4a4022e761e36b769b84b13028d1 Mon Sep 17 00:00:00 2001 From: zhupr Date: Thu, 27 May 2021 14:18:17 +0800 Subject: [PATCH 17/24] add get_feature_importance to model interpret --- examples/model_interpreter.py | 81 ++++ qlib/contrib/model/catboost_model.py | 15 +- qlib/contrib/model/double_ensemble.py | 516 +++++++++++----------- qlib/contrib/model/gbdt.py | 7 +- qlib/contrib/model/highfreq_gdbt_model.py | 15 +- qlib/contrib/model/xgboost.py | 17 +- qlib/model/interpret/__init__.py | 0 qlib/model/interpret/base.py | 33 ++ 8 files changed, 419 insertions(+), 265 deletions(-) create mode 100644 examples/model_interpreter.py create mode 100644 qlib/model/interpret/__init__.py create mode 100644 qlib/model/interpret/base.py diff --git a/examples/model_interpreter.py b/examples/model_interpreter.py new file mode 100644 index 000000000..1d9230b8c --- /dev/null +++ b/examples/model_interpreter.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import qlib +from qlib.config import REG_CN + +from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.tests.data import GetData + +market = "csi300" +benchmark = "SH000300" + +################################### +# config +################################### +data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, +} + +task = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, +} + + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + ################################### + # train model + ################################### + # model initialization + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + model.fit(dataset) + + # get model feature importance + feature_importance = model.get_feature_importance() + print("feature importance:") + print(feature_importance) diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py index 98b9b9c2d..5138e0e6f 100644 --- a/qlib/contrib/model/catboost_model.py +++ b/qlib/contrib/model/catboost_model.py @@ -10,9 +10,10 @@ from catboost.utils import get_gpu_device_count from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import FeatureInt -class CatBoostModel(Model): +class CatBoostModel(Model, FeatureInt): """CatBoost Model""" def __init__(self, loss="RMSE", **kwargs): @@ -69,6 +70,18 @@ class CatBoostModel(Model): x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return pd.Series(self.model.predict(x_test.values), index=x_test.index) + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ----- + parameters references: + https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance + """ + return pd.Series( + data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_ + ).sort_values(ascending=False) + if __name__ == "__main__": cat = CatBoostModel() diff --git a/qlib/contrib/model/double_ensemble.py b/qlib/contrib/model/double_ensemble.py index 4b267a2b0..d3ca898f8 100644 --- a/qlib/contrib/model/double_ensemble.py +++ b/qlib/contrib/model/double_ensemble.py @@ -1,251 +1,265 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import lightgbm as lgb -import numpy as np -import pandas as pd -from typing import Text, Union -from ...model.base import Model -from ...data.dataset import DatasetH -from ...data.dataset.handler import DataHandlerLP -from ...log import get_module_logger - - -class DEnsembleModel(Model): - """Double Ensemble Model""" - - def __init__( - self, - base_model="gbm", - loss="mse", - num_models=6, - enable_sr=True, - enable_fs=True, - alpha1=1.0, - alpha2=1.0, - bins_sr=10, - bins_fs=5, - decay=None, - sample_ratios=None, - sub_weights=None, - epochs=100, - **kwargs - ): - self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" - self.num_models = num_models # the number of sub-models - self.enable_sr = enable_sr - self.enable_fs = enable_fs - self.alpha1 = alpha1 - self.alpha2 = alpha2 - self.bins_sr = bins_sr - self.bins_fs = bins_fs - self.decay = decay - if sample_ratios is None: # the default values for sample_ratios - sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4] - if sub_weights is None: # the default values for sub_weights - sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2] - if not len(sample_ratios) == bins_fs: - raise ValueError("The length of sample_ratios should be equal to bins_fs.") - self.sample_ratios = sample_ratios - if not len(sub_weights) == num_models: - raise ValueError("The length of sub_weights should be equal to num_models.") - self.sub_weights = sub_weights - self.epochs = epochs - self.logger = get_module_logger("DEnsembleModel") - self.logger.info("Double Ensemble Model...") - self.ensemble = [] # the current ensemble model, a list contains all the sub-models - self.sub_features = [] # the features for each sub model in the form of pandas.Index - self.params = {"objective": loss} - self.params.update(kwargs) - self.loss = loss - - def fit(self, dataset: DatasetH): - df_train, df_valid = dataset.prepare( - ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L - ) - x_train, y_train = df_train["feature"], df_train["label"] - # initialize the sample weights - N, F = x_train.shape - weights = pd.Series(np.ones(N, dtype=float)) - # initialize the features - features = x_train.columns - pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index) - # train sub-models - for k in range(self.num_models): - self.sub_features.append(features) - self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models)) - model_k = self.train_submodel(df_train, df_valid, weights, features) - self.ensemble.append(model_k) - # no further sample re-weight and feature selection needed for the last sub-model - if k + 1 == self.num_models: - break - - self.logger.info("Retrieving loss curve and loss values...") - loss_curve = self.retrieve_loss_curve(model_k, df_train, features) - pred_k = self.predict_sub(model_k, df_train, features) - pred_sub.iloc[:, k] = pred_k - pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1) - loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values)) - - if self.enable_sr: - self.logger.info("Sample re-weighting...") - weights = self.sample_reweight(loss_curve, loss_values, k + 1) - - if self.enable_fs: - self.logger.info("Feature selection...") - features = self.feature_selection(df_train, loss_values) - - def train_submodel(self, df_train, df_valid, weights, features): - dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features) - evals_result = dict() - model = lgb.train( - self.params, - dtrain, - num_boost_round=self.epochs, - valid_sets=[dtrain, dvalid], - valid_names=["train", "valid"], - verbose_eval=20, - evals_result=evals_result, - ) - evals_result["train"] = list(evals_result["train"].values())[0] - evals_result["valid"] = list(evals_result["valid"].values())[0] - return model - - def _prepare_data_gbm(self, df_train, df_valid, weights, features): - x_train, y_train = df_train["feature"].loc[:, features], df_train["label"] - x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"] - - # Lightgbm need 1D array as its label - if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: - y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values) - else: - raise ValueError("LightGBM doesn't support multi-label training") - - dtrain = lgb.Dataset(x_train.values, label=y_train, weight=weights) - dvalid = lgb.Dataset(x_valid.values, label=y_valid) - return dtrain, dvalid - - def sample_reweight(self, loss_curve, loss_values, k_th): - """ - the SR module of Double Ensemble - :param loss_curve: the shape is NxT - the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample - after the t-th iteration in the training of the previous sub-model. - :param loss_values: the shape is N - the loss of the current ensemble on the i-th sample. - :param k_th: the index of the current sub-model, starting from 1 - :return: weights - the weights for all the samples. - """ - # normalize loss_curve and loss_values with ranking - loss_curve_norm = loss_curve.rank(axis=0, pct=True) - loss_values_norm = (-loss_values).rank(pct=True) - - # calculate l_start and l_end from loss_curve - N, T = loss_curve.shape - part = np.maximum(int(T * 0.1), 1) - l_start = loss_curve_norm.iloc[:, :part].mean(axis=1) - l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1) - - # calculate h-value for each sample - h1 = loss_values_norm - h2 = (l_end / l_start).rank(pct=True) - h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2}) - - # calculate weights - h["bins"] = pd.cut(h["h_value"], self.bins_sr) - h_avg = h.groupby("bins")["h_value"].mean() - weights = pd.Series(np.zeros(N, dtype=float)) - for i_b, b in enumerate(h_avg.index): - weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1) - return weights - - def feature_selection(self, df_train, loss_values): - """ - the FS module of Double Ensemble - :param df_train: the shape is NxF - :param loss_values: the shape is N - the loss of the current ensemble on the i-th sample. - :return: res_feat: in the form of pandas.Index - - """ - x_train, y_train = df_train["feature"], df_train["label"] - features = x_train.columns - N, F = x_train.shape - g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)}) - M = len(self.ensemble) - - # shuffle specific columns and calculate g-value for each feature - x_train_tmp = x_train.copy() - for i_f, feat in enumerate(features): - x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values) - pred = pd.Series(np.zeros(N), index=x_train_tmp.index) - for i_s, submodel in enumerate(self.ensemble): - pred += ( - pd.Series( - submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index - ) - / M - ) - loss_feat = self.get_loss(y_train.values.squeeze(), pred.values) - g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7) - x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy() - - # one column in train features is all-nan # if g['g_value'].isna().any() - g["g_value"].replace(np.nan, 0, inplace=True) - - # divide features into bins_fs bins - g["bins"] = pd.cut(g["g_value"], self.bins_fs) - - # randomly sample features from bins to construct the new features - res_feat = [] - sorted_bins = sorted(g["bins"].unique(), reverse=True) - for i_b, b in enumerate(sorted_bins): - b_feat = features[g["bins"] == b] - num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat))) - res_feat = res_feat + np.random.choice(b_feat, size=num_feat).tolist() - return pd.Index(res_feat) - - def get_loss(self, label, pred): - if self.loss == "mse": - return (label - pred) ** 2 - else: - raise ValueError("not implemented yet") - - def retrieve_loss_curve(self, model, df_train, features): - if self.base_model == "gbm": - num_trees = model.num_trees() - x_train, y_train = df_train["feature"].loc[:, features], df_train["label"] - # Lightgbm need 1D array as its label - if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: - y_train = np.squeeze(y_train.values) - else: - raise ValueError("LightGBM doesn't support multi-label training") - - N = x_train.shape[0] - loss_curve = pd.DataFrame(np.zeros((N, num_trees))) - pred_tree = np.zeros(N, dtype=float) - for i_tree in range(num_trees): - pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1) - loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree) - else: - raise ValueError("not implemented yet") - return loss_curve - - def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): - if self.ensemble is None: - raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) - pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index) - for i_sub, submodel in enumerate(self.ensemble): - feat_sub = self.sub_features[i_sub] - pred += ( - pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index) - * self.sub_weights[i_sub] - ) - return pred - - def predict_sub(self, submodel, df_data, features): - x_data, y_data = df_data["feature"].loc[:, features], df_data["label"] - pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index) - return pred_sub +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import lightgbm as lgb +import numpy as np +import pandas as pd +from typing import Text, Union +from ...model.base import Model +from ...data.dataset import DatasetH +from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import FeatureInt +from ...log import get_module_logger + + +class DEnsembleModel(Model, FeatureInt): + """Double Ensemble Model""" + + def __init__( + self, + base_model="gbm", + loss="mse", + num_models=6, + enable_sr=True, + enable_fs=True, + alpha1=1.0, + alpha2=1.0, + bins_sr=10, + bins_fs=5, + decay=None, + sample_ratios=None, + sub_weights=None, + epochs=100, + **kwargs + ): + self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" + self.num_models = num_models # the number of sub-models + self.enable_sr = enable_sr + self.enable_fs = enable_fs + self.alpha1 = alpha1 + self.alpha2 = alpha2 + self.bins_sr = bins_sr + self.bins_fs = bins_fs + self.decay = decay + if sample_ratios is None: # the default values for sample_ratios + sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4] + if sub_weights is None: # the default values for sub_weights + sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2] + if not len(sample_ratios) == bins_fs: + raise ValueError("The length of sample_ratios should be equal to bins_fs.") + self.sample_ratios = sample_ratios + if not len(sub_weights) == num_models: + raise ValueError("The length of sub_weights should be equal to num_models.") + self.sub_weights = sub_weights + self.epochs = epochs + self.logger = get_module_logger("DEnsembleModel") + self.logger.info("Double Ensemble Model...") + self.ensemble = [] # the current ensemble model, a list contains all the sub-models + self.sub_features = [] # the features for each sub model in the form of pandas.Index + self.params = {"objective": loss} + self.params.update(kwargs) + self.loss = loss + + def fit(self, dataset: DatasetH): + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"], df_train["label"] + # initialize the sample weights + N, F = x_train.shape + weights = pd.Series(np.ones(N, dtype=float)) + # initialize the features + features = x_train.columns + pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index) + # train sub-models + for k in range(self.num_models): + self.sub_features.append(features) + self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models)) + model_k = self.train_submodel(df_train, df_valid, weights, features) + self.ensemble.append(model_k) + # no further sample re-weight and feature selection needed for the last sub-model + if k + 1 == self.num_models: + break + + self.logger.info("Retrieving loss curve and loss values...") + loss_curve = self.retrieve_loss_curve(model_k, df_train, features) + pred_k = self.predict_sub(model_k, df_train, features) + pred_sub.iloc[:, k] = pred_k + pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1) + loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values)) + + if self.enable_sr: + self.logger.info("Sample re-weighting...") + weights = self.sample_reweight(loss_curve, loss_values, k + 1) + + if self.enable_fs: + self.logger.info("Feature selection...") + features = self.feature_selection(df_train, loss_values) + + def train_submodel(self, df_train, df_valid, weights, features): + dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features) + evals_result = dict() + model = lgb.train( + self.params, + dtrain, + num_boost_round=self.epochs, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + verbose_eval=20, + evals_result=evals_result, + ) + evals_result["train"] = list(evals_result["train"].values())[0] + evals_result["valid"] = list(evals_result["valid"].values())[0] + return model + + def _prepare_data_gbm(self, df_train, df_valid, weights, features): + x_train, y_train = df_train["feature"].loc[:, features], df_train["label"] + x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"] + + # Lightgbm need 1D array as its label + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values) + else: + raise ValueError("LightGBM doesn't support multi-label training") + + dtrain = lgb.Dataset(x_train, label=y_train, weight=weights) + dvalid = lgb.Dataset(x_valid, label=y_valid) + return dtrain, dvalid + + def sample_reweight(self, loss_curve, loss_values, k_th): + """ + the SR module of Double Ensemble + :param loss_curve: the shape is NxT + the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample + after the t-th iteration in the training of the previous sub-model. + :param loss_values: the shape is N + the loss of the current ensemble on the i-th sample. + :param k_th: the index of the current sub-model, starting from 1 + :return: weights + the weights for all the samples. + """ + # normalize loss_curve and loss_values with ranking + loss_curve_norm = loss_curve.rank(axis=0, pct=True) + loss_values_norm = (-loss_values).rank(pct=True) + + # calculate l_start and l_end from loss_curve + N, T = loss_curve.shape + part = np.maximum(int(T * 0.1), 1) + l_start = loss_curve_norm.iloc[:, :part].mean(axis=1) + l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1) + + # calculate h-value for each sample + h1 = loss_values_norm + h2 = (l_end / l_start).rank(pct=True) + h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2}) + + # calculate weights + h["bins"] = pd.cut(h["h_value"], self.bins_sr) + h_avg = h.groupby("bins")["h_value"].mean() + weights = pd.Series(np.zeros(N, dtype=float)) + for i_b, b in enumerate(h_avg.index): + weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1) + return weights + + def feature_selection(self, df_train, loss_values): + """ + the FS module of Double Ensemble + :param df_train: the shape is NxF + :param loss_values: the shape is N + the loss of the current ensemble on the i-th sample. + :return: res_feat: in the form of pandas.Index + + """ + x_train, y_train = df_train["feature"], df_train["label"] + features = x_train.columns + N, F = x_train.shape + g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)}) + M = len(self.ensemble) + + # shuffle specific columns and calculate g-value for each feature + x_train_tmp = x_train.copy() + for i_f, feat in enumerate(features): + x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values) + pred = pd.Series(np.zeros(N), index=x_train_tmp.index) + for i_s, submodel in enumerate(self.ensemble): + pred += ( + pd.Series( + submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index + ) + / M + ) + loss_feat = self.get_loss(y_train.values.squeeze(), pred.values) + g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7) + x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy() + + # one column in train features is all-nan # if g['g_value'].isna().any() + g["g_value"].replace(np.nan, 0, inplace=True) + + # divide features into bins_fs bins + g["bins"] = pd.cut(g["g_value"], self.bins_fs) + + # randomly sample features from bins to construct the new features + res_feat = [] + sorted_bins = sorted(g["bins"].unique(), reverse=True) + for i_b, b in enumerate(sorted_bins): + b_feat = features[g["bins"] == b] + num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat))) + res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist() + return pd.Index(set(res_feat)) + + def get_loss(self, label, pred): + if self.loss == "mse": + return (label - pred) ** 2 + else: + raise ValueError("not implemented yet") + + def retrieve_loss_curve(self, model, df_train, features): + if self.base_model == "gbm": + num_trees = model.num_trees() + x_train, y_train = df_train["feature"].loc[:, features], df_train["label"] + # Lightgbm need 1D array as its label + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + y_train = np.squeeze(y_train.values) + else: + raise ValueError("LightGBM doesn't support multi-label training") + + N = x_train.shape[0] + loss_curve = pd.DataFrame(np.zeros((N, num_trees))) + pred_tree = np.zeros(N, dtype=float) + for i_tree in range(num_trees): + pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1) + loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree) + else: + raise ValueError("not implemented yet") + return loss_curve + + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): + if self.ensemble is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index) + for i_sub, submodel in enumerate(self.ensemble): + feat_sub = self.sub_features[i_sub] + pred += ( + pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index) + * self.sub_weights[i_sub] + ) + return pred + + def predict_sub(self, submodel, df_data, features): + x_data, y_data = df_data["feature"].loc[:, features], df_data["label"] + pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index) + return pred_sub + + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ----- + parameters reference: + https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance + """ + res = [] + for _model, _weight in zip(self.ensemble, self.sub_weights): + res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight) + return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False) diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 463cf8f4f..1a7cf7fba 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -8,9 +8,10 @@ from typing import Text, Union from ...model.base import ModelFT from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import LightGBMFInt -class LGBModel(ModelFT): +class LGBModel(ModelFT, LightGBMFInt): """LightGBM Model""" def __init__(self, loss="mse", **kwargs): @@ -33,8 +34,8 @@ class LGBModel(ModelFT): else: raise ValueError("LightGBM doesn't support multi-label training") - dtrain = lgb.Dataset(x_train.values, label=y_train) - dvalid = lgb.Dataset(x_valid.values, label=y_valid) + dtrain = lgb.Dataset(x_train, label=y_train) + dvalid = lgb.Dataset(x_valid, label=y_valid) return dtrain, dvalid def fit( diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index 5a2eeb50a..04d6ab9d5 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -1,17 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import warnings import numpy as np import pandas as pd import lightgbm as lgb -from qlib.model.base import ModelFT -from qlib.data.dataset import DatasetH -from qlib.data.dataset.handler import DataHandlerLP -import warnings +from ...model.base import ModelFT +from ...data.dataset import DatasetH +from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import LightGBMFInt -class HFLGBModel(ModelFT): +class HFLGBModel(ModelFT, LightGBMFInt): """LightGBM Model for high frequency prediction""" def __init__(self, loss="mse", **kwargs): @@ -97,8 +98,8 @@ class HFLGBModel(ModelFT): else: raise ValueError("LightGBM doesn't support multi-label training") - dtrain = lgb.Dataset(x_train.values, label=y_train) - dvalid = lgb.Dataset(x_valid.values, label=y_valid) + dtrain = lgb.Dataset(x_train, label=y_train) + dvalid = lgb.Dataset(x_valid, label=y_valid) return dtrain, dvalid def fit( diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py index cbba14678..2a38f4fe1 100755 --- a/qlib/contrib/model/xgboost.py +++ b/qlib/contrib/model/xgboost.py @@ -8,9 +8,10 @@ from typing import Text, Union from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import FeatureInt -class XGBModel(Model): +class XGBModel(Model, FeatureInt): """XGBModel Model""" def __init__(self, **kwargs): @@ -42,8 +43,8 @@ class XGBModel(Model): else: raise ValueError("XGBoost doesn't support multi-label training") - dtrain = xgb.DMatrix(x_train.values, label=y_train_1d) - dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d) + dtrain = xgb.DMatrix(x_train, label=y_train_1d) + dvalid = xgb.DMatrix(x_valid, label=y_valid_1d) self.model = xgb.train( self._params, dtrain=dtrain, @@ -62,3 +63,13 @@ class XGBModel(Model): raise ValueError("model is not fitted yet!") x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index) + + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ------- + parameters reference: + https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score + """ + return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False) diff --git a/qlib/model/interpret/__init__.py b/qlib/model/interpret/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/model/interpret/base.py b/qlib/model/interpret/base.py new file mode 100644 index 000000000..70d79faca --- /dev/null +++ b/qlib/model/interpret/base.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Interfaces to interpret models +""" + +import pandas as pd +from abc import abstractmethod + + +class FeatureInt: + """Feature (Int)erpreter""" + + @abstractmethod + def get_feature_importance(self) -> pd.Series: + ... + + +class LightGBMFInt(FeatureInt): + """LightGBM (F)eature (Int)erpreter""" + + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ----- + parameters reference: + https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance + """ + return pd.Series(self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()).sort_values( + ascending=False + ) From c12c861b7a99093d09bf426f2bb89b1529fa69dc Mon Sep 17 00:00:00 2001 From: al Date: Thu, 27 May 2021 19:37:57 +0800 Subject: [PATCH 18/24] Remove repeated package from requirements --- scripts/data_collector/yahoo/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/data_collector/yahoo/requirements.txt b/scripts/data_collector/yahoo/requirements.txt index 3e3e0d1e0..5f08026e5 100644 --- a/scripts/data_collector/yahoo/requirements.txt +++ b/scripts/data_collector/yahoo/requirements.txt @@ -5,5 +5,4 @@ numpy pandas tqdm lxml -loguru yahooquery From 7ceec37848b7840bd4c4f995afc26ba07788038f Mon Sep 17 00:00:00 2001 From: al Date: Thu, 27 May 2021 22:35:43 +0800 Subject: [PATCH 19/24] Update integration.rst Fix typo --- docs/start/integration.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/start/integration.rst b/docs/start/integration.rst index 3ecae1090..3d4043826 100644 --- a/docs/start/integration.rst +++ b/docs/start/integration.rst @@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html# return pd.Series(self.model.predict(x_test.values), index=x_test.index) - Override the `finetune` method (Optional) - - This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`. + - This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`. - The parameters must include the parameter `dataset`. - Code Example: In the following example, users will use `LightGBM` as the model and finetune it. .. code-block:: Python From e409bee9b9e779d73296c11b533d9cc6026f1872 Mon Sep 17 00:00:00 2001 From: al Date: Fri, 28 May 2021 07:54:45 +0800 Subject: [PATCH 20/24] Update report.rst typo --- docs/component/report.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/component/report.rst b/docs/component/report.rst index 7d8053c78..6f4bff4f9 100644 --- a/docs/component/report.rst +++ b/docs/component/report.rst @@ -101,7 +101,7 @@ Graphical Result - Axis Y: - `ic` The `Pearson correlation coefficient` series between `label` and `prediction score`. - In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue `_ for more details. + In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature `_ for more details. - `rank_ic` The `Spearman's rank correlation coefficient` series between `label` and `prediction score`. From 98eacf8f88de66aa88ee877004a76c2a60f7c5f5 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 28 May 2021 13:24:47 +0800 Subject: [PATCH 21/24] add test/config.py --- examples/highfreq/workflow.py | 17 +-- .../LightGBM/hyperparameter_158.py | 58 +++------- .../LightGBM/hyperparameter_360.py | 57 +++------ examples/model_interpreter.py | 81 ------------- examples/model_interpreter/feature.py | 32 ++++++ .../model_rolling/task_manager_rolling.py | 62 +--------- .../online_srv/online_management_simulate.py | 62 +--------- .../online_srv/rolling_online_management.py | 65 ++--------- examples/online_srv/update_online_pred.py | 49 +------- examples/rolling_process_data/workflow.py | 8 +- examples/run_all_model.py | 11 +- examples/workflow_by_code.py | 76 ++---------- qlib/model/interpret/base.py | 9 +- qlib/tests/__init__.py | 18 ++- qlib/tests/config.py | 108 ++++++++++++++++++ qlib/tests/data.py | 7 ++ tests/dataset_tests/test_datalayer.py | 22 +--- tests/test_all_pipeline.py | 66 ++--------- tests/test_contrib_workflow.py | 65 ++--------- tests/test_get_data.py | 5 +- tests/test_register_ops.py | 5 - 21 files changed, 246 insertions(+), 637 deletions(-) delete mode 100644 examples/model_interpreter.py create mode 100644 examples/model_interpreter/feature.py create mode 100644 qlib/tests/config.py diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 5660ab2e9..856885b25 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -1,24 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import fire -from pathlib import Path import qlib import pickle -import numpy as np -import pandas as pd from qlib.config import REG_CN, HIGH_FREQ_CONFIG -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.utils import init_instance_by_config, exists_qlib_data +from qlib.utils import init_instance_by_config from qlib.data.dataset.handler import DataHandlerLP from qlib.data.ops import Operators from qlib.data.data import Cal @@ -96,9 +85,7 @@ class HighfreqWorkflow: # use yahoo_cn_1min data QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} provider_uri = QLIB_INIT_CONFIG.get("provider_uri") - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) qlib.init(**QLIB_INIT_CONFIG) def _prepare_calender_cache(self): diff --git a/examples/hyperparameter/LightGBM/hyperparameter_158.py b/examples/hyperparameter/LightGBM/hyperparameter_158.py index 5e4887a14..9e4557ed5 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_158.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_158.py @@ -1,46 +1,9 @@ import qlib -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data, init_instance_by_config import optuna - -provider_uri = "~/.qlib/qlib_data/cn_data" -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(scripts_dir)) - from get_data import GetData - - GetData().qlib_data(target_dir=provider_uri, region="cn") -qlib.init(provider_uri=provider_uri, region="cn") - -market = "csi300" -benchmark = "SH000300" - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} -dataset_task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} -dataset = init_instance_by_config(dataset_task["dataset"]) +from qlib.config import REG_CN +from qlib.utils import init_instance_by_config +from qlib.tests.config import CSI300_DATASET_CONFIG +from qlib.tests.data import GetData def objective(trial): @@ -65,12 +28,19 @@ def objective(trial): }, }, } - evals_result = dict() model = init_instance_by_config(task["model"]) model.fit(dataset, evals_result=evals_result) return min(evals_result["valid"]) -study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3") -study.optimize(objective, n_jobs=6) +if __name__ == "__main__": + + provider_uri = "~/.qlib/qlib_data/cn_data" + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + qlib.init(provider_uri=provider_uri, region="cn") + + dataset = init_instance_by_config(CSI300_DATASET_CONFIG) + + study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3") + study.optimize(objective, n_jobs=6) diff --git a/examples/hyperparameter/LightGBM/hyperparameter_360.py b/examples/hyperparameter/LightGBM/hyperparameter_360.py index 8b498e912..a8127014b 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_360.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_360.py @@ -1,46 +1,11 @@ import qlib -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data, init_instance_by_config import optuna +from qlib.config import REG_CN +from qlib.utils import init_instance_by_config +from qlib.tests.data import GetData +from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS -provider_uri = "~/.qlib/qlib_data/cn_data" -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(scripts_dir)) - from get_data import GetData - - GetData().qlib_data(target_dir=provider_uri, region="cn") -qlib.init(provider_uri=provider_uri, region="cn") - -market = "csi300" -benchmark = "SH000300" - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} -dataset_task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha360", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} -dataset = init_instance_by_config(dataset_task["dataset"]) +DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS) def objective(trial): @@ -72,5 +37,13 @@ def objective(trial): return min(evals_result["valid"]) -study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3") -study.optimize(objective, n_jobs=6) +if __name__ == "__main__": + + provider_uri = "~/.qlib/qlib_data/cn_data" + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + qlib.init(provider_uri=provider_uri, region=REG_CN) + + dataset = init_instance_by_config(DATASET_CONFIG) + + study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3") + study.optimize(objective, n_jobs=6) diff --git a/examples/model_interpreter.py b/examples/model_interpreter.py deleted file mode 100644 index 1d9230b8c..000000000 --- a/examples/model_interpreter.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import qlib -from qlib.config import REG_CN - -from qlib.utils import exists_qlib_data, init_instance_by_config -from qlib.tests.data import GetData - -market = "csi300" -benchmark = "SH000300" - -################################### -# config -################################### -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} - - -if __name__ == "__main__": - - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - - qlib.init(provider_uri=provider_uri, region=REG_CN) - - ################################### - # train model - ################################### - # model initialization - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) - model.fit(dataset) - - # get model feature importance - feature_importance = model.get_feature_importance() - print("feature importance:") - print(feature_importance) diff --git a/examples/model_interpreter/feature.py b/examples/model_interpreter/feature.py new file mode 100644 index 000000000..1c29fda6e --- /dev/null +++ b/examples/model_interpreter/feature.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import qlib +from qlib.config import REG_CN + +from qlib.utils import init_instance_by_config +from qlib.tests.data import GetData +from qlib.tests.config import CSI300_GBDT_TASK + + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + ################################### + # train model + ################################### + # model initialization + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) + model.fit(dataset) + + # get model feature importance + feature_importance = model.get_feature_importance() + print("feature importance:") + print(feature_importance) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 4f3ac04b1..9ef8694bf 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,63 +17,7 @@ from qlib.workflow.task.manage import TaskManager from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup from qlib.model.trainer import TrainerRM - - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG class RollingTaskExample: @@ -85,11 +29,13 @@ class RollingTaskExample: task_db_name="rolling_db", experiment_name="rolling_exp", task_pool="rolling_task", - task_config=[task_xgboost_config, task_lgb_config], + task_config=None, rolling_step=550, rolling_type=RollingGen.ROLL_SD, ): # TaskManager config + if task_config is None: + task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] mongo_conf = { "task_url": task_url, "task_db_name": task_db_name, diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 4bb5022ee..8c9e77bf7 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -13,63 +13,7 @@ from qlib.workflow.online.manager import OnlineManager from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager - - -data_handler_config = { - "start_time": "2018-01-01", - "end_time": "2018-10-31", - "fit_start_time": "2018-01-01", - "fit_end_time": "2018-03-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2018-01-01", "2018-03-31"), - "valid": ("2018-04-01", "2018-05-31"), - "test": ("2018-06-01", "2018-09-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG class OnlineSimulationExample: @@ -84,7 +28,7 @@ class OnlineSimulationExample: rolling_step=80, start_time="2018-09-10", end_time="2018-10-31", - tasks=[task_xgboost_config, task_lgb_config], + tasks=None, ): """ Init OnlineManagerExample. @@ -101,6 +45,8 @@ class OnlineSimulationExample: end_time (str, optional): the end time of simulating. Defaults to "2018-10-31". tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ + if tasks is None: + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 25b8b2a0c..592f1f866 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -17,62 +17,7 @@ from qlib.workflow import R from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager - -data_handler_config = { - "start_time": "2013-01-01", - "end_time": "2020-09-25", - "fit_start_time": "2013-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2013-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2015-12-31"), - "test": ("2016-01-01", "2020-07-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG class RollingOnlineExample: @@ -83,9 +28,13 @@ class RollingOnlineExample: task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550, - tasks=[task_xgboost_config], - add_tasks=[task_lgb_config], + tasks=None, + add_tasks=None, ): + if add_tasks is None: + add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG] + if tasks is None: + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG] mongo_conf = { "task_url": task_url, # your MongoDB url "task_db_name": task_db_name, # database name diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index 228bc0dac..8afc66553 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -7,56 +7,19 @@ There are two parts including first_train and update_online_pred. Firstly, we will finish the training and set the trained models to the `online` models. Next, we will finish updating online predictions. """ +import copy import fire import qlib from qlib.config import REG_CN from qlib.model.trainer import task_train from qlib.workflow.online.utils import OnlineToolR +from qlib.tests.config import CSI300_GBDT_TASK -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} +task = copy.deepcopy(CSI300_GBDT_TASK) -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - "record": { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, +task["record"] = { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", } diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 5757aaa87..bfa2d1ec4 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -4,13 +4,11 @@ import qlib import fire import pickle -import pandas as pd from datetime import datetime from qlib.config import REG_CN from qlib.data.dataset.handler import DataHandlerLP -from qlib.contrib.data.handler import Alpha158 -from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.utils import init_instance_by_config from qlib.tests.data import GetData @@ -25,9 +23,7 @@ class RollingDataWorkflow: """initialize qlib""" # use yahoo_cn_1min data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) def _dump_pre_handler(self, path): diff --git a/examples/run_all_model.py b/examples/run_all_model.py index d587eff15..8875b9aa1 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -5,13 +5,11 @@ import os import sys import fire import time -import venv import glob import shutil import signal import inspect import tempfile -import traceback import functools import statistics import subprocess @@ -23,8 +21,7 @@ from pprint import pprint import qlib from qlib.config import REG_CN from qlib.workflow import R -from qlib.workflow.cli import workflow -from qlib.utils import exists_qlib_data +from qlib.tests.data import GetData # init qlib @@ -39,12 +36,8 @@ exp_manager = { "default_exp_name": "Experiment", }, } -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) - from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) +GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager) # decorator to check the arguments diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index d5dab8917..2e84cadc2 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -1,82 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys -from pathlib import Path - import qlib -import pandas as pd from qlib.config import REG_CN -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData +from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK + if __name__ == "__main__": # use default data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) - market = "csi300" - benchmark = "SH000300" - - ################################### - # train model - ################################### - data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, - } - - task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - } - port_analysis_config = { "strategy": { "class": "TopkDropoutStrategy", @@ -90,7 +30,7 @@ if __name__ == "__main__": "verbose": False, "limit_threshold": 0.095, "account": 100000000, - "benchmark": benchmark, + "benchmark": CSI300_BENCH, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, @@ -100,8 +40,8 @@ if __name__ == "__main__": } # model initialization - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # NOTE: This line is optional # It demonstrates that the dataset can be used standalone. @@ -110,7 +50,7 @@ if __name__ == "__main__": # start exp with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) R.save_objects(**{"params.pkl": model}) diff --git a/qlib/model/interpret/base.py b/qlib/model/interpret/base.py index 70d79faca..57cc7929a 100644 --- a/qlib/model/interpret/base.py +++ b/qlib/model/interpret/base.py @@ -14,7 +14,14 @@ class FeatureInt: @abstractmethod def get_feature_importance(self) -> pd.Series: - ... + """get feature importance + + Returns + ------- + The index is the feature name. + + The greater the value, the higher importance. + """ class LightGBMFInt(FeatureInt): diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index 8b53bc53a..e72f000ba 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -1,6 +1,4 @@ -import sys import unittest -from ..utils import exists_qlib_data from .data import GetData from .. import init from ..config import REG_CN @@ -14,14 +12,12 @@ class TestAutoData(unittest.TestCase): @classmethod def setUpClass(cls) -> None: # use default data - if not exists_qlib_data(cls.provider_uri): - print(f"Qlib data is not found in {cls.provider_uri}") - GetData().qlib_data( - name="qlib_data_simple", - region="cn", - interval="1d", - target_dir=cls.provider_uri, - delete_old=False, - ) + GetData().qlib_data( + name="qlib_data_simple", + region=REG_CN, + interval="1d", + target_dir=cls.provider_uri, + delete_old=False, + ) init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs) diff --git a/qlib/tests/config.py b/qlib/tests/config.py new file mode 100644 index 000000000..80461f6f9 --- /dev/null +++ b/qlib/tests/config.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +CSI300_MARKET = "csi300" +CSI100_MARKET = "csi100" + +CSI300_BENCH = "SH000300" + +DATASET_ALPHA158_CLASS = "Alpha158" +DATASET_ALPHA360_CLASS = "Alpha360" + +################################### +# config +################################### + + +GBDT_MODEL = { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, +} + + +RECORD_CONFIG = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + + +def get_data_handler_config(market=CSI300_MARKET): + return { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, + } + + +def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS): + return { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": dataset_class, + "module_path": "qlib.contrib.data.handler", + "kwargs": get_data_handler_config(market), + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + } + + +def get_gbdt_task(market=CSI300_MARKET): + return { + "model": GBDT_MODEL, + "dataset": get_dataset_config(market), + } + + +def get_record_lgb_config(market=CSI300_MARKET): + return { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": get_dataset_config(market), + "record": RECORD_CONFIG, + } + + +def get_record_xgboost_config(market=CSI300_MARKET): + return { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": get_dataset_config(market), + "record": RECORD_CONFIG, + } + + +CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET) +CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET) + +CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET) +CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET) diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 3bf6a2c96..0f226c6b1 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -10,6 +10,7 @@ import datetime from tqdm import tqdm from pathlib import Path from loguru import logger +from qlib.utils import exists_qlib_data class GetData: @@ -112,6 +113,7 @@ class GetData: interval="1d", region="cn", delete_old=True, + exists_skip=True, ): """download cn qlib data from remote @@ -129,6 +131,8 @@ class GetData: data region, value from [cn, us], by default cn delete_old: bool delete an existing directory, by default True + exists_skip: bool + exists skip, by default True Examples --------- @@ -140,6 +144,9 @@ class GetData: ------- """ + if exists_skip and exists_qlib_data(target_dir): + return + qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__)) def _get_file_name(v): diff --git a/tests/dataset_tests/test_datalayer.py b/tests/dataset_tests/test_datalayer.py index 9d282b167..bdd0d915b 100644 --- a/tests/dataset_tests/test_datalayer.py +++ b/tests/dataset_tests/test_datalayer.py @@ -1,26 +1,10 @@ -import sys -from pathlib import Path -import qlib -from qlib.data import D -from qlib.config import REG_CN import unittest import numpy as np -from qlib.utils import exists_qlib_data +from qlib.data import D +from qlib.tests import TestAutoData -class TestDataset(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))) - from get_data import GetData - - GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri) - qlib.init(provider_uri=provider_uri, region=REG_CN) - +class TestDataset(TestAutoData): def testCSI300(self): close_p = D.features(D.instruments("csi300"), ["$close"]) size = close_p.groupby("datetime").size() diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index d34c1773a..4c20405fa 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -12,55 +12,7 @@ from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord from qlib.tests import TestAutoData - - -market = "csi300" -benchmark = "SH000300" - -################################### -# train model -################################### -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} +from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH port_analysis_config = { "strategy": { @@ -75,7 +27,7 @@ port_analysis_config = { "verbose": False, "limit_threshold": 0.095, "account": 100000000, - "benchmark": benchmark, + "benchmark": CSI300_BENCH, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, @@ -96,15 +48,15 @@ def train(): """ # model initiaiton - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # To test __repr__ print(dataset) print(R) # start exp with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) # prediction @@ -137,12 +89,12 @@ def train_with_sigana(): performance: dict model performance """ - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # start exp with R.start(experiment_name="workflow_with_sigana"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) # predict and calculate ic and ric @@ -171,7 +123,7 @@ def fake_experiment(): default_uri = R.get_uri() current_uri = "file:./temp-test-exp-mag" with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) current_uri_to_check = R.get_uri() default_uri_to_check = R.get_uri() diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py index ccd3c6a90..9b1edbd4e 100644 --- a/tests/test_contrib_workflow.py +++ b/tests/test_contrib_workflow.py @@ -1,73 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import shutil import unittest from pathlib import Path -import qlib -from qlib.config import C from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.tests import TestAutoData - - -market = "csi300" -benchmark = "SH000300" - -################################### -# train model -################################### -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} +from qlib.tests.config import CSI300_GBDT_TASK def train_multiseg(): - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) recorder = R.get_recorder() sr = MultiSegRecord(model, dataset, recorder) @@ -77,10 +26,10 @@ def train_multiseg(): def train_mse(): - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) recorder = R.get_recorder() sr = SignalMseRecord(recorder, model=model, dataset=dataset) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index c511d1b91..55a2c3318 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -1,16 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import shutil import unittest from pathlib import Path -sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) -from get_data import GetData - import qlib from qlib.data import D +from qlib.tests.data import GetData DATA_DIR = Path(__file__).parent.joinpath("test_get_data") SOURCE_DIR = DATA_DIR.joinpath("source") diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py index 7d3322ddc..ac86be59c 100644 --- a/tests/test_register_ops.py +++ b/tests/test_register_ops.py @@ -1,17 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import unittest import numpy as np -import qlib from qlib.data import D from qlib.data.ops import ElemOperator, PairOperator -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data from qlib.tests import TestAutoData -from qlib.tests.data import GetData class Diff(ElemOperator): From ef11a9d95c06965680b0e886b6c6657687895c69 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 28 May 2021 14:57:06 +0800 Subject: [PATCH 22/24] modify the default value of exists_skip in the GetData.qlib_data parameter to False --- examples/highfreq/workflow.py | 2 +- examples/hyperparameter/LightGBM/hyperparameter_158.py | 2 +- examples/hyperparameter/LightGBM/hyperparameter_360.py | 2 +- examples/model_interpreter/feature.py | 2 +- examples/rolling_process_data/workflow.py | 2 +- examples/run_all_model.py | 2 +- examples/workflow_by_code.py | 2 +- qlib/tests/__init__.py | 1 + qlib/tests/data.py | 8 ++++++-- tests/test_get_data.py | 4 +++- 10 files changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 856885b25..7bf5fd09a 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -85,7 +85,7 @@ class HighfreqWorkflow: # use yahoo_cn_1min data QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} provider_uri = QLIB_INIT_CONFIG.get("provider_uri") - GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True) qlib.init(**QLIB_INIT_CONFIG) def _prepare_calender_cache(self): diff --git a/examples/hyperparameter/LightGBM/hyperparameter_158.py b/examples/hyperparameter/LightGBM/hyperparameter_158.py index 9e4557ed5..89cc10cc6 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_158.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_158.py @@ -37,7 +37,7 @@ def objective(trial): if __name__ == "__main__": provider_uri = "~/.qlib/qlib_data/cn_data" - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region="cn") dataset = init_instance_by_config(CSI300_DATASET_CONFIG) diff --git a/examples/hyperparameter/LightGBM/hyperparameter_360.py b/examples/hyperparameter/LightGBM/hyperparameter_360.py index a8127014b..bc0cc245d 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_360.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_360.py @@ -40,7 +40,7 @@ def objective(trial): if __name__ == "__main__": provider_uri = "~/.qlib/qlib_data/cn_data" - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) dataset = init_instance_by_config(DATASET_CONFIG) diff --git a/examples/model_interpreter/feature.py b/examples/model_interpreter/feature.py index 1c29fda6e..a1288e07d 100644 --- a/examples/model_interpreter/feature.py +++ b/examples/model_interpreter/feature.py @@ -14,7 +14,7 @@ if __name__ == "__main__": # use default data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index bfa2d1ec4..387d5cde7 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -23,7 +23,7 @@ class RollingDataWorkflow: """initialize qlib""" # use yahoo_cn_1min data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) def _dump_pre_handler(self, path): diff --git a/examples/run_all_model.py b/examples/run_all_model.py index 8875b9aa1..c79fee004 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -37,7 +37,7 @@ exp_manager = { }, } -GetData().qlib_data(target_dir=provider_uri, region=REG_CN) +GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager) # decorator to check the arguments diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 2e84cadc2..1cdf2ac80 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -14,7 +14,7 @@ if __name__ == "__main__": # use default data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) port_analysis_config = { diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index e72f000ba..7f43cd99a 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -19,5 +19,6 @@ class TestAutoData(unittest.TestCase): interval="1d", target_dir=cls.provider_uri, delete_old=False, + exists_skip=True, ) init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs) diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 0f226c6b1..2bfe43590 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -113,7 +113,7 @@ class GetData: interval="1d", region="cn", delete_old=True, - exists_skip=True, + exists_skip=False, ): """download cn qlib data from remote @@ -132,7 +132,7 @@ class GetData: delete_old: bool delete an existing directory, by default True exists_skip: bool - exists skip, by default True + exists skip, by default False Examples --------- @@ -145,6 +145,10 @@ class GetData: """ if exists_skip and exists_qlib_data(target_dir): + logger.warning( + f"Data already exists: {target_dir}, the data download will be skipped\n" + f"\tIf downloading is required: `exists_skip=False` or `change target_dir`" + ) return qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__)) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index 55a2c3318..93a852f55 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -34,7 +34,9 @@ class TestGetData(unittest.TestCase): def test_0_qlib_data(self): - GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False) + GetData().qlib_data( + name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False, exists_skip=True + ) df = D.features(D.instruments("csi300"), self.FIELDS) self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed") self.assertFalse(df.dropna().empty, "get qlib data failed") From 02e34eb9e9f7d65d9dd782e8d023e44b9014619e Mon Sep 17 00:00:00 2001 From: al Date: Sun, 30 May 2021 08:27:21 +0800 Subject: [PATCH 23/24] Add import stock pool (csi300) in documentation --- docs/component/data.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/component/data.rst b/docs/component/data.rst index ff0d0c2d7..cd30ee98b 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -152,6 +152,16 @@ After conversion, users can find their Qlib format data in the directory `~/.qli In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended. +Stock Pool (Market) +-------------------------------- + +``Qlib`` defines `stock pool `_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows. + +.. code-block:: bash + + python collector.py --index_name CSI300 --qlib_dir --method parse_instruments + + Multiple Stock Modes -------------------------------- From 4ff0c4fb0f98abfd36faaf666cd8f6811c3a7cc7 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Mon, 31 May 2021 08:52:41 +0800 Subject: [PATCH 24/24] Update strategy.rst --- docs/component/strategy.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/component/strategy.rst b/docs/component/strategy.rst index 0720dcdad..e4a5a94d1 100644 --- a/docs/component/strategy.rst +++ b/docs/component/strategy.rst @@ -111,8 +111,6 @@ Usage & Example pred_score, strategy=strategy, **BACKTEST_CONFIG ) -Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``. - To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction `_. To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing `_.