Add hnsw.ef_stream setting to limit streaming HNSW scans

Adds an integer setting, hnsw.ef_stream (default -1, described as
"-1 means all"). HnswSearchLayer takes a new optional "int64 *tuples"
argument; when it is not NULL the function adds one per entry point
(only when initVisited is set) and adds unvisitedLength each time a
candidate's neighbors are loaded. The scan code passes &so->tuples for
the layer-0 searches, and hnswgettuple no longer increments so->tuples
itself. Once so->tuples reaches hnsw_ef_search + hnsw_ef_stream,
hnswgettuple stops fetching new batches and only returns what is left
in so->discarded.

Review notes:
- The limit is computed in int64. Both settings are plain ints and
  hnsw.ef_stream may be as large as INT_MAX, so an int addition could
  overflow.
- The negative macro values are parenthesized.

diff --git a/src/hnsw.c b/src/hnsw.c
index 3309966..f4a8320 100644
--- a/src/hnsw.c
+++ b/src/hnsw.c
@@ -18,6 +18,7 @@
 #endif
 
 int			hnsw_ef_search;
+int			hnsw_ef_stream;
 bool		hnsw_streaming;
 int			hnsw_lock_tranche_id;
 static relopt_kind hnsw_relopt_kind;
@@ -74,7 +75,10 @@ HnswInit(void)
 							 NULL, &hnsw_streaming,
 							 HNSW_DEFAULT_STREAMING, PGC_USERSET, 0, NULL, NULL, NULL);
 
-	/* TODO Add option for limiting iterative search */
+	/* TODO Figure out name */
+	DefineCustomIntVariable("hnsw.ef_stream", "Sets the max number of additional candidates to visit for streaming search",
+							"-1 means all", &hnsw_ef_stream,
+							HNSW_DEFAULT_EF_STREAM, HNSW_MIN_EF_STREAM, HNSW_MAX_EF_STREAM, PGC_USERSET, 0, NULL, NULL, NULL);
 
 	MarkGUCPrefixReserved("hnsw");
 }
diff --git a/src/hnsw.h b/src/hnsw.h
index e6bfde5..0e1aa54 100644
--- a/src/hnsw.h
+++ b/src/hnsw.h
@@ -47,6 +47,9 @@
 #define HNSW_MIN_EF_SEARCH		1
 #define HNSW_MAX_EF_SEARCH		1000
 #define HNSW_DEFAULT_STREAMING	false
+#define HNSW_DEFAULT_EF_STREAM	(-1)
+#define HNSW_MIN_EF_STREAM		(-1)
+#define HNSW_MAX_EF_STREAM		INT_MAX
 
 /* Tuple types */
 #define HNSW_ELEMENT_TUPLE_TYPE  1
@@ -126,6 +129,7 @@
 
 /* Variables */
 extern int	hnsw_ef_search;
+extern int	hnsw_ef_stream;
 extern bool hnsw_streaming;
 extern int	hnsw_lock_tranche_id;
 
@@ -412,7 +416,7 @@ bool		HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
 Buffer		HnswNewBuffer(Relation index, ForkNumber forkNum);
 void		HnswInitPage(Buffer buf, Page page);
 void		HnswInit(void);
-List	   *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited);
+List	   *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples);
 HnswElement HnswGetEntryPoint(Relation index);
 void		HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint);
 void	   *HnswAlloc(HnswAllocator * allocator, Size size);
diff --git a/src/hnswscan.c b/src/hnswscan.c
index d63090f..de68e6a 100644
--- a/src/hnswscan.c
+++ b/src/hnswscan.c
@@ -36,11 +36,11 @@ GetScanItems(IndexScanDesc scan, Datum q)
 
 	for (int lc = entryPoint->level; lc >= 1; lc--)
 	{
-		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true);
+		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true, NULL);
 		ep = w;
 	}
 
-	return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, hnsw_streaming ? &so->discarded : NULL, true);
+	return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, hnsw_streaming ? &so->discarded : NULL, true, &so->tuples);
 }
 
 /*
@@ -73,7 +73,7 @@ ResumeScanItems(IndexScanDesc scan)
 		ep = lappend(ep, hc);
 	}
 
-	return HnswSearchLayer(base, so->q, ep, batch_size, 0, index, procinfo, collation, so->m, false, NULL, &so->v, &so->discarded, false);
+	return HnswSearchLayer(base, so->q, ep, batch_size, 0, index, procinfo, collation, so->m, false, NULL, &so->v, &so->discarded, false, &so->tuples);
 }
 
 /*
@@ -219,8 +219,17 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
 			if (!hnsw_streaming)
 				break;
 
+			/* Reached max number of additional tuples (int64 math: the sum can exceed INT_MAX) */
+			if (hnsw_ef_stream != -1 && so->tuples >= (int64) hnsw_ef_search + hnsw_ef_stream)
+			{
+				if (pairingheap_is_empty(so->discarded))
+					break;
+
+				/* Return remaining tuples */
+				so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded)));
+			}
 			/* Prevent scans from consuming too much memory */
-			if (MemoryContextMemAllocated(so->tmpCtx, false) > (Size) work_mem * 1024L)
+			else if (MemoryContextMemAllocated(so->tmpCtx, false) > (Size) work_mem * 1024L)
 			{
 				if (pairingheap_is_empty(so->discarded))
 				{
@@ -278,8 +287,6 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
 			continue;
 		}
 
-		so->tuples++;
-
 		heaptid = &element->heaptids[--element->heaptidsLength];
 
 		MemoryContextSwitchTo(oldCtx);
diff --git a/src/hnswutils.c b/src/hnswutils.c
index 1246872..206cc6d 100644
--- a/src/hnswutils.c
+++ b/src/hnswutils.c
@@ -807,7 +807,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u
  * Algorithm 2 from paper
  */
 List *
-HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited)
+HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples)
 {
 	List	   *w = NIL;
 	pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
@@ -849,8 +849,13 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 		bool		found;
 
 		if (initVisited)
+		{
 			AddToVisited(base, v, hc->element, index, &found);
 
+			if (tuples != NULL)
+				(*tuples)++;
+		}
+
 		pairingheap_add(C, &hc->c_node);
 		pairingheap_add(W, &hc->w_node);
 
@@ -879,6 +884,9 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 		else
 			HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc);
 
+		if (tuples != NULL)
+			(*tuples) += unvisitedLength;
+
 		for (int i = 0; i < unvisitedLength; i++)
 		{
 			HnswElement eElement;
@@ -1318,7 +1326,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
 	/* 1st phase: greedy search to insert level */
 	for (int lc = entryLevel; lc >= level + 1; lc--)
 	{
-		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
+		w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true, NULL);
 		ep = w;
 	}
 
@@ -1337,7 +1345,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
 		List	   *lw = NIL;
 		ListCell   *lc2;
 
-		w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
+		w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true, NULL);
 
 		/* Convert search candidates to candidates */
 		foreach(lc2, w)