diff --git a/CHANGELOG.md b/CHANGELOG.md index a7d9924..6753a7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.8.0 (unreleased) +- Added support for iterative index scans - Added casts for arrays to `sparsevec` - Improved cost estimation - Improved performance of HNSW inserts and on-disk index builds diff --git a/src/hnsw.c b/src/hnsw.c index c2579c1..57fcbdb 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -18,7 +18,16 @@ #define MarkGUCPrefixReserved(x) EmitWarningsOnPlaceholders(x) #endif +static const struct config_enum_entry hnsw_iterative_search_options[] = { + {"off", HNSW_ITERATIVE_SEARCH_OFF, false}, + {"on", HNSW_ITERATIVE_SEARCH_RELAXED, false}, + {"strict", HNSW_ITERATIVE_SEARCH_STRICT, false}, + {NULL, 0, false} +}; + int hnsw_ef_search; +int hnsw_iterative_search_max_tuples; +int hnsw_iterative_search; int hnsw_lock_tranche_id; static relopt_kind hnsw_relopt_kind; @@ -69,6 +78,15 @@ HnswInit(void) "Valid range is 1..1000.", &hnsw_ef_search, HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL); + DefineCustomEnumVariable("hnsw.iterative_search", "Sets iterative search", + NULL, &hnsw_iterative_search, + HNSW_ITERATIVE_SEARCH_OFF, hnsw_iterative_search_options, PGC_USERSET, 0, NULL, NULL, NULL); + + /* TODO Ensure ivfflat.max_probes uses same value for "all" */ + DefineCustomIntVariable("hnsw.iterative_search_max_tuples", "Sets the max number of candidates to visit for iterative search", + "-1 means all", &hnsw_iterative_search_max_tuples, + -1, -1, INT_MAX, PGC_USERSET, 0, NULL, NULL, NULL); + MarkGUCPrefixReserved("hnsw"); } diff --git a/src/hnsw.h b/src/hnsw.h index b2614d1..254a60a 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -109,8 +109,17 @@ /* Variables */ extern int hnsw_ef_search; +extern int hnsw_iterative_search; +extern int hnsw_iterative_search_max_tuples; extern int hnsw_lock_tranche_id; +typedef enum HnswIterativeSearchType +{ + HNSW_ITERATIVE_SEARCH_OFF, + 
HNSW_ITERATIVE_SEARCH_RELAXED, + HNSW_ITERATIVE_SEARCH_STRICT +} HnswIterativeSearchType; + typedef struct HnswElementData HnswElementData; typedef struct HnswNeighborArray HnswNeighborArray; @@ -132,6 +141,7 @@ struct HnswElementData uint8 heaptidsLength; uint8 level; uint8 deleted; + uint8 version; uint32 hash; HnswNeighborsPtr neighbors; BlockNumber blkno; @@ -319,10 +329,10 @@ typedef struct HnswElementTupleData uint8 type; uint8 level; uint8 deleted; - uint8 unused; + uint8 version; ItemPointerData heaptids[HNSW_HEAPTIDS]; ItemPointerData neighbortid; - uint16 unused2; + uint16 unused; Vector data; } HnswElementTupleData; @@ -331,7 +341,7 @@ typedef HnswElementTupleData * HnswElementTuple; typedef struct HnswNeighborTupleData { uint8 type; - uint8 unused; + uint8 version; uint16 count; ItemPointerData indextids[FLEXIBLE_ARRAY_MEMBER]; } HnswNeighborTupleData; @@ -356,6 +366,12 @@ typedef struct HnswScanOpaqueData const HnswTypeInfo *typeInfo; bool first; List *w; + visited_hash v; + pairingheap *discarded; + HnswQuery q; + int m; + int64 tuples; + double previousDistance; MemoryContext tmpCtx; /* Support functions */ @@ -399,7 +415,7 @@ bool HnswCheckNorm(HnswSupport * support, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 84eb1d4..a5fac4e 100644 
--- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -36,7 +36,7 @@ GetInsertPage(Relation index) * Check for a free offset */ static bool -HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage) +HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage, uint8 *tupleVersion) { OffsetNumber offno; OffsetNumber maxoffno = PageGetMaxOffsetNumber(page); @@ -98,6 +98,7 @@ HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size { *freeOffno = offno; *freeNeighborOffno = neighborOffno; + *tupleVersion = etup->version; return true; } else if (*nbuf != buf) @@ -153,6 +154,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B OffsetNumber freeOffno = InvalidOffsetNumber; OffsetNumber freeNeighborOffno = InvalidOffsetNumber; BlockNumber newInsertPage = InvalidBlockNumber; + uint8 tupleVersion; char *base = NULL; /* Calculate sizes */ @@ -202,7 +204,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B } /* Next, try space from a deleted element */ - if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage)) + if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage, &tupleVersion)) { if (nbuf != buf) { @@ -212,6 +214,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B npage = GenericXLogRegisterBuffer(state, nbuf, 0); } + /* Set tuple version */ + etup->version = tupleVersion; + ntup->version = tupleVersion; + break; } diff --git a/src/hnswscan.c b/src/hnswscan.c index 2c6a454..3a5c5d5 100644 --- a/src/hnswscan.c 
+++ b/src/hnswscan.c @@ -1,5 +1,7 @@ #include "postgres.h" +#include <math.h> + #include "access/relscan.h" #include "hnsw.h" #include "pgstat.h" @@ -21,25 +23,57 @@ GetScanItems(IndexScanDesc scan, Datum value) int m; HnswElement entryPoint; char *base = NULL; - HnswQuery q; - - q.value = value; + HnswQuery *q = &so->q; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); + q->value = value; + so->m = m; + if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, false)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, false)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, false, NULL); + w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, false, NULL, NULL, NULL, true, NULL); ep = w; } - return HnswSearchLayer(base, &q, ep, hnsw_ef_search, 0, index, support, m, false, NULL, &so->v, hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF ? 
&so->discarded : NULL, true, &so->tuples); +} + +/* + * Resume scan at ground level with discarded candidates + */ +static List * +ResumeScanItems(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + Relation index = scan->indexRelation; + List *ep = NIL; + char *base = NULL; + int batch_size = hnsw_ef_search; + + if (pairingheap_is_empty(so->discarded)) + return NIL; + + /* Get next batch of candidates */ + for (int i = 0; i < batch_size; i++) + { + HnswSearchCandidate *sc; + + if (pairingheap_is_empty(so->discarded)) + break; + + sc = HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded)); + + ep = lappend(ep, sc); + } + + return HnswSearchLayer(base, &so->q, ep, batch_size, 0, index, &so->support, so->m, false, NULL, &so->v, &so->discarded, false, &so->tuples); } /* @@ -83,6 +117,8 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); so->typeInfo = HnswGetTypeInfo(index); so->first = true; + so->v.tids = NULL; + so->discarded = NULL; so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw scan temporary context", ALLOCSET_DEFAULT_SIZES); @@ -103,7 +139,15 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + if (so->v.tids != NULL) + tidhash_reset(so->v.tids); + + if (so->discarded != NULL) + pairingheap_reset(so->discarded); + so->first = true; + so->tuples = 0; + so->previousDistance = -INFINITY; MemoryContextReset(so->tmpCtx); if (keys && scan->numberOfKeys > 0) @@ -165,22 +209,100 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) #endif } - while (list_length(so->w) > 0) + for (;;) { char *base = NULL; - HnswSearchCandidate *sc = llast(so->w); - HnswElement element = HnswPtrAccess(base, sc->element); + HnswSearchCandidate *sc; + HnswElement element; ItemPointer heaptid; + if (list_length(so->w) == 0) + { + if (hnsw_iterative_search == 
HNSW_ITERATIVE_SEARCH_OFF) + break; + + /* Empty index */ + if (so->discarded == NULL) + break; + + /* Reached max number of additional tuples */ + if (hnsw_iterative_search_max_tuples != -1 && so->tuples >= hnsw_iterative_search_max_tuples) + { + if (pairingheap_is_empty(so->discarded)) + break; + + /* Return remaining tuples */ + so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded))); + } + /* Prevent scans from consuming too much memory */ + else if (MemoryContextMemAllocated(so->tmpCtx, false) > (Size) work_mem * 1024L) + { + if (pairingheap_is_empty(so->discarded)) + { + ereport(DEBUG1, + (errmsg("hnsw index scan exceeded work_mem after " INT64_FORMAT " tuples", so->tuples), + errhint("Increase work_mem to scan more tuples."))); + + break; + } + + /* Return remaining tuples */ + so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded))); + } + else + { + /* + * Locking ensures when neighbors are read, the elements they + * reference will not be deleted (and replaced) during the + * iteration. + * + * Elements loaded into memory on previous iterations may have + * been deleted (and replaced), so when reading neighbors, the + * element version must be checked. 
+ */ + LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + + so->w = ResumeScanItems(scan); + + UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + +#if defined(HNSW_MEMORY) + elog(INFO, "memory: %zu KB", MemoryContextMemAllocated(so->tmpCtx, false) / 1024); +#endif + } + + if (list_length(so->w) == 0) + break; + } + + sc = llast(so->w); + element = HnswPtrAccess(base, sc->element); + /* Move to next element if no valid heap TIDs */ if (element->heaptidsLength == 0) { so->w = list_delete_last(so->w); + + /* Mark memory as free for next iteration */ + if (hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF) + { + pfree(element); + pfree(sc); + } + continue; } heaptid = &element->heaptids[--element->heaptidsLength]; + if (hnsw_iterative_search == HNSW_ITERATIVE_SEARCH_STRICT) + { + if (sc->distance < so->previousDistance) + continue; + + so->previousDistance = sc->distance; + } + MemoryContextSwitchTo(oldCtx); scan->xs_heaptid = *heaptid; diff --git a/src/hnswutils.c b/src/hnswutils.c index c51fe28..732fcd7 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -251,6 +251,8 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel, element->level = level; element->deleted = 0; + /* Start at one to make it easier to find issues */ + element->version = 1; HnswInitNeighbors(base, element, m, allocator); @@ -430,6 +432,7 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; etup->deleted = 0; + etup->version = element->version; for (int i = 0; i < HNSW_HEAPTIDS; i++) { if (i < element->heaptidsLength) @@ -472,6 +475,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m) } ntup->count = idx; + ntup->version = e->version; } /* @@ -482,6 +486,7 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe { element->level = etup->level; element->deleted = etup->deleted; + element->version = 
etup->version; element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); element->heaptidsLength = 0; @@ -608,6 +613,21 @@ CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, v return 0; } +/* + * Compare discarded candidate distances + */ +static int +CompareNearestDiscardedCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (HnswGetSearchCandidateConst(w_node, a)->distance < HnswGetSearchCandidateConst(w_node, b)->distance) + return 1; + + if (HnswGetSearchCandidateConst(w_node, a)->distance > HnswGetSearchCandidateConst(w_node, b)->distance) + return -1; + + return 0; +} + /* * Compare candidate distances */ @@ -728,8 +748,11 @@ HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation i ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); - /* Ensure expected neighbors */ - if (ntup->count != (element->level + 2) * m) + /* + * Ensure the neighbor tuple has not been deleted or replaced between + * index scan iterations + */ + if (ntup->version != element->version || ntup->count != (element->level + 2) * m) { UnlockReleaseBuffer(buf); return false; @@ -775,13 +798,13 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); int wlen = 0; - visited_hash v; + 
visited_hash vh; ListCell *lc2; HnswNeighborArray *localNeighborhood = NULL; Size neighborhoodSize = 0; @@ -790,7 +813,19 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in int unvisitedLength; bool inMemory = index == NULL; - InitVisited(base, &v, inMemory, ef, m); + if (v == NULL) + { + v = &vh; + initVisited = true; + } + + if (initVisited) + { + InitVisited(base, v, inMemory, ef, m); + + if (discarded != NULL) + *discarded = pairingheap_allocate(CompareNearestDiscardedCandidates, NULL); + } /* Create local memory for neighborhood if needed */ if (inMemory) @@ -805,7 +840,13 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in HnswSearchCandidate *sc = (HnswSearchCandidate *) lfirst(lc2); bool found; - AddToVisited(base, &v, sc->element, inMemory, &found); + if (initVisited) + { + AddToVisited(base, v, sc->element, inMemory, &found); + + if (tuples != NULL) + (*tuples)++; + } pairingheap_add(C, &sc->c_node); pairingheap_add(W, &sc->w_node); @@ -831,9 +872,12 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in cElement = HnswPtrAccess(base, c->element); if (inMemory) - HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize); + HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, localNeighborhood, neighborhoodSize); else - HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc); + HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc); + + if (tuples != NULL) + (*tuples) += unvisitedLength; for (int i = 0; i < unvisitedLength; i++) { @@ -857,16 +901,25 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd ? 
NULL : &f->distance, &eElement); + HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement); if (eElement == NULL) continue; } - if (!(eDistance < f->distance || alwaysAdd)) - continue; + if (eElement == NULL || !(eDistance < f->distance || alwaysAdd)) + { + if (discarded != NULL) + { + /* Create a new candidate */ + e = palloc(sizeof(HnswSearchCandidate)); + HnswPtrStore(base, e->element, eElement); + e->distance = eDistance; + pairingheap_add(*discarded, &e->w_node); + } - Assert(!eElement->deleted); + continue; + } /* Make robust to issues */ if (eElement->level < lc) @@ -890,7 +943,12 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in /* No need to decrement wlen */ if (wlen > ef) - pairingheap_remove_first(W); + { + HnswSearchCandidate *d = HnswGetSearchCandidate(w_node, pairingheap_remove_first(W)); + + if (discarded != NULL) + pairingheap_add(*discarded, &d->w_node); + } } } } @@ -1225,7 +1283,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); ep = w; } @@ -1244,7 +1302,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); /* Convert search candidates to candidates */ foreach(lc2, w) diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index d3cdf68..251d9d9 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -527,6 +527,14 @@ MarkDeleted(HnswVacuumState * 
vacuumstate) for (int i = 0; i < ntup->count; i++) ItemPointerSetInvalid(&ntup->indextids[i]); + /* Increment version */ + /* This is used to avoid incorrect reads for iterative scans */ + /* Reserve some bits for future use */ + etup->version++; + if (etup->version > 15) + etup->version = 1; + ntup->version = etup->version; + /* * We modified the tuples in place, no need to call * PageIndexTupleOverwrite diff --git a/test/t/043_hnsw_iterative_search.pl b/test/t/043_hnsw_iterative_search.pl new file mode 100644 index 0000000..6905fc4 --- /dev/null +++ b/test/t/043_hnsw_iterative_search.pl @@ -0,0 +1,67 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4 PRIMARY KEY, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); +$node->safe_psql("postgres", qq( + SET maintenance_work_mem = '128MB'; + SET max_parallel_maintenance_workers = 2; + CREATE INDEX ON tst USING hnsw (v vector_l2_ops) +)); + +my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.iterative_search = on; + SET work_mem = '8MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +is($count, 10); + +foreach ((30000, 50000, 70000)) +{ + my $max_tuples = $_; + my $expected = $max_tuples / 10000; + my $sum = 0; + + for my $i (1 .. 
20) + { + $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.iterative_search = on; + SET hnsw.iterative_search_max_tuples = $max_tuples; + SET work_mem = '8MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst WHERE i = $i) LIMIT 11) t; + )); + $sum += $count; + } + + my $avg = $sum / 20; + cmp_ok($avg, '>', $expected - 2); + cmp_ok($avg, '<', $expected + 2); +} + +my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.iterative_search = on; + SET client_min_messages = debug1; + SET work_mem = '2MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +like($stderr, qr/hnsw index scan exceeded work_mem after \d+ tuples/); + +done_testing(); diff --git a/test/t/044_hnsw_iterative_search_recall.pl b/test/t/044_hnsw_iterative_search_recall.pl new file mode 100644 index 0000000..8bedc32 --- /dev/null +++ b/test/t/044_hnsw_iterative_search_recall.pl @@ -0,0 +1,131 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my @cs = (100, 1000); + +sub test_recall +{ + my ($c, $ef_search, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + SET hnsw.iterative_search = on; + EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. 
$#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + SET hnsw.iterative_search = on; + SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + + my @expected_ids = split("\n", $expected[$i]); + my %expected_set = map { $_ => 1 } @expected_ids; + + foreach (@actual_ids) + { + if (exists($expected_set{$_})) + { + $correct++; + } + } + + $total += $limit; + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); +} + +# Check each index type +my @operators = ("<->", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_cosine_ops"); + +for my $i (0 .. 
$#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + $node->safe_psql("postgres", qq( + SET maintenance_work_mem = '128MB'; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + + foreach (@cs) + { + my $c = $_; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + WITH top AS ( + SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit + ) + SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top) + )); + push(@expected, $res); + } + + if ($c == 100) + { + test_recall($c, 40, 0.99, $operator); + } + else + { + if ($operator eq "<->") + { + test_recall($c, 40, 0.99, $operator); + } + else + { + test_recall($c, 40, 0.99, $operator); + } + } + } + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();