diff --git a/src/hnsw.h b/src/hnsw.h index 09e90f3..7ee2f2b 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -362,6 +362,11 @@ typedef struct HnswVacuumState MemoryContext tmpCtx; } HnswVacuumState; +typedef struct HnswQuery +{ + Datum value; +} HnswQuery; + /* Methods */ int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); @@ -370,14 +375,14 @@ bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * re Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing); -HnswCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); +HnswCandidate *HnswEntryCandidate(char *base, HnswElement em, HnswQuery * q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); @@ -385,7 +390,7 @@ void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building); void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building); void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswLoadElement(HnswElement element, float *distance, HnswQuery * q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); void HnswLoadNeighbors(HnswElement element, Relation index, int m); diff --git a/src/hnswscan.c b/src/hnswscan.c index eaf0519..2267077 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -11,7 +11,7 @@ * Algorithm 5 from paper */ static List * -GetScanItems(IndexScanDesc scan, Datum q) +GetScanItems(IndexScanDesc scan, Datum value) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; @@ -22,6 +22,9 @@ GetScanItems(IndexScanDesc scan, Datum q) int m; HnswElement entryPoint; char *base = NULL; + HnswQuery q; + + q.value = value; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); @@ -29,15 +32,15 @@ GetScanItems(IndexScanDesc scan, Datum q) if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, false)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, procinfo, collation, false)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, procinfo, collation, m, false, NULL); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); + return HnswSearchLayer(base, &q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); } /* diff --git a/src/hnswutils.c b/src/hnswutils.c index 212214e..799e98f 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -555,7 +555,7 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswLoadElement(HnswElement element, float *distance, HnswQuery * q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) { Buffer buf; Page page; @@ -575,7 +575,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, /* Calculate distance */ if (distance != NULL) - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q->value, PointerGetDatum(&etup->data))); UnlockReleaseBuffer(buf); } @@ -584,19 +584,19 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, * Get the distance for a candidate */ static float -GetCandidateDistance(char *base, HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +GetCandidateDistance(char *base, HnswCandidate * hc, HnswQuery * q, FmgrInfo *procinfo, Oid collation) { HnswElement hce = HnswPtrAccess(base, hc->element); Datum value = HnswGetValue(base, hce); - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q->value, value)); } /* * Create a candidate for the entry point */ HnswCandidate * -HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswEntryCandidate(char *base, HnswElement entryPoint, HnswQuery * q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) { HnswCandidate *hc = palloc(sizeof(HnswCandidate)); @@ -604,7 +604,7 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, if (index == NULL) hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation); else - HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec); + HnswLoadElement(entryPoint, &hc->distance, q, index, procinfo, collation, loadVec); return hc; } @@ -722,7 +722,7 @@ CountElement(char *base, HnswElement skipElement, HnswCandidate * hc) * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); @@ -806,7 +806,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F if (index == NULL) eDistance = GetCandidateDistance(base, e, q, procinfo, collation); else - HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting); + HnswLoadElement(eElement, &eDistance, q, index, procinfo, collation, inserting); Assert(!eElement->deleted); @@ -1091,7 +1091,9 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm /* Load elements on insert */ if (index != NULL) { - Datum q = HnswGetValue(base, hce); + HnswQuery q; + + q.value = HnswGetValue(base, hce); for (int i = 0; i < currentNeighbors->length; i++) { @@ -1101,7 +1103,7 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm if (HnswPtrIsNull(base, hc3Element->value)) HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true); else - hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation); + hc3->distance = GetCandidateDistance(base, hc3, &q, procinfo, collation); /* Prune element if being deleted */ if (hc3Element->heaptidsLength == 0) @@ -1201,9 +1203,11 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *w; int level = element->level; int entryLevel; - Datum q = HnswGetValue(base, element); + HnswQuery q; HnswElement skipElement = existing ? element : NULL; + q.value = HnswGetValue(base, element); + #if PG_VERSION_NUM >= 130000 /* Precompute hash */ if (index == NULL) @@ -1215,13 +1219,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, true)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, procinfo, collation, true)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); ep = w; } @@ -1239,7 +1243,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *neighbors; List *lw; - w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */