From af1727775d6635df9879793247d81beb9d8aad59 Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Wed, 18 Sep 2024 14:55:58 -0700
Subject: [PATCH] Added streaming option for HNSW [skip ci]

---
Review notes (everything between "---" and the first "diff --git" is
ignored by git am):

- Adds the boolean GUC hnsw.streaming (default off, PGC_USERSET). When it
  is on, hnswgettuple calls ResumeScanItems once so->w is empty, which
  re-runs the layer 0 search using so->discarded as entry points and the
  visited hash kept in so->v (initVisited = false).
- HnswSearchLayer gains three trailing arguments: v (caller-owned visited
  hash, or NULL to use a local one), discarded (list that receives
  candidates evicted from W or not close enough to enter it, or NULL to
  drop them) and initVisited (whether to call InitVisited on v).
- The hnsw.streaming descriptions replace the "todo" placeholders, the
  ternary passed to HnswLoadElement is parenthesized (no behavior change),
  and the TAP assertion has a description.
- NOTE(review): hnswrescan calls tidhash_reset(so->v.tids) whenever
  so->first is false. Nothing in this patch initializes so->v when
  GetScanItems returns early for entryPoint == NULL; confirm that path
  cannot reach the reset with an uninitialized hash. hnsw.h only
  forward-declares struct tidhash_hash, so also confirm tidhash_reset is
  declared where hnswscan.c can see it.
- NOTE(review): Assert(eElement->level >= lc) is immediately followed by
  the "Make robust to issues" check for the same condition; in assert
  builds the Assert fires first. Confirm this is intended.
- NOTE(review): locking around ResumeScanItems is still marked TODO.
- NOTE(review): line breaks and indentation were reconstructed from the
  hunk headers; if context whitespace differs from the tree, apply with
  "git apply --ignore-whitespace".

 src/hnsw.c                   |  5 +++
 src/hnsw.h                   | 15 ++++++++-
 src/hnswscan.c               | 60 +++++++++++++++++++++++++++++++++---
 src/hnswutils.c              | 46 +++++++++++++++++----------
 test/t/039_hnsw_streaming.pl | 30 ++++++++++++++++++
 5 files changed, 134 insertions(+), 22 deletions(-)
 create mode 100644 test/t/039_hnsw_streaming.pl

diff --git a/src/hnsw.c b/src/hnsw.c
index 72dd6a7..3e08eb1 100644
--- a/src/hnsw.c
+++ b/src/hnsw.c
@@ -17,6 +17,7 @@
 #endif
 
 int			hnsw_ef_search;
+bool		hnsw_streaming;
 int			hnsw_lock_tranche_id;
 static relopt_kind hnsw_relopt_kind;
 
@@ -75,6 +76,10 @@ HnswInit(void)
 							"Valid range is 1..1000.", &hnsw_ef_search,
 							HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
 
+	DefineCustomBoolVariable("hnsw.streaming", "Enables streaming of additional results from HNSW index scans",
+							 "When on, a scan resumes the search from discarded candidates after the initial results are exhausted.", &hnsw_streaming,
+							 HNSW_DEFAULT_STREAMING, PGC_USERSET, 0, NULL, NULL, NULL);
+
 	MarkGUCPrefixReserved("hnsw");
 }
 
diff --git a/src/hnsw.h b/src/hnsw.h
index 480ad9f..3bc454e 100644
--- a/src/hnsw.h
+++ b/src/hnsw.h
@@ -42,6 +42,7 @@
 #define HNSW_DEFAULT_EF_SEARCH	40
 #define HNSW_MIN_EF_SEARCH		1
 #define HNSW_MAX_EF_SEARCH		1000
+#define HNSW_DEFAULT_STREAMING	false
 
 /* Tuple types */
 #define HNSW_ELEMENT_TUPLE_TYPE  1
@@ -111,6 +112,7 @@
 
 /* Variables */
 extern int	hnsw_ef_search;
+extern bool hnsw_streaming;
 extern int	hnsw_lock_tranche_id;
 
 typedef struct HnswElementData HnswElementData;
@@ -328,11 +330,22 @@ typedef struct HnswNeighborTupleData
 
 typedef HnswNeighborTupleData * HnswNeighborTuple;
 
