From a8fdffc9a2ab4bb893ec0c8711534af9df8fd173 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 5 Feb 2024 21:40:20 -0800 Subject: [PATCH] Added support for inline filtering with HNSW --- CHANGELOG.md | 4 + README.md | 6 ++ sql/vector--0.6.1--0.7.0.sql | 6 ++ sql/vector.sql | 6 ++ src/hnsw.c | 2 +- src/hnsw.h | 12 ++- src/hnswbuild.c | 38 +++++++-- src/hnswinsert.c | 16 +++- src/hnswscan.c | 7 +- src/hnswutils.c | 150 +++++++++++++++++++++++++++++++---- src/hnswvacuum.c | 20 ++++- test/t/020_hnsw_filtering.pl | 113 ++++++++++++++++++++++++++ 12 files changed, 341 insertions(+), 39 deletions(-) create mode 100644 sql/vector--0.6.1--0.7.0.sql create mode 100644 test/t/020_hnsw_filtering.pl diff --git a/CHANGELOG.md b/CHANGELOG.md index 3604076..60d47ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.0 (unreleased) + +- Added support for inline filtering with HNSW + ## 0.6.1 (unreleased) - Fixed error with `ANALYZE` and vectors with different dimensions diff --git a/README.md b/README.md index 7bd93e7..c5a67ff 100644 --- a/README.md +++ b/README.md @@ -377,6 +377,12 @@ Create an index on one [or more](https://www.postgresql.org/docs/current/indexes CREATE INDEX ON items (category_id); ``` +Or a composite HNSW index for approximate search (added in 0.7.0) + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_l2_ops, category_id); +``` + Or a [partial index](https://www.postgresql.org/docs/current/indexes-partial.html) on the vector column for approximate search ```sql diff --git a/sql/vector--0.6.1--0.7.0.sql b/sql/vector--0.6.1--0.7.0.sql new file mode 100644 index 0000000..2b3fa8e --- /dev/null +++ b/sql/vector--0.6.1--0.7.0.sql @@ -0,0 +1,6 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. 
\quit + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer); diff --git a/sql/vector.sql b/sql/vector.sql index 4b17faa..bc5ede2 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -290,3 +290,9 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); + +-- hnsw attributes + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer); diff --git a/src/hnsw.c b/src/hnsw.c index 9689b17..ed7a7a5 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -200,7 +200,7 @@ hnswhandler(PG_FUNCTION_ARGS) amroutine->amcanorderbyop = true; amroutine->amcanbackward = false; /* can change direction mid-scan */ amroutine->amcanunique = false; - amroutine->amcanmulticol = false; + amroutine->amcanmulticol = true; amroutine->amoptionalkey = true; amroutine->amsearcharray = false; amroutine->amsearchnulls = false; diff --git a/src/hnsw.h b/src/hnsw.h index 09e90f3..1f3a838 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -128,6 +128,7 @@ HnswPtrDeclare(HnswElementData, HnswElementRelptr, HnswElementPtr); HnswPtrDeclare(HnswNeighborArray, HnswNeighborArrayRelptr, HnswNeighborArrayPtr); HnswPtrDeclare(HnswNeighborArrayPtr, HnswNeighborsRelptr, HnswNeighborsPtr); HnswPtrDeclare(char, DatumRelptr, DatumPtr); +HnswPtrDeclare(IndexTupleData, IndexTupleRelptr, IndexTuplePtr); typedef struct HnswElementData { @@ -143,6 +144,7 @@ typedef struct HnswElementData OffsetNumber neighborOffno; BlockNumber neighborPage; DatumPtr value; + IndexTuplePtr itup; LWLock lock; } HnswElementData; @@ -370,7 +372,7 @@ bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * re Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, 
int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(char *base, Datum q, IndexScanDesc scan, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); @@ -384,11 +386,13 @@ void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * alloc); bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building); void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building); -void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); -void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); +void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index); +void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, IndexScanDesc scan, bool *matches); +void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple); void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); void HnswLoadNeighbors(HnswElement element, Relation index, int m); +TupleDesc HnswTupleDesc(Relation index); +IndexTuple HnswFormIndexTuple(Relation index, TupleDesc tupdesc, Datum value, Datum 
*values, bool *isnull); PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 680789b..e961c2d 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -157,6 +157,7 @@ CreateGraphPages(HnswBuildState * buildstate) Page page; HnswElementPtr iter = buildstate->graph->head; char *base = buildstate->hnswarea; + bool useIndexTuple = IndexRelationGetNumberOfAttributes(index) > 1; /* Calculate sizes */ maxSize = HNSW_MAX_SIZE; @@ -176,7 +177,8 @@ CreateGraphPages(HnswBuildState * buildstate) Size etupSize; Size ntupSize; Size combinedSize; - void *valuePtr = HnswPtrAccess(base, element->value); + Pointer valuePtr = HnswPtrAccess(base, element->value); + IndexTuple itup = HnswPtrAccess(base, element->itup); /* Update iterator */ iter = element->next; @@ -185,7 +187,7 @@ CreateGraphPages(HnswBuildState * buildstate) MemSet(etup, 0, HNSW_TUPLE_ALLOC_SIZE); /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(valuePtr)); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? 
IndexTupleSize(itup) : VARSIZE_ANY(valuePtr)); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); @@ -193,7 +195,7 @@ CreateGraphPages(HnswBuildState * buildstate) if (etupSize > HNSW_TUPLE_ALLOC_SIZE) elog(ERROR, "index tuple too large"); - HnswSetElementTuple(base, etup, element); + HnswSetElementTuple(base, etup, element, useIndexTuple); /* Keep element and neighbors on the same page if possible */ if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize)) @@ -471,10 +473,14 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn HnswGraph *graph = buildstate->graph; HnswElement element; HnswAllocator *allocator = &buildstate->allocator; - Size valueSize; + IndexTuple itup; + Size itupSize; + IndexTuple itupPtr; Pointer valuePtr; LWLock *flushLock = &graph->flushLock; char *base = buildstate->hnswarea; + TupleDesc tupdesc = HnswTupleDesc(index); + bool unused; /* Detoast once for all calls */ Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); @@ -487,7 +493,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn } /* Get datum size */ - valueSize = VARSIZE_ANY(DatumGetPointer(value)); + itup = HnswFormIndexTuple(index, tupdesc, value, values, isnull); + itupSize = IndexTupleSize(itup); /* Ensure graph not flushed when inserting */ LWLockAcquire(flushLock, LW_SHARED); @@ -534,7 +541,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn /* Ok, we can proceed to allocate the element */ element = HnswInitElement(base, heaptid, buildstate->m, buildstate->ml, buildstate->maxLevel, allocator); - valuePtr = HnswAlloc(allocator, valueSize); + itupPtr = HnswAlloc(allocator, itupSize); /* * We have now allocated the space needed for the element, so we don't @@ -543,8 +550,10 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, 
Hn */ LWLockRelease(&graph->allocatorLock); - /* Copy the datum */ - memcpy(valuePtr, DatumGetPointer(value), valueSize); + /* Copy the index tuple */ + memcpy(itupPtr, itup, itupSize); + HnswPtrStore(base, element->itup, itupPtr); + valuePtr = DatumGetPointer(index_getattr(itupPtr, 1, tupdesc, &unused)); HnswPtrStore(base, element->value, valuePtr); /* Create a lock for the element */ @@ -669,6 +678,19 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->efConstruction = HnswGetEfConstruction(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + /* For now */ + if (IndexRelationGetNumberOfKeyAttributes(index) > 2) + elog(ERROR, "index cannot have more than two columns"); + + if (!OidIsValid(index_getprocid(index, 1, HNSW_DISTANCE_PROC))) + elog(ERROR, "first column must be a vector"); + + for (int i = 1; i < IndexRelationGetNumberOfKeyAttributes(index); i++) + { + if (OidIsValid(index_getprocid(index, i + 1, HNSW_DISTANCE_PROC))) + elog(ERROR, "column %d cannot be a vector", i + 1); + } + /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index c3c2885..f233f14 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -136,9 +136,12 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B OffsetNumber freeNeighborOffno = InvalidOffsetNumber; BlockNumber newInsertPage = InvalidBlockNumber; char *base = NULL; + bool useIndexTuple = IndexRelationGetNumberOfAttributes(index) > 1; + Pointer valuePtr = HnswPtrAccess(base, e->value); + IndexTuple itup = HnswPtrAccess(base, e->itup); /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(HnswPtrAccess(base, e->value))); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? 
IndexTupleSize(itup) : VARSIZE_ANY(valuePtr)); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); maxSize = HNSW_MAX_SIZE; @@ -146,7 +149,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B /* Prepare element tuple */ etup = palloc0(etupSize); - HnswSetElementTuple(base, etup, e); + HnswSetElementTuple(base, etup, e, useIndexTuple); /* Prepare neighbor tuple */ ntup = palloc0(ntupSize); @@ -564,6 +567,10 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, Oid collation = index->rd_indcollation[0]; LOCKMODE lockmode = ShareLock; char *base = NULL; + TupleDesc tupdesc = HnswTupleDesc(index); + bool unused; + IndexTuple itup; + Pointer valuePtr; /* * Get a shared lock. This allows vacuum to ensure no in-flight inserts @@ -577,7 +584,10 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, /* Create an element */ element = HnswInitElement(base, heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); - HnswPtrStore(base, element->value, DatumGetPointer(value)); + itup = HnswFormIndexTuple(index, tupdesc, value, values, isnull); + HnswPtrStore(base, element->itup, itup); + valuePtr = DatumGetPointer(index_getattr(itup, 1, tupdesc, &unused)); + HnswPtrStore(base, element->value, valuePtr); /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) diff --git a/src/hnswscan.c b/src/hnswscan.c index eaf0519..8df6dfe 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -33,11 +33,12 @@ GetScanItems(IndexScanDesc scan, Datum q) for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL); + w = HnswSearchLayer(base, q, NULL, ep, 1, lc, index, procinfo, collation, m, false, NULL); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); 
+ /* Only pass scan to check matches for layer 0 */ + return HnswSearchLayer(base, q, scan, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); } /* @@ -202,7 +203,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) MemoryContextSwitchTo(oldCtx); scan->xs_heaptid = *heaptid; - scan->xs_recheck = false; + scan->xs_recheck = scan->numberOfKeys > 0; scan->xs_recheckorderby = false; return true; } diff --git a/src/hnswutils.c b/src/hnswutils.c index 212214e..55668fd 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -3,6 +3,7 @@ #include #include "access/generic_xlog.h" +#include "access/relscan.h" #include "hnsw.h" #include "lib/pairingheap.h" #include "storage/bufmgr.h" @@ -295,6 +296,35 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) return element; } +/* + * Get the tuple descriptor + */ +TupleDesc +HnswTupleDesc(Relation index) +{ + TupleDesc tupdesc = CreateTupleDescCopyConstr(RelationGetDescr(index)); + + /* Prevent compression */ + TupleDescAttr(tupdesc, 0)->attstorage = TYPSTORAGE_PLAIN; + + return tupdesc; +} + +/* + * Form an index tuple + */ +IndexTuple +HnswFormIndexTuple(Relation index, TupleDesc tupdesc, Datum value, Datum *values, bool *isnull) +{ + Size size = sizeof(Datum) * IndexRelationGetNumberOfAttributes(index); + Datum *newValues = palloc(size); + + memcpy(newValues, values, size); + newValues[0] = value; + + return index_form_tuple(tupdesc, newValues, isnull); +} + /* * Get the metapage info */ @@ -404,10 +434,8 @@ HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, Bloc * Set element tuple, except for neighbor info */ void -HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) +HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple) { - Pointer valuePtr = HnswPtrAccess(base, element->value); - etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; etup->deleted = 0; @@ -418,7 +446,19 @@ 
HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) else ItemPointerSetInvalid(&etup->heaptids[i]); } - memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr)); + + if (useIndexTuple) + { + IndexTuple itup = HnswPtrAccess(base, element->itup); + + memcpy(&etup->data, itup, IndexTupleSize(itup)); + } + else + { + Pointer valuePtr = HnswPtrAccess(base, element->value); + + memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr)); + } } /* @@ -522,7 +562,7 @@ HnswLoadNeighbors(HnswElement element, Relation index, int m) * Load an element from a tuple */ void -HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec) +HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index) { element->level = etup->level; element->deleted = etup->deleted; @@ -545,17 +585,64 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe if (loadVec) { char *base = NULL; - Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1); - HnswPtrStore(base, element->value, DatumGetPointer(value)); + if (IndexRelationGetNumberOfAttributes(index) > 1) + { + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; + IndexTuple itup = CopyIndexTuple((IndexTuple) &etup->data); + Datum value = index_getattr(itup, 1, tupdesc, &unused); + + HnswPtrStore(base, element->itup, itup); + HnswPtrStore(base, element->value, DatumGetPointer(value)); + } + else + { + Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1); + + HnswPtrStore(base, element->value, DatumGetPointer(value)); + } } } +/* + * Check if an element matches + */ +static bool +HnswCheckMatches(Relation index, HnswElementTuple etup, IndexScanDesc scan) +{ + if (scan == NULL) + return true; + + for (int i = 0; i < scan->numberOfKeys; i++) + { + IndexTuple itup = (IndexTuple) &etup->data; + TupleDesc tupdesc = RelationGetDescr(index); + ScanKey key = &scan->keyData[i]; + bool isnull; 
+ Datum value = index_getattr(itup, key->sk_attno, tupdesc, &isnull); + bool attnull = key->sk_flags & SK_ISNULL; + + if (isnull || attnull) + { + if (isnull != attnull) + return false; + } + else + { + if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument))) + return false; + } + } + + return true; +} + /* * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, IndexScanDesc scan, bool *matches) { Buffer buf; Page page; @@ -571,11 +658,29 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, Assert(HnswIsElementTuple(etup)); /* Load element */ - HnswLoadElementFromTuple(element, etup, true, loadVec); + HnswLoadElementFromTuple(element, etup, true, loadVec, index); /* Calculate distance */ if (distance != NULL) - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + { + Datum value; + + if (IndexRelationGetNumberOfAttributes(index) > 1) + { + IndexTuple itup = (IndexTuple) &etup->data; + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; + + value = index_getattr(itup, 1, tupdesc, &unused); + } + else + value = PointerGetDatum(&etup->data); + + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, value)); + } + + if (matches != NULL) + *matches = HnswCheckMatches(index, etup, scan); UnlockReleaseBuffer(buf); } @@ -604,7 +709,7 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, if (index == NULL) hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation); else - HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec); + HnswLoadElement(entryPoint, &hc->distance, &q, index, 
procinfo, collation, loadVec, NULL, NULL); return hc; } @@ -722,12 +827,14 @@ CountElement(char *base, HnswElement skipElement, HnswCandidate * hc) * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(char *base, Datum q, IndexScanDesc scan, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); int wlen = 0; + uint64 additional = 0; + uint64 maxAdditional = (scan != NULL && scan->numberOfKeys > 0) ? 100 * ef : 0; visited_hash v; ListCell *lc2; HnswNeighborArray *neighborhoodData = NULL; @@ -799,6 +906,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F if (!visited) { float eDistance; + bool eMatches = true; HnswElement eElement = HnswPtrAccess(base, e->element); f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; @@ -806,7 +914,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F if (index == NULL) eDistance = GetCandidateDistance(base, e, q, procinfo, collation); else - HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting); + HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, scan, &eMatches); Assert(!eElement->deleted); @@ -825,6 +933,16 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node)); pairingheap_add(W, &(CreatePairingHeapNode(ec)->ph_node)); + /* + * Do not count elements that do not match filter towards + * ef + */ + if (!eMatches) + { + if ((++additional) <= maxAdditional) + continue; + } + /* * Do not count elements being deleted towards ef when * vacuuming. 
It would be ideal to do this for inserts as @@ -1099,7 +1217,7 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm HnswElement hc3Element = HnswPtrAccess(base, hc3->element); if (HnswPtrIsNull(base, hc3Element->value)) - HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true); + HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true, NULL, NULL); else hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation); @@ -1221,7 +1339,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, q, NULL, ep, 1, lc, index, procinfo, collation, m, true, skipElement); ep = w; } @@ -1239,7 +1357,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *neighbors; List *lw; - w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, q, NULL, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */ diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 7c14e54..68afec8 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -256,7 +256,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL, NULL); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -294,7 +294,7 @@ RepairGraphEntryPoint(HnswVacuumState * 
vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL, NULL); if (NeedsUpdated(vacuumstate, entryPoint)) { @@ -370,7 +370,7 @@ RepairGraph(HnswVacuumState * vacuumstate) /* Create an element */ element = HnswInitElementFromBlock(blkno, offno); - HnswLoadElementFromTuple(element, etup, false, true); + HnswLoadElementFromTuple(element, etup, false, true, index); elements = lappend(elements, element); } @@ -440,6 +440,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) BlockNumber insertPage = InvalidBlockNumber; Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; + bool useIndexTuple = IndexRelationGetNumberOfAttributes(index) > 1; /* * Wait for index scans to complete. Scans before this point may contain @@ -521,7 +522,18 @@ MarkDeleted(HnswVacuumState * vacuumstate) /* Overwrite element */ etup->deleted = 1; - MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data)); + if (useIndexTuple) + { + IndexTuple itup = (IndexTuple) &etup->data; + + MemSet(itup, 0, IndexTupleSize(itup)); + } + else + { + Vector *vec = (Vector *) (&etup->data); + + MemSet(vec, 0, VARSIZE_ANY(vec)); + } /* Overwrite neighbors */ for (int i = 0; i < ntup->count; i++) diff --git a/test/t/020_hnsw_filtering.pl b/test/t/020_hnsw_filtering.pl new file mode 100644 index 0000000..f85f9a3 --- /dev/null +++ b/test/t/020_hnsw_filtering.pl @@ -0,0 +1,113 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @cs = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my $nc = 50; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = 
$node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $cs[0] ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Cond/, "$operator filter is applied as an index condition"); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + is(scalar(@actual_ids), $limit, "query $i returns $limit rows"); + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc FROM generate_series(1, 10000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc FROM generate_series(1, 10000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); + push(@cs, int(rand() * $nc)); +} + +# Get exact results +@expected = (); +for my $i (0 .. 
$#queries) +{ + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v <-> '$queries[$i]' LIMIT $limit; + )); + push(@expected, $res); +} + +# Test recall +test_recall(0.99, '<->'); + +# Test vacuum +$node->safe_psql("postgres", "DELETE FROM tst WHERE c > 5;"); +$node->safe_psql("postgres", "VACUUM tst;"); + +# Test columns +my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c);"); +like($stderr, qr/first column must be a vector/, 'rejects index without a vector column'); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c, v vector_l2_ops);"); +like($stderr, qr/first column must be a vector/, 'rejects index whose first column is not a vector'); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c, c);"); +like($stderr, qr/index cannot have more than two columns/, 'rejects index with more than two columns'); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, v vector_l2_ops);"); +like($stderr, qr/column 2 cannot be a vector/, 'rejects second vector column'); + +done_testing();