diff --git a/src/hnsw.c b/src/hnsw.c index a7b1e5f..3309966 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -18,6 +18,7 @@ #endif int hnsw_ef_search; +bool hnsw_streaming; int hnsw_lock_tranche_id; static relopt_kind hnsw_relopt_kind; @@ -68,6 +69,13 @@ HnswInit(void) "Valid range is 1..1000.", &hnsw_ef_search, HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL); + /* TODO Figure out name */ + DefineCustomBoolVariable("hnsw.streaming", "Use streaming mode", + NULL, &hnsw_streaming, + HNSW_DEFAULT_STREAMING, PGC_USERSET, 0, NULL, NULL, NULL); + + /* TODO Add option for limiting iterative search */ + MarkGUCPrefixReserved("hnsw"); } @@ -126,6 +134,8 @@ hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, /* Account for number of tuples (or entry level), m, and ef_search */ costs.numIndexTuples = (entryLevel + 2) * m; + /* TODO Adjust for selectivity for iterative scans */ + genericcostestimate(root, path, loop_count, &costs); /* Use total cost since most work happens before first tuple is returned */ diff --git a/src/hnsw.h b/src/hnsw.h index 9fb650a..e6bfde5 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -12,6 +12,10 @@ #include "utils/sampling.h" #include "vector.h" +#ifdef HNSW_BENCH +#include "portability/instr_time.h" +#endif + #define HNSW_MAX_DIM 2000 #define HNSW_MAX_NNZ 1000 @@ -42,6 +46,7 @@ #define HNSW_DEFAULT_EF_SEARCH 40 #define HNSW_MIN_EF_SEARCH 1 #define HNSW_MAX_EF_SEARCH 1000 +#define HNSW_DEFAULT_STREAMING false /* Tuple types */ #define HNSW_ELEMENT_TUPLE_TYPE 1 @@ -68,6 +73,21 @@ #define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) #define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page)) +#ifdef HNSW_BENCH +#define HnswBench(name, code) \ + do { \ + instr_time start; \ + instr_time duration; \ + INSTR_TIME_SET_CURRENT(start); \ + (code); \ + INSTR_TIME_SET_CURRENT(duration); \ + INSTR_TIME_SUBTRACT(duration, start); \ + elog(INFO, "%s: %.3f ms", 
name, INSTR_TIME_GET_MILLISEC(duration)); \ + } while (0) +#else +#define HnswBench(name, code) (code) +#endif + #if PG_VERSION_NUM >= 150000 #define RandomDouble() pg_prng_double(&pg_global_prng_state) #define SeedRandom(seed) pg_prng_seed(&pg_global_prng_state, seed) @@ -106,6 +126,7 @@ /* Variables */ extern int hnsw_ef_search; +extern bool hnsw_streaming; extern int hnsw_lock_tranche_id; typedef struct HnswElementData HnswElementData; @@ -129,6 +150,7 @@ struct HnswElementData uint8 heaptidsLength; uint8 level; uint8 deleted; + uint8 version; uint32 hash; HnswNeighborsPtr neighbors; BlockNumber blkno; @@ -163,6 +185,9 @@ typedef struct HnswSearchCandidate float distance; } HnswSearchCandidate; +#define HnswGetSearchCandidate(membername, ptr) pairingheap_container(HnswSearchCandidate, membername, ptr) +#define HnswGetSearchCandidateConst(membername, ptr) pairingheap_const_container(HnswSearchCandidate, membername, ptr) + /* HNSW index options */ typedef struct HnswOptions { @@ -306,10 +331,10 @@ typedef struct HnswElementTupleData uint8 type; uint8 level; uint8 deleted; - uint8 unused; + uint8 version; ItemPointerData heaptids[HNSW_HEAPTIDS]; ItemPointerData neighbortid; - uint16 unused2; + uint16 unused; Vector data; } HnswElementTupleData; @@ -318,18 +343,30 @@ typedef HnswElementTupleData * HnswElementTuple; typedef struct HnswNeighborTupleData { uint8 type; - uint8 unused; + uint8 version; uint16 count; ItemPointerData indextids[FLEXIBLE_ARRAY_MEMBER]; } HnswNeighborTupleData; typedef HnswNeighborTupleData * HnswNeighborTuple; +typedef union +{ + struct pointerhash_hash *pointers; + struct offsethash_hash *offsets; + struct tidhash_hash *tids; +} visited_hash; + typedef struct HnswScanOpaqueData { const HnswTypeInfo *typeInfo; bool first; List *w; + visited_hash v; + pairingheap *discarded; + Datum q; + int m; + int64 tuples; MemoryContext tmpCtx; /* Support functions */ @@ -375,7 +412,7 @@ bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); 
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 2dce16f..fdc18c0 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -36,7 +36,7 @@ GetInsertPage(Relation index) * Check for a free offset */ static bool -HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage) +HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage, uint8 *tupleVersion) { OffsetNumber offno; OffsetNumber maxoffno = PageGetMaxOffsetNumber(page); @@ -98,6 +98,7 @@ HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size { *freeOffno = offno; *freeNeighborOffno = neighborOffno; + *tupleVersion = etup->version; return true; } else if (*nbuf != buf) @@ -153,6 +154,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B OffsetNumber freeOffno = InvalidOffsetNumber; OffsetNumber freeNeighborOffno = InvalidOffsetNumber; BlockNumber newInsertPage = InvalidBlockNumber; + uint8 tupleVersion; char *base = NULL; /* Calculate sizes */ @@ -202,7 
+204,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B } /* Next, try space from a deleted element */ - if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage)) + if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage, &tupleVersion)) { if (nbuf != buf) { @@ -212,6 +214,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B npage = GenericXLogRegisterBuffer(state, nbuf, 0); } + /* Set tuple version */ + etup->version = tupleVersion; + ntup->version = tupleVersion; + break; } diff --git a/src/hnswscan.c b/src/hnswscan.c index 30815af..2e594b4 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -26,6 +26,11 @@ GetScanItems(IndexScanDesc scan, Datum q) /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); + so->q = q; + so->m = m; + /* Only allocated once the ground layer is searched, so mark as absent for an empty index */ + so->discarded = NULL; + if (entryPoint == NULL) return NIL; @@ -33,11 +38,44 @@ GetScanItems(IndexScanDesc scan, Datum q) for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL); + w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); + return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, &so->discarded, true); +} + +/* + * Resume scan at ground level with discarded candidates; returns NIL when there is nothing to resume from + */ +static List * +ResumeScanItems(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + Relation index = scan->indexRelation; + FmgrInfo *procinfo = so->procinfo; + Oid collation = so->collation; + List *ep = NIL; + char *base = NULL; + int batch_size = hnsw_ef_search; + + if (so->discarded == NULL || pairingheap_is_empty(so->discarded)) + return NIL; + + /* Get next batch of 
candidates */ + for (int i = 0; i < batch_size; i++) + { + HnswSearchCandidate *hc; + + if (pairingheap_is_empty(so->discarded)) + break; + + hc = HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded)); + + ep = lappend(ep, hc); + } + + return HnswSearchLayer(base, so->q, ep, batch_size, 0, index, procinfo, collation, so->m, false, NULL, &so->v, &so->discarded, false); } /* @@ -103,7 +141,14 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + /* Visited set and discarded heap only exist if the previous scan searched the ground layer */ + if (!so->first && so->discarded != NULL) + { + pairingheap_reset(so->discarded); + tidhash_reset(so->v.tids); + } so->first = true; + so->tuples = 0; MemoryContextReset(so->tmpCtx); if (keys && scan->numberOfKeys > 0) @@ -153,7 +198,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) */ LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); - so->w = GetScanItems(scan, value); + HnswBench("scan iteration", so->w = GetScanItems(scan, value)); /* Release shared lock */ UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); @@ -165,20 +210,79 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) #endif } - while (list_length(so->w) > 0) + for (;;) { char *base = NULL; - HnswSearchCandidate *hc = llast(so->w); - HnswElement element = HnswPtrAccess(base, hc->element); + HnswSearchCandidate *hc; + HnswElement element; ItemPointer heaptid; + if (list_length(so->w) == 0) + { + if (!hnsw_streaming) + break; + + /* Prevent scans from consuming too much memory */ + if (MemoryContextMemAllocated(so->tmpCtx, false) > (Size) work_mem * 1024L) + { + if (so->discarded == NULL || pairingheap_is_empty(so->discarded)) + { + ereport(NOTICE, + (errmsg("hnsw iterative search exceeded work_mem after " INT64_FORMAT " tuples", so->tuples), + errhint("Increase work_mem to scan more tuples."))); + + break; + } + + /* Return remaining tuples */ + so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded))); + } + else + { + /* + * Locking ensures when 
neighbors are read, the elements they + * reference will not be deleted (and replaced) during the + * iteration. + * + * Elements loaded into memory on previous iterations may have + * been deleted (and replaced), so when reading neighbors, the + * element version must be checked. + */ + LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + + HnswBench("scan iteration", so->w = ResumeScanItems(scan)); + + UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + +#if defined(HNSW_MEMORY) + elog(INFO, "memory: %zu KB", MemoryContextMemAllocated(so->tmpCtx, false) / 1024); +#endif + } + + if (list_length(so->w) == 0) + break; + } + + hc = llast(so->w); + element = HnswPtrAccess(base, hc->element); + /* Move to next element if no valid heap TIDs */ if (element->heaptidsLength == 0) { so->w = list_delete_last(so->w); + + /* Mark memory as free for next iteration */ + if (hnsw_streaming) + { + pfree(element); + pfree(hc); + } + continue; } + so->tuples++; + heaptid = &element->heaptids[--element->heaptidsLength]; MemoryContextSwitchTo(oldCtx); diff --git a/src/hnswutils.c b/src/hnswutils.c index ac1e7de..1246872 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -100,13 +100,6 @@ hash_offset(Size offset) #define SH_DEFINE #include "lib/simplehash.h" -typedef union -{ - pointerhash_hash *pointers; - offsethash_hash *offsets; - tidhash_hash *tids; -} visited_hash; - typedef union { HnswElement element; @@ -253,6 +246,8 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel, element->level = level; element->deleted = 0; + /* Start at one to make it easier to find issues */ + element->version = 1; HnswInitNeighbors(base, element, m, allocator); @@ -405,6 +400,7 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; etup->deleted = 0; + etup->version = element->version; for (int i = 0; i < HNSW_HEAPTIDS; i++) { if (i < element->heaptidsLength) 
@@ -447,6 +443,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m) } ntup->count = idx; + ntup->version = e->version; } /* @@ -520,6 +517,7 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe { element->level = etup->level; element->deleted = etup->deleted; + element->version = etup->version; element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); element->heaptidsLength = 0; @@ -621,9 +619,6 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, return hc; } -#define HnswGetSearchCandidate(membername, ptr) pairingheap_container(HnswSearchCandidate, membername, ptr) -#define HnswGetSearchCandidateConst(membername, ptr) pairingheap_const_container(HnswSearchCandidate, membername, ptr) - /* * Compare candidate distances */ @@ -639,6 +634,21 @@ CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, v return 0; } +/* + * Compare discarded candidate distances + */ +static int +CompareNearestDiscardedCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (HnswGetSearchCandidateConst(w_node, a)->distance < HnswGetSearchCandidateConst(w_node, b)->distance) + return 1; + + if (HnswGetSearchCandidateConst(w_node, a)->distance > HnswGetSearchCandidateConst(w_node, b)->distance) + return -1; + + return 0; +} + /* * Compare candidate distances */ @@ -754,20 +764,30 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u int start; ItemPointerData indextids[HNSW_MAX_M * 2]; + *unvisitedLength = 0; + buf = ReadBuffer(index, element->neighborPage); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); - start = (element->level - lc) * m; + + /* + * Ensure the neighbor tuple has not been deleted or replaced between + * index 
scan iterations + */ + if (ntup->version != element->version) + { + UnlockReleaseBuffer(buf); + return; + } /* Copy to minimize lock time */ + start = (element->level - lc) * m; memcpy(&indextids, ntup->indextids + start, lm * sizeof(ItemPointerData)); UnlockReleaseBuffer(buf); - *unvisitedLength = 0; - for (int i = 0; i < lm; i++) { ItemPointer indextid = &indextids[i]; @@ -787,13 +807,13 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); int wlen = 0; - visited_hash v; + visited_hash vh; ListCell *lc2; HnswNeighborArray *localNeighborhood = NULL; Size neighborhoodSize = 0; @@ -801,7 +821,19 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; - InitVisited(base, &v, index, ef, m); + if (v == NULL) + { + v = &vh; + initVisited = true; + } + + if (initVisited) + { + InitVisited(base, v, index, ef, m); + + if (discarded != NULL) + *discarded = pairingheap_allocate(CompareNearestDiscardedCandidates, NULL); + } /* Create local memory for neighborhood if needed */ if (index == NULL) @@ -816,7 +848,8 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F HnswSearchCandidate *hc = (HnswSearchCandidate *) lfirst(lc2); bool found; - AddToVisited(base, &v, hc->element, index, &found); + if (initVisited) + AddToVisited(base, v, 
hc->element, index, &found); pairingheap_add(C, &hc->c_node); pairingheap_add(W, &hc->w_node); @@ -842,9 +875,9 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F cElement = HnswPtrAccess(base, c->element); if (index == NULL) - HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize); + HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, localNeighborhood, neighborhoodSize); else - HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc); + HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc); for (int i = 0; i < unvisitedLength; i++) { @@ -868,16 +901,22 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement); - - if (eElement == NULL) - continue; + HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? 
NULL : &f->distance, &eElement); } - if (!(eDistance < f->distance || alwaysAdd)) - continue; + if (eElement == NULL || !(eDistance < f->distance || alwaysAdd)) + { + if (discarded != NULL) + { + /* Create a new candidate */ + e = palloc(sizeof(HnswSearchCandidate)); + HnswPtrStore(base, e->element, eElement); + e->distance = eDistance; + pairingheap_add(*discarded, &e->w_node); + } - Assert(!eElement->deleted); + continue; + } /* Make robust to issues */ if (eElement->level < lc) @@ -901,7 +940,12 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* No need to decrement wlen */ if (wlen > ef) - pairingheap_remove_first(W); + { + HnswSearchCandidate *d = HnswGetSearchCandidate(w_node, pairingheap_remove_first(W)); + + if (discarded != NULL) + pairingheap_add(*discarded, &d->w_node); + } } } } @@ -1274,7 +1318,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true); ep = w; } @@ -1293,7 +1337,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true); /* Convert search candidates to candidates */ foreach(lc2, w) diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 67cc645..c4a777c 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -527,6 +527,14 @@ MarkDeleted(HnswVacuumState * vacuumstate) for (int i = 0; i < ntup->count; i++) ItemPointerSetInvalid(&ntup->indextids[i]); + /* Increment version */ + /* This is used to 
avoid incorrect reads for iterative scans */ + /* Reserve some bits for future use */ + etup->version++; + if (etup->version > 15) + etup->version = 1; + ntup->version = etup->version; + /* * We modified the tuples in place, no need to call * PageIndexTupleOverwrite diff --git a/test/t/039_hnsw_streaming.pl b/test/t/039_hnsw_streaming.pl new file mode 100644 index 0000000..ecb5b5e --- /dev/null +++ b/test/t/039_hnsw_streaming.pl @@ -0,0 +1,43 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); +$node->safe_psql("postgres", qq( + SET maintenance_work_mem = '128MB'; + SET max_parallel_maintenance_workers = 2; + CREATE INDEX ON tst USING hnsw (v vector_l2_ops) +)); + +my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.streaming = on; + SET work_mem = '8MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +is($count, 10); + +my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.streaming = on; + SET work_mem = '2MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +like($stderr, qr/iterative search exceeded work_mem after \d+ tuples/); + +done_testing(); diff --git a/test/t/040_hnsw_streaming_recall.pl b/test/t/040_hnsw_streaming_recall.pl new file mode 100644 index 0000000..2c2eca4 --- /dev/null +++ b/test/t/040_hnsw_streaming_recall.pl @@ -0,0 
+1,131 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my @cs = (100, 1000); + +sub test_recall +{ + my ($c, $ef_search, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + SET hnsw.streaming = on; + EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + SET hnsw.streaming = on; + SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + + my @expected_ids = split("\n", $expected[$i]); + my %expected_set = map { $_ => 1 } @expected_ids; + + foreach (@actual_ids) + { + if (exists($expected_set{$_})) + { + $correct++; + } + } + + $total += $limit; + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); +} + +# Check each index type +my @operators = ("<->", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_cosine_ops"); + +for my $i (0 .. 
$#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + $node->safe_psql("postgres", qq( + SET maintenance_work_mem = '128MB'; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + + foreach (@cs) + { + my $c = $_; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + WITH top AS ( + SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit + ) + SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top) + )); + push(@expected, $res); + } + + if ($c == 100) + { + test_recall($c, 40, 0.99, $operator); + } + else + { + if ($operator eq "<->") + { + test_recall($c, 40, 0.99, $operator); + } + else + { + test_recall($c, 40, 0.99, $operator); + } + } + } + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();