Fixed results for NULL and NaN distances [skip ci]

This commit is contained in:
Andrew Kane
2023-08-04 13:51:16 -07:00
parent 7f4acf9d43
commit 029c336c62
5 changed files with 63 additions and 21 deletions

View File

@@ -35,6 +35,29 @@ GetScanItems(IndexScanDesc scan, Datum q)
so->w = SearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, false, NULL, NULL);
}
/*
* Get dimensions from metapage
*/
static int
GetDimensions(Relation index)
{
Buffer buf;
Page page;
HnswMetaPage metap;
int dimensions;
buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
metap = HnswPageGetMeta(page);
dimensions = metap->dimensions;
UnlockReleaseBuffer(buf);
return dimensions;
}
/*
* Prepare for an index scan
*/
@@ -107,19 +130,18 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
/* No items will match if null */
if (scan->orderByData->sk_flags & SK_ISNULL)
return false;
value = scan->orderByData->sk_argument;
/* Value should not be compressed or toasted */
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
if (so->normprocinfo != NULL)
value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation)));
else
{
/* No items will match if normalization fails */
if (!HnswNormValue(so->normprocinfo, so->collation, &value, NULL))
return false;
value = scan->orderByData->sk_argument;
/* Value should not be compressed or toasted */
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
/* Fine if normalization fails */
if (so->normprocinfo != NULL)
HnswNormValue(so->normprocinfo, so->collation, &value, NULL);
}
GetScanItems(scan, value);

View File

@@ -11,9 +11,20 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
[1,2,4]
(3 rows)
SELECT * FROM t ORDER BY val <=> '[0,0,0]';
val
---------
[1,1,1]
[1,2,4]
[1,2,3]
(3 rows)
SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
val
-----
(0 rows)
val
---------
[1,1,1]
[1,2,4]
[1,2,3]
(3 rows)
DROP TABLE t;

View File

@@ -13,8 +13,12 @@ SELECT * FROM t ORDER BY val <#> '[3,3,3]';
(4 rows)
SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
val
-----
(0 rows)
val
---------
[1,1,1]
[1,2,4]
[1,2,3]
[0,0,0]
(4 rows)
DROP TABLE t;

View File

@@ -13,9 +13,13 @@ SELECT * FROM t ORDER BY val <-> '[3,3,3]';
(4 rows)
SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
val
-----
(0 rows)
val
---------
[0,0,0]
[1,1,1]
[1,2,3]
[1,2,4]
(4 rows)
SELECT COUNT(*) FROM t;
count

View File

@@ -7,6 +7,7 @@ CREATE INDEX ON t USING hnsw (val vector_cosine_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <=> '[3,3,3]';
SELECT * FROM t ORDER BY val <=> '[0,0,0]';
SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
DROP TABLE t;