+typedef union
+{
+	struct pointerhash_hash *pointers;
+	struct offsethash_hash *offsets;
+	struct tidhash_hash *tids;
+}			visited_hash;
+
 typedef struct HnswScanOpaqueData
 {
 	const		HnswTypeInfo *typeInfo;
 	bool		first;
 	List	   *w;
+	visited_hash v;
+	List	   *discarded;
+	Datum		q;
+	int			m;
 	MemoryContext tmpCtx;
 
 	/* Support functions */
@@ -378,7 +391,7 @@ bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
 Buffer		HnswNewBuffer(Relation index, ForkNumber forkNum);
 void		HnswInitPage(Buffer buf, Page page);
 void		HnswInit(void);
-List	   *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement);
+List	   *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, List **discarded, bool initVisited);
 HnswElement HnswGetEntryPoint(Relation index);
 void		HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint);
 void	   *HnswAlloc(HnswAllocator * allocator, Size size);
diff --git a/src/hnswscan.c b/src/hnswscan.c
index 0efbaa1..bcfefb5 100644
--- a/src/hnswscan.c
+++ b/src/hnswscan.c
@@ -26,6 +26,9 @@ GetScanItems(IndexScanDesc scan, Datum q)
 	/* Get m and entry point */
 	HnswGetMetaPageInfo(index, &m, &entryPoint);
 
+	so->q = q;
+	so->m = m;
+
 	if (entryPoint == NULL)
 		return NIL;
 
@@ -33,11 +36,32 @@ GetScanItems(IndexScanDesc scan, Datum q)
 
 	for (int lc = entryPoint->level; lc >= 1; lc--)
 	{
-		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL);
+		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true);
 		ep = w;
 	}
 
-	return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL);
+	return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, &so->discarded, true);
+}
+
+/*
+ * Resume scan at ground level with discarded candidates
+ */
+static List *
+ResumeScanItems(IndexScanDesc scan)
+{
+	HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
+	Relation	index = scan->indexRelation;
+	FmgrInfo   *procinfo = so->procinfo;
+	Oid			collation = so->collation;
+	List	   *ep;
+	char	   *base = NULL;
+
+	if (list_length(so->discarded) == 0)
+		return NIL;
+
+	ep = so->discarded;
+	so->discarded = NIL;
+	return HnswSearchLayer(base, so->q, ep, hnsw_ef_search, 0, index, procinfo, collation, so->m, false, NULL, &so->v, &so->discarded, false);
 }
 
 /*
@@ -103,7 +127,10 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no
 {
 	HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
 
+	if (!so->first)
+		tidhash_reset(so->v.tids);
 	so->first = true;
+	so->discarded = NIL;
 	MemoryContextReset(so->tmpCtx);
 
 	if (keys && scan->numberOfKeys > 0)
@@ -165,13 +192,36 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
 #endif
 	}
 
-	while (list_length(so->w) > 0)
+	for (;;)
 	{
 		char	   *base = NULL;
-		HnswCandidate *hc = llast(so->w);
-		HnswElement element = HnswPtrAccess(base, hc->element);
+		HnswCandidate *hc;
+		HnswElement element;
 		ItemPointer heaptid;
 
+		if (list_length(so->w) == 0)
+		{
+			if (!hnsw_streaming)
+				break;
+
+			/* TODO figure out locking */
+			LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
+
+			so->w = ResumeScanItems(scan);
+
+			UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
+
+#if defined(HNSW_MEMORY) && PG_VERSION_NUM >= 130000
+			elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(so->tmpCtx, false) / (1024 * 1024));
+#endif
+
+			if (list_length(so->w) == 0)
+				break;
+		}
+
+		hc = llast(so->w);
+		element = HnswPtrAccess(base, hc->element);
+
 		/* Move to next element if no valid heap TIDs */
 		if (element->heaptidsLength == 0)
 		{
diff --git a/src/hnswutils.c b/src/hnswutils.c
index 96c5026..371e42f 100644
--- a/src/hnswutils.c
+++ b/src/hnswutils.c
@@ -105,13 +105,6 @@ hash_offset(Size offset)
 #define SH_DEFINE
 #include "lib/simplehash.h"
 
-typedef union
-{
-	pointerhash_hash *pointers;
-	offsethash_hash *offsets;
-	tidhash_hash *tids;
-}			visited_hash;
-
 /*
  * Get the max number of connections in an upper layer for each element in the index
  */
@@ -721,18 +714,22 @@ CountElement(char *base, HnswElement skipElement, HnswCandidate * hc)
  * Algorithm 2 from paper
  */
 List *
-HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement)
+HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, List **discarded, bool initVisited)
 {
 	List	   *w = NIL;
 	pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
 	pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
 	int			wlen = 0;
-	visited_hash v;
+	visited_hash v2;
 	ListCell   *lc2;
 	HnswNeighborArray *neighborhoodData = NULL;
 	Size		neighborhoodSize = 0;
 
-	InitVisited(base, &v, index, ef, m);
+	if (v == NULL)
+		v = &v2;
+
+	if (initVisited)
+		InitVisited(base, v, index, ef, m);
 
 	/* Create local memory for neighborhood if needed */
 	if (index == NULL)
@@ -747,7 +744,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 		HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
 		bool		found;
 
-		AddToVisited(base, &v, hc, index, &found);
+		AddToVisited(base, v, hc, index, &found);
 
 		pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node));
 		pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node));
