mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Update index_data.py for datatype conversion and alignment (#1813)
* Update index_data.py for data convertion and alignment * Update qlib/utils/index_data.py * Update qlib/utils/index_data.py * fix linting --------- Co-authored-by: taozhiwang <taozhiwa@gmail.com> Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
@@ -108,6 +108,12 @@ class Index:
|
||||
self.index_map = self.idx_list = np.arange(idx_list)
|
||||
self._is_sorted = True
|
||||
else:
|
||||
# Check if all elements in idx_list are of the same type
|
||||
if not all(isinstance(x, type(idx_list[0])) for x in idx_list):
|
||||
raise TypeError("All elements in idx_list must be of the same type")
|
||||
# Check if all elements in idx_list are of the same datetime64 precision
|
||||
if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list):
|
||||
raise TypeError("All elements in idx_list must be of the same datetime64 precision")
|
||||
self.idx_list = np.array(idx_list)
|
||||
# NOTE: only the first appearance is indexed
|
||||
self.index_map = dict(zip(self.idx_list, range(len(self))))
|
||||
@@ -131,7 +137,12 @@ class Index:
|
||||
if self.idx_list.dtype.type is np.datetime64:
|
||||
if isinstance(item, pd.Timestamp):
|
||||
# This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp
|
||||
return item.to_numpy()
|
||||
return item.to_numpy().astype(self.idx_list.dtype)
|
||||
elif isinstance(item, np.datetime64):
|
||||
# This happens often when creating index based on np.datetime64 and query with another precision
|
||||
return item.astype(self.idx_list.dtype)
|
||||
# NOTE: It is hard to consider every case at first.
|
||||
# We just try to cover part of cases to make it more user-friendly
|
||||
return item
|
||||
|
||||
def index(self, item) -> int:
|
||||
|
||||
@@ -94,6 +94,24 @@ class IndexDataTest(unittest.TestCase):
|
||||
print(sd)
|
||||
self.assertTrue(sd.iloc[0] == 2)
|
||||
|
||||
# test different precisions of time data
|
||||
timeindex = [
|
||||
np.datetime64("2024-06-22T00:00:00.000000000"),
|
||||
np.datetime64("2024-06-21T00:00:00.000000000"),
|
||||
np.datetime64("2024-06-20T00:00:00.000000000"),
|
||||
]
|
||||
sd = idd.SingleData([1, 2, 3], index=timeindex)
|
||||
self.assertTrue(
|
||||
sd.index.index(np.datetime64("2024-06-21T00:00:00.000000000"))
|
||||
== sd.index.index(np.datetime64("2024-06-21T00:00:00"))
|
||||
)
|
||||
self.assertTrue(sd.index.index(pd.Timestamp("2024-06-21 00:00")) == 1)
|
||||
|
||||
# Bad case: the input is not aligned
|
||||
timeindex[1] = (np.datetime64("2024-06-21T00:00:00.00"),)
|
||||
with self.assertRaises(TypeError):
|
||||
sd = idd.SingleData([1, 2, 3], index=timeindex)
|
||||
|
||||
def test_ops(self):
|
||||
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
|
||||
Reference in New Issue
Block a user