Reduced memory usage for HNSW index scans

Co-authored-by: Heikki Linnakangas <heikki.linnakangas@iki.fi>
This commit is contained in:
Andrew Kane
2024-09-19 02:17:51 -07:00
parent d74d3065bc
commit 8dde14a736

View File

@@ -580,13 +580,12 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index,
}
/*
* Get the distance for a candidate
* Get the distance for an element
*/
static float
GetCandidateDistance(char *base, HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation)
GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation)
{
HnswElement hce = HnswPtrAccess(base, hc->element);
Datum value = HnswGetValue(base, hce);
Datum value = HnswGetValue(base, element);
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value));
}
@@ -601,7 +600,7 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index,
HnswPtrStore(base, hc->element, entryPoint);
if (index == NULL)
hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation);
hc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation);
else
HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec, NULL);
return hc;
@@ -706,20 +705,87 @@ AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, b
* Count element towards ef
*/
static inline bool
CountElement(char *base, HnswElement skipElement, HnswCandidate * hc)
CountElement(char *base, HnswElement skipElement, HnswElement e)
{
HnswElement e;
if (skipElement == NULL)
return true;
/* Ensure does not access heaptidsLength during in-memory build */
pg_memory_barrier();
e = HnswPtrAccess(base, hc->element);
return e->heaptidsLength != 0;
}
/*
 * Collect the not-yet-visited neighbors of an in-memory element
 *
 * The neighbor array of "element" at layer "lc" is copied into the
 * caller-provided scratch buffer "neighborhoodData" ("neighborhoodSize"
 * bytes) while holding the element's lock in shared mode, and only the
 * private copy is examined afterwards.
 *
 * Every neighbor in the copy is passed to AddToVisited. Those it reports as
 * not found are written to "unvisited", and their count is returned in
 * "*unvisitedLength".
 */
static void
HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswElement * unvisited, int *unvisitedLength, visited_hash * v, int lc, HnswNeighborArray * neighborhoodData, Size neighborhoodSize)
{
	/* Locate the shared neighbor array for layer lc */
	HnswNeighborArray *shared = HnswGetNeighbors(base, element, lc);
	HnswNeighborArray *snapshot = neighborhoodData;

	/* Take a private snapshot so the lock is held only for the copy */
	LWLockAcquire(&element->lock, LW_SHARED);
	memcpy(snapshot, shared, neighborhoodSize);
	LWLockRelease(&element->lock);

	*unvisitedLength = 0;

	for (int idx = 0; idx < snapshot->length; idx++)
	{
		HnswCandidate *neighbor = &snapshot->items[idx];
		bool		seen;

		AddToVisited(base, v, neighbor, NULL, &seen);

		if (seen)
			continue;

		/* First encounter: hand the element back to the caller */
		unvisited[*unvisitedLength] = HnswPtrAccess(base, neighbor->element);
		(*unvisitedLength)++;
	}
}
/*
 * Collect the not-yet-visited neighbors of an on-disk element
 *
 * Reads the neighbor tuple at element->neighborPage / element->neighborOffno
 * and copies "lm" neighbor tids, starting at slot (element->level - lc) * m,
 * into a local array. The buffer is unlocked and released right after the
 * copy, before any tid is examined.
 *
 * Scanning stops at the first invalid tid. Each valid tid is inserted into
 * v->tids. For those not already present, a new element is created with
 * HnswInitElementFromBlock and written to "unvisited". The count is returned
 * in "*unvisitedLength".
 */
