diff --git a/CHANGELOG.md b/CHANGELOG.md index af972fe..246bba0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.5 (unreleased) + +- Reduced memory usage for HNSW index scans + ## 0.7.4 (2024-08-05) - Fixed locking for parallel HNSW index builds diff --git a/src/hnsw.h b/src/hnsw.h index 3bc454e..feca91c 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -164,8 +164,9 @@ struct HnswNeighborArray typedef struct HnswPairingHeapNode { - pairingheap_node ph_node; HnswCandidate *inner; + pairingheap_node c_node; + pairingheap_node w_node; } HnswPairingHeapNode; /* HNSW index options */ diff --git a/src/hnswutils.c b/src/hnswutils.c index 371e42f..bc72813 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -105,6 +105,12 @@ hash_offset(Size offset) #define SH_DEFINE #include "lib/simplehash.h" +typedef union +{ + HnswElement element; + ItemPointerData indextid; +} HnswUnvisited; + /* * Get the max number of connections in an upper layer for each element in the index */ @@ -540,19 +546,19 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe /* * Load an element and optionally get its distance from q */ -void -HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance) +static void +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance, HnswElement * element) { Buffer buf; Page page; HnswElementTuple etup; /* Read vector */ - buf = ReadBuffer(index, element->blkno); + buf = ReadBuffer(index, blkno); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); - etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); Assert(HnswIsElementTuple(etup)); @@ -567,19 +573,32 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, /* Load element */ if (distance == NULL || maxDistance == NULL || *distance < *maxDistance) - HnswLoadElementFromTuple(element, etup, true, loadVec); + { + if (*element == NULL) + *element = HnswInitElementFromBlock(blkno, offno); + + HnswLoadElementFromTuple(*element, etup, true, loadVec); + } UnlockReleaseBuffer(buf); } /* - * Get the distance for a candidate + * Load an element and optionally get its distance from q + */ +void +HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance) +{ + HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, procinfo, collation, loadVec, maxDistance, &element); +} + +/* + * Get the distance for an element */ static float -GetCandidateDistance(char *base, HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation) { - HnswElement hce = HnswPtrAccess(base, hc->element); - Datum value = HnswGetValue(base, hce); + Datum value = HnswGetValue(base, element); return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); } @@ -594,22 +613,25 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, HnswPtrStore(base, hc->element, entryPoint); if (index == NULL) - hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation); + hc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation); else HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec, NULL); return hc; } +#define HnswGetPairingHeapCandidate(membername, ptr) (pairingheap_container(HnswPairingHeapNode, membername, ptr)->inner) +#define HnswGetPairingHeapCandidateConst(membername, ptr) (pairingheap_const_container(HnswPairingHeapNode, membername, ptr)->inner) + /* * Compare candidate distances */ static int CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) { - if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + if (HnswGetPairingHeapCandidateConst(c_node, a)->distance < HnswGetPairingHeapCandidateConst(c_node, b)->distance) return 1; - if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + if (HnswGetPairingHeapCandidateConst(c_node, a)->distance > HnswGetPairingHeapCandidateConst(c_node, b)->distance) return -1; return 0; @@ -621,10 +643,10 @@ CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, v static int CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) { - if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + if (HnswGetPairingHeapCandidateConst(w_node, a)->distance < HnswGetPairingHeapCandidateConst(w_node, b)->distance) return -1; - if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + if (HnswGetPairingHeapCandidateConst(w_node, a)->distance > HnswGetPairingHeapCandidateConst(w_node, b)->distance) return 1; return 0; @@ -660,11 +682,11 @@ InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) * Add to visited */ static inline void -AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, bool *found) +AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, Relation index, bool *found) { if (index != NULL) { - HnswElement element = HnswPtrAccess(base, hc->element); + HnswElement element = HnswPtrAccess(base, elementPtr); ItemPointerData indextid; ItemPointerSet(&indextid, element->blkno, element->offno); @@ -673,21 +695,21 @@ AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, b else if (base != NULL) { #if PG_VERSION_NUM >= 130000 - HnswElement element = HnswPtrAccess(base, hc->element); + HnswElement element = HnswPtrAccess(base, elementPtr); - offsethash_insert_hash(v->offsets, HnswPtrOffset(hc->element), element->hash, found); + offsethash_insert_hash(v->offsets, HnswPtrOffset(elementPtr), element->hash, found); #else - offsethash_insert(v->offsets, HnswPtrOffset(hc->element), found); + offsethash_insert(v->offsets, HnswPtrOffset(elementPtr), found); #endif } else { #if PG_VERSION_NUM >= 130000 - HnswElement element = HnswPtrAccess(base, hc->element); + HnswElement element = HnswPtrAccess(base, elementPtr); - pointerhash_insert_hash(v->pointers, (uintptr_t) HnswPtrPointer(hc->element), element->hash, found); + pointerhash_insert_hash(v->pointers, (uintptr_t) HnswPtrPointer(elementPtr), element->hash, found); #else - pointerhash_insert(v->pointers, (uintptr_t) HnswPtrPointer(hc->element), found); + pointerhash_insert(v->pointers, (uintptr_t) HnswPtrPointer(elementPtr), found); #endif } } @@ -696,20 +718,86 @@ AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, b * Count element towards ef */ static inline bool -CountElement(char *base, HnswElement skipElement, HnswCandidate * hc) +CountElement(HnswElement skipElement, HnswElement e) { - HnswElement e; - if (skipElement == NULL) return true; /* Ensure does not access heaptidsLength during in-memory build */ pg_memory_barrier(); - e = HnswPtrAccess(base, hc->element); return e->heaptidsLength != 0; } +/* + * Load unvisited neighbors from memory + */ +static void +HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, int lc, HnswNeighborArray * neighborhoodData, Size neighborhoodSize) +{ + /* Get the neighborhood at layer lc */ + HnswNeighborArray *neighborhood = HnswGetNeighbors(base, element, lc); + + /* Copy neighborhood to local memory */ + LWLockAcquire(&element->lock, LW_SHARED); + memcpy(neighborhoodData, neighborhood, neighborhoodSize); + LWLockRelease(&element->lock); + + *unvisitedLength = 0; + + for (int i = 0; i < neighborhoodData->length; i++) + { + HnswCandidate *hc = &neighborhoodData->items[i]; + bool found; + + AddToVisited(base, v, hc->element, NULL, &found); + + if (!found) + unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element); + } +} + +/* + * Load unvisited neighbors from disk + */ +static void +HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc) +{ + Buffer buf; + Page page; + HnswNeighborTuple ntup; + int start; + ItemPointerData indextids[HNSW_MAX_M * 2]; + + buf = ReadBuffer(index, element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + start = (element->level - lc) * m; + + /* Copy to minimize lock time */ + memcpy(&indextids, ntup->indextids + start, lm * sizeof(ItemPointerData)); + + UnlockReleaseBuffer(buf); + + *unvisitedLength = 0; + + for (int i = 0; i < lm; i++) + { + ItemPointer indextid = &indextids[i]; + bool found; + + if (!ItemPointerIsValid(indextid)) + break; + + tidhash_insert(v->tids, *indextid, &found); + + if (!found) + unvisited[(*unvisitedLength)++].indextid = *indextid; + } +} + /* * Algorithm 2 from paper */ @@ -724,6 +812,9 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F ListCell *lc2; HnswNeighborArray *neighborhoodData = NULL; Size neighborhoodSize = 0; + int lm = HnswGetLayerM(m, lc); + HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); + int unvisitedLength; if (v == NULL) v = &v2; @@ -734,7 +825,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Create local memory for neighborhood if needed */ if (index == NULL) { - neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(HnswGetLayerM(m, lc)); + neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm); neighborhoodData = palloc(neighborhoodSize); } @@ -743,26 +834,27 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F { HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); bool found; + HnswPairingHeapNode *node; - AddToVisited(base, v, hc, index, &found); + AddToVisited(base, v, hc->element, index, &found); - pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); - pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); + node = CreatePairingHeapNode(hc); + pairingheap_add(C, &node->c_node); + pairingheap_add(W, &node->w_node); /* * Do not count elements being deleted towards ef when vacuuming. It * would be ideal to do this for inserts as well, but this could * affect insert performance. */ - if (CountElement(base, skipElement, hc)) + if (CountElement(skipElement, HnswPtrAccess(base, hc->element))) wlen++; } while (!pairingheap_is_empty(C)) { - HnswNeighborArray *neighborhood; - HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner; - HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; + HnswCandidate *c = HnswGetPairingHeapCandidate(c_node, pairingheap_remove_first(C)); + HnswCandidate *f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W)); HnswElement cElement; if (c->distance > f->distance) @@ -770,89 +862,86 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F cElement = HnswPtrAccess(base, c->element); - if (HnswPtrIsNull(base, cElement->neighbors)) - HnswLoadNeighbors(cElement, index, m); - - /* Get the neighborhood at layer lc */ - neighborhood = HnswGetNeighbors(base, cElement, lc); - - /* Copy neighborhood to local memory if needed */ if (index == NULL) + HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, neighborhoodData, neighborhoodSize); + else + HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc); + + for (int i = 0; i < unvisitedLength; i++) { - LWLockAcquire(&cElement->lock, LW_SHARED); - memcpy(neighborhoodData, neighborhood, neighborhoodSize); - LWLockRelease(&cElement->lock); - neighborhood = neighborhoodData; - } + HnswElement eElement; + HnswCandidate *e; + HnswPairingHeapNode *node; + float eDistance; + bool alwaysAdd = wlen < ef; + bool discard; - for (int i = 0; i < neighborhood->length; i++) - { - HnswCandidate *e = &neighborhood->items[i]; - bool visited; + f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W)); - AddToVisited(base, v, e, index, &visited); - - if (!visited) + if (index == NULL) { - float eDistance; - HnswElement eElement = HnswPtrAccess(base, e->element); - bool alwaysAdd = wlen < ef; + eElement = unvisited[i].element; + eDistance = GetElementDistance(base, eElement, q, procinfo, collation); - f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; + discard = !(eDistance < f->distance || alwaysAdd); + } + else + { + ItemPointer indextid = &unvisited[i].indextid; + BlockNumber blkno = ItemPointerGetBlockNumber(indextid); + OffsetNumber offno = ItemPointerGetOffsetNumber(indextid); - if (index == NULL) - eDistance = GetCandidateDistance(base, e, q, procinfo, collation); - else - HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance); + /* Avoid any allocations if not adding */ + eElement = NULL; + HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement); - if (eDistance < f->distance || alwaysAdd) + discard = eElement == NULL || !(eDistance < f->distance || alwaysAdd); + } + + if (discard) + { + if (discarded != NULL) { - HnswCandidate *ec; + /* Create a new candidate */ + e = palloc(sizeof(HnswCandidate)); + HnswPtrStore(base, e->element, eElement); + e->distance = eDistance; - Assert(!eElement->deleted); - Assert(eElement->level >= lc); - - /* Make robust to issues */ - if (eElement->level < lc) - continue; - - /* Copy e */ - ec = palloc(sizeof(HnswCandidate)); - HnswPtrStore(base, ec->element, eElement); - ec->distance = eDistance; - - pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node)); - pairingheap_add(W, &(CreatePairingHeapNode(ec)->ph_node)); - - /* - * Do not count elements being deleted towards ef when - * vacuuming. It would be ideal to do this for inserts as - * well, but this could affect insert performance. - */ - if (CountElement(base, skipElement, e)) - { - wlen++; - - /* No need to decrement wlen */ - if (wlen > ef) - { - HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner; - - if (discarded != NULL) - *discarded = lappend(*discarded, hc); - } - } + *discarded = lappend(*discarded, e); } - else if (discarded != NULL) + + continue; + } + + /* Make robust to issues */ + if (eElement->level < lc) + continue; + + /* Create a new candidate */ + e = palloc(sizeof(HnswCandidate)); + HnswPtrStore(base, e->element, eElement); + e->distance = eDistance; + + node = CreatePairingHeapNode(e); + pairingheap_add(C, &node->c_node); + pairingheap_add(W, &node->w_node); + + /* + * Do not count elements being deleted towards ef when vacuuming. + * It would be ideal to do this for inserts as well, but this + * could affect insert performance. + */ + if (CountElement(skipElement, eElement)) + { + wlen++; + + /* No need to decrement wlen */ + if (wlen > ef) { - HnswCandidate *ec; + HnswCandidate *hc = HnswGetPairingHeapCandidate(w_node, pairingheap_remove_first(W)); - /* Copy e */ - ec = palloc(sizeof(HnswCandidate)); - HnswPtrStore(base, ec->element, eElement); - ec->distance = eDistance; - - *discarded = lappend(*discarded, ec); + if (discarded != NULL) + *discarded = lappend(*discarded, hc); } } } @@ -861,7 +950,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Add each element of W to w */ while (!pairingheap_is_empty(W)) { - HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner; + HnswCandidate *hc = HnswGetPairingHeapCandidate(w_node, pairingheap_remove_first(W)); w = lappend(w, hc); } @@ -1124,7 +1213,7 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm if (HnswPtrIsNull(base, hc3Element->value)) HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true, NULL); else - hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation); + hc3->distance = GetElementDistance(base, hc3Element, q, procinfo, collation); /* Prune element if being deleted */ if (hc3Element->heaptidsLength == 0)