diff --git a/src/hnsw.h b/src/hnsw.h index 7748ec0..78bd4f8 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -250,6 +250,13 @@ typedef struct HnswSupport Oid *collation; } HnswSupport; +typedef struct HnswQuery +{ + Datum value; + IndexTuple itup; + ScanKeyData *keyData; +} HnswQuery; + typedef struct HnswBuildState { /* Info */ @@ -386,14 +393,14 @@ bool HnswCheckNorm(HnswSupport * support, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, bool inMemory); +List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, bool inMemory); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing, bool inMemory); -HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation rel, HnswSupport * support, bool loadVec, bool inMemory); +HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, HnswQuery * q, Relation rel, HnswSupport * support, bool loadVec, bool inMemory); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); @@ -402,7 +409,7 @@ void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * bool HnswInsertTupleOnDisk(Relation index, HnswSupport * support, IndexTuple itup, ItemPointer heaptid, bool building); void HnswUpdateNeighborsOnDisk(Relation index, HnswSupport * support, HnswElement e, int m, bool checkExisting, bool building); void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index); -void HnswLoadElement(HnswElement element, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); +void HnswLoadElement(HnswElement element, double *distance, bool *matches, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple); void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, HnswSupport * support); bool HnswFormIndexTuple(IndexTuple *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support, TupleDesc tupdesc); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index b3cf9a3..58487db 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -369,7 +369,7 @@ HnswLoadNeighbors(HnswElement element, Relation index, int m, int lm, int lc) * Load elements for insert */ static void -LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, IndexTuple qtup, int *idx, Relation index, HnswSupport * support) +LoadElementsForInsert(HnswNeighborArray * neighbors, HnswQuery * q, int *idx, Relation index, HnswSupport * support) { char *base = NULL; @@ -380,7 +380,7 @@ LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, IndexTuple qtup, i double distance; bool matches; - HnswLoadElement(element, &distance, &matches, &q, qtup, NULL, index, support, true, NULL); + HnswLoadElement(element, &distance, &matches, q, index, support, true, NULL); hc->distance = distance; /* Prune element if being deleted */ @@ -421,10 +421,13 @@ GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int idx = -2; else { - Datum q = HnswGetValue(base, element); - IndexTuple qtup = HnswPtrAccess(base, element->itup);; + HnswQuery q; - LoadElementsForInsert(neighbors, q, qtup, &idx, index, support); + q.value = HnswGetValue(base, element); + q.itup = HnswPtrAccess(base, element->itup); + q.keyData = NULL; + + LoadElementsForInsert(neighbors, &q, &idx, index, support); if (idx == -1) HnswUpdateConnection(base, neighbors, newElement, distance, lm, &idx, index, support); diff --git a/src/hnswscan.c b/src/hnswscan.c index 3bdc494..2b9a316 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -11,7 +11,7 @@ * Algorithm 5 from paper */ static List * -GetScanItems(IndexScanDesc scan, Datum q) +GetScanItems(IndexScanDesc scan, Datum value) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; @@ -22,7 +22,11 @@ GetScanItems(IndexScanDesc scan, Datum q) HnswElement entryPoint; char *base = NULL; bool inMemory = false; - ScanKeyData *keyData = scan->keyData; + HnswQuery q; + + q.value = value; + q.itup = NULL; + q.keyData = scan->keyData; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); @@ -30,15 +34,15 @@ GetScanItems(IndexScanDesc scan, Datum q) if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, NULL, keyData, index, support, false, inMemory)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, false, inMemory)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, NULL, keyData, ep, 1, lc, index, support, m, false, NULL, inMemory); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, false, NULL, inMemory); ep = w; } - return HnswSearchLayer(base, q, NULL, keyData, ep, hnsw_ef_search, 0, index, support, m, false, NULL, inMemory); + return HnswSearchLayer(base, &q, ep, hnsw_ef_search, 0, index, support, m, false, NULL, inMemory); } /* diff --git a/src/hnswutils.c b/src/hnswutils.c index f33e064..72ff96e 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -627,14 +627,14 @@ AttributeDistance(double e) * Calculate the distance between values */ static double -HnswGetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, HnswSupport * support, bool *matches) +HnswGetDistance(IndexTuple itup, Datum vec, HnswQuery * q, Relation index, HnswSupport * support, bool *matches) { double g; - if (DatumGetPointer(q) == NULL) + if (DatumGetPointer(q->value) == NULL) g = 0; else - g = DatumGetFloat8(FunctionCall2Coll(support->procinfo[0], support->collation[0], q, vec)); + g = DatumGetFloat8(FunctionCall2Coll(support->procinfo[0], support->collation[0], q->value, vec)); Assert(PointerIsValid(matches)); *matches = true; @@ -645,14 +645,14 @@ HnswGetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyDat double e = 0.0; TupleDesc tupdesc = RelationGetDescr(index); - if (keyData) + if (q->keyData) { /* TODO need to pass length of key data */ int keyCount = 1; for (int i = 0; i < keyCount; i++) { - ScanKey key = &keyData[i]; + ScanKey key = &q->keyData[i]; bool isnull; Datum value = index_getattr(itup, key->sk_attno, tupdesc, &isnull); bool attnull = key->sk_flags & SK_ISNULL; @@ -681,7 +681,7 @@ HnswGetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyDat return w * g + AttributeDistance(e); } - else if (qtup) + else if (q->itup) { int keyCount = IndexRelationGetNumberOfKeyAttributes(index) - 1; @@ -690,7 +690,7 @@ HnswGetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyDat bool isnull; bool attnull; Datum value = index_getattr(itup, i + 2, tupdesc, &isnull); - Datum value2 = index_getattr(qtup, i + 2, tupdesc, &attnull); + Datum value2 = index_getattr(q->itup, i + 2, tupdesc, &attnull); if (isnull || attnull) { @@ -712,7 +712,7 @@ HnswGetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyDat * Load an element and optionally get its distance from q */ static void -HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, bool *matches, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) { Buffer buf; Page page; @@ -746,7 +746,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, boo value = PointerGetDatum(&etup->data); } - *distance = HnswGetDistance(itup, value, *q, qtup, keyData, index, support, matches); + *distance = HnswGetDistance(itup, value, q, index, support, matches); } /* Load element */ @@ -765,36 +765,36 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, boo * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) +HnswLoadElement(HnswElement element, double *distance, bool *matches, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) { - HnswLoadElementImpl(element->blkno, element->offno, distance, matches, q, qtup, keyData, index, support, loadVec, maxDistance, &element); + HnswLoadElementImpl(element->blkno, element->offno, distance, matches, q, index, support, loadVec, maxDistance, &element); } /* * Get the distance for an element */ static double -GetElementDistance(char *base, HnswElement element, bool *matches, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, HnswSupport * support) +GetElementDistance(char *base, HnswElement element, bool *matches, HnswQuery * q, Relation index, HnswSupport * support) { Datum value = HnswGetValue(base, element); IndexTuple itup = HnswPtrAccess(base, element->itup); - return HnswGetDistance(itup, value, q, qtup, keyData, index, support, matches); + return HnswGetDistance(itup, value, q, index, support, matches); } /* * Create a candidate for the entry point */ HnswSearchCandidate * -HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, HnswSupport * support, bool loadVec, bool inMemory) +HnswEntryCandidate(char *base, HnswElement entryPoint, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, bool inMemory) { HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate)); HnswPtrStore(base, sc->element, entryPoint); if (inMemory) - sc->distance = GetElementDistance(base, entryPoint, &sc->matches, q, qtup, keyData, index, support); + sc->distance = GetElementDistance(base, entryPoint, &sc->matches, q, index, support); else - HnswLoadElement(entryPoint, &sc->distance, &sc->matches, &q, qtup, keyData, index, support, loadVec, NULL); + HnswLoadElement(entryPoint, &sc->distance, &sc->matches, q, index, support, loadVec, NULL); return sc; } @@ -983,7 +983,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, bool inMemory) +HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, bool inMemory) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); @@ -997,7 +997,7 @@ HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; uint64 additional = 0; - uint64 maxAdditional = keyData && lc == 0 ? 10000 : 0; + uint64 maxAdditional = q->keyData && lc == 0 ? 10000 : 0; InitVisited(base, &v, inMemory, ef, m); @@ -1061,7 +1061,7 @@ HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List if (inMemory) { eElement = unvisited[i].element; - eDistance = GetElementDistance(base, eElement, &eMatches, q, qtup, keyData, index, support); + eDistance = GetElementDistance(base, eElement, &eMatches, q, index, support); } else { @@ -1071,7 +1071,7 @@ HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, &eMatches, &q, qtup, keyData, index, support, inserting, alwaysAdd ? NULL : &f->distance, &eElement); + HnswLoadElementImpl(blkno, offno, &eDistance, &eMatches, q, index, support, inserting, alwaysAdd ? NULL : &f->distance, &eElement); if (eElement == NULL) continue; @@ -1180,10 +1180,13 @@ static bool CheckElementCloser(char *base, HnswCandidate * e, List *r, Relation index, HnswSupport * support) { HnswElement eElement = HnswPtrAccess(base, e->element); - Datum eValue = HnswGetValue(base, eElement); - IndexTuple etup = HnswPtrAccess(base, eElement->itup); + HnswQuery q; ListCell *lc2; + q.value = HnswGetValue(base, eElement); + q.itup = HnswPtrAccess(base, eElement->itup); + q.keyData = NULL; + foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); @@ -1191,7 +1194,7 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, Relation index, HnswS Datum riValue = HnswGetValue(base, riElement); IndexTuple ritup = HnswPtrAccess(base, riElement->itup); bool matches; - float distance = HnswGetDistance(etup, eValue, riValue, ritup, NULL, index, support, &matches); + float distance = HnswGetDistance(ritup, riValue, &q, index, support, &matches); if (distance <= e->distance) return false; @@ -1425,11 +1428,14 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *w; int level = element->level; int entryLevel; - Datum q = HnswGetValue(base, element); - IndexTuple qtup = HnswPtrAccess(base, element->itup); - ScanKeyData *keyData = NULL; + HnswQuery q; + HnswElement skipElement = existing ? element : NULL; + q.value = HnswGetValue(base, element); + q.itup = HnswPtrAccess(base, element->itup); + q.keyData = NULL; + /* Precompute hash */ if (inMemory) PrecomputeHash(base, element); @@ -1439,13 +1445,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, qtup, keyData, index, support, true, inMemory)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, true, inMemory)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, qtup, keyData, ep, 1, lc, index, support, m, true, skipElement, inMemory); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement, inMemory); ep = w; } @@ -1464,7 +1470,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, q, qtup, keyData, ep, efConstruction, lc, index, support, m, true, skipElement, inMemory); + w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement, inMemory); /* Convert search candidates to candidates */ foreach(lc2, w) diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index bef5fbf..d3e72ea 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -256,7 +256,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, NULL, index, support, true, NULL); + HnswLoadElement(highestPoint, NULL, NULL, NULL, index, support, true, NULL); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -294,7 +294,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, NULL, index, support, true, NULL); + HnswLoadElement(entryPoint, NULL, NULL, NULL, index, support, true, NULL); if (NeedsUpdated(vacuumstate, entryPoint)) {