DRY HNSW distance calculations

This commit is contained in:
Andrew Kane
2024-10-09 17:01:49 -07:00
parent 77688b4309
commit f4b67b078f

View File

@@ -482,6 +482,15 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe
}
}
/*
* Calculate the distance between values
*/
static inline float
HnswGetDistance(Datum a, Datum b, FmgrInfo *procinfo, Oid collation)
{
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a, b));
}
/*
* Load an element and optionally get its distance from q
*/
@@ -507,7 +516,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat
if (DatumGetPointer(*q) == NULL)
*distance = 0;
else
*distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data)));
*distance = HnswGetDistance(*q, PointerGetDatum(&etup->data), procinfo, collation);
}
/* Load element */
@@ -539,7 +548,7 @@ GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo,
{
Datum value = HnswGetValue(base, element);
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value));
return HnswGetDistance(q, value, procinfo, collation);
}
/*
@@ -921,18 +930,6 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b)
return 0;
}
/*
* Calculate the distance between elements
*/
static float
HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation)
{
Datum aValue = HnswGetValue(base, a);
Datum bValue = HnswGetValue(base, b);
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue));
}
/*
* Check if an element is closer to q than any element from R
*/
@@ -940,13 +937,15 @@ static bool
CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation)
{
HnswElement eElement = HnswPtrAccess(base, e->element);
Datum eValue = HnswGetValue(base, eElement);
ListCell *lc2;
foreach(lc2, r)
{
HnswCandidate *ri = lfirst(lc2);
HnswElement riElement = HnswPtrAccess(base, ri->element);
float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation);
Datum riValue = HnswGetValue(base, riElement);
float distance = HnswGetDistance(eValue, riValue, procinfo, collation);
if (distance <= e->distance)
return false;