From 029c336c62d832b2ca52825b37c457204c1f50d7 Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Fri, 4 Aug 2023 13:51:16 -0700
Subject: [PATCH] Fixed results for NULL and NaN distances [skip ci]

---
 src/hnswscan.c                | 48 ++++++++++++++++++++++++++----------
 test/expected/hnsw_cosine.out | 17 ++++++++++---
 test/expected/hnsw_ip.out     | 10 +++++---
 test/expected/hnsw_l2.out     | 10 +++++---
 test/sql/hnsw_cosine.sql      |  1 +
 5 files changed, 64 insertions(+), 22 deletions(-)

diff --git a/src/hnswscan.c b/src/hnswscan.c
index 3dca1ce..a178b60 100644
--- a/src/hnswscan.c
+++ b/src/hnswscan.c
@@ -35,6 +35,29 @@ GetScanItems(IndexScanDesc scan, Datum q)
 	so->w = SearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, false, NULL, NULL);
 }
 
+/*
+ * Get dimensions from metapage (read under a shared buffer lock)
+ */
+static int
+GetDimensions(Relation index)
+{
+	Buffer		buf;
+	Page		page;
+	HnswMetaPage metap;
+	int			dimensions;
+
+	buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO);
+	LockBuffer(buf, BUFFER_LOCK_SHARE);
+	page = BufferGetPage(buf);
+	metap = HnswPageGetMeta(page);
+
+	dimensions = metap->dimensions;
+
+	UnlockReleaseBuffer(buf);
+
+	return dimensions;
+}
+
 /*
  * Prepare for an index scan
  */
@@ -107,19 +130,18 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
 
-		/* No items will match if null */
+		/* Use a freshly initialized vector of the index dimensions if null */
 		if (scan->orderByData->sk_flags & SK_ISNULL)
-			return false;
-
-		value = scan->orderByData->sk_argument;
-
-		/* Value should not be compressed or toasted */
-		Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
-		Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
-
-		if (so->normprocinfo != NULL)
+			value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation)));
+		else
 		{
-			/* No items will match if normalization fails */
-			if (!HnswNormValue(so->normprocinfo, so->collation, &value, NULL))
-				return false;
+			value = scan->orderByData->sk_argument;
+
+			/* Value should not be compressed or toasted */
+			Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
+			Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
+
+			/* Fine if normalization fails */
+			if (so->normprocinfo != NULL)
+				HnswNormValue(so->normprocinfo, so->collation, &value, NULL);
 		}
 
 		GetScanItems(scan, value);
diff --git a/test/expected/hnsw_cosine.out b/test/expected/hnsw_cosine.out
index eec40d9..9bf8f59 100644
--- a/test/expected/hnsw_cosine.out
+++ b/test/expected/hnsw_cosine.out
@@ -11,9 +11,20 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
  [1,2,4]
 (3 rows)
 
+SELECT * FROM t ORDER BY val <=> '[0,0,0]';
+   val   
+---------
+ [1,1,1]
+ [1,2,4]
+ [1,2,3]
+(3 rows)
+
 SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
- val 
------
-(0 rows)
+   val   
+---------
+ [1,1,1]
+ [1,2,4]
+ [1,2,3]
+(3 rows)
 
 DROP TABLE t;
diff --git a/test/expected/hnsw_ip.out b/test/expected/hnsw_ip.out
index 85a4648..d6ae2ea 100644
--- a/test/expected/hnsw_ip.out
+++ b/test/expected/hnsw_ip.out
@@ -13,8 +13,12 @@ SELECT * FROM t ORDER BY val <#> '[3,3,3]';
 (4 rows)
 
 SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
- val 
------
-(0 rows)
+   val   
+---------
+ [1,1,1]
+ [1,2,4]
+ [1,2,3]
+ [0,0,0]
+(4 rows)
 
 DROP TABLE t;
diff --git a/test/expected/hnsw_l2.out b/test/expected/hnsw_l2.out
index 4136b82..e8a16c8 100644
--- a/test/expected/hnsw_l2.out
+++ b/test/expected/hnsw_l2.out
@@ -13,9 +13,13 @@ SELECT * FROM t ORDER BY val <-> '[3,3,3]';
 (4 rows)
 
 SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
- val 
------
-(0 rows)
+   val   
+---------
+ [0,0,0]
+ [1,1,1]
+ [1,2,3]
+ [1,2,4]
+(4 rows)
 
 SELECT COUNT(*) FROM t;
  count 
diff --git a/test/sql/hnsw_cosine.sql b/test/sql/hnsw_cosine.sql
index 4398150..9b84d09 100644
--- a/test/sql/hnsw_cosine.sql
+++ b/test/sql/hnsw_cosine.sql
@@ -7,6 +7,7 @@ CREATE INDEX ON t USING hnsw (val vector_cosine_ops);
 INSERT INTO t (val) VALUES ('[1,2,4]');
 
 SELECT * FROM t ORDER BY val <=> '[3,3,3]';
+SELECT * FROM t ORDER BY val <=> '[0,0,0]';
 SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
 
 DROP TABLE t;