@@ -793,7 +790,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 			HnswCandidate *e = &neighborhood->items[i];
 			bool		visited;
 
-			AddToVisited(base, &v, e, index, &visited);
+			AddToVisited(base, v, e, index, &visited);
 
 			if (!visited)
 			{
@@ -806,13 +803,14 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 				if (index == NULL)
 					eDistance = GetCandidateDistance(base, e, q, procinfo, collation);
 				else
-					HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance);
+					HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, (alwaysAdd || discarded != NULL) ? NULL : &f->distance);
 
 				if (eDistance < f->distance || alwaysAdd)
 				{
 					HnswCandidate *ec;
 
 					Assert(!eElement->deleted);
+					Assert(eElement->level >= lc);
 
 					/* Make robust to issues */
 					if (eElement->level < lc)
@@ -837,9 +835,25 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 
 						/* No need to decrement wlen */
 						if (wlen > ef)
-							pairingheap_remove_first(W);
+						{
+							HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner;
+
+							if (discarded != NULL)
+								*discarded = lappend(*discarded, hc);
+						}
 					}
 				}
+				else if (discarded != NULL)
+				{
+					HnswCandidate *ec;
+
+					/* Copy e */
+					ec = palloc(sizeof(HnswCandidate));
+					HnswPtrStore(base, ec->element, eElement);
+					ec->distance = eDistance;
+
+					*discarded = lappend(*discarded, ec);
+				}
 			}
 		}
 	}
@@ -1230,7 +1244,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
 	/* 1st phase: greedy search to insert level */
 	for (int lc = entryLevel; lc >= level + 1; lc--)
 	{
-		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement);
+		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
 		ep = w;
 	}
 
@@ -1248,7 +1262,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
 		List	   *neighbors;
 		List	   *lw;
 
-		w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement);
+		w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
 
 		/* Elements being deleted or skipped can help with search */
 		/* but should be removed before selecting neighbors */
diff --git a/test/t/039_hnsw_streaming.pl b/test/t/039_hnsw_streaming.pl
new file mode 100644
index 0000000..5379b56
--- /dev/null
+++ b/test/t/039_hnsw_streaming.pl
@@ -0,0 +1,30 @@
+use strict;
+use warnings FATAL => 'all';
+use PostgreSQL::Test::Cluster;
+use PostgreSQL::Test::Utils;
+use Test::More;
+
+my $dim = 3;
+my $array_sql = join(",", ('random()') x $dim);
+
+# Initialize node
+my $node = PostgreSQL::Test::Cluster->new('node');
+$node->init;
+$node->start;
+
+# Create table
+$node->safe_psql("postgres", "CREATE EXTENSION vector;");
+$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
+$node->safe_psql("postgres",
+	"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;"
+);
+$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");
+
+my $count = $node->safe_psql("postgres", qq(
+	SET enable_seqscan = off;
+	SET hnsw.streaming = on;
+	SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 1000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
+));
+is($count, 10, "streaming scan returns all 10 rows that pass the filter");
+
+done_testing();