From 17266ed409f6075569e886851a19f79a49606ed6 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 9 Oct 2024 21:49:32 -0700 Subject: [PATCH] Use inMemory for conditionals --- src/hnswutils.c | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/hnswutils.c b/src/hnswutils.c index 198e438..03c033e 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -598,9 +598,10 @@ HnswSearchCandidate * HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) { HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate)); + bool inMemory = index == NULL; HnswPtrStore(base, sc->element, entryPoint); - if (index == NULL) + if (inMemory) sc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation); else HnswLoadElement(entryPoint, &sc->distance, &q, index, procinfo, collation, loadVec, NULL); @@ -644,9 +645,9 @@ CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, * Init visited */ static inline void -InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) +InitVisited(char *base, visited_hash * v, bool inMemory, int ef, int m) { - if (index != NULL) + if (!inMemory) v->tids = tidhash_create(CurrentMemoryContext, ef * m * 2, NULL); else if (base != NULL) v->offsets = offsethash_create(CurrentMemoryContext, ef * m * 2, NULL); @@ -658,9 +659,9 @@ InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) * Add to visited */ static inline void -AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, Relation index, bool *found) +AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, bool inMemory, bool *found) { - if (index != NULL) + if (!inMemory) { HnswElement element = HnswPtrAccess(base, elementPtr); ItemPointerData indextid; @@ -721,7 +722,7 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unv HnswCandidate *hc = &localNeighborhood->items[i]; bool found; - AddToVisited(base, v, hc->element, NULL, &found); + AddToVisited(base, v, hc->element, true, &found); if (!found) unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element); @@ -805,11 +806,12 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F int lm = HnswGetLayerM(m, lc); HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; + bool inMemory = index == NULL; - InitVisited(base, &v, index, ef, m); + InitVisited(base, &v, inMemory, ef, m); /* Create local memory for neighborhood if needed */ - if (index == NULL) + if (inMemory) { neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm); localNeighborhood = palloc(neighborhoodSize); @@ -821,7 +823,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F HnswSearchCandidate *sc = (HnswSearchCandidate *) lfirst(lc2); bool found; - AddToVisited(base, &v, sc->element, index, &found); + AddToVisited(base, &v, sc->element, inMemory, &found); pairingheap_add(C, &sc->c_node); pairingheap_add(W, &sc->w_node); @@ -846,7 +848,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F cElement = HnswPtrAccess(base, c->element); - if (index == NULL) + if (inMemory) HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize); else HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc); @@ -860,7 +862,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F f = HnswGetSearchCandidate(w_node, pairingheap_first(W)); - if (index == NULL) + if (inMemory) { eElement = unvisited[i].element; eDistance = GetElementDistance(base, eElement, q, procinfo, collation); @@ -1222,9 +1224,10 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint int entryLevel; Datum q = HnswGetValue(base, element); HnswElement skipElement = existing ? element : NULL; + bool inMemory = index == NULL; /* Precompute hash */ - if (index == NULL) + if (inMemory) PrecomputeHash(base, element); /* No neighbors if no entry point */ @@ -1273,7 +1276,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */ - if (index != NULL) + if (!inMemory) lw = RemoveElements(base, lw, skipElement); /*