diff --git a/CHANGELOG.md b/CHANGELOG.md
index 82e589b..532f40b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
 - Added `sum` aggregate
 - Improved performance of distance functions
 - Fixed out of range results for cosine distance
+- Fixed results for NULL and NaN distances
 
 ## 0.4.4 (2023-06-12)
 
diff --git a/src/ivfscan.c b/src/ivfscan.c
index a2108ef..6703120 100644
--- a/src/ivfscan.c
+++ b/src/ivfscan.c
@@ -180,6 +180,29 @@ GetScanItems(IndexScanDesc scan, Datum value)
 	tuplesort_performsort(so->sortstate);
 }
 
+/*
+ * Get dimensions from metapage (read under a shared buffer lock, which is released before returning)
+ */
+static int
+GetDimensions(Relation index)
+{
+	Buffer		buf;
+	Page		page;
+	IvfflatMetaPage metap;
+	int			dimensions;
+
+	buf = ReadBuffer(index, IVFFLAT_METAPAGE_BLKNO);
+	LockBuffer(buf, BUFFER_LOCK_SHARE);
+	page = BufferGetPage(buf);
+	metap = IvfflatPageGetMeta(page);
+
+	dimensions = metap->dimensions;	/* copy the field out while the buffer is still locked */
+
+	UnlockReleaseBuffer(buf);
+
+	return dimensions;
+}
+
 /*
  * Prepare for an index scan
  */
@@ -285,21 +308,19 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
 		if (scan->orderByData == NULL)
 			elog(ERROR, "cannot scan ivfflat index without order");
 
-		/* No items will match if null */
 		if (scan->orderByData->sk_flags & SK_ISNULL)
-			return false;
-
-		value = scan->orderByData->sk_argument;
-
-		/* Value should not be compressed or toasted */
-		Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
-		Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
-
-		if (so->normprocinfo != NULL)
+			value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation)));	/* NULL order-by value: scan with a vector from InitVector sized to the index's dimensions instead of returning no rows (presumably zero-filled - confirm against InitVector) */
+		else
 		{
-			/* No items will match if normalization fails */
-			if (!IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL))
-				return false;
+			value = scan->orderByData->sk_argument;
+
+			/* Value should not be compressed or toasted */
+			Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
+			Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
+
+			/* Fine if normalization fails: the result is deliberately ignored and the scan continues */
+			if (so->normprocinfo != NULL)
+				IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL);
 		}
 
 		IvfflatBench("GetScanLists", GetScanLists(scan, value));
diff --git a/test/expected/ivfflat_cosine.out b/test/expected/ivfflat_cosine.out
index 96db5e0..8584d95 100644
--- a/test/expected/ivfflat_cosine.out
+++ b/test/expected/ivfflat_cosine.out
@@ -11,9 +11,16 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
  [1,2,4]
 (3 rows)
 
-SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
- val 
------
-(0 rows)
+SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
+ count 
+-------
+     3
+(1 row)
+
+SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
+ count 
+-------
+     3
+(1 row)
 
 DROP TABLE t;
diff --git a/test/expected/ivfflat_ip.out b/test/expected/ivfflat_ip.out
index d4fc538..d2bc386 100644
--- a/test/expected/ivfflat_ip.out
+++ b/test/expected/ivfflat_ip.out
@@ -12,9 +12,10 @@ SELECT * FROM t ORDER BY val <#> '[3,3,3]';
  [0,0,0]
 (4 rows)
 
-SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
- val 
------
-(0 rows)
+SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2;
+ count 
+-------
+     4
+(1 row)
 
 DROP TABLE t;
diff --git a/test/expected/ivfflat_l2.out b/test/expected/ivfflat_l2.out
index 2e8c6c2..1f510f2 100644
--- a/test/expected/ivfflat_l2.out
+++ b/test/expected/ivfflat_l2.out
@@ -13,9 +13,13 @@ SELECT * FROM t ORDER BY val <-> '[3,3,3]';
 (4 rows)
 
 SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
- val 
------
-(0 rows)
+   val   
+---------
+ [0,0,0]
+ [1,1,1]
+ [1,2,3]
+ [1,2,4]
+(4 rows)
 
 SELECT COUNT(*) FROM t;
  count 
diff --git a/test/sql/ivfflat_cosine.sql b/test/sql/ivfflat_cosine.sql
index 1fec6cf..a891a04 100644
--- a/test/sql/ivfflat_cosine.sql
+++ b/test/sql/ivfflat_cosine.sql
@@ -7,6 +7,7 @@ CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1);
 INSERT INTO t (val) VALUES ('[1,2,4]');
 
 SELECT * FROM t ORDER BY val <=> '[3,3,3]';
-SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
+SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
+SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
 
 DROP TABLE t;
diff --git a/test/sql/ivfflat_ip.sql b/test/sql/ivfflat_ip.sql
index 46daa4e..1560c55 100644
--- a/test/sql/ivfflat_ip.sql
+++ b/test/sql/ivfflat_ip.sql
@@ -7,6 +7,6 @@ CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1);
 INSERT INTO t (val) VALUES ('[1,2,4]');
 
 SELECT * FROM t ORDER BY val <#> '[3,3,3]';
-SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
+SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2;
 
 DROP TABLE t;