diff --git a/src/hnsw.h b/src/hnsw.h index bfc12f6..e034068 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -244,6 +244,11 @@ typedef struct HnswSupport Oid collation; } HnswSupport; +typedef struct HnswQuery +{ + Datum value; +} HnswQuery; + typedef struct HnswBuildState { /* Info */ @@ -378,14 +383,14 @@ bool HnswCheckNorm(HnswSupport * support, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing); -HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, Relation rel, HnswSupport * support, bool loadVec); +HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, HnswQuery * q, Relation rel, HnswSupport * support, bool loadVec); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); @@ -394,7 +399,7 @@ void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * bool HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPointer heaptid, bool building); void HnswUpdateNeighborsOnDisk(Relation index, HnswSupport * support, HnswElement e, int m, bool checkExisting, bool building); void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); +void HnswLoadElement(HnswElement element, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); bool HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support); void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, HnswSupport * support); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 87204ca..84eb1d4 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -368,7 +368,7 @@ HnswLoadNeighbors(HnswElement element, Relation index, int m, int lm, int lc) * Load elements for insert */ static void -LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation index, HnswSupport * support) +LoadElementsForInsert(HnswNeighborArray * neighbors, HnswQuery * q, int *idx, Relation index, HnswSupport * support) { char *base = NULL; @@ -378,7 +378,7 @@ LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation HnswElement element = HnswPtrAccess(base, hc->element); double distance; - HnswLoadElement(element, &distance, &q, index, support, true, NULL); + HnswLoadElement(element, &distance, q, index, support, true, NULL); hc->distance = distance; /* Prune element if being deleted */ @@ -419,9 +419,11 @@ GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int idx = -2; else { - Datum q = HnswGetValue(base, element); + HnswQuery q; - LoadElementsForInsert(neighbors, q, &idx, index, support); + q.value = HnswGetValue(base, element); + + LoadElementsForInsert(neighbors, &q, &idx, index, support); if (idx == -1) HnswUpdateConnection(base, neighbors, newElement, distance, lm, &idx, index, support); diff --git a/src/hnswscan.c b/src/hnswscan.c index e3aaced..2c6a454 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -11,7 +11,7 @@ * Algorithm 5 from paper */ static List * -GetScanItems(IndexScanDesc scan, Datum q) +GetScanItems(IndexScanDesc scan, Datum value) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; @@ -21,6 +21,9 @@ GetScanItems(IndexScanDesc scan, Datum q) int m; HnswElement entryPoint; char *base = NULL; + HnswQuery q; + + q.value = value; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); @@ -28,15 +31,15 @@ GetScanItems(IndexScanDesc scan, Datum q) if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, false)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, false)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, false, NULL); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, false, NULL); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, support, m, false, NULL); + return HnswSearchLayer(base, &q, ep, hnsw_ef_search, 0, index, support, m, false, NULL); } /* diff --git a/src/hnswutils.c b/src/hnswutils.c index 7fa0720..fe2b16e 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -533,7 +533,7 @@ HnswGetDistance(Datum a, Datum b, HnswSupport * support) * Load an element and optionally get its distance from q */ static void -HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Datum *q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) { Buffer buf; Page page; @@ -551,10 +551,10 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat /* Calculate distance */ if (distance != NULL) { - if (DatumGetPointer(*q) == NULL) + if (DatumGetPointer(q->value) == NULL) *distance = 0; else - *distance = HnswGetDistance(*q, PointerGetDatum(&etup->data), support); + *distance = HnswGetDistance(q->value, PointerGetDatum(&etup->data), support); } /* Load element */ @@ -573,7 +573,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) +HnswLoadElement(HnswElement element, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) { HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, support, loadVec, maxDistance, &element); } @@ -582,18 +582,18 @@ HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, * Get the distance for an element */ static double -GetElementDistance(char *base, HnswElement element, Datum q, HnswSupport * support) +GetElementDistance(char *base, HnswElement element, HnswQuery * q, HnswSupport * support) { Datum value = HnswGetValue(base, element); - return HnswGetDistance(q, value, support); + return HnswGetDistance(q->value, value, support); } /* * Create a candidate for the entry point */ HnswSearchCandidate * -HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, HnswSupport * support, bool loadVec) +HnswEntryCandidate(char *base, HnswElement entryPoint, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec) { HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate)); bool inMemory = index == NULL; @@ -602,7 +602,7 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, if (inMemory) sc->distance = GetElementDistance(base, entryPoint, q, support); else - HnswLoadElement(entryPoint, &sc->distance, &q, index, support, loadVec, NULL); + HnswLoadElement(entryPoint, &sc->distance, q, index, support, loadVec, NULL); return sc; } @@ -791,7 +791,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); @@ -873,7 +873,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, H /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, support, inserting, alwaysAdd ? NULL : &f->distance, &eElement); + HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd ? NULL : &f->distance, &eElement); if (eElement == NULL) continue; @@ -1220,10 +1220,12 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *w; int level = element->level; int entryLevel; - Datum q = HnswGetValue(base, element); + HnswQuery q; HnswElement skipElement = existing ? element : NULL; bool inMemory = index == NULL; + q.value = HnswGetValue(base, element); + /* Precompute hash */ if (inMemory) PrecomputeHash(base, element); @@ -1233,13 +1235,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, true)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, true)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement); ep = w; } @@ -1258,7 +1260,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, support, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement); /* Convert search candidates to candidates */ foreach(lc2, w)