diff --git a/src/hnswutils.c b/src/hnswutils.c index 8f7a783..b12709c 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -580,13 +580,12 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, } /* - * Get the distance for a candidate + * Get the distance for an element */ static float -GetCandidateDistance(char *base, HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation) { - HnswElement hce = HnswPtrAccess(base, hc->element); - Datum value = HnswGetValue(base, hce); + Datum value = HnswGetValue(base, element); return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); } @@ -601,7 +600,7 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, HnswPtrStore(base, hc->element, entryPoint); if (index == NULL) - hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation); + hc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation); else HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec, NULL); return hc; @@ -706,20 +705,87 @@ AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, b * Count element towards ef */ static inline bool -CountElement(char *base, HnswElement skipElement, HnswCandidate * hc) +CountElement(char *base, HnswElement skipElement, HnswElement e) { - HnswElement e; - if (skipElement == NULL) return true; /* Ensure does not access heaptidsLength during in-memory build */ pg_memory_barrier(); - e = HnswPtrAccess(base, hc->element); return e->heaptidsLength != 0; } +/* + * Load unvisited neighbors from memory + */ +static void +HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswElement * unvisited, int *unvisitedLength, visited_hash * v, int lc, HnswNeighborArray * neighborhoodData, Size neighborhoodSize) +{ + /* Get the neighborhood at layer lc */ + HnswNeighborArray *neighborhood = HnswGetNeighbors(base, element, lc); + + /* Copy neighborhood to local memory */ + LWLockAcquire(&element->lock, LW_SHARED); + memcpy(neighborhoodData, neighborhood, neighborhoodSize); + LWLockRelease(&element->lock); + neighborhood = neighborhoodData; + + *unvisitedLength = 0; + + for (int i = 0; i < neighborhood->length; i++) + { + HnswCandidate *hc = &neighborhood->items[i]; + bool found; + + AddToVisited(base, v, hc, NULL, &found); + + if (!found) + unvisited[(*unvisitedLength)++] = HnswPtrAccess(base, hc->element); + } +} + +/* + * Load unvisited neighbors from disk + */ +static void +HnswLoadUnvisitedFromDisk(HnswElement element, HnswElement * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc) +{ + Buffer buf; + Page page; + HnswNeighborTuple ntup; + int start; + ItemPointerData indextids[HNSW_MAX_M * 2]; + + buf = ReadBuffer(index, element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + start = (element->level - lc) * m; + + /* Copy to minimize lock time */ + memcpy(&indextids, ntup->indextids + start, lm * sizeof(ItemPointerData)); + + UnlockReleaseBuffer(buf); + + *unvisitedLength = 0; + + for (int i = 0; i < lm; i++) + { + ItemPointer indextid = &indextids[i]; + bool found; + + if (!ItemPointerIsValid(indextid)) + break; + + tidhash_insert(v->tids, *indextid, &found); + + if (!found) + unvisited[(*unvisitedLength)++] = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); + } +} + /* * Algorithm 2 from paper */ @@ -734,13 +800,16 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F ListCell *lc2; HnswNeighborArray *neighborhoodData = NULL; Size neighborhoodSize = 0; + int lm = HnswGetLayerM(m, lc); + HnswElement *unvisited = palloc(lm * sizeof(HnswElement)); + int unvisitedLength; InitVisited(base, &v, index, ef, m); /* Create local memory for neighborhood if needed */ if (index == NULL) { - neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(HnswGetLayerM(m, lc)); + neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm); neighborhoodData = palloc(neighborhoodSize); } @@ -762,13 +831,12 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F * would be ideal to do this for inserts as well, but this could * affect insert performance. */ - if (CountElement(base, skipElement, hc)) + if (CountElement(base, skipElement, HnswPtrAccess(base, hc->element))) wlen++; } while (!pairingheap_is_empty(C)) { - HnswNeighborArray *neighborhood; HnswCandidate *c = HnswGetPairingHeapCandidate(c_node, pairingheap_remove_first(C)); HnswCandidate *f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W)); HnswElement cElement; @@ -778,74 +846,56 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F cElement = HnswPtrAccess(base, c->element); - if (HnswPtrIsNull(base, cElement->neighbors)) - HnswLoadNeighbors(cElement, index, m); - - /* Get the neighborhood at layer lc */ - neighborhood = HnswGetNeighbors(base, cElement, lc); - - /* Copy neighborhood to local memory if needed */ if (index == NULL) + HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, neighborhoodData, neighborhoodSize); + else + HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc); + + for (int i = 0; i < unvisitedLength; i++) { - LWLockAcquire(&cElement->lock, LW_SHARED); - memcpy(neighborhoodData, neighborhood, neighborhoodSize); - LWLockRelease(&cElement->lock); - neighborhood = neighborhoodData; - } + HnswElement eElement = unvisited[i]; + float eDistance; + bool alwaysAdd = wlen < ef; - for (int i = 0; i < neighborhood->length; i++) - { - HnswCandidate *e = &neighborhood->items[i]; - bool visited; + f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W)); - AddToVisited(base, &v, e, index, &visited); + if (index == NULL) + eDistance = GetElementDistance(base, eElement, q, procinfo, collation); + else + HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance); - if (!visited) + if (eDistance < f->distance || alwaysAdd) { - float eDistance; - HnswElement eElement = HnswPtrAccess(base, e->element); - bool alwaysAdd = wlen < ef; + HnswCandidate *e; + HnswPairingHeapNode *node; - f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W)); + Assert(!eElement->deleted); - if (index == NULL) - eDistance = GetCandidateDistance(base, e, q, procinfo, collation); - else - HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance); + /* Make robust to issues */ + if (eElement->level < lc) + continue; - if (eDistance < f->distance || alwaysAdd) + /* Create a new candidate */ + e = palloc(sizeof(HnswCandidate)); + HnswPtrStore(base, e->element, eElement); + e->distance = eDistance; + + node = CreatePairingHeapNode(e); + pairingheap_add(C, &node->c_node); + pairingheap_add(W, &node->w_node); + + /* + * Do not count elements being deleted towards ef when + * vacuuming. It would be ideal to do this for inserts as + * well, but this could affect insert performance. + */ + if (CountElement(base, skipElement, eElement)) { - HnswCandidate *ec; - HnswPairingHeapNode *node; + wlen++; - Assert(!eElement->deleted); - - /* Make robust to issues */ - if (eElement->level < lc) - continue; - - /* Copy e */ - ec = palloc(sizeof(HnswCandidate)); - HnswPtrStore(base, ec->element, eElement); - ec->distance = eDistance; - - node = CreatePairingHeapNode(ec); - pairingheap_add(C, &node->c_node); - pairingheap_add(W, &node->w_node); - - /* - * Do not count elements being deleted towards ef when - * vacuuming. It would be ideal to do this for inserts as - * well, but this could affect insert performance. - */ - if (CountElement(base, skipElement, e)) - { - wlen++; - - /* No need to decrement wlen */ - if (wlen > ef) - pairingheap_remove_first(W); - } + /* No need to decrement wlen */ + if (wlen > ef) + pairingheap_remove_first(W); } } } @@ -1117,7 +1167,7 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm if (HnswPtrIsNull(base, hc3Element->value)) HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true, NULL); else - hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation); + hc3->distance = GetElementDistance(base, hc3Element, q, procinfo, collation); /* Prune element if being deleted */ if (hc3Element->heaptidsLength == 0)