diff --git a/src/hnswutils.c b/src/hnswutils.c index 4881b60..764e9d5 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -112,6 +112,12 @@ typedef union tidhash_hash *tids; } visited_hash; +typedef union +{ + HnswElement element; + ItemPointerData indextid; +} HnswUnvisited; + /* * Get the max number of connections in an upper layer for each element in the index */ @@ -547,19 +553,19 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe /* * Load an element and optionally get its distance from q */ -void -HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance) +static void +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance, HnswElement * element) { Buffer buf; Page page; HnswElementTuple etup; /* Read vector */ - buf = ReadBuffer(index, element->blkno); + buf = ReadBuffer(index, blkno); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); - etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); Assert(HnswIsElementTuple(etup)); @@ -574,11 +580,25 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, /* Load element */ if (distance == NULL || maxDistance == NULL || *distance < *maxDistance) - HnswLoadElementFromTuple(element, etup, true, loadVec); + { + if (*element == NULL) + *element = HnswInitElementFromBlock(blkno, offno); + + HnswLoadElementFromTuple(*element, etup, true, loadVec); + } UnlockReleaseBuffer(buf); } +/* + * Load an element and optionally get its distance from q + */ +void +HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance) +{ + HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, procinfo, collation, loadVec, maxDistance, &element); +} + /* * Get the distance for an element */ @@ -720,7 +740,7 @@ CountElement(char *base, HnswElement skipElement, HnswElement e) * Load unvisited neighbors from memory */ static void -HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswElement * unvisited, int *unvisitedLength, visited_hash * v, int lc, HnswNeighborArray * neighborhoodData, Size neighborhoodSize) +HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, int lc, HnswNeighborArray * neighborhoodData, Size neighborhoodSize) { /* Get the neighborhood at layer lc */ HnswNeighborArray *neighborhood = HnswGetNeighbors(base, element, lc); @@ -741,7 +761,7 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswElement * unvis AddToVisited(base, v, hc->element, NULL, &found); if (!found) - unvisited[(*unvisitedLength)++] = HnswPtrAccess(base, hc->element); + unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element); } } @@ -749,7 +769,7 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswElement * unvis * Load unvisited neighbors from disk */ static void -HnswLoadUnvisitedFromDisk(HnswElement element, HnswElement * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc) +HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc) { Buffer buf; Page page; @@ -782,7 +802,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswElement * unvisited, int *unv tidhash_insert(v->tids, *indextid, &found); if (!found) - unvisited[(*unvisitedLength)++] = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); + unvisited[(*unvisitedLength)++].indextid = *indextid; } } @@ -801,7 +821,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F HnswNeighborArray *neighborhoodData = NULL; Size neighborhoodSize = 0; int lm = HnswGetLayerM(m, lc); - HnswElement *unvisited = palloc(lm * sizeof(HnswElement)); + HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; InitVisited(base, &v, index, ef, m); @@ -853,16 +873,27 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F for (int i = 0; i < unvisitedLength; i++) { - HnswElement eElement = unvisited[i]; + HnswElement eElement; float eDistance; bool alwaysAdd = wlen < ef; f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W)); if (index == NULL) + { + eElement = unvisited[i].element; eDistance = GetElementDistance(base, eElement, q, procinfo, collation); + } else - HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance); + { + ItemPointer indextid = &unvisited[i].indextid; + BlockNumber blkno = ItemPointerGetBlockNumber(indextid); + OffsetNumber offno = ItemPointerGetOffsetNumber(indextid); + + /* Avoid any allocations if not adding */ + eElement = NULL; + HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement); + } if (eDistance < f->distance || alwaysAdd) {