diff --git a/README.md b/README.md index c554513..51c085f 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ Supported index types are: ## IVFFlat -TODO Add description +An IVFFlat index clusters vectors into lists, and then searches a subset of those lists. It has faster build times and uses less memory than HNSW, but has lower query performance. Three keys to achieving good recall are: @@ -217,7 +217,12 @@ COMMIT; ## HNSW -TODO Add description and options +An HNSW index creates a multilayer graph between vectors. It has slower build times and uses more memory than IVFFlat, but has better query performance. There’s no training step like IVFFlat, so the index can be created without any data in the table. + +The options for HNSW are: + +- `m` - the max number of connections per layer (the bottom layer uses `2 * m`) +- `ef_construction` - the size of the dynamic candidate list for constructing the graph Add an index for each distance function you want to use. diff --git a/src/hnsw.h b/src/hnsw.h index 5581ca8..56f2ccf 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -41,6 +41,7 @@ #define HNSW_ELEMENT_TUPLE_TYPE 1 #define HNSW_NEIGHBOR_TUPLE_TYPE 2 +/* Make graph robust against non-HOT updates */ #define HNSW_HEAPTIDS 10 /* Build phases */ @@ -49,7 +50,6 @@ #define HNSW_ELEMENT_TUPLE_SIZE(_dim) MAXALIGN(offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim)) #define HNSW_NEIGHBOR_TUPLE_SIZE(level, m) MAXALIGN(offsetof(HnswNeighborTupleData, neighbors) + ((level) + 2) * (m) * sizeof(HnswNeighborTupleItem)) -#define HNSW_NEIGHBOR_COUNT(itemid) ((ItemIdGetLength(itemid) - offsetof(HnswNeighborTupleData, neighbors)) / sizeof(HnswNeighborTupleItem)) #define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) #define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page)) @@ -164,8 +164,8 @@ typedef struct HnswMetaPageData uint32 magicNumber; uint32 version; uint32 dimensions; - uint32 m; - uint32 efConstruction; + uint16 m; + uint16 
efConstruction; BlockNumber entryBlkno; OffsetNumber entryOffno; int16 entryLevel; @@ -201,15 +201,14 @@ typedef struct HnswNeighborTupleItem { ItemPointerData indextid; uint16 unused; - float distance; + float distance; /* improves performance of inserts */ } HnswNeighborTupleItem; typedef struct HnswNeighborTupleData { uint8 type; uint8 unused; - uint16 unused2; - uint32 unused3; + uint16 count; HnswNeighborTupleItem neighbors[FLEXIBLE_ARRAY_MEMBER]; } HnswNeighborTupleData; @@ -277,7 +276,7 @@ void HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); void HnswInitNeighbors(HnswElement element, int m); bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel); -void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec); +void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); void HnswSetElementTuple(HnswElementTuple etup, HnswElement element); /* Index access methods */ diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 0ba7d50..c2a17ff 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -295,6 +295,7 @@ UpdateNeighborPages(Relation index, HnswElement e, int m, List *updates) GenericXLogState *state; HnswUpdate *update = lfirst(lc); ItemId itemid; + HnswNeighborTuple ntup; Size ntupSize; int idx; OffsetNumber offno = update->hc.element->neighborOffno; @@ -305,23 +306,24 @@ UpdateNeighborPages(Relation index, HnswElement e, int m, List *updates) state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); + /* Get tuple */ itemid = PageGetItemId(page, offno); + ntup = (HnswNeighborTuple) PageGetItem(page, itemid); ntupSize = ItemIdGetLength(itemid); + /* Calculate index */ idx = HnswGetIndex(update, m); - /* Make robust against issues */ - if (idx < (int) 
HNSW_NEIGHBOR_COUNT(itemid)) + /* Make robust to issues */ + if (idx < ntup->count) { - HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, itemid); - HnswNeighborTupleItem *neighbor = &ntup->neighbors[idx]; - /* Set item data */ + /* Update neighbor */ ItemPointerSet(&neighbor->indextid, e->blkno, e->offno); neighbor->distance = update->hc.distance; - /* Update connections */ + /* Overwrite tuple */ if (!PageIndexTupleOverwrite(page, offno, (Item) ntup, ntupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); @@ -337,7 +339,7 @@ UpdateNeighborPages(Relation index, HnswElement e, int m, List *updates) } /* - * Add a heap tid to an existing element + * Add a heap TID to an existing element */ static bool HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) @@ -371,10 +373,10 @@ HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) return false; } - /* Add heap tid */ + /* Add heap TID */ etup->heaptids[i] = *((ItemPointer) linitial(element->heaptids)); - /* Update index tuple */ + /* Overwrite tuple */ if (!PageIndexTupleOverwrite(page, dup->offno, (Item) etup, etupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); diff --git a/src/hnswutils.c b/src/hnswutils.c index 7a7dde0..1d3409e 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -123,426 +123,7 @@ HnswCommitBuffer(Buffer buf, GenericXLogState *state) } /* - * Create an element from block and offset - */ -static HnswElement -CreateElementFromBlock(BlockNumber blkno, OffsetNumber offno) -{ - HnswElement element = palloc(sizeof(HnswElementData)); - - element->blkno = blkno; - element->offno = offno; - element->neighbors = NULL; - element->vec = NULL; - return element; -} - -/* - * Get the entry point - */ -HnswElement -HnswGetEntryPoint(Relation index) -{ - Buffer buf; - Page page; - HnswMetaPage metap; - HnswElement entryPoint = NULL; - - buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); - 
LockBuffer(buf, BUFFER_LOCK_SHARE); - page = BufferGetPage(buf); - metap = HnswPageGetMeta(page); - - if (BlockNumberIsValid(metap->entryBlkno)) - entryPoint = CreateElementFromBlock(metap->entryBlkno, metap->entryOffno); - - UnlockReleaseBuffer(buf); - - return entryPoint; -} - -/* - * Compare candidate distances - */ -static int -CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) -{ - if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) - return 1; - - if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) - return -1; - - return 0; -} - -/* - * Compare candidate distances - */ -static int -CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) -{ - if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) - return -1; - - if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) - return 1; - - return 0; -} - -/* - * Create a pairing heap node for a candidate - */ -static HnswPairingHeapNode * -CreatePairingHeapNode(HnswCandidate * c) -{ - HnswPairingHeapNode *node = palloc(sizeof(HnswPairingHeapNode)); - - node->inner = c; - return node; -} - -/* - * Calculate the distance between elements - */ -static float -HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid collation) -{ - /* Look for cached distance */ - if (a->neighbors != NULL) - { - Assert(a->level >= lc); - - for (int i = 0; i < a->neighbors[lc].length; i++) - { - if (a->neighbors[lc].items[i].element == b) - return a->neighbors[lc].items[i].distance; - } - } - - if (b->neighbors != NULL) - { - Assert(b->level >= lc); - - for (int i = 0; i < b->neighbors[lc].length; i++) - { - if (b->neighbors[lc].items[i].element == a) - return b->neighbors[lc].items[i].distance; - } - } - - return 
DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec))); -} - -/* - * Check if an element is closer to q than any element from R - */ -static bool -CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation) -{ - ListCell *lc2; - - foreach(lc2, r) - { - HnswCandidate *ri = lfirst(lc2); - float distance = HnswGetDistance(e->element, ri->element, lc, procinfo, collation); - - if (distance <= e->distance) - return false; - } - - return true; -} - -/* - * Algorithm 4 from paper - */ -static List * -SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswCandidate * *pruned) -{ - List *r = NIL; - List *w = list_copy(c); - pairingheap *wd; - - if (list_length(w) < m) - return w; - - wd = pairingheap_allocate(CompareNearestCandidates, NULL); - - while (list_length(w) > 0 && list_length(r) < m) - { - /* Assumes w is already ordered desc */ - HnswCandidate *e = llast(w); - bool closer; - - w = list_delete_last(w); - - closer = CheckElementCloser(e, r, lc, procinfo, collation); - - if (closer) - r = lappend(r, e); - else - pairingheap_add(wd, &(CreatePairingHeapNode(e)->ph_node)); - } - - /* Keep pruned connections */ - while (!pairingheap_is_empty(wd) && list_length(r) < m) - r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner); - - /* Return pruned for update connections */ - if (pruned != NULL) - { - if (!pairingheap_is_empty(wd)) - *pruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner; - else - *pruned = linitial(w); - } - - return r; -} - -/* - * Add connections - */ -static void -AddConnections(HnswElement element, List *neighbors, int m, int lc) -{ - ListCell *lc2; - HnswNeighborArray *a = &element->neighbors[lc]; - - foreach(lc2, neighbors) - a->items[a->length++] = *((HnswCandidate *) lfirst(lc2)); -} - -/* - * Create update - */ -static HnswUpdate * -CreateUpdate(HnswCandidate * hc, int level, int index) -{ - HnswUpdate *update = 
palloc(sizeof(HnswUpdate)); - - update->hc = *hc; - update->level = level; - update->index = index; - return update; -} - -/* - * Compare candidate distances - */ -static int -#if PG_VERSION_NUM >= 130000 -CompareCandidateDistances(const ListCell *a, const ListCell *b) -#else -CompareCandidateDistances(const void *a, const void *b) -#endif -{ - HnswCandidate *hca = lfirst((ListCell *) a); - HnswCandidate *hcb = lfirst((ListCell *) b); - - if (hca->distance < hcb->distance) - return 1; - - if (hca->distance > hcb->distance) - return -1; - - return 0; -} - -/* - * Add a heap TID to an element - */ -void -HnswAddHeapTid(HnswElement element, ItemPointer heaptid) -{ - ItemPointer copy = palloc(sizeof(ItemPointerData)); - - ItemPointerCopy(heaptid, copy); - element->heaptids = lappend(element->heaptids, copy); -} - -/* - * Load neighbors from page - */ -static void -LoadNeighborsFromPage(HnswElement element, Relation index, Page page) -{ - int m = HnswGetM(index); - ItemId itemid = PageGetItemId(page, element->neighborOffno); - int neighborCount = (element->level + 2) * m; - - HnswInitNeighbors(element, m); - - /* Ensure expected neighbors */ - if (HNSW_NEIGHBOR_COUNT(itemid) == neighborCount) - { - HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, itemid); - - Assert(HnswIsNeighborTuple(ntup)); - - for (int i = 0; i < neighborCount; i++) - { - HnswElement e; - int level; - HnswCandidate *hc; - HnswNeighborTupleItem *neighbor; - HnswNeighborArray *neighbors; - - neighbor = &ntup->neighbors[i]; - - if (!ItemPointerIsValid(&neighbor->indextid)) - continue; - - e = CreateElementFromBlock(ItemPointerGetBlockNumber(&neighbor->indextid), ItemPointerGetOffsetNumber(&neighbor->indextid)); - - /* Calculate level based on offset */ - level = element->level - i / m; - if (level < 0) - level = 0; - - neighbors = &element->neighbors[level]; - hc = &neighbors->items[neighbors->length++]; - hc->element = e; - hc->distance = neighbor->distance; - } - } -} - -/* - * Load an 
element and optionally get its distance from q - */ -void -HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec) -{ - Buffer buf; - Page page; - HnswElementTuple etup; - - /* Read vector */ - buf = ReadBuffer(index, element->blkno); - LockBuffer(buf, BUFFER_LOCK_SHARE); - page = BufferGetPage(buf); - - etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); - - Assert(HnswIsElementTuple(etup)); - - /* Load element */ - element->heaptids = NIL; - for (int i = 0; i < HNSW_HEAPTIDS; i++) - { - /* Can stop at first invalid */ - if (!ItemPointerIsValid(&etup->heaptids[i])) - break; - - HnswAddHeapTid(element, &etup->heaptids[i]); - } - element->level = etup->level; - element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); - element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); - element->deleted = etup->deleted; - - if (loadvec) - { - element->vec = palloc(VECTOR_SIZE(etup->vec.dim)); - memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim)); - } - - /* Calculate distance */ - if (distance != NULL) - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec))); - - UnlockReleaseBuffer(buf); -} - -/* - * Update connections - */ -static void -UpdateConnections(HnswElement element, List *neighbors, int m, int lc, List **updates, Relation index, FmgrInfo *procinfo, Oid collation) -{ - ListCell *lc2; - - foreach(lc2, neighbors) - { - HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); - HnswNeighborArray *currentNeighbors = &hc->element->neighbors[lc]; - - HnswCandidate hc2; - - hc2.element = element; - hc2.distance = hc->distance; - - if (currentNeighbors->length < m) - { - currentNeighbors->items[currentNeighbors->length++] = hc2; - - /* Track updates */ - if (updates != NULL) - *updates = lappend(*updates, CreateUpdate(hc, lc, currentNeighbors->length - 1)); - } - else - { - /* Shrink 
connections */ - HnswCandidate *pruned = NULL; - List *c = NIL; - - /* Add and sort candidates */ - for (int i = 0; i < currentNeighbors->length; i++) - c = lappend(c, &currentNeighbors->items[i]); - c = lappend(c, &hc2); - list_sort(c, CompareCandidateDistances); - - /* Load elements on insert */ - if (index != NULL) - { - for (int i = 0; i < currentNeighbors->length; i++) - { - if (currentNeighbors->items[i].element->vec == NULL) - { - HnswLoadElement(currentNeighbors->items[i].element, NULL, NULL, index, procinfo, collation, true); - - /* Prune deleted element */ - if (currentNeighbors->items[i].element->deleted) - { - pruned = &currentNeighbors->items[i]; - break; - } - } - } - } - - if (pruned == NULL) - { - SelectNeighbors(c, m, lc, procinfo, collation, &pruned); - - /* Should not happen */ - if (pruned == NULL) - continue; - } - - /* Find and replace the pruned element */ - for (int i = 0; i < currentNeighbors->length; i++) - { - if (currentNeighbors->items[i].element == pruned->element) - { - currentNeighbors->items[i] = hc2; - - /* Track updates */ - if (updates != NULL) - *updates = lappend(*updates, CreateUpdate(hc, lc, i)); - - break; - } - } - } - } -} - -/* - * Initialize neighbors + * Allocate neighbors */ void HnswInitNeighbors(HnswElement element, int m) @@ -562,24 +143,6 @@ HnswInitNeighbors(HnswElement element, int m) } } -/* - * Load neighbors - */ -static void -LoadNeighbors(HnswElement element, Relation index) -{ - Buffer buf; - Page page; - - buf = ReadBuffer(index, element->neighborPage); - LockBuffer(buf, BUFFER_LOCK_SHARE); - page = BufferGetPage(buf); - - LoadNeighborsFromPage(element, index, page); - - UnlockReleaseBuffer(buf); -} - /* * Allocate an element */ @@ -619,6 +182,97 @@ HnswFreeElement(HnswElement element) pfree(element); } +/* + * Add a heap TID to an element + */ +void +HnswAddHeapTid(HnswElement element, ItemPointer heaptid) +{ + ItemPointer copy = palloc(sizeof(ItemPointerData)); + + ItemPointerCopy(heaptid, copy); + element->heaptids
= lappend(element->heaptids, copy); +} + +/* + * Allocate an element from block and offset numbers + */ +static HnswElement +InitElementFromBlock(BlockNumber blkno, OffsetNumber offno) +{ + HnswElement element = palloc(sizeof(HnswElementData)); + + element->blkno = blkno; + element->offno = offno; + element->neighbors = NULL; + element->vec = NULL; + return element; +} + +/* + * Get the entry point + */ +HnswElement +HnswGetEntryPoint(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + HnswElement entryPoint = NULL; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + if (BlockNumberIsValid(metap->entryBlkno)) + entryPoint = InitElementFromBlock(metap->entryBlkno, metap->entryOffno); + + UnlockReleaseBuffer(buf); + + return entryPoint; +} + +/* + * Update the metapage + */ +void +HnswUpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum) +{ + Buffer buf; + Page page; + GenericXLogState *state; + HnswMetaPage metap; + + buf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + metap = HnswPageGetMeta(page); + + if (updateEntry) + { + if (entryPoint == NULL) + { + metap->entryBlkno = InvalidBlockNumber; + metap->entryOffno = InvalidOffsetNumber; + metap->entryLevel = -1; + } + else + { + metap->entryBlkno = entryPoint->blkno; + metap->entryOffno = entryPoint->offno; + metap->entryLevel = entryPoint->level; + } + } + + if (BlockNumberIsValid(insertPage)) + metap->insertPage = insertPage; + + HnswCommitBuffer(buf, state); +} + /* * Set element tuple, except for neighbor info */ @@ -638,6 +292,153 @@ HnswSetElementTuple(HnswElementTuple etup, HnswElement element) memcpy(&etup->vec, element->vec, VECTOR_SIZE(element->vec->dim)); } +/* + * 
Set neighbor tuple + */ +void +HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m) +{ + int idx = 0; + + ntup->type = HNSW_NEIGHBOR_TUPLE_TYPE; + + for (int lc = e->level; lc >= 0; lc--) + { + HnswNeighborArray *neighbors = &e->neighbors[lc]; + int lm = HnswGetLayerM(m, lc); + + for (int i = 0; i < lm; i++) + { + HnswNeighborTupleItem *neighbor = &ntup->neighbors[idx++]; + + if (i < neighbors->length) + { + HnswCandidate *hc = &neighbors->items[i]; + + ItemPointerSet(&neighbor->indextid, hc->element->blkno, hc->element->offno); + neighbor->distance = hc->distance; + } + else + { + ItemPointerSetInvalid(&neighbor->indextid); + neighbor->distance = NAN; + } + } + } + + ntup->count = idx; +} + +/* + * Load neighbors from page + */ +static void +LoadNeighborsFromPage(HnswElement element, Relation index, Page page) +{ + HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + int m = HnswGetM(index); + int neighborCount = (element->level + 2) * m; + + Assert(HnswIsNeighborTuple(ntup)); + + HnswInitNeighbors(element, m); + + /* Ensure expected neighbors */ + if (ntup->count != neighborCount) + return; + + for (int i = 0; i < neighborCount; i++) + { + HnswElement e; + int level; + HnswCandidate *hc; + HnswNeighborTupleItem *neighbor; + HnswNeighborArray *neighbors; + + neighbor = &ntup->neighbors[i]; + + if (!ItemPointerIsValid(&neighbor->indextid)) + continue; + + e = InitElementFromBlock(ItemPointerGetBlockNumber(&neighbor->indextid), ItemPointerGetOffsetNumber(&neighbor->indextid)); + + /* Calculate level based on offset */ + level = element->level - i / m; + if (level < 0) + level = 0; + + neighbors = &element->neighbors[level]; + hc = &neighbors->items[neighbors->length++]; + hc->element = e; + hc->distance = neighbor->distance; + } +} + +/* + * Load neighbors + */ +static void +LoadNeighbors(HnswElement element, Relation index) +{ + Buffer buf; + Page page; + + buf = ReadBuffer(index, 
element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + LoadNeighborsFromPage(element, index, page); + + UnlockReleaseBuffer(buf); +} + +/* + * Load an element and optionally get its distance from q + */ +void +HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +{ + Buffer buf; + Page page; + HnswElementTuple etup; + + /* Read vector */ + buf = ReadBuffer(index, element->blkno); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); + + Assert(HnswIsElementTuple(etup)); + + /* Load element */ + element->heaptids = NIL; + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + /* Can stop at first invalid */ + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + + HnswAddHeapTid(element, &etup->heaptids[i]); + } + element->level = etup->level; + element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + element->deleted = etup->deleted; + + if (loadVec) + { + element->vec = palloc(VECTOR_SIZE(etup->vec.dim)); + memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim)); + } + + /* Calculate distance */ + if (distance != NULL) + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec))); + + UnlockReleaseBuffer(buf); +} + /* * Get the distance for a candidate */ @@ -647,6 +448,64 @@ GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collat return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, PointerGetDatum(hc->element->vec))); } +/* + * Create a candidate for the entry point + */ +HnswCandidate * +HnswEntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec) +{ + HnswCandidate *hc = palloc(sizeof(HnswCandidate)); + + hc->element = 
entryPoint; + if (index == NULL) + hc->distance = GetCandidateDistance(hc, q, procinfo, collation); + else + HnswLoadElement(hc->element, &hc->distance, &q, index, procinfo, collation, loadvec); + return hc; +} + +/* + * Compare candidate distances + */ +static int +CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + return 0; +} + +/* + * Compare candidate distances + */ +static int +CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + return 0; +} + +/* + * Create a pairing heap node for a candidate + */ +static HnswPairingHeapNode * +CreatePairingHeapNode(HnswCandidate * c) +{ + HnswPairingHeapNode *node = palloc(sizeof(HnswPairingHeapNode)); + + node->inner = c; + return node; +} + /* * Add to visited */ @@ -752,7 +611,7 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro if (skipPage != NULL && e->element->neighborPage == *skipPage && e->element->neighborOffno == *skipOffno) continue; - /* Stale read */ + /* Make robust to issues */ if (e->element->level < lc) continue; @@ -788,19 +647,102 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro } /* - * Create a candidate for the entry point + * Calculate the distance between elements */ -HnswCandidate * -HnswEntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec) +static float +HnswGetDistance(HnswElement a, 
HnswElement b, int lc, FmgrInfo *procinfo, Oid collation) { - HnswCandidate *hc = palloc(sizeof(HnswCandidate)); + /* Look for cached distance */ + if (a->neighbors != NULL) + { + Assert(a->level >= lc); - hc->element = entryPoint; - if (index == NULL) - hc->distance = GetCandidateDistance(hc, q, procinfo, collation); - else - HnswLoadElement(hc->element, &hc->distance, &q, index, procinfo, collation, loadvec); - return hc; + for (int i = 0; i < a->neighbors[lc].length; i++) + { + if (a->neighbors[lc].items[i].element == b) + return a->neighbors[lc].items[i].distance; + } + } + + if (b->neighbors != NULL) + { + Assert(b->level >= lc); + + for (int i = 0; i < b->neighbors[lc].length; i++) + { + if (b->neighbors[lc].items[i].element == a) + return b->neighbors[lc].items[i].distance; + } + } + + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec))); +} + +/* + * Check if an element is closer to q than any element from R + */ +static bool +CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation) +{ + ListCell *lc2; + + foreach(lc2, r) + { + HnswCandidate *ri = lfirst(lc2); + float distance = HnswGetDistance(e->element, ri->element, lc, procinfo, collation); + + if (distance <= e->distance) + return false; + } + + return true; +} + +/* + * Algorithm 4 from paper + */ +static List * +SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswCandidate * *pruned) +{ + List *r = NIL; + List *w = list_copy(c); + pairingheap *wd; + + if (list_length(w) < m) + return w; + + wd = pairingheap_allocate(CompareNearestCandidates, NULL); + + while (list_length(w) > 0 && list_length(r) < m) + { + /* Assumes w is already ordered desc */ + HnswCandidate *e = llast(w); + bool closer; + + w = list_delete_last(w); + + closer = CheckElementCloser(e, r, lc, procinfo, collation); + + if (closer) + r = lappend(r, e); + else + pairingheap_add(wd, 
&(CreatePairingHeapNode(e)->ph_node)); + } + + /* Keep pruned connections */ + while (!pairingheap_is_empty(wd) && list_length(r) < m) + r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner); + + /* Return pruned for update connections */ + if (pruned != NULL) + { + if (!pairingheap_is_empty(wd)) + *pruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner; + else + *pruned = linitial(w); + } + + return r; } /* @@ -827,6 +769,139 @@ HnswFindDuplicate(HnswElement e, List *neighbors) return NULL; } +/* + * Add connections + */ +static void +AddConnections(HnswElement element, List *neighbors, int m, int lc) +{ + ListCell *lc2; + HnswNeighborArray *a = &element->neighbors[lc]; + + foreach(lc2, neighbors) + a->items[a->length++] = *((HnswCandidate *) lfirst(lc2)); +} + +/* + * Compare candidate distances + */ +static int +#if PG_VERSION_NUM >= 130000 +CompareCandidateDistances(const ListCell *a, const ListCell *b) +#else +CompareCandidateDistances(const void *a, const void *b) +#endif +{ + HnswCandidate *hca = lfirst((ListCell *) a); + HnswCandidate *hcb = lfirst((ListCell *) b); + + if (hca->distance < hcb->distance) + return 1; + + if (hca->distance > hcb->distance) + return -1; + + return 0; +} + +/* + * Create update + */ +static HnswUpdate * +CreateUpdate(HnswCandidate * hc, int level, int index) +{ + HnswUpdate *update = palloc(sizeof(HnswUpdate)); + + update->hc = *hc; + update->level = level; + update->index = index; + return update; +} + +/* + * Update connections + */ +static void +UpdateConnections(HnswElement element, List *neighbors, int m, int lc, List **updates, Relation index, FmgrInfo *procinfo, Oid collation) +{ + ListCell *lc2; + + foreach(lc2, neighbors) + { + HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + HnswNeighborArray *currentNeighbors = &hc->element->neighbors[lc]; + + HnswCandidate hc2; + + hc2.element = element; + hc2.distance = hc->distance; + + if (currentNeighbors->length < m) + { + 
currentNeighbors->items[currentNeighbors->length++] = hc2; + + /* Track updates */ + if (updates != NULL) + *updates = lappend(*updates, CreateUpdate(hc, lc, currentNeighbors->length - 1)); + } + else + { + /* Shrink connections */ + HnswCandidate *pruned = NULL; + List *c = NIL; + + /* Add and sort candidates */ + for (int i = 0; i < currentNeighbors->length; i++) + c = lappend(c, &currentNeighbors->items[i]); + c = lappend(c, &hc2); + list_sort(c, CompareCandidateDistances); + + /* Load elements on insert */ + if (index != NULL) + { + for (int i = 0; i < currentNeighbors->length; i++) + { + if (currentNeighbors->items[i].element->vec == NULL) + { + HnswLoadElement(currentNeighbors->items[i].element, NULL, NULL, index, procinfo, collation, true); + + /* Prune deleted element */ + if (currentNeighbors->items[i].element->deleted) + { + pruned = &currentNeighbors->items[i]; + break; + } + } + } + } + + if (pruned == NULL) + { + SelectNeighbors(c, m, lc, procinfo, collation, &pruned); + + /* Should not happen */ + if (pruned == NULL) + continue; + } + + /* Find and replace the pruned element */ + for (int i = 0; i < currentNeighbors->length; i++) + { + if (currentNeighbors->items[i].element == pruned->element) + { + currentNeighbors->items[i] = hc2; + + /* Track updates */ + if (updates != NULL) + *updates = lappend(*updates, CreateUpdate(hc, lc, i)); + + break; + } + } + } + } +} + /* * Algorithm 1 from paper */ @@ -859,6 +934,7 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F removeEntryPoint = false; } + /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, true, skipPage, skipOffno); @@ -868,17 +944,22 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F if (level > entryLevel) level = entryLevel; + /* 2nd phase */ for (int lc = level; lc >= 0; lc--) { int lm = HnswGetLayerM(m, lc); w = HnswSearchLayer(q,
ep, efConstruction, lc, index, procinfo, collation, true, skipPage, skipOffno); + + /* Remove entry point if it's being deleted */ if (removeEntryPoint) w = list_delete_ptr(w, entryCandidate); + newNeighbors[lc] = SelectNeighbors(w, lm, lc, procinfo, collation, NULL); ep = w; } + /* Look for duplicate */ if (level >= 0 && !vacuuming) { dup = HnswFindDuplicate(element, newNeighbors[0]); @@ -899,78 +980,3 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F return NULL; } - -/* - * Update the metapage - */ -void -HnswUpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum) -{ - Buffer buf; - Page page; - GenericXLogState *state; - HnswMetaPage metap; - - buf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL); - LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); - state = GenericXLogStart(index); - page = GenericXLogRegisterBuffer(state, buf, 0); - - metap = HnswPageGetMeta(page); - - if (updateEntry) - { - if (entryPoint == NULL) - { - metap->entryBlkno = InvalidBlockNumber; - metap->entryOffno = InvalidOffsetNumber; - metap->entryLevel = -1; - } - else - { - metap->entryBlkno = entryPoint->blkno; - metap->entryOffno = entryPoint->offno; - metap->entryLevel = entryPoint->level; - } - } - - if (BlockNumberIsValid(insertPage)) - metap->insertPage = insertPage; - - HnswCommitBuffer(buf, state); -} - -/* - * Set neighbor tuple - */ -void -HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m) -{ - int idx = 0; - - ntup->type = HNSW_NEIGHBOR_TUPLE_TYPE; - - for (int lc = e->level; lc >= 0; lc--) - { - HnswNeighborArray *neighbors = &e->neighbors[lc]; - int lm = HnswGetLayerM(m, lc); - - for (int i = 0; i < lm; i++) - { - HnswNeighborTupleItem *neighbor = &ntup->neighbors[idx++]; - - if (i < neighbors->length) - { - HnswCandidate *hc = &neighbors->items[i]; - - ItemPointerSet(&neighbor->indextid, hc->element->blkno, hc->element->offno); - neighbor->distance 
= hc->distance; - } - else - { - ItemPointerSetInvalid(&neighbor->indextid); - neighbor->distance = NAN; - } - } - } -} diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 756ab1d..b37c362 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -20,7 +20,7 @@ DeletedContains(HTAB *deleted, ItemPointer indextid) } /* - * Remove deleted heap tids + * Remove deleted heap TIDs * * OK to remove for entry point, since always considered for searches and inserts */ @@ -114,6 +114,7 @@ RemoveHeapTids(HnswVacuumState * vacuumstate) /* Keep track of highest non-entry point */ highestPoint->blkno = blkno; highestPoint->offno = offno; + highestPoint->level = etup->level; highestLevel = etup->level; } } @@ -142,22 +143,18 @@ NeedsUpdated(HnswVacuumState * vacuumstate, HnswElement element) BufferAccessStrategy bas = vacuumstate->bas; Buffer buf; Page page; - ItemId itemid; - int neighborCount; HnswNeighborTuple ntup; bool needsUpdated = false; buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); - itemid = PageGetItemId(page, element->neighborOffno); - ntup = (HnswNeighborTuple) PageGetItem(page, itemid); - neighborCount = HNSW_NEIGHBOR_COUNT(itemid); + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); Assert(HnswIsNeighborTuple(ntup)); /* Check neighbors */ - for (int i = 0; i < neighborCount; i++) + for (int i = 0; i < ntup->count; i++) { HnswNeighborTupleItem *neighbor = &ntup->neighbors[i]; @@ -213,26 +210,32 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element) return; entryPoint = &vacuumstate->highestPoint; + + /* Reset neighbors from previous update */ entryPoint->neighbors = NULL; } else entryPoint = NULL; } + /* Init fields */ HnswInitNeighbors(element, m); element->heaptids = NIL; + /* Add element to graph, skipping itself */ HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, 
NULL, true); - /* Write out new neighbors on page */ + /* Update neighbor tuple */ + /* Do this before getting page to minimize locking */ + HnswSetNeighborTuple(ntup, element, m); + + /* Get neighbor page */ buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); - /* Update neighbors */ - HnswSetNeighborTuple(ntup, element, m); - + /* Overwrite tuple */ if (!PageIndexTupleOverwrite(page, element->neighborOffno, (Item) ntup, ntupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); @@ -261,6 +264,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) RepairGraphElement(vacuumstate, highestPoint); } + /* See if entry point needs updated */ entryPoint = HnswGetEntryPoint(index); if (entryPoint != NULL) { @@ -402,7 +406,6 @@ MarkDeleted(HnswVacuumState * vacuumstate) Page npage; BlockNumber neighborPage; OffsetNumber neighborOffno; - int neighborCount; /* Skip neighbor tuples */ if (!HnswIsElementTuple(etup)) @@ -412,20 +415,20 @@ MarkDeleted(HnswVacuumState * vacuumstate) if (etup->deleted) continue; + /* Skip live tuples */ if (ItemPointerIsValid(&etup->heaptids[0])) { stats->num_index_tuples++; continue; } + /* Update stats */ stats->tuples_removed++; /* Calculate sizes */ etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(etup->level, vacuumstate->m); - neighborCount = (etup->level + 2) * vacuumstate->m; - /* Get neighbor page */ neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); @@ -449,15 +452,17 @@ MarkDeleted(HnswVacuumState * vacuumstate) MemSet(&etup->vec.x, 0, etup->vec.dim * sizeof(float)); /* Overwrite neighbors */ - for (int i = 0; i < neighborCount; i++) + for (int i = 0; i < ntup->count; i++) { ItemPointerSetInvalid(&ntup->neighbors[i].indextid); 
ntup->neighbors[i].distance = NAN; } + /* Overwrite element tuple */ if (!PageIndexTupleOverwrite(page, offno, (Item) etup, etupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + /* Overwrite neighbor tuple */ if (!PageIndexTupleOverwrite(npage, neighborOffno, (Item) ntup, ntupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); @@ -543,7 +548,7 @@ hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, InitVacuumState(&vacuumstate, info, stats, callback, callback_state); - /* Pass 1: Remove heap tids */ + /* Pass 1: Remove heap TIDs */ RemoveHeapTids(&vacuumstate); /* Pass 2: Repair graph */