diff --git a/CHANGELOG.md b/CHANGELOG.md index 353f4cc..6753a7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Added support for iterative index scans - Added casts for arrays to `sparsevec` - Improved cost estimation +- Improved performance of HNSW inserts and on-disk index builds - Reduced memory usage for HNSW index scans - Dropped support for Postgres 12 diff --git a/src/hnsw.c b/src/hnsw.c index 22e6985..6cb2deb 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -188,13 +188,13 @@ hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, else ratio = 1; - /* Set startup cost since this work happens before first tuple is returned */ - costs.indexStartupCost = costs.indexTotalCost * ratio; - startupPages = costs.numIndexPages * ratio; - get_tablespace_page_costs(path->indexinfo->reltablespace, NULL, &spc_seq_page_cost); + /* Startup cost is cost before returning the first row */ + costs.indexStartupCost = costs.indexTotalCost * ratio; + /* Adjust cost if needed since TOAST not included in seq scan cost */ + startupPages = costs.numIndexPages * ratio; if (startupPages > path->indexinfo->rel->pages && ratio < 0.5) { /* Change all page cost from random to sequential */ diff --git a/src/hnsw.h b/src/hnsw.h index e290e1b..37293e6 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -220,8 +220,8 @@ typedef struct HnswGraph /* Allocations state */ LWLock allocatorLock; - long memoryUsed; - long memoryTotal; + Size memoryUsed; + Size memoryTotal; /* Flushed state */ LWLock flushLock; @@ -272,6 +272,18 @@ typedef struct HnswTypeInfo void (*checkValue) (Pointer v); } HnswTypeInfo; +typedef struct HnswSupport +{ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; +} HnswSupport; + +typedef struct HnswQuery +{ + Datum value; +} HnswQuery; + typedef struct HnswBuildState { /* Info */ @@ -291,9 +303,7 @@ typedef struct HnswBuildState double reltuples; /* Support functions */ - FmgrInfo *procinfo; - FmgrInfo *normprocinfo; - Oid collation; + 
HnswSupport support; /* Variables */ HnswGraph graphData; @@ -374,16 +384,14 @@ typedef struct HnswScanOpaqueData List *w; visited_hash v; pairingheap *discarded; - Datum q; + HnswQuery q; int m; int64 tuples; double previousDistance; MemoryContext tmpCtx; /* Support functions */ - FmgrInfo *procinfo; - FmgrInfo *normprocinfo; - Oid collation; + HnswSupport support; } HnswScanOpaqueData; typedef HnswScanOpaqueData * HnswScanOpaque; @@ -401,8 +409,7 @@ typedef struct HnswVacuumState int efConstruction; /* Support functions */ - FmgrInfo *procinfo; - Oid collation; + HnswSupport support; /* Variables */ struct tidhash_hash *deleted; @@ -418,30 +425,33 @@ typedef struct HnswVacuumState int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); +void HnswInitSupport(HnswSupport * support, Relation index); Datum HnswNormValue(const HnswTypeInfo * typeInfo, Oid collation, Datum value); -bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); +bool HnswCheckNorm(HnswSupport * support, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples); +List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); 
HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); -void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing); -HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing); +HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, HnswQuery * q, Relation rel, HnswSupport * support, bool loadVec); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); +HnswNeighborArray *HnswInitNeighborArray(int lm, HnswAllocator * allocator); void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * alloc); -bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building); -void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building); +bool HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPointer heaptid, bool building); +void HnswUpdateNeighborsOnDisk(Relation index, HnswSupport * support, HnswElement e, int m, bool checkExisting, bool building); void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance); +void HnswLoadElement(HnswElement element, double *distance, 
HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); +bool HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support); void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); -void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); -void HnswLoadNeighbors(HnswElement element, Relation index, int m); +void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, HnswSupport * support); +bool HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation index, int m, int lm, int lc); void HnswInitLockTranche(void); const HnswTypeInfo *HnswGetTypeInfo(Relation index); PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 498b5d9..b667478 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -366,7 +366,7 @@ AddElementInMemory(char *base, HnswGraph * graph, HnswElement element) * Update neighbors */ static void -UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswElement e, int m) +UpdateNeighborsInMemory(char *base, HnswSupport * support, HnswElement e, int m) { for (int lc = e->level; lc >= 0; lc--) { @@ -388,7 +388,7 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme Assert(neighborElement); LWLockAcquire(&neighborElement->lock, LW_EXCLUSIVE); - HnswUpdateConnection(base, e, hc, lm, lc, NULL, NULL, procinfo, collation); + HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, NULL, support); LWLockRelease(&neighborElement->lock); } } @@ -398,7 +398,7 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme * Update graph in memory */ static void 
-UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate) +UpdateGraphInMemory(HnswSupport * support, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate) { HnswGraph *graph = buildstate->graph; char *base = buildstate->hnswarea; @@ -411,7 +411,7 @@ UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int AddElementInMemory(base, graph, element); /* Update neighbors */ - UpdateNeighborsInMemory(base, procinfo, collation, element, m); + UpdateNeighborsInMemory(base, support, element, m); /* Update entry point if needed (already have lock) */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -424,9 +424,8 @@ UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int static void InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) { - FmgrInfo *procinfo = buildstate->procinfo; - Oid collation = buildstate->collation; HnswGraph *graph = buildstate->graph; + HnswSupport *support = &buildstate->support; HnswElement entryPoint; LWLock *entryLock = &graph->entryLock; LWLock *entryWaitLock = &graph->entryWaitLock; @@ -458,10 +457,10 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, NULL, procinfo, collation, m, efConstruction, false); + HnswFindElementNeighbors(base, element, entryPoint, NULL, support, m, efConstruction, false); /* Update graph in memory */ - UpdateGraphInMemory(procinfo, collation, element, m, efConstruction, entryPoint, buildstate); + UpdateGraphInMemory(support, element, m, efConstruction, entryPoint, buildstate); /* Release entry lock */ LWLockRelease(entryLock); @@ -473,30 +472,19 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) static bool InsertTuple(Relation index, Datum *values, bool *isnull, 
ItemPointer heaptid, HnswBuildState * buildstate) { - const HnswTypeInfo *typeInfo = buildstate->typeInfo; HnswGraph *graph = buildstate->graph; HnswElement element; HnswAllocator *allocator = &buildstate->allocator; + HnswSupport *support = &buildstate->support; Size valueSize; Pointer valuePtr; LWLock *flushLock = &graph->flushLock; char *base = buildstate->hnswarea; + Datum value; - /* Detoast once for all calls */ - Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); - - /* Check value */ - if (typeInfo->checkValue != NULL) - typeInfo->checkValue(DatumGetPointer(value)); - - /* Normalize if needed */ - if (buildstate->normprocinfo != NULL) - { - if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) - return false; - - value = HnswNormValue(typeInfo, buildstate->collation, value); - } + /* Form index value */ + if (!HnswFormIndexValue(&value, values, isnull, buildstate->typeInfo, support)) + return false; /* Get datum size */ valueSize = VARSIZE_ANY(DatumGetPointer(value)); @@ -509,7 +497,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn { LWLockRelease(flushLock); - return HnswInsertTupleOnDisk(index, value, values, isnull, heaptid, true); + return HnswInsertTupleOnDisk(index, support, value, heaptid, true); } /* @@ -541,7 +529,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn LWLockRelease(flushLock); - return HnswInsertTupleOnDisk(index, value, values, isnull, heaptid, true); + return HnswInsertTupleOnDisk(index, support, value, heaptid, true); } /* Ok, we can proceed to allocate the element */ @@ -607,7 +595,7 @@ BuildCallback(Relation index, ItemPointer tid, Datum *values, * Initialize the graph */ static void -InitGraph(HnswGraph * graph, char *base, long memoryTotal) +InitGraph(HnswGraph * graph, char *base, Size memoryTotal) { /* Initialize the lock tranche if needed */ HnswInitLockTranche(); @@ -704,11 +692,9 @@ InitBuildState(HnswBuildState * 
buildstate, Relation heap, Relation index, Index buildstate->indtuples = 0; /* Get support functions */ - buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - buildstate->collation = index->rd_indcollation[0]; + HnswInitSupport(&buildstate->support, index); - InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L); + InitGraph(&buildstate->graphData, NULL, (Size) maintenance_work_mem * 1024L); buildstate->graph = &buildstate->graphData; buildstate->ml = HnswGetMl(buildstate->m); buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index fdc18c0..a5fac4e 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -340,6 +340,107 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B *updatedInsertPage = newInsertPage; } +/* + * Load neighbors + */ +static HnswNeighborArray * +HnswLoadNeighbors(HnswElement element, Relation index, int m, int lm, int lc) +{ + char *base = NULL; + HnswNeighborArray *neighbors = HnswInitNeighborArray(lm, NULL); + ItemPointerData indextids[HNSW_MAX_M * 2]; + + if (!HnswLoadNeighborTids(element, indextids, index, m, lm, lc)) + return neighbors; + + for (int i = 0; i < lm; i++) + { + ItemPointer indextid = &indextids[i]; + HnswElement e; + HnswCandidate *hc; + + if (!ItemPointerIsValid(indextid)) + break; + + e = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); + hc = &neighbors->items[neighbors->length++]; + HnswPtrStore(base, hc->element, e); + } + + return neighbors; +} + +/* + * Load elements for insert + */ +static void +LoadElementsForInsert(HnswNeighborArray * neighbors, HnswQuery * q, int *idx, Relation index, HnswSupport * support) +{ + char *base = NULL; + + for (int i = 0; i < neighbors->length; i++) + { + HnswCandidate *hc = &neighbors->items[i]; + HnswElement element = HnswPtrAccess(base, 
hc->element); + double distance; + + HnswLoadElement(element, &distance, q, index, support, true, NULL); + hc->distance = distance; + + /* Prune element if being deleted */ + if (element->heaptidsLength == 0) + { + *idx = i; + break; + } + } +} + +/* + * Get update index + */ +static int +GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int m, int lm, int lc, Relation index, HnswSupport * support, MemoryContext updateCtx) +{ + char *base = NULL; + int idx = -1; + HnswNeighborArray *neighbors; + MemoryContext oldCtx = MemoryContextSwitchTo(updateCtx); + + /* + * Get latest neighbors since they may have changed. Do not lock yet since + * selecting neighbors can take time. Could use optimistic locking to + * retry if another update occurs before getting exclusive lock. + */ + neighbors = HnswLoadNeighbors(element, index, m, lm, lc); + + /* + * Could improve performance for vacuuming by checking neighbors against + * list of elements being deleted to find index. It's important to exclude + * already deleted elements for this since they can be replaced at any + * time. 
+ */ + + if (neighbors->length < lm) + idx = -2; + else + { + HnswQuery q; + + q.value = HnswGetValue(base, element); + + LoadElementsForInsert(neighbors, &q, &idx, index, support); + + if (idx == -1) + HnswUpdateConnection(base, neighbors, newElement, distance, lm, &idx, index, support); + } + + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(updateCtx); + + return idx; +} + /* * Check if connection already exists */ @@ -360,14 +461,94 @@ ConnectionExists(HnswElement e, HnswNeighborTuple ntup, int startIdx, int lm) return false; } +/* + * Update neighbor + */ +static void +UpdateNeighborOnDisk(HnswElement element, HnswElement newElement, int idx, int m, int lm, int lc, Relation index, bool checkExisting, bool building) +{ + Buffer buf; + Page page; + GenericXLogState *state; + HnswNeighborTuple ntup; + int startIdx; + OffsetNumber offno = element->neighborOffno; + + /* Register page */ + buf = ReadBuffer(index, element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + if (building) + { + state = NULL; + page = BufferGetPage(buf); + } + else + { + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + } + + /* Get tuple */ + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno)); + + /* Calculate index for update */ + startIdx = (element->level - lc) * m; + + /* Check for existing connection */ + if (checkExisting && ConnectionExists(newElement, ntup, startIdx, lm)) + idx = -1; + else if (idx == -2) + { + /* Find free offset if still exists */ + /* TODO Retry updating connections if not */ + for (int j = 0; j < lm; j++) + { + if (!ItemPointerIsValid(&ntup->indextids[startIdx + j])) + { + idx = startIdx + j; + break; + } + } + } + else + idx += startIdx; + + /* Make robust to issues */ + if (idx >= 0 && idx < ntup->count) + { + ItemPointer indextid = &ntup->indextids[idx]; + + /* Update neighbor on the buffer */ + ItemPointerSet(indextid, newElement->blkno, newElement->offno); + + /* Commit */ + if 
(building) + MarkBufferDirty(buf); + else + GenericXLogFinish(state); + } + else if (!building) + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); +} + /* * Update neighbors */ void -HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building) +HnswUpdateNeighborsOnDisk(Relation index, HnswSupport * support, HnswElement e, int m, bool checkExisting, bool building) { char *base = NULL; + /* Use separate memory context to improve performance for larger vectors */ + MemoryContext updateCtx = GenerationContextCreate(CurrentMemoryContext, + "Hnsw insert update context", +#if PG_VERSION_NUM >= 150000 + 128 * 1024, 128 * 1024, +#endif + 128 * 1024); + for (int lc = e->level; lc >= 0; lc--) { int lm = HnswGetLayerM(m, lc); @@ -376,96 +557,20 @@ HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, Hns for (int i = 0; i < neighbors->length; i++) { HnswCandidate *hc = &neighbors->items[i]; - Buffer buf; - Page page; - GenericXLogState *state; - HnswNeighborTuple ntup; - int idx = -1; - int startIdx; HnswElement neighborElement = HnswPtrAccess(base, hc->element); - OffsetNumber offno = neighborElement->neighborOffno; + int idx; - /* - * Get latest neighbors since they may have changed. Do not lock - * yet since selecting neighbors can take time. Could use - * optimistic locking to retry if another update occurs before - * getting exclusive lock. - */ - HnswLoadNeighbors(neighborElement, index, m); - - /* - * Could improve performance for vacuuming by checking neighbors - * against list of elements being deleted to find index. It's - * important to exclude already deleted elements for this since - * they can be replaced at any time. 
- */ - - /* Select neighbors */ - HnswUpdateConnection(NULL, e, hc, lm, lc, &idx, index, procinfo, collation); + idx = GetUpdateIndex(neighborElement, e, hc->distance, m, lm, lc, index, support, updateCtx); /* New element was not selected as a neighbor */ if (idx == -1) continue; - /* Register page */ - buf = ReadBuffer(index, neighborElement->neighborPage); - LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); - if (building) - { - state = NULL; - page = BufferGetPage(buf); - } - else - { - state = GenericXLogStart(index); - page = GenericXLogRegisterBuffer(state, buf, 0); - } - - /* Get tuple */ - ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno)); - - /* Calculate index for update */ - startIdx = (neighborElement->level - lc) * m; - - /* Check for existing connection */ - if (checkExisting && ConnectionExists(e, ntup, startIdx, lm)) - idx = -1; - else if (idx == -2) - { - /* Find free offset if still exists */ - /* TODO Retry updating connections if not */ - for (int j = 0; j < lm; j++) - { - if (!ItemPointerIsValid(&ntup->indextids[startIdx + j])) - { - idx = startIdx + j; - break; - } - } - } - else - idx += startIdx; - - /* Make robust to issues */ - if (idx >= 0 && idx < ntup->count) - { - ItemPointer indextid = &ntup->indextids[idx]; - - /* Update neighbor on the buffer */ - ItemPointerSet(indextid, e->blkno, e->offno); - - /* Commit */ - if (building) - MarkBufferDirty(buf); - else - GenericXLogFinish(state); - } - else if (!building) - GenericXLogAbort(state); - - UnlockReleaseBuffer(buf); + UpdateNeighborOnDisk(neighborElement, e, idx, m, lm, lc, index, checkExisting, building); } } + + MemoryContextDelete(updateCtx); } /* @@ -555,7 +660,7 @@ FindDuplicateOnDisk(Relation index, HnswElement element, bool building) * Update graph on disk */ static void -UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building) +UpdateGraphOnDisk(Relation index, 
HnswSupport * support, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building) { BlockNumber newInsertPage = InvalidBlockNumber; @@ -571,7 +676,7 @@ UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement HnswUpdateMetaPage(index, 0, NULL, newInsertPage, MAIN_FORKNUM, building); /* Update neighbors */ - HnswUpdateNeighborsOnDisk(index, procinfo, collation, element, m, false, building); + HnswUpdateNeighborsOnDisk(index, support, element, m, false, building); /* Update entry point if needed */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -582,14 +687,12 @@ UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement * Insert a tuple into the index */ bool -HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building) +HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPointer heaptid, bool building) { HnswElement entryPoint; HnswElement element; int m; int efConstruction = HnswGetEfConstruction(index); - FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - Oid collation = index->rd_indcollation[0]; LOCKMODE lockmode = ShareLock; char *base = NULL; @@ -604,7 +707,7 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, HnswGetMetaPageInfo(index, &m, &entryPoint); /* Create an element */ - element = HnswInitElement(base, heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); + element = HnswInitElement(base, heaptid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); HnswPtrStore(base, element->value, DatumGetPointer(value)); /* Prevent concurrent inserts when likely updating entry point */ @@ -622,10 +725,10 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false); + 
HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, false); /* Update graph on disk */ - UpdateGraphOnDisk(index, procinfo, collation, element, m, efConstruction, entryPoint, building); + UpdateGraphOnDisk(index, support, element, m, efConstruction, entryPoint, building); /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); @@ -637,31 +740,19 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, * Insert a tuple into the index */ static void -HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid) +HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid) { Datum value; const HnswTypeInfo *typeInfo = HnswGetTypeInfo(index); - FmgrInfo *normprocinfo; - Oid collation = index->rd_indcollation[0]; + HnswSupport support; - /* Detoast once for all calls */ - value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + HnswInitSupport(&support, index); - /* Check value */ - if (typeInfo->checkValue != NULL) - typeInfo->checkValue(DatumGetPointer(value)); + /* Form index value */ + if (!HnswFormIndexValue(&value, values, isnull, typeInfo, &support)) + return; - /* Normalize if needed */ - normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - if (normprocinfo != NULL) - { - if (!HnswCheckNorm(normprocinfo, collation, value)) - return; - - value = HnswNormValue(typeInfo, collation, value); - } - - HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false); + HnswInsertTupleOnDisk(index, &support, value, heaptid, false); } /* diff --git a/src/hnswscan.c b/src/hnswscan.c index de8fad6..1b61784 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -12,36 +12,36 @@ * Algorithm 5 from paper */ static List * -GetScanItems(IndexScanDesc scan, Datum q) +GetScanItems(IndexScanDesc scan, Datum value) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; - FmgrInfo *procinfo = so->procinfo; - Oid collation = 
so->collation; + HnswSupport *support = &so->support; List *ep; List *w; int m; HnswElement entryPoint; char *base = NULL; + HnswQuery *q = &so->q; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); - so->q = q; + q->value = value; so->m = m; if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, false)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, false)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true, NULL); + w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, false, NULL, NULL, NULL, true, NULL); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF ? &so->discarded : NULL, true, &so->tuples); + return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, support, m, false, NULL, &so->v, hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF ? 
&so->discarded : NULL, true, &so->tuples); } /* @@ -52,8 +52,6 @@ ResumeScanItems(IndexScanDesc scan) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; - FmgrInfo *procinfo = so->procinfo; - Oid collation = so->collation; List *ep = NIL; char *base = NULL; int batch_size = hnsw_ef_search; @@ -74,7 +72,7 @@ ResumeScanItems(IndexScanDesc scan) ep = lappend(ep, hc); } - return HnswSearchLayer(base, so->q, ep, batch_size, 0, index, procinfo, collation, so->m, false, NULL, &so->v, &so->discarded, false, &so->tuples); + return HnswSearchLayer(base, &so->q, ep, batch_size, 0, index, &so->support, so->m, false, NULL, &so->v, &so->discarded, false, &so->tuples); } /* @@ -97,8 +95,8 @@ GetScanValue(IndexScanDesc scan) Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); /* Normalize if needed */ - if (so->normprocinfo != NULL) - value = HnswNormValue(so->typeInfo, so->collation, value); + if (so->support.normprocinfo != NULL) + value = HnswNormValue(so->typeInfo, so->support.collation, value); } return value; @@ -125,9 +123,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) ALLOCSET_DEFAULT_SIZES); /* Set support functions */ - so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - so->collation = index->rd_indcollation[0]; + HnswInitSupport(&so->support, index); scan->opaque = so; @@ -215,7 +211,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) for (;;) { char *base = NULL; - HnswSearchCandidate *hc; + HnswSearchCandidate *sc; HnswElement element; ItemPointer heaptid; @@ -278,8 +274,8 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) break; } - hc = llast(so->w); - element = HnswPtrAccess(base, hc->element); + sc = llast(so->w); + element = HnswPtrAccess(base, sc->element); /* Move to next element if no valid heap TIDs */ if (element->heaptidsLength == 0) @@ -290,7 +286,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) if 
(hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF) { pfree(element); - pfree(hc); + pfree(sc); } continue; @@ -300,10 +296,10 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) if (hnsw_iterative_search == HNSW_ITERATIVE_SEARCH_STRICT) { - if (hc->distance < so->previousDistance) + if (sc->distance < so->previousDistance) continue; - so->previousDistance = hc->distance; + so->previousDistance = sc->distance; } MemoryContextSwitchTo(oldCtx); diff --git a/src/hnswutils.c b/src/hnswutils.c index 2b731b5..fb35563 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -146,6 +146,17 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } +/* + * Init support functions + */ +void +HnswInitSupport(HnswSupport * support, Relation index) +{ + support->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + support->collation = index->rd_indcollation[0]; + support->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); +} + /* * Normalize value */ @@ -159,9 +170,9 @@ HnswNormValue(const HnswTypeInfo * typeInfo, Oid collation, Datum value) * Check if non-zero norm */ bool -HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) +HnswCheckNorm(HnswSupport * support, Datum value) { - return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0; + return DatumGetFloat8(FunctionCall1Coll(support->normprocinfo, support->collation, value)) > 0; } /* @@ -190,7 +201,7 @@ HnswInitPage(Buffer buf, Page page) /* * Allocate a neighbor array */ -static HnswNeighborArray * +HnswNeighborArray * HnswInitNeighborArray(int lm, HnswAllocator * allocator) { HnswNeighborArray *a = HnswAlloc(allocator, HNSW_NEIGHBOR_ARRAY_SIZE(lm)); @@ -389,6 +400,33 @@ HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, Bloc UnlockReleaseBuffer(buf); } +/* + * Form index value + */ +bool +HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support) +{ + /* 
Detoast once for all calls */ + Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Check value */ + if (typeInfo->checkValue != NULL) + typeInfo->checkValue(DatumGetPointer(value)); + + /* Normalize if needed */ + if (support->normprocinfo != NULL) + { + if (!HnswCheckNorm(support, value)) + return false; + + value = HnswNormValue(typeInfo, support->collation, value); + } + + *out = value; + + return true; +} + /* * Set element tuple, except for neighbor info */ @@ -446,69 +484,6 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m) ntup->version = e->version; } -/* - * Load neighbors from page - */ -static void -LoadNeighborsFromPage(HnswElement element, Relation index, Page page, int m) -{ - char *base = NULL; - - HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); - int neighborCount = (element->level + 2) * m; - - Assert(HnswIsNeighborTuple(ntup)); - - HnswInitNeighbors(base, element, m, NULL); - - /* Ensure expected neighbors */ - if (ntup->count != neighborCount) - return; - - for (int i = 0; i < neighborCount; i++) - { - HnswElement e; - int level; - HnswCandidate *hc; - ItemPointer indextid; - HnswNeighborArray *neighbors; - - indextid = &ntup->indextids[i]; - - if (!ItemPointerIsValid(indextid)) - continue; - - e = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); - - /* Calculate level based on offset */ - level = element->level - i / m; - if (level < 0) - level = 0; - - neighbors = HnswGetNeighbors(base, element, level); - hc = &neighbors->items[neighbors->length++]; - HnswPtrStore(base, hc->element, e); - } -} - -/* - * Load neighbors - */ -void -HnswLoadNeighbors(HnswElement element, Relation index, int m) -{ - Buffer buf; - Page page; - - buf = ReadBuffer(index, element->neighborPage); - LockBuffer(buf, BUFFER_LOCK_SHARE); - page = BufferGetPage(buf); - - LoadNeighborsFromPage(element, index, page, m); 
- - UnlockReleaseBuffer(buf); -} - /* * Load an element from a tuple */ @@ -543,11 +518,20 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe } } +/* + * Calculate the distance between values + */ +static inline double +HnswGetDistance(Datum a, Datum b, HnswSupport * support) +{ + return DatumGetFloat8(FunctionCall2Coll(support->procinfo, support->collation, a, b)); +} + /* * Load an element and optionally get its distance from q */ static void -HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance, HnswElement * element) +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) { Buffer buf; Page page; @@ -565,10 +549,10 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat /* Calculate distance */ if (distance != NULL) { - if (DatumGetPointer(*q) == NULL) + if (DatumGetPointer(q->value) == NULL) *distance = 0; else - *distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + *distance = HnswGetDistance(q->value, PointerGetDatum(&etup->data), support); } /* Load element */ @@ -587,35 +571,36 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance) +HnswLoadElement(HnswElement element, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) { - HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, procinfo, collation, loadVec, maxDistance, &element); + HnswLoadElementImpl(element->blkno, element->offno, 
distance, q, index, support, loadVec, maxDistance, &element); } /* * Get the distance for an element */ static double -GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation) +GetElementDistance(char *base, HnswElement element, HnswQuery * q, HnswSupport * support) { Datum value = HnswGetValue(base, element); - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); + return HnswGetDistance(q->value, value, support); } /* * Create a candidate for the entry point */ HnswSearchCandidate * -HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswEntryCandidate(char *base, HnswElement entryPoint, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec) { HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate)); + bool inMemory = index == NULL; HnswPtrStore(base, sc->element, entryPoint); - if (index == NULL) - sc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation); + if (inMemory) + sc->distance = GetElementDistance(base, entryPoint, q, support); else - HnswLoadElement(entryPoint, &sc->distance, &q, index, procinfo, collation, loadVec, NULL); + HnswLoadElement(entryPoint, &sc->distance, q, index, support, loadVec, NULL); return sc; } @@ -668,9 +653,9 @@ CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, * Init visited */ static inline void -InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) +InitVisited(char *base, visited_hash * v, bool inMemory, int ef, int m) { - if (index != NULL) + if (!inMemory) v->tids = tidhash_create(CurrentMemoryContext, ef * m * 2, NULL); else if (base != NULL) v->offsets = offsethash_create(CurrentMemoryContext, ef * m * 2, NULL); @@ -682,9 +667,9 @@ InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) * Add to visited */ static inline void -AddToVisited(char *base, visited_hash * v, HnswElementPtr 
elementPtr, Relation index, bool *found) +AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, bool inMemory, bool *found) { - if (index != NULL) + if (!inMemory) { HnswElement element = HnswPtrAccess(base, elementPtr); ItemPointerData indextid; @@ -745,7 +730,7 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unv HnswCandidate *hc = &localNeighborhood->items[i]; bool found; - AddToVisited(base, v, hc->element, NULL, &found); + AddToVisited(base, v, hc->element, true, &found); if (!found) unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element); @@ -753,18 +738,15 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unv } /* - * Load unvisited neighbors from disk + * Load neighbor index TIDs */ -static void -HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc) +bool +HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation index, int m, int lm, int lc) { Buffer buf; Page page; HnswNeighborTuple ntup; int start; - ItemPointerData indextids[HNSW_MAX_M * 2]; - - *unvisitedLength = 0; buf = ReadBuffer(index, element->neighborPage); LockBuffer(buf, BUFFER_LOCK_SHARE); @@ -779,14 +761,29 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u if (ntup->version != element->version || ntup->count != (element->level + 2) * m) { UnlockReleaseBuffer(buf); - return; + return false; } /* Copy to minimize lock time */ start = (element->level - lc) * m; - memcpy(&indextids, ntup->indextids + start, lm * sizeof(ItemPointerData)); + memcpy(indextids, ntup->indextids + start, lm * sizeof(ItemPointerData)); UnlockReleaseBuffer(buf); + return true; +} + +/* + * Load unvisited neighbors from disk + */ +static void +HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int 
m, int lm, int lc) +{ + ItemPointerData indextids[HNSW_MAX_M * 2]; + + *unvisitedLength = 0; + + if (!HnswLoadNeighborTids(element, indextids, index, m, lm, lc)) + return; for (int i = 0; i < lm; i++) { @@ -807,7 +804,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples) +HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); @@ -820,6 +817,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F int lm = HnswGetLayerM(m, lc); HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; + bool inMemory = index == NULL; if (v == NULL) { @@ -829,14 +827,14 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F if (initVisited) { - InitVisited(base, v, index, ef, m); + InitVisited(base, v, inMemory, ef, m); if (discarded != NULL) *discarded = pairingheap_allocate(CompareNearestDiscardedCandidates, NULL); } /* Create local memory for neighborhood if needed */ - if (index == NULL) + if (inMemory) { neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm); localNeighborhood = palloc(neighborhoodSize); @@ -850,7 +848,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F if (initVisited) { - AddToVisited(base, v, sc->element, index, &found); + AddToVisited(base, v, sc->element, inMemory, &found); if (tuples != NULL) (*tuples)++; @@ -879,7 +877,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, 
Relation index, F cElement = HnswPtrAccess(base, c->element); - if (index == NULL) + if (inMemory) HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, localNeighborhood, neighborhoodSize); else HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc); @@ -896,10 +894,10 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F f = HnswGetSearchCandidate(w_node, pairingheap_first(W)); - if (index == NULL) + if (inMemory) { eElement = unvisited[i].element; - eDistance = GetElementDistance(base, eElement, q, procinfo, collation); + eDistance = GetElementDistance(base, eElement, q, support); } else { @@ -909,7 +907,10 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement); + HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd || discarded != NULL ? 
NULL : &f->distance, &eElement); + + if (eElement == NULL) + continue; } if (eElement == NULL || !(eDistance < f->distance || alwaysAdd)) @@ -1017,32 +1018,22 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b) return 0; } -/* - * Calculate the distance between elements - */ -static float -HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation) -{ - Datum aValue = HnswGetValue(base, a); - Datum bValue = HnswGetValue(base, b); - - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue)); -} - /* * Check if an element is closer to q than any element from R */ static bool -CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation) +CheckElementCloser(char *base, HnswCandidate * e, List *r, HnswSupport * support) { HnswElement eElement = HnswPtrAccess(base, e->element); + Datum eValue = HnswGetValue(base, eElement); ListCell *lc2; foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); HnswElement riElement = HnswPtrAccess(base, ri->element); - float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation); + Datum riValue = HnswGetValue(base, riElement); + float distance = HnswGetDistance(eValue, riValue, support); if (distance <= e->distance) return false; @@ -1055,15 +1046,14 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, O * Algorithm 4 from paper */ static List * -SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) +SelectNeighbors(char *base, List *c, int lm, HnswSupport * support, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) { List *r = NIL; List *w = list_copy(c); HnswCandidate **wd; int wdlen = 0; int wdoff = 0; - HnswNeighborArray *neighbors = HnswGetNeighbors(base, e2, lc); - bool mustCalculate = !neighbors->closerSet; + bool 
mustCalculate = !(*closerSet); List *added = NIL; bool removedAny = false; @@ -1090,7 +1080,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col /* Use previous state of r and wd to skip work when possible */ if (mustCalculate) - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, support); else if (list_length(added) > 0) { /* Keep Valgrind happy for in-memory, parallel builds */ @@ -1103,7 +1093,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col */ if (e->closer) { - e->closer = CheckElementCloser(base, e, added, procinfo, collation); + e->closer = CheckElementCloser(base, e, added, support); if (!e->closer) removedAny = true; @@ -1116,7 +1106,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col */ if (removedAny) { - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, support); if (e->closer) added = lappend(added, e); } @@ -1124,7 +1114,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col } else if (e == newCandidate) { - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, support); if (e->closer) added = lappend(added, e); } @@ -1140,7 +1130,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col } /* Cached value can only be used in future if sorted deterministically */ - neighbors->closerSet = sortCandidates; + *closerSet = sortCandidates; /* Keep pruned connections */ while (wdoff < wdlen && list_length(r) < lm) @@ -1175,18 +1165,16 @@ AddConnections(char *base, HnswElement element, List *neighbors, int lc) * Update connections */ void -HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation) +HnswUpdateConnection(char *base, 
HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, HnswSupport * support) { - HnswElement hce = HnswPtrAccess(base, hc->element); - HnswNeighborArray *currentNeighbors = HnswGetNeighbors(base, hce, lc); - HnswCandidate hc2; + HnswCandidate newHc; - HnswPtrStore(base, hc2.element, element); - hc2.distance = hc->distance; + HnswPtrStore(base, newHc.element, newElement); + newHc.distance = distance; - if (currentNeighbors->length < lm) + if (neighbors->length < lm) { - currentNeighbors->items[currentNeighbors->length++] = hc2; + neighbors->items[neighbors->length++] = newHc; /* Track update */ if (updateIdx != NULL) @@ -1195,59 +1183,26 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm else { /* Shrink connections */ + List *c = NIL; HnswCandidate *pruned = NULL; - /* Load elements on insert */ - if (index != NULL) - { - Datum q = HnswGetValue(base, hce); + /* Add candidates */ + for (int i = 0; i < neighbors->length; i++) + c = lappend(c, &neighbors->items[i]); + c = lappend(c, &newHc); - for (int i = 0; i < currentNeighbors->length; i++) - { - HnswCandidate *hc3 = &currentNeighbors->items[i]; - HnswElement hc3Element = HnswPtrAccess(base, hc3->element); - - if (HnswPtrIsNull(base, hc3Element->value)) - { - double distance; - - HnswLoadElement(hc3Element, &distance, &q, index, procinfo, collation, true, NULL); - hc3->distance = distance; - } - else - hc3->distance = GetElementDistance(base, hc3Element, q, procinfo, collation); - - /* Prune element if being deleted */ - if (hc3Element->heaptidsLength == 0) - { - pruned = &currentNeighbors->items[i]; - break; - } - } - } + SelectNeighbors(base, c, lm, support, &neighbors->closerSet, &newHc, &pruned, true); + /* Should not happen */ if (pruned == NULL) - { - List *c = NIL; - - /* Add candidates */ - for (int i = 0; i < currentNeighbors->length; i++) - c = lappend(c, &currentNeighbors->items[i]); - c = lappend(c, &hc2); - - SelectNeighbors(base,
c, lm, lc, procinfo, collation, hce, &hc2, &pruned, true); - - /* Should not happen */ - if (pruned == NULL) - return; - } + return; /* Find and replace the pruned element */ - for (int i = 0; i < currentNeighbors->length; i++) + for (int i = 0; i < neighbors->length; i++) { - if (HnswPtrEqual(base, currentNeighbors->items[i].element, pruned->element)) + if (HnswPtrEqual(base, neighbors->items[i].element, pruned->element)) { - currentNeighbors->items[i] = hc2; + neighbors->items[i] = newHc; /* Track update */ if (updateIdx != NULL) @@ -1307,17 +1262,20 @@ PrecomputeHash(char *base, HnswElement element) * Algorithm 1 from paper */ void -HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing) +HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing) { List *ep; List *w; int level = element->level; int entryLevel; - Datum q = HnswGetValue(base, element); + HnswQuery q; HnswElement skipElement = existing ? 
element : NULL; + bool inMemory = index == NULL; + + q.value = HnswGetValue(base, element); /* Precompute hash */ - if (index == NULL) + if (inMemory) PrecomputeHash(base, element); /* No neighbors if no entry point */ @@ -1325,13 +1283,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, true)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, true)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true, NULL); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); ep = w; } @@ -1350,7 +1308,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true, NULL); + w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); /* Convert search candidates to candidates */ foreach(lc2, w) @@ -1366,7 +1324,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */ - if (index != NULL) + if (!inMemory) lw = RemoveElements(base, lw, skipElement); /* @@ -1374,7 +1332,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint * sortCandidates to true for in-memory builds to enable closer * caching, but there does not seem to be a difference in performance. 
*/ - neighbors = SelectNeighbors(base, lw, lm, lc, procinfo, collation, element, NULL, NULL, false); + neighbors = SelectNeighbors(base, lw, lm, support, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false); AddConnections(base, element, neighbors, lc); diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index c4a777c..251d9d9 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -184,13 +184,12 @@ static void RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswElement entryPoint) { Relation index = vacuumstate->index; + HnswSupport *support = &vacuumstate->support; Buffer buf; Page page; GenericXLogState *state; int m = vacuumstate->m; int efConstruction = vacuumstate->efConstruction; - FmgrInfo *procinfo = vacuumstate->procinfo; - Oid collation = vacuumstate->collation; BufferAccessStrategy bas = vacuumstate->bas; HnswNeighborTuple ntup = vacuumstate->ntup; Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); @@ -205,7 +204,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme element->heaptidsLength = 0; /* Find neighbors for element, skipping itself */ - HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, true); + HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, true); /* Zero memory for each element */ MemSet(ntup, 0, HNSW_TUPLE_ALLOC_SIZE); @@ -229,7 +228,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme UnlockReleaseBuffer(buf); /* Update neighbors */ - HnswUpdateNeighborsOnDisk(index, procinfo, collation, element, m, true, false); + HnswUpdateNeighborsOnDisk(index, support, element, m, true, false); } /* @@ -239,6 +238,7 @@ static void RepairGraphEntryPoint(HnswVacuumState * vacuumstate) { Relation index = vacuumstate->index; + HnswSupport *support = &vacuumstate->support; HnswElement highestPoint = &vacuumstate->highestPoint; HnswElement entryPoint; 
MemoryContext oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); @@ -256,7 +256,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); + HnswLoadElement(highestPoint, NULL, NULL, index, support, true, NULL); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -294,7 +294,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); + HnswLoadElement(entryPoint, NULL, NULL, index, support, true, NULL); if (NeedsUpdated(vacuumstate, entryPoint)) { @@ -581,13 +581,13 @@ InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkD vacuumstate->callback_state = callback_state; vacuumstate->efConstruction = HnswGetEfConstruction(index); vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); - vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - vacuumstate->collation = index->rd_indcollation[0]; vacuumstate->ntup = palloc0(HNSW_TUPLE_ALLOC_SIZE); vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw vacuum temporary context", ALLOCSET_DEFAULT_SIZES); + HnswInitSupport(&vacuumstate->support, index); + /* Get m from metapage */ HnswGetMetaPageInfo(index, &vacuumstate->m, NULL); diff --git a/src/ivfflat.c b/src/ivfflat.c index 986e19d..395040d 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -69,6 +69,8 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, GenericCosts costs; int lists; double ratio; + double sequentialRatio = 0.5; + double startupPages; double spc_seq_page_cost; Relation index; @@ -85,6 +87,8 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, 
double loop_count, MemSet(&costs, 0, sizeof(costs)); + genericcostestimate(root, path, loop_count, &costs); + index = index_open(path->indexinfo->indexoid, NoLock); IvfflatGetMetaPageInfo(index, &lists, NULL); index_close(index, NoLock); @@ -94,34 +98,26 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, if (ratio > 1.0) ratio = 1.0; - /* - * This gives us the subset of tuples to visit. This value is passed into - * the generic cost estimator to determine the number of pages to visit - * during the index scan. - */ - costs.numIndexTuples = path->indexinfo->tuples * ratio; - - genericcostestimate(root, path, loop_count, &costs); - get_tablespace_page_costs(path->indexinfo->reltablespace, NULL, &spc_seq_page_cost); + /* Change some page cost from random to sequential */ + costs.indexTotalCost -= sequentialRatio * costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); + + /* Startup cost is cost before returning the first row */ + costs.indexStartupCost = costs.indexTotalCost * ratio; + /* Adjust cost if needed since TOAST not included in seq scan cost */ - if (costs.numIndexPages > path->indexinfo->rel->pages && ratio < 0.5) + startupPages = costs.numIndexPages * ratio; + if (startupPages > path->indexinfo->rel->pages && ratio < 0.5) { - /* Change all page cost from random to sequential */ - costs.indexTotalCost -= costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); + /* Change rest of page cost from random to sequential */ + costs.indexStartupCost -= (1 - sequentialRatio) * startupPages * (costs.spc_random_page_cost - spc_seq_page_cost); /* Remove cost of extra pages */ - costs.indexTotalCost -= (costs.numIndexPages - path->indexinfo->rel->pages) * spc_seq_page_cost; - } - else - { - /* Change some page cost from random to sequential */ - costs.indexTotalCost -= 0.5 * costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); + costs.indexStartupCost -= (startupPages - 
path->indexinfo->rel->pages) * spc_seq_page_cost; } - /* Use total cost since most work happens before first tuple is returned */ - *indexStartupCost = costs.indexTotalCost; + *indexStartupCost = costs.indexStartupCost; *indexTotalCost = costs.indexTotalCost; *indexSelectivity = costs.indexSelectivity; *indexCorrelation = costs.indexCorrelation; diff --git a/test/t/039_hnsw_cost.pl b/test/t/039_hnsw_cost.pl index 763e374..97ea5e7 100644 --- a/test/t/039_hnsw_cost.pl +++ b/test/t/039_hnsw_cost.pl @@ -17,12 +17,11 @@ $node->safe_psql("postgres", "CREATE EXTENSION vector;"); for my $dim (@dims) { my $array_sql = join(",", ('random()') x $dim); - my $n = 6000; # Create table and index $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); $node->safe_psql("postgres", - "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, $n) i;" + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 2000) i;" ); $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);"); $node->safe_psql("postgres", "ANALYZE tst;"); @@ -40,6 +39,16 @@ for my $dim (@dims) )); like($explain, qr/Index Scan using idx/); + # 3x the rows are needed for distance filters + # since the planner uses DEFAULT_INEQ_SEL for the selectivity (should be 1) + # Recreate index for performance + $node->safe_psql("postgres", "DROP INDEX idx;"); + $node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(2001, 6000) i;" + ); + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);"); + $node->safe_psql("postgres", "ANALYZE tst;"); + $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query' LIMIT $limit; ));