From f4b67b078f8dc158b4d88a125d7043cc6cc6ac65 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 9 Oct 2024 17:01:49 -0700 Subject: [PATCH] DRY HNSW distance calculations --- src/hnswutils.c | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/hnswutils.c b/src/hnswutils.c index 3d0b484..a7eb819 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -482,6 +482,15 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe } } +/* + * Calculate the distance between values + */ +static inline float +HnswGetDistance(Datum a, Datum b, FmgrInfo *procinfo, Oid collation) +{ + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a, b)); +} + /* * Load an element and optionally get its distance from q */ @@ -507,7 +516,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat if (DatumGetPointer(*q) == NULL) *distance = 0; else - *distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + *distance = HnswGetDistance(*q, PointerGetDatum(&etup->data), procinfo, collation); } /* Load element */ @@ -539,7 +548,7 @@ GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, { Datum value = HnswGetValue(base, element); - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); + return HnswGetDistance(q, value, procinfo, collation); } /* @@ -921,18 +930,6 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b) return 0; } -/* - * Calculate the distance between elements - */ -static float -HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation) -{ - Datum aValue = HnswGetValue(base, a); - Datum bValue = HnswGetValue(base, b); - - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue)); -} - /* * Check if an element is closer to q than any element from R */ @@ -940,13 +937,15 @@ static bool CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation) { HnswElement eElement = HnswPtrAccess(base, e->element); + Datum eValue = HnswGetValue(base, eElement); ListCell *lc2; foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); HnswElement riElement = HnswPtrAccess(base, ri->element); - float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation); + Datum riValue = HnswGetValue(base, riElement); + float distance = HnswGetDistance(eValue, riValue, procinfo, collation); if (distance <= e->distance) return false;