diff --git a/src/hnswutils.c b/src/hnswutils.c index 6536646..dc75130 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -559,13 +559,12 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, } /* - * Get the distance for a candidate + * Get the distance between 'hc' and 'q' */ static float -GetCandidateDistance(char *base, HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +GetElementDistance(char *base, HnswElement hc, Datum q, FmgrInfo *procinfo, Oid collation) { - HnswElement hce = HnswPtrAccess(base, hc->element); - Datum value = HnswGetValue(base, hce); + Datum value = HnswGetValue(base, hc); return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); } @@ -580,7 +579,7 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, HnswPtrStore(base, hc->element, entryPoint); if (index == NULL) - hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation); + hc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation); else HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec); return hc; @@ -667,17 +666,14 @@ AddToVisited(char *base, visited_hash * v, HnswElement element, Relation index, * Count element towards ef */ static inline bool -CountElement(char *base, HnswElement skipElement, HnswCandidate * hc) +CountElement(char *base, HnswElement skipElement, HnswElement e) { - HnswElement e; - if (skipElement == NULL) return true; /* Ensure does not access heaptidsLength during in-memory build */ pg_memory_barrier(); - e = HnswPtrAccess(base, hc->element); return e->heaptidsLength != 0; } @@ -709,9 +705,10 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F foreach(lc2, ep) { HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + HnswElement e = HnswPtrAccess(base, hc->element); bool found; - AddToVisited(base, &v, HnswPtrAccess(base, hc->element), index, &found); + AddToVisited(base, &v, e, index, &found); pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); @@ -721,7 +718,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F * would be ideal to do this for inserts as well, but this could * affect insert performance. */ - if (CountElement(base, skipElement, hc)) + if (CountElement(base, skipElement, e)) wlen++; } @@ -754,8 +751,8 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F for (int i = 0; i < neighborhood->length; i++) { - HnswCandidate *e = &neighborhood->items[i]; - HnswElement eElement = HnswPtrAccess(base, e->element); + HnswCandidate *hc = &neighborhood->items[i]; + HnswElement eElement = HnswPtrAccess(base, hc->element); bool visited; AddToVisited(base, &v, eElement, index, &visited); @@ -767,7 +764,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; if (index == NULL) - eDistance = GetCandidateDistance(base, e, q, procinfo, collation); + eDistance = GetElementDistance(base, eElement, q, procinfo, collation); else HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting); @@ -782,7 +779,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Copy e */ HnswCandidate *ec = palloc(sizeof(HnswCandidate)); - HnswPtrStore(base, ec->element, eElement); + ec->element = hc->element; ec->distance = eDistance; pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node)); @@ -793,7 +790,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F * vacuuming. It would be ideal to do this for inserts as * well, but this could affect insert performance. */ - if (CountElement(base, skipElement, e)) + if (CountElement(base, skipElement, eElement)) { wlen++; @@ -874,15 +871,12 @@ CompareCandidateDistancesOffset(const void *a, const void *b) } /* - * Calculate the distance between elements + * Calculate the distance between two vectors */ static float -HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation) +HnswGetDistance(Datum a, Datum b, FmgrInfo *procinfo, Oid collation) { - Datum aValue = HnswGetValue(base, a); - Datum bValue = HnswGetValue(base, b); - - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue)); + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a, b)); } /* @@ -892,13 +886,15 @@ static bool CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation) { HnswElement eElement = HnswPtrAccess(base, e->element); + Datum eValue = HnswGetValue(base, eElement); ListCell *lc2; foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); HnswElement riElement = HnswPtrAccess(base, ri->element); - float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation); + Datum riValue = HnswGetValue(base, riElement); + float distance = HnswGetDistance(eValue, riValue, procinfo, collation); if (distance <= e->distance) return false; @@ -1064,7 +1060,7 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm if (HnswPtrIsNull(base, hc3Element->value)) HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true); else - hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation); + hc3->distance = GetElementDistance(base, hc3Element, q, procinfo, collation); /* Prune element if being deleted */ if (hc3Element->heaptidsLength == 0)