Merge branch 'master' into hnsw-streaming

This commit is contained in:
Andrew Kane
2024-09-19 14:52:34 -07:00
3 changed files with 201 additions and 107 deletions

View File

@@ -1,3 +1,7 @@
## 0.7.5 (unreleased)
- Reduced memory usage for HNSW index scans
## 0.7.4 (2024-08-05)
- Fixed locking for parallel HNSW index builds

View File

@@ -164,8 +164,9 @@ struct HnswNeighborArray
typedef struct HnswPairingHeapNode
{
pairingheap_node ph_node;
HnswCandidate *inner;
pairingheap_node c_node;
pairingheap_node w_node;
} HnswPairingHeapNode;
/* HNSW index options */

View File

@@ -105,6 +105,12 @@ hash_offset(Size offset)
#define SH_DEFINE
#include "lib/simplehash.h"
typedef union
{
HnswElement element;
ItemPointerData indextid;
} HnswUnvisited;
/*
* Get the max number of connections in an upper layer for each element in the index
*/
@@ -540,19 +546,19 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe
/*
* Load an element and optionally get its distance from q
*/
void
HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance)
static void
HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance, HnswElement * element)
{
Buffer buf;
Page page;
HnswElementTuple etup;
/* Read vector */
buf = ReadBuffer(index, element->blkno);
buf = ReadBuffer(index, blkno);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno));
etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno));
Assert(HnswIsElementTuple(etup));
@@ -567,19 +573,32 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index,
/* Load element */
if (distance == NULL || maxDistance == NULL || *distance < *maxDistance)
HnswLoadElementFromTuple(element, etup, true, loadVec);
{
if (*element == NULL)
*element = HnswInitElementFromBlock(blkno, offno);
HnswLoadElementFromTuple(*element, etup, true, loadVec);
}
UnlockReleaseBuffer(buf);
}
/*
* Get the distance for a candidate
* Load an element and optionally get its distance from q
*/
void
HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance)
{
HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, procinfo, collation, loadVec, maxDistance, &element);
}
/*
* Get the distance for an element
*/
static float
GetCandidateDistance(char *base, HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation)
GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation)
{
HnswElement hce = HnswPtrAccess(base, hc->element);
Datum value = HnswGetValue(base, hce);
Datum value = HnswGetValue(base, element);
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value));
}
@@ -594,22 +613,25 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index,
HnswPtrStore(base, hc->element, entryPoint);
if (index == NULL)
hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation);
hc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation);
else
HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec, NULL);
return hc;
}
#define HnswGetPairingHeapCandidate(membername, ptr) (pairingheap_container(HnswPairingHeapNode, membername, ptr)->inner)
#define HnswGetPairingHeapCandidateConst(membername, ptr) (pairingheap_const_container(HnswPairingHeapNode, membername, ptr)->inner)
/*
* Compare candidate distances
*/
static int
CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance)
if (HnswGetPairingHeapCandidateConst(c_node, a)->distance < HnswGetPairingHeapCandidateConst(c_node, b)->distance)
return 1;
if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance)
if (HnswGetPairingHeapCandidateConst(c_node, a)->distance > HnswGetPairingHeapCandidateConst(c_node, b)->distance)
return -1;
return 0;
@@ -621,10 +643,10 @@ CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, v
static int
CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance)
if (HnswGetPairingHeapCandidateConst(w_node, a)->distance < HnswGetPairingHeapCandidateConst(w_node, b)->distance)
return -1;
if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance)
if (HnswGetPairingHeapCandidateConst(w_node, a)->distance > HnswGetPairingHeapCandidateConst(w_node, b)->distance)
return 1;
return 0;
@@ -660,11 +682,11 @@ InitVisited(char *base, visited_hash * v, Relation index, int ef, int m)
* Add to visited
*/
static inline void
AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, bool *found)
AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, Relation index, bool *found)
{
if (index != NULL)
{
HnswElement element = HnswPtrAccess(base, hc->element);
HnswElement element = HnswPtrAccess(base, elementPtr);
ItemPointerData indextid;
ItemPointerSet(&indextid, element->blkno, element->offno);
@@ -673,21 +695,21 @@ AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, b
else if (base != NULL)
{
#if PG_VERSION_NUM >= 130000
HnswElement element = HnswPtrAccess(base, hc->element);
HnswElement element = HnswPtrAccess(base, elementPtr);
offsethash_insert_hash(v->offsets, HnswPtrOffset(hc->element), element->hash, found);
offsethash_insert_hash(v->offsets, HnswPtrOffset(elementPtr), element->hash, found);
#else
offsethash_insert(v->offsets, HnswPtrOffset(hc->element), found);
offsethash_insert(v->offsets, HnswPtrOffset(elementPtr), found);
#endif
}
else
{
#if PG_VERSION_NUM >= 130000
HnswElement element = HnswPtrAccess(base, hc->element);
HnswElement element = HnswPtrAccess(base, elementPtr);
pointerhash_insert_hash(v->pointers, (uintptr_t) HnswPtrPointer(hc->element), element->hash, found);
pointerhash_insert_hash(v->pointers, (uintptr_t) HnswPtrPointer(elementPtr), element->hash, found);
#else
pointerhash_insert(v->pointers, (uintptr_t) HnswPtrPointer(hc->element), found);
pointerhash_insert(v->pointers, (uintptr_t) HnswPtrPointer(elementPtr), found);
#endif
}
}
@@ -696,20 +718,86 @@ AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, b
* Count element towards ef
*/
static inline bool
CountElement(char *base, HnswElement skipElement, HnswCandidate * hc)
CountElement(HnswElement skipElement, HnswElement e)
{
HnswElement e;
if (skipElement == NULL)
return true;
/* Ensure does not access heaptidsLength during in-memory build */
pg_memory_barrier();
e = HnswPtrAccess(base, hc->element);
return e->heaptidsLength != 0;
}
/*
* Load unvisited neighbors from memory
*/
static void
HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, int lc, HnswNeighborArray * neighborhoodData, Size neighborhoodSize)
{
/* Get the neighborhood at layer lc */
HnswNeighborArray *neighborhood = HnswGetNeighbors(base, element, lc);
/* Copy neighborhood to local memory */
LWLockAcquire(&element->lock, LW_SHARED);
memcpy(neighborhoodData, neighborhood, neighborhoodSize);
LWLockRelease(&element->lock);
*unvisitedLength = 0;
for (int i = 0; i < neighborhoodData->length; i++)
{
HnswCandidate *hc = &neighborhoodData->items[i];
bool found;
AddToVisited(base, v, hc->element, NULL, &found);
if (!found)
unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element);
}
}
/*
* Load unvisited neighbors from disk
*/
static void
HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc)
{
Buffer buf;
Page page;
HnswNeighborTuple ntup;
int start;
ItemPointerData indextids[HNSW_MAX_M * 2];
buf = ReadBuffer(index, element->neighborPage);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno));
start = (element->level - lc) * m;
/* Copy to minimize lock time */
memcpy(&indextids, ntup->indextids + start, lm * sizeof(ItemPointerData));
UnlockReleaseBuffer(buf);
*unvisitedLength = 0;
for (int i = 0; i < lm; i++)
{
ItemPointer indextid = &indextids[i];
bool found;
if (!ItemPointerIsValid(indextid))
break;
tidhash_insert(v->tids, *indextid, &found);
if (!found)
unvisited[(*unvisitedLength)++].indextid = *indextid;
}
}
/*
* Algorithm 2 from paper
*/
@@ -724,6 +812,9 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
ListCell *lc2;
HnswNeighborArray *neighborhoodData = NULL;
Size neighborhoodSize = 0;
int lm = HnswGetLayerM(m, lc);
HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited));
int unvisitedLength;
if (v == NULL)
v = &v2;
@@ -734,7 +825,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
/* Create local memory for neighborhood if needed */
if (index == NULL)
{
neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(HnswGetLayerM(m, lc));
neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm);
neighborhoodData = palloc(neighborhoodSize);
}
@@ -743,26 +834,27 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
{
HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
bool found;
HnswPairingHeapNode *node;
AddToVisited(base, v, hc, index, &found);
AddToVisited(base, v, hc->element, index, &found);
pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node));
node = CreatePairingHeapNode(hc);
pairingheap_add(C, &node->c_node);
pairingheap_add(W, &node->w_node);
/*
* Do not count elements being deleted towards ef when vacuuming. It
* would be ideal to do this for inserts as well, but this could
* affect insert performance.
*/
if (CountElement(base, skipElement, hc))
if (CountElement(skipElement, HnswPtrAccess(base, hc->element)))
wlen++;
}
while (!pairingheap_is_empty(C))
{
HnswNeighborArray *neighborhood;
HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner;
HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner;
HnswCandidate *c = HnswGetPairingHeapCandidate(c_node, pairingheap_remove_first(C));
HnswCandidate *f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W));
HnswElement cElement;
if (c->distance > f->distance)
@@ -770,89 +862,86 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
cElement = HnswPtrAccess(base, c->element);
if (HnswPtrIsNull(base, cElement->neighbors))
HnswLoadNeighbors(cElement, index, m);
/* Get the neighborhood at layer lc */
neighborhood = HnswGetNeighbors(base, cElement, lc);
/* Copy neighborhood to local memory if needed */
if (index == NULL)
HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, neighborhoodData, neighborhoodSize);
else
HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc);
for (int i = 0; i < unvisitedLength; i++)
{
LWLockAcquire(&cElement->lock, LW_SHARED);
memcpy(neighborhoodData, neighborhood, neighborhoodSize);
LWLockRelease(&cElement->lock);
neighborhood = neighborhoodData;
}
HnswElement eElement;
HnswCandidate *e;
HnswPairingHeapNode *node;
float eDistance;
bool alwaysAdd = wlen < ef;
bool discard;
for (int i = 0; i < neighborhood->length; i++)
{
HnswCandidate *e = &neighborhood->items[i];
bool visited;
f = HnswGetPairingHeapCandidate(w_node, pairingheap_first(W));
AddToVisited(base, v, e, index, &visited);
if (!visited)
if (index == NULL)
{
float eDistance;
HnswElement eElement = HnswPtrAccess(base, e->element);
bool alwaysAdd = wlen < ef;
eElement = unvisited[i].element;
eDistance = GetElementDistance(base, eElement, q, procinfo, collation);
f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner;
discard = !(eDistance < f->distance || alwaysAdd);
}
else
{
ItemPointer indextid = &unvisited[i].indextid;
BlockNumber blkno = ItemPointerGetBlockNumber(indextid);
OffsetNumber offno = ItemPointerGetOffsetNumber(indextid);
if (index == NULL)
eDistance = GetCandidateDistance(base, e, q, procinfo, collation);
else
HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance);
/* Avoid any allocations if not adding */
eElement = NULL;
HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement);
if (eDistance < f->distance || alwaysAdd)
discard = eElement == NULL || !(eDistance < f->distance || alwaysAdd);
}
if (discard)
{
if (discarded != NULL)
{
HnswCandidate *ec;
/* Create a new candidate */
e = palloc(sizeof(HnswCandidate));
HnswPtrStore(base, e->element, eElement);
e->distance = eDistance;
Assert(!eElement->deleted);
Assert(eElement->level >= lc);
/* Make robust to issues */
if (eElement->level < lc)
continue;
/* Copy e */
ec = palloc(sizeof(HnswCandidate));
HnswPtrStore(base, ec->element, eElement);
ec->distance = eDistance;
pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(ec)->ph_node));
/*
* Do not count elements being deleted towards ef when
* vacuuming. It would be ideal to do this for inserts as
* well, but this could affect insert performance.
*/
if (CountElement(base, skipElement, e))
{
wlen++;
/* No need to decrement wlen */
if (wlen > ef)
{
HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner;
if (discarded != NULL)
*discarded = lappend(*discarded, hc);
}
}
*discarded = lappend(*discarded, e);
}
else if (discarded != NULL)
continue;
}
/* Make robust to issues */
if (eElement->level < lc)
continue;
/* Create a new candidate */
e = palloc(sizeof(HnswCandidate));
HnswPtrStore(base, e->element, eElement);
e->distance = eDistance;
node = CreatePairingHeapNode(e);
pairingheap_add(C, &node->c_node);
pairingheap_add(W, &node->w_node);
/*
* Do not count elements being deleted towards ef when vacuuming.
* It would be ideal to do this for inserts as well, but this
* could affect insert performance.
*/
if (CountElement(skipElement, eElement))
{
wlen++;
/* No need to decrement wlen */
if (wlen > ef)
{
HnswCandidate *ec;
HnswCandidate *hc = HnswGetPairingHeapCandidate(w_node, pairingheap_remove_first(W));
/* Copy e */
ec = palloc(sizeof(HnswCandidate));
HnswPtrStore(base, ec->element, eElement);
ec->distance = eDistance;
*discarded = lappend(*discarded, ec);
if (discarded != NULL)
*discarded = lappend(*discarded, hc);
}
}
}
@@ -861,7 +950,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
/* Add each element of W to w */
while (!pairingheap_is_empty(W))
{
HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner;
HnswCandidate *hc = HnswGetPairingHeapCandidate(w_node, pairingheap_remove_first(W));
w = lappend(w, hc);
}
@@ -1124,7 +1213,7 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm
if (HnswPtrIsNull(base, hc3Element->value))
HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true, NULL);
else
hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation);
hc3->distance = GetElementDistance(base, hc3Element, q, procinfo, collation);
/* Prune element if being deleted */
if (hc3Element->heaptidsLength == 0)