diff --git a/src/hnswscan.c b/src/hnswscan.c index 63e960b..365ac71 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -40,29 +40,6 @@ GetScanItems(IndexScanDesc scan, Datum q) return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); } -/* - * Get dimensions from metapage - */ -static int -GetDimensions(Relation index) -{ - Buffer buf; - Page page; - HnswMetaPage metap; - int dimensions; - - buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); - LockBuffer(buf, BUFFER_LOCK_SHARE); - page = BufferGetPage(buf); - metap = HnswPageGetMeta(page); - - dimensions = metap->dimensions; - - UnlockReleaseBuffer(buf); - - return dimensions; -} - /* * Get scan value */ @@ -73,7 +50,7 @@ GetScanValue(IndexScanDesc scan) Datum value; if (scan->orderByData->sk_flags & SK_ISNULL) - value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation))); + value = PointerGetDatum(NULL); else { value = scan->orderByData->sk_argument; diff --git a/src/hnswutils.c b/src/hnswutils.c index 5ecfda0..4104f32 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -573,7 +573,12 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, /* Calculate distance */ if (distance != NULL) - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + { + if (DatumGetPointer(*q) == NULL) + *distance = 0; + else + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + } UnlockReleaseBuffer(buf); } diff --git a/test/expected/hnsw_l2.out b/test/expected/hnsw_l2.out index 4209661..aa56c2b 100644 --- a/test/expected/hnsw_l2.out +++ b/test/expected/hnsw_l2.out @@ -12,14 +12,11 @@ SELECT * FROM t ORDER BY val <-> '[3,3,3]'; [0,0,0] (4 rows) -SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); - val ---------- - [0,0,0] - [1,1,1] - [1,2,3] - [1,2,4] -(4 rows) +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector)) t2; + count +------- + 4 +(1 row) SELECT COUNT(*) FROM t; count diff --git a/test/sql/hnsw_l2.sql b/test/sql/hnsw_l2.sql index 70bb50a..32d9cac 100644 --- a/test/sql/hnsw_l2.sql +++ b/test/sql/hnsw_l2.sql @@ -7,7 +7,7 @@ CREATE INDEX ON t USING hnsw (val vector_l2_ops); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; -SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector)) t2; SELECT COUNT(*) FROM t; TRUNCATE t;