static void
HnswLoadUnvisitedFromDisk(HnswElement element, HnswElement * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc)
{
	ItemPointerData tids[HNSW_MAX_M * 2];
	int			firstSlot = (element->level - lc) * m;
	Buffer		buf;
	Page		page;
	HnswNeighborTuple ntup;
	int			pos = 0;

	buf = ReadBuffer(index, element->neighborPage);
	LockBuffer(buf, BUFFER_LOCK_SHARE);
	page = BufferGetPage(buf);
	ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno));

	/* Copy the tids out so the buffer lock is held as briefly as possible */
	memcpy(tids, &ntup->indextids[firstSlot], lm * sizeof(ItemPointerData));

	UnlockReleaseBuffer(buf);

	*unvisitedLength = 0;

	while (pos < lm)
	{
		ItemPointer tid = &tids[pos++];
		bool		seen;

		/* An invalid tid ends the scan */
		if (!ItemPointerIsValid(tid))
			break;

		tidhash_insert(v->tids, *tid, &seen);

		if (seen)
			continue;

		/* First encounter: create an element for the caller to load */
		unvisited[*unvisitedLength] = HnswInitElementFromBlock(ItemPointerGetBlockNumber(tid), ItemPointerGetOffsetNumber(tid));
		(*unvisitedLength)++;
	}
}
/*
* Algorithm 2 from paper
*/
@@ -734,13 +800,16 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
ListCell *lc2;
HnswNeighborArray *neighborhoodData = NULL;
Size neighborhoodSize = 0;
int lm = HnswGetLayerM(m, lc);
HnswElement *unvisited = palloc(lm * sizeof(HnswElement));
int unvisitedLength;
InitVisited(base, &v, index, ef, m);
/* Create local memory for neighborhood if needed */
if (index == NULL)
{
neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(HnswGetLayerM(m, lc));
neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm);
neighborhoodData = palloc(neighborhoodSize);
}
@@ -762,13 +831,12 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
* would be ideal to do this for inserts as well, but this could
* affect insert performance.
*/
if (CountElement(base, skipElement, hc))
if (CountElement(base, skipElement, HnswPtrAccess(base, hc->element)))
wlen++;
}
while (!pairingheap_is_empty(C))
{
HnswNeighborArray *neighborhood;
HnswCandidate *c = HnswGetPairingHeapCandidate(c_node, pairingheap_remove_first(C));
HnswCandidate *f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W));
HnswElement cElement;
@@ -778,74 +846,56 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
cElement = HnswPtrAccess(base, c->element);
if (HnswPtrIsNull(base, cElement->neighbors))
HnswLoadNeighbors(cElement, index, m);
/* Get the neighborhood at layer lc */
neighborhood = HnswGetNeighbors(base, cElement, lc);
/* Copy neighborhood to local memory if needed */
if (index == NULL)
HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, neighborhoodData, neighborhoodSize);
else
HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc);
for (int i = 0; i < unvisitedLength; i++)
{
LWLockAcquire(&cElement->lock, LW_SHARED);
memcpy(neighborhoodData, neighborhood, neighborhoodSize);
LWLockRelease(&cElement->lock);
neighborhood = neighborhoodData;
}
HnswElement eElement = unvisited[i];
float eDistance;
bool alwaysAdd = wlen < ef;
for (int i = 0; i < neighborhood->length; i++)
{
HnswCandidate *e = &neighborhood->items[i];
bool visited;
f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W));
AddToVisited(base, &v, e, index, &visited);
if (index == NULL)
eDistance = GetElementDistance(base, eElement, q, procinfo, collation);
else
HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance);
if (!visited)
if (eDistance < f->distance || alwaysAdd)
{
float eDistance;
HnswElement eElement = HnswPtrAccess(base, e->element);
bool alwaysAdd = wlen < ef;
HnswCandidate *e;
HnswPairingHeapNode *node;
f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W));
Assert(!eElement->deleted);
if (index == NULL)
eDistance = GetCandidateDistance(base, e, q, procinfo, collation);
else
HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance);
/* Make robust to issues */
if (eElement->level < lc)
continue;
if (eDistance < f->distance || alwaysAdd)
/* Create a new candidate */
e = palloc(sizeof(HnswCandidate));
HnswPtrStore(base, e->element, eElement);
e->distance = eDistance;
node = CreatePairingHeapNode(e);
pairingheap_add(C, &node->c_node);
pairingheap_add(W, &node->w_node);
/*
* Do not count elements being deleted towards ef when
* vacuuming. It would be ideal to do this for inserts as
* well, but this could affect insert performance.
*/
if (CountElement(base, skipElement, eElement))
{
HnswCandidate *ec;
HnswPairingHeapNode *node;
wlen++;
Assert(!eElement->deleted);
/* Make robust to issues */
if (eElement->level < lc)
continue;
/* Copy e */
ec = palloc(sizeof(HnswCandidate));
HnswPtrStore(base, ec->element, eElement);
ec->distance = eDistance;
node = CreatePairingHeapNode(ec);
pairingheap_add(C, &node->c_node);
pairingheap_add(W, &node->w_node);
/*
* Do not count elements being deleted towards ef when
* vacuuming. It would be ideal to do this for inserts as
* well, but this could affect insert performance.
*/
if (CountElement(base, skipElement, e))
{
wlen++;
/* No need to decrement wlen */
if (wlen > ef)
pairingheap_remove_first(W);
}
/* No need to decrement wlen */
if (wlen > ef)
pairingheap_remove_first(W);
}
}
}
@@ -1117,7 +1167,7 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm
if (HnswPtrIsNull(base, hc3Element->value))
HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true, NULL);
else
hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation);
hc3->distance = GetElementDistance(base, hc3Element, q, procinfo, collation);
/* Prune element if being deleted */
if (hc3Element->heaptidsLength == 0)