diff --git a/src/hnsw.h b/src/hnsw.h index 878e02d..442a9ea 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -363,6 +363,7 @@ typedef struct HnswScanOpaqueData List *discarded; Datum q; int m; + int64 tuples; MemoryContext tmpCtx; /* Support functions */ diff --git a/src/hnswscan.c b/src/hnswscan.c index 000c335..287341c 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -131,6 +131,7 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no tidhash_reset(so->v.tids); so->first = true; so->discarded = NIL; + so->tuples = 0; MemoryContextReset(so->tmpCtx); if (keys && scan->numberOfKeys > 0) @@ -204,15 +205,34 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) if (!hnsw_streaming) break; - /* - * Ensure vacuum does not mark tuples as deleted during an - * iteration - */ - LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + if (MemoryContextMemAllocated(so->tmpCtx, false) > (work_mem * 1024L)) + { + if (list_length(so->discarded) == 0) + { + ereport(NOTICE, + (errmsg("iterative search exceeded work_mem after " INT64_FORMAT " tuples", so->tuples), + errhint("Increase work_mem to scan more tuples."))); - HnswBench("scan iteration", so->w = ResumeScanItems(scan)); + break; + } - UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + /* Return remaining tuples and exit */ + /* TODO sort */ + so->w = so->discarded; + so->discarded = NIL; + } + else + { + /* + * Ensure vacuum does not mark tuples as deleted during an + * iteration + */ + LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + + HnswBench("scan iteration", so->w = ResumeScanItems(scan)); + + UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + } #if defined(HNSW_MEMORY) elog(INFO, "memory: %zu KB", MemoryContextMemAllocated(so->tmpCtx, false) / 1024); @@ -240,6 +260,8 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) continue; } + so->tuples++; + heaptid = &element->heaptids[--element->heaptidsLength]; MemoryContextSwitchTo(oldCtx); diff --git a/test/t/039_hnsw_streaming.pl b/test/t/039_hnsw_streaming.pl index 5379b56..ecb5b5e 100644 --- a/test/t/039_hnsw_streaming.pl +++ b/test/t/039_hnsw_streaming.pl @@ -16,15 +16,28 @@ $node->start; $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); $node->safe_psql("postgres", - "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" ); -$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); +$node->safe_psql("postgres", qq( + SET maintenance_work_mem = '128MB'; + SET max_parallel_maintenance_workers = 2; + CREATE INDEX ON tst USING hnsw (v vector_l2_ops) +)); my $count = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET hnsw.streaming = on; - SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 1000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; + SET work_mem = '8MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; )); is($count, 10); +my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.streaming = on; + SET work_mem = '2MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +like($stderr, qr/iterative search exceeded work_mem after \d+ tuples/); + done_testing();