From 3ccfab8f921cf92c8372438e36a76a48b5604916 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 9 Oct 2024 19:02:40 -0700 Subject: [PATCH] Added support for inline filtering with HNSW --- CHANGELOG.md | 1 + README.md | 7 + sql/vector--0.7.4--0.8.0.sql | 8 + sql/vector.sql | 10 ++ src/hnsw.c | 18 ++- src/hnsw.h | 38 +++-- src/hnswbuild.c | 97 +++++++++--- src/hnswinsert.c | 69 ++++++--- src/hnswscan.c | 20 +-- src/hnswutils.c | 292 ++++++++++++++++++++++++++++------- src/hnswvacuum.c | 26 ++-- test/t/041_hnsw_filtering.pl | 109 +++++++++++++ 12 files changed, 567 insertions(+), 128 deletions(-) create mode 100644 test/t/041_hnsw_filtering.pl diff --git a/CHANGELOG.md b/CHANGELOG.md index a7d9924..bcce7ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.8.0 (unreleased) +- Added support for inline filtering with HNSW - Added casts for arrays to `sparsevec` - Improved cost estimation - Improved performance of HNSW inserts and on-disk index builds diff --git a/README.md b/README.md index 6cdca4f..90284ae 100644 --- a/README.md +++ b/README.md @@ -439,6 +439,12 @@ Create an index on one [or more](https://www.postgresql.org/docs/current/indexes CREATE INDEX ON items (category_id); ``` +Or a composite HNSW index for approximate search (added in 0.8.0) + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_l2_ops, category_id); +``` + Or a [partial index](https://www.postgresql.org/docs/current/indexes-partial.html) on the vector column for approximate search ```sql @@ -1189,6 +1195,7 @@ Thanks to: - [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf) - [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf) - [Efficient and Robust Approximate Nearest Neighbor Search using Hierarchical Navigable Small World Graphs](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf) +- [HQANN: Efficient and Robust Similarity Search for Hybrid Queries with Structured and Unstructured Constraints](https://arxiv.org/pdf/2207.07940.pdf) ## History diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql index e00348d..cd3e3c5 100644 --- a/sql/vector--0.7.4--0.8.0.sql +++ b/sql/vector--0.7.4--0.8.0.sql @@ -24,3 +24,11 @@ CREATE CAST (double precision[] AS sparsevec) CREATE CAST (numeric[] AS sparsevec) WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT; + +CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer), + FUNCTION 4 hnsw_attribute_distance(integer, integer); diff --git a/sql/vector.sql b/sql/vector.sql index 7fc3671..adf70d2 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -916,3 +916,13 @@ CREATE OPERATOR CLASS sparsevec_l1_ops OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), FUNCTION 3 hnsw_sparsevec_support(internal); + +-- hnsw attributes + +CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer), + FUNCTION 4 hnsw_attribute_distance(integer, integer); diff --git a/src/hnsw.c b/src/hnsw.c index c2579c1..bbfa6ce 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -227,13 +227,13 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 3; + amroutine->amsupport = 4; amroutine->amoptsprocnum = 0; amroutine->amcanorder = false; amroutine->amcanorderbyop = true; amroutine->amcanbackward = false; /* can change direction mid-scan */ amroutine->amcanunique = false; - amroutine->amcanmulticol = false; + amroutine->amcanmulticol = true; amroutine->amoptionalkey = true; amroutine->amsearcharray = false; amroutine->amsearchnulls = false; @@ -285,3 +285,17 @@ hnswhandler(PG_FUNCTION_ARGS) PG_RETURN_POINTER(amroutine); } + +/* + * Get the distance between two int4 attributes + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_int4_attribute_distance); +Datum +hnsw_int4_attribute_distance(PG_FUNCTION_ARGS) +{ + int32 a = PG_GETARG_INT32(0); + int32 b = PG_GETARG_INT32(1); + double distance = ((double) a) - ((double) b); + + PG_RETURN_FLOAT8(distance); +} diff --git a/src/hnsw.h b/src/hnsw.h index 116d9bc..9efb53f 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -19,6 +19,7 @@ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 #define HNSW_TYPE_INFO_PROC 3 +#define HNSW_ATTRIBUTE_DISTANCE_PROC 4 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -104,6 +105,8 @@ #define HnswPtrPointer(hp) (hp).ptr #define HnswPtrOffset(hp) relptr_offset((hp).relptr) +#define HnswUseIndexTuple(index) (IndexRelationGetNumberOfAttributes(index) > 1) + /* Variables */ extern int hnsw_ef_search; extern int hnsw_lock_tranche_id; @@ -121,6 +124,7 @@ HnswPtrDeclare(HnswElementData, HnswElementRelptr, HnswElementPtr); HnswPtrDeclare(HnswNeighborArray, HnswNeighborArrayRelptr, HnswNeighborArrayPtr); HnswPtrDeclare(HnswNeighborArrayPtr, HnswNeighborsRelptr, HnswNeighborsPtr); HnswPtrDeclare(char, DatumRelptr, DatumPtr); +HnswPtrDeclare(IndexTupleData, IndexTupleRelptr, IndexTuplePtr); struct HnswElementData { @@ -136,6 +140,7 @@ struct HnswElementData OffsetNumber neighborOffno; BlockNumber neighborPage; DatumPtr value; + IndexTuplePtr itup; LWLock lock; }; @@ -161,6 +166,7 @@ typedef struct HnswSearchCandidate pairingheap_node w_node; HnswElementPtr element; double distance; + bool matches; } HnswSearchCandidate; /* HNSW index options */ @@ -256,15 +262,17 @@ typedef struct HnswBuildState double reltuples; /* Support functions */ - FmgrInfo *procinfo; + FmgrInfo *procinfo[2]; FmgrInfo *normprocinfo; - Oid collation; + Oid *collation; /* Variables */ HnswGraph graphData; HnswGraph *graph; double ml; int maxLevel; + bool useIndexTuple; + TupleDesc tupdesc; /* Memory */ MemoryContext graphCtx; @@ -333,9 +341,9 @@ typedef struct HnswScanOpaqueData MemoryContext tmpCtx; /* Support functions */ - FmgrInfo *procinfo; + FmgrInfo *procinfo[2]; FmgrInfo *normprocinfo; - Oid collation; + Oid *collation; } HnswScanOpaqueData; typedef HnswScanOpaqueData * HnswScanOpaque; @@ -353,8 +361,8 @@ typedef struct HnswVacuumState int efConstruction; /* Support functions */ - FmgrInfo *procinfo; - Oid collation; + FmgrInfo *procinfo[2]; + Oid *collation; /* Variables */ struct tidhash_hash *deleted; @@ -375,29 +383,31 @@ bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, FmgrInfo **procinfo, Oid *collation, int m, bool inserting, HnswElement skipElement, bool inMemory); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); -void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing); -HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo **procinfo, Oid *collation, int m, int efConstruction, bool existing, bool inMemory); +HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation rel, FmgrInfo **procinfo, Oid *collation, bool loadVec, bool inMemory); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); HnswNeighborArray *HnswInitNeighborArray(int lm, HnswAllocator * allocator); void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * alloc); bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building); -void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building); -void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance); -void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); -void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); +void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement e, int m, bool checkExisting, bool building); +void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index); +void HnswLoadElement(HnswElement element, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, double *maxDistance); +void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple); +void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo **procinfo, Oid *collation); bool HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation index, int m, int lm, int lc); void HnswInitLockTranche(void); const HnswTypeInfo *HnswGetTypeInfo(Relation index); PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); +void HnswInitProcinfo(FmgrInfo **procinfo, Relation index); +bool HnswIndexTupleIsEqual(IndexTuple a, IndexTuple b, TupleDesc tupdesc); /* Index access methods */ IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 87d4823..bda90ca 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -148,6 +148,7 @@ CreateGraphPages(HnswBuildState * buildstate) Page page; HnswElementPtr iter = buildstate->graph->head; char *base = buildstate->hnswarea; + bool useIndexTuple = buildstate->useIndexTuple; /* Calculate sizes */ maxSize = HNSW_MAX_SIZE; @@ -167,7 +168,6 @@ CreateGraphPages(HnswBuildState * buildstate) Size etupSize; Size ntupSize; Size combinedSize; - Pointer valuePtr = HnswPtrAccess(base, element->value); /* Update iterator */ iter = element->next; @@ -176,7 +176,7 @@ CreateGraphPages(HnswBuildState * buildstate) MemSet(etup, 0, HNSW_TUPLE_ALLOC_SIZE); /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(valuePtr)); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? IndexTupleSize(HnswPtrAccess(base, element->itup)) : VARSIZE_ANY(HnswPtrAccess(base, element->value))); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); @@ -186,7 +186,7 @@ CreateGraphPages(HnswBuildState * buildstate) (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), errmsg("index tuple too large"))); - HnswSetElementTuple(base, etup, element); + HnswSetElementTuple(base, etup, element, useIndexTuple); /* Keep element and neighbors on the same page if possible */ if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize)) @@ -327,20 +327,29 @@ AddDuplicateInMemory(HnswElement element, HnswElement dup) * Find duplicate element */ static bool -FindDuplicateInMemory(char *base, HnswElement element) +FindDuplicateInMemory(char *base, HnswElement element, bool useIndexTuple, TupleDesc tupdesc) { HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0); Datum value = HnswGetValue(base, element); + IndexTuple itup = HnswPtrAccess(base, element->itup); for (int i = 0; i < neighbors->length; i++) { HnswCandidate *neighbor = &neighbors->items[i]; HnswElement neighborElement = HnswPtrAccess(base, neighbor->element); - Datum neighborValue = HnswGetValue(base, neighborElement); - /* Exit early since ordered by distance */ - if (!datumIsEqual(value, neighborValue, false, -1)) - return false; + if (useIndexTuple) + { + /* Exit early since ordered by distance */ + if (!HnswIndexTupleIsEqual(itup, HnswPtrAccess(base, neighborElement->itup), tupdesc)) + return false; + } + else + { + /* Exit early since ordered by distance */ + if (!datumIsEqual(value, HnswGetValue(base, neighborElement), false, -1)) + return false; + } /* Check for space */ if (AddDuplicateInMemory(element, neighborElement)) @@ -366,7 +375,7 @@ AddElementInMemory(char *base, HnswGraph * graph, HnswElement element) * Update neighbors */ static void -UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswElement e, int m) +UpdateNeighborsInMemory(char *base, Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement e, int m) { for (int lc = e->level; lc >= 0; lc--) { @@ -388,7 +397,7 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme Assert(neighborElement); LWLockAcquire(&neighborElement->lock, LW_EXCLUSIVE); - HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, NULL, procinfo, collation); + HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, index, procinfo, collation); LWLockRelease(&neighborElement->lock); } } @@ -398,20 +407,20 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme * Update graph in memory */ static void -UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate) +UpdateGraphInMemory(FmgrInfo **procinfo, Oid *collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate) { HnswGraph *graph = buildstate->graph; char *base = buildstate->hnswarea; /* Look for duplicate */ - if (FindDuplicateInMemory(base, element)) + if (FindDuplicateInMemory(base, element, buildstate->useIndexTuple, buildstate->tupdesc)) return; /* Add element */ AddElementInMemory(base, graph, element); /* Update neighbors */ - UpdateNeighborsInMemory(base, procinfo, collation, element, m); + UpdateNeighborsInMemory(base, buildstate->index, procinfo, collation, element, m); /* Update entry point if needed (already have lock) */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -424,8 +433,9 @@ UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int static void InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) { - FmgrInfo *procinfo = buildstate->procinfo; - Oid collation = buildstate->collation; + Relation index = buildstate->index; + FmgrInfo **procinfo = buildstate->procinfo; + Oid *collation = buildstate->collation; HnswGraph *graph = buildstate->graph; HnswElement entryPoint; LWLock *entryLock = &graph->entryLock; @@ -458,7 +468,7 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, NULL, procinfo, collation, m, efConstruction, false); + HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false, true); /* Update graph in memory */ UpdateGraphInMemory(procinfo, collation, element, m, efConstruction, entryPoint, buildstate); @@ -481,6 +491,11 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn Pointer valuePtr; LWLock *flushLock = &graph->flushLock; char *base = buildstate->hnswarea; + bool useIndexTuple = buildstate->useIndexTuple; + TupleDesc tupdesc = buildstate->tupdesc; + IndexTuple itup; + Size itupSize; + IndexTuple itupPtr; /* Detoast once for all calls */ Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); @@ -492,10 +507,10 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) + if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation[0], value)) return false; - value = HnswNormValue(typeInfo, buildstate->collation, value); + value = HnswNormValue(typeInfo, buildstate->collation[0], value); } /* Get datum size */ @@ -546,7 +561,17 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn /* Ok, we can proceed to allocate the element */ element = HnswInitElement(base, heaptid, buildstate->m, buildstate->ml, buildstate->maxLevel, allocator); - valuePtr = HnswAlloc(allocator, valueSize); + + if (useIndexTuple) + { + /* TODO fix */ + values[0] = value; + itup = index_form_tuple(tupdesc, values, isnull); + itupSize = IndexTupleSize(itup); + itupPtr = HnswAlloc(allocator, itupSize); + } + else + valuePtr = HnswAlloc(allocator, valueSize); /* * We have now allocated the space needed for the element, so we don't @@ -556,8 +581,19 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn LWLockRelease(&graph->allocatorLock); /* Copy the datum */ - memcpy(valuePtr, DatumGetPointer(value), valueSize); - HnswPtrStore(base, element->value, valuePtr); + if (useIndexTuple) + { + bool unused; + + memcpy(itupPtr, itup, itupSize); + HnswPtrStore(base, element->itup, itupPtr); + HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itupPtr, 1, tupdesc, &unused))); + } + else + { + memcpy(valuePtr, DatumGetPointer(value), valueSize); + HnswPtrStore(base, element->value, valuePtr); + } /* Create a lock for the element */ LWLockInitialize(&element->lock, hnsw_lock_tranche_id); @@ -684,6 +720,19 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), errmsg("type not supported for hnsw index"))); + /* TODO See if needed */ + if (IndexRelationGetNumberOfKeyAttributes(index) > 2) + elog(ERROR, "index cannot have more than two columns"); + + if (!OidIsValid(index_getprocid(index, 1, HNSW_DISTANCE_PROC))) + elog(ERROR, "first column must be a vector"); + + for (int i = 1; i < IndexRelationGetNumberOfKeyAttributes(index); i++) + { + if (!OidIsValid(index_getprocid(index, i + 1, HNSW_ATTRIBUTE_DISTANCE_PROC))) + elog(ERROR, "column %d cannot be a vector", i + 1); + } + /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) ereport(ERROR, @@ -704,14 +753,16 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->indtuples = 0; /* Get support functions */ - buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + HnswInitProcinfo(buildstate->procinfo, index); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - buildstate->collation = index->rd_indcollation[0]; + buildstate->collation = index->rd_indcollation; InitGraph(&buildstate->graphData, NULL, (Size) maintenance_work_mem * 1024L); buildstate->graph = &buildstate->graphData; buildstate->ml = HnswGetMl(buildstate->m); buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); + buildstate->useIndexTuple = HnswUseIndexTuple(index); + buildstate->tupdesc = RelationGetDescr(index); buildstate->graphCtx = GenerationContextCreate(CurrentMemoryContext, "Hnsw build graph context", diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 2dfd8d3..62ff0e1 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -154,9 +154,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B OffsetNumber freeNeighborOffno = InvalidOffsetNumber; BlockNumber newInsertPage = InvalidBlockNumber; char *base = NULL; + bool useIndexTuple = HnswUseIndexTuple(index); /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(HnswPtrAccess(base, e->value))); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? IndexTupleSize(HnswPtrAccess(base, e->itup)) : VARSIZE_ANY(HnswPtrAccess(base, e->value))); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); maxSize = HNSW_MAX_SIZE; @@ -164,7 +165,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B /* Prepare element tuple */ etup = palloc0(etupSize); - HnswSetElementTuple(base, etup, e); + HnswSetElementTuple(base, etup, e, useIndexTuple); /* Prepare neighbor tuple */ ntup = palloc0(ntupSize); @@ -368,7 +369,7 @@ HnswLoadNeighbors(HnswElement element, Relation index, int m, int lm, int lc) * Load elements for insert */ static void -LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation index, FmgrInfo *procinfo, Oid collation) +LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, IndexTuple qtup, int *idx, Relation index, FmgrInfo **procinfo, Oid *collation) { char *base = NULL; @@ -377,8 +378,9 @@ LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation HnswCandidate *hc = &neighbors->items[i]; HnswElement element = HnswPtrAccess(base, hc->element); double distance; + bool matches; - HnswLoadElement(element, &distance, &q, index, procinfo, collation, true, NULL); + HnswLoadElement(element, &distance, &matches, &q, qtup, NULL, index, procinfo, collation, true, NULL); hc->distance = distance; /* Prune element if being deleted */ @@ -394,7 +396,7 @@ LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation * Get update index */ static int -GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int m, int lm, int lc, Relation index, FmgrInfo *procinfo, Oid collation, MemoryContext updateCtx) +GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int m, int lm, int lc, Relation index, FmgrInfo **procinfo, Oid *collation, MemoryContext updateCtx) { char *base = NULL; int idx = -1; @@ -420,8 +422,9 @@ GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int else { Datum q = HnswGetValue(base, element); + IndexTuple qtup = HnswPtrAccess(base, element->itup);; - LoadElementsForInsert(neighbors, q, &idx, index, procinfo, collation); + LoadElementsForInsert(neighbors, q, qtup, &idx, index, procinfo, collation); if (idx == -1) HnswUpdateConnection(base, neighbors, newElement, distance, lm, &idx, index, procinfo, collation); @@ -529,7 +532,7 @@ UpdateNeighborOnDisk(HnswElement element, HnswElement newElement, int idx, int m * Update neighbors */ void -HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building) +HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement e, int m, bool checkExisting, bool building) { char *base = NULL; @@ -630,16 +633,26 @@ FindDuplicateOnDisk(Relation index, HnswElement element, bool building) char *base = NULL; HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0); Datum value = HnswGetValue(base, element); + IndexTuple itup = HnswPtrAccess(base, element->itup); + TupleDesc tupdesc = RelationGetDescr(index); for (int i = 0; i < neighbors->length; i++) { HnswCandidate *neighbor = &neighbors->items[i]; HnswElement neighborElement = HnswPtrAccess(base, neighbor->element); - Datum neighborValue = HnswGetValue(base, neighborElement); - /* Exit early since ordered by distance */ - if (!datumIsEqual(value, neighborValue, false, -1)) - return false; + if (HnswUseIndexTuple(index)) + { + /* Exit early since ordered by distance */ + if (!HnswIndexTupleIsEqual(itup, HnswPtrAccess(base, neighborElement->itup), tupdesc)) + return false; + } + else + { + /* Exit early since ordered by distance */ + if (!datumIsEqual(value, HnswGetValue(base, neighborElement), false, -1)) + return false; + } if (AddDuplicateOnDisk(index, element, neighborElement, building)) return true; @@ -652,7 +665,7 @@ FindDuplicateOnDisk(Relation index, HnswElement element, bool building) * Update graph on disk */ static void -UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building) +UpdateGraphOnDisk(Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building) { BlockNumber newInsertPage = InvalidBlockNumber; @@ -685,11 +698,13 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, HnswElement element; int m; int efConstruction = HnswGetEfConstruction(index); - FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - Oid collation = index->rd_indcollation[0]; + FmgrInfo *procinfo[2]; + Oid *collation = index->rd_indcollation; LOCKMODE lockmode = ShareLock; char *base = NULL; + HnswInitProcinfo(procinfo, index); + /* * Get a shared lock. This allows vacuum to ensure no in-flight inserts * before repairing graph. Use a page lock so it does not interfere with @@ -702,7 +717,23 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, /* Create an element */ element = HnswInitElement(base, heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); - HnswPtrStore(base, element->value, DatumGetPointer(value)); + if (HnswUseIndexTuple(index)) + { + /* TODO no toast */ + TupleDesc tupdesc = RelationGetDescr(index); + IndexTuple itup; + bool unused; + + /* TODO fix */ + values[0] = value; + itup = index_form_tuple(tupdesc, values, isnull); + + HnswPtrStore(base, element->itup, itup); + HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itup, 1, tupdesc, &unused))); + + } + else + HnswPtrStore(base, element->value, DatumGetPointer(value)); /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -719,7 +750,7 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false); + HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false, false); /* Update graph on disk */ UpdateGraphOnDisk(index, procinfo, collation, element, m, efConstruction, entryPoint, building); @@ -739,7 +770,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti Datum value; const HnswTypeInfo *typeInfo = HnswGetTypeInfo(index); FmgrInfo *normprocinfo; - Oid collation = index->rd_indcollation[0]; + Oid *collation = index->rd_indcollation; /* Detoast once for all calls */ value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); @@ -752,10 +783,10 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); if (normprocinfo != NULL) { - if (!HnswCheckNorm(normprocinfo, collation, value)) + if (!HnswCheckNorm(normprocinfo, collation[0], value)) return; - value = HnswNormValue(typeInfo, collation, value); + value = HnswNormValue(typeInfo, collation[0], value); } HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false); diff --git a/src/hnswscan.c b/src/hnswscan.c index 30815af..efb693a 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -15,13 +15,15 @@ GetScanItems(IndexScanDesc scan, Datum q) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; - FmgrInfo *procinfo = so->procinfo; - Oid collation = so->collation; + FmgrInfo **procinfo = so->procinfo; + Oid *collation = so->collation; List *ep; List *w; int m; HnswElement entryPoint; char *base = NULL; + bool inMemory = false; + ScanKeyData *keyData = scan->keyData; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); @@ -29,15 +31,15 @@ GetScanItems(IndexScanDesc scan, Datum q) if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, false)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, q, NULL, keyData, index, procinfo, collation, false, inMemory)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL); + w = HnswSearchLayer(base, q, NULL, keyData, ep, 1, lc, index, procinfo, collation, m, false, NULL, inMemory); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); + return HnswSearchLayer(base, q, NULL, keyData, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, inMemory); } /* @@ -61,7 +63,7 @@ GetScanValue(IndexScanDesc scan) /* Normalize if needed */ if (so->normprocinfo != NULL) - value = HnswNormValue(so->typeInfo, so->collation, value); + value = HnswNormValue(so->typeInfo, so->collation[0], value); } return value; @@ -86,9 +88,9 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) ALLOCSET_DEFAULT_SIZES); /* Set support functions */ - so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + HnswInitProcinfo(so->procinfo, index); so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - so->collation = index->rd_indcollation[0]; + so->collation = index->rd_indcollation; scan->opaque = so; @@ -173,7 +175,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) ItemPointer heaptid; /* Move to next element if no valid heap TIDs */ - if (element->heaptidsLength == 0) + if (!hc->matches || element->heaptidsLength == 0) { so->w = list_delete_last(so->w); continue; diff --git a/src/hnswutils.c b/src/hnswutils.c index 856c309..09b73ed 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -153,6 +153,18 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } +/* + * Init procinfo + */ +void +HnswInitProcinfo(FmgrInfo **procinfo, Relation index) +{ + procinfo[0] = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + + if (IndexRelationGetNumberOfKeyAttributes(index) > 1) + procinfo[1] = index_getprocinfo(index, 2, HNSW_ATTRIBUTE_DISTANCE_PROC); +} + /* * Normalize value */ @@ -171,6 +183,37 @@ HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0; } +/* + * Check if index tuples are equal + */ +bool +HnswIndexTupleIsEqual(IndexTuple a, IndexTuple b, TupleDesc tupdesc) +{ + for (int i = 0; i < tupdesc->natts; i++) + { + bool nullA; + bool nullB; + + Datum datumA = index_getattr(a, i + 1, tupdesc, &nullA); + Datum datumB = index_getattr(b, i + 1, tupdesc, &nullB); + + if (nullA || nullB) + { + if (nullA != nullB) + return false; + } + else + { + Form_pg_attribute att = TupleDescAttr(tupdesc, i); + + if (!datumIsEqual(datumA, datumB, att->attbyval, att->attlen)) + return false; + } + } + + return true; +} + /* * New buffer */ @@ -257,6 +300,7 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel, HnswInitNeighbors(base, element, m, allocator); HnswPtrStore(base, element->value, (Pointer) NULL); + HnswPtrStore(base, element->itup, (IndexTuple) NULL); return element; } @@ -283,6 +327,7 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) element->offno = offno; HnswPtrStore(base, element->neighbors, (HnswNeighborArrayPtr *) NULL); HnswPtrStore(base, element->value, (Pointer) NULL); + HnswPtrStore(base, element->itup, (IndexTuple) NULL); return element; } @@ -398,10 +443,8 @@ HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, Bloc * Set element tuple, except for neighbor info */ void -HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) +HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple) { - Pointer valuePtr = HnswPtrAccess(base, element->value); - etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; etup->deleted = 0; @@ -412,7 +455,19 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) else ItemPointerSetInvalid(&etup->heaptids[i]); } - memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr)); + + if (useIndexTuple) + { + IndexTuple itup = HnswPtrAccess(base, element->itup); + + memcpy(&etup->data, itup, IndexTupleSize(itup)); + } + else + { + Pointer valuePtr = HnswPtrAccess(base, element->value); + + memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr)); + } } /* @@ -453,7 +508,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m) * Load an element from a tuple */ void -HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec) +HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index) { element->level = etup->level; element->deleted = etup->deleted; @@ -476,26 +531,128 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe if (loadVec) { char *base = NULL; - Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1); - HnswPtrStore(base, element->value, DatumGetPointer(value)); + if (HnswUseIndexTuple(index)) + { + IndexTuple itup = CopyIndexTuple((IndexTuple) &etup->data); + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; + + HnswPtrStore(base, element->itup, itup); + HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itup, 1, tupdesc, &unused))); + } + else + { + Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1); + + HnswPtrStore(base, element->value, DatumGetPointer(value)); + } } } +/* + * Get the attribute distance + */ +static inline double +AttributeDistance(double e) +{ + /* TODO Better bias */ + /* must be >> max(w * g) + 1 / log10(2) */ + double bias = 4.32; + + return e > 0 ? bias - 1.0 / log10(e + 1) : 0; +} + /* * Calculate the distance between values */ -static inline double -HnswGetDistance(Datum a, Datum b, FmgrInfo *procinfo, Oid collation) +static double +HnswGetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool *matches) { - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a, b)); + double g; + + if (DatumGetPointer(q) == NULL) + g = 0; + else + g = DatumGetFloat8(FunctionCall2Coll(procinfo[0], collation[0], q, vec)); + + Assert(PointerIsValid(matches)); + *matches = true; + + if (IndexRelationGetNumberOfKeyAttributes(index) > 1) + { + double w = 0.25; + double e = 0.0; + TupleDesc tupdesc = RelationGetDescr(index); + + if (keyData) + { + /* TODO need to pass length of key data */ + int keyCount = 1; + + for (int i = 0; i < keyCount; i++) + { + ScanKey key = &keyData[i]; + bool isnull; + Datum value = index_getattr(itup, key->sk_attno, tupdesc, &isnull); + bool attnull = key->sk_flags & SK_ISNULL; + + if (isnull || attnull) + { + if (isnull != attnull) + { + e += 1000; + *matches = false; + } + } + else if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument))) + { + double ei = fabs(DatumGetFloat8(FunctionCall2Coll(procinfo[key->sk_attno - 1], collation[key->sk_attno - 1], value, key->sk_argument))); + + if (ei > 0) + e += ei; + else + /* Distance is zero for inequality */ + e += 1000; + + *matches = false; + } + } + + return w * g + AttributeDistance(e); + } + else if (qtup) + { + int keyCount = IndexRelationGetNumberOfKeyAttributes(index) - 1; + + for (int i = 0; i < keyCount; i++) + { + bool isnull; + bool attnull; + Datum value = index_getattr(itup, i + 2, tupdesc, &isnull); + Datum value2 = index_getattr(qtup, i + 2, tupdesc, &attnull); + + if (isnull || attnull) + { + if (isnull != attnull) + e += 1000; + } + else + e += fabs(DatumGetFloat8(FunctionCall2Coll(procinfo[i + 1], collation[i + 1], value, value2))); + } + + return w * g + AttributeDistance(e); + } + } + + return g; } /* * Load an element and optionally get its distance from q */ static void -HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance, HnswElement * element) +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, double *maxDistance, HnswElement * element) { Buffer buf; Page page; @@ -513,10 +670,23 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat /* Calculate distance */ if (distance != NULL) { - if (DatumGetPointer(*q) == NULL) - *distance = 0; + IndexTuple itup = NULL; + Datum value; + + if (HnswUseIndexTuple(index)) + { + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; + + itup = (IndexTuple) &etup->data; + value = index_getattr(itup, 1, tupdesc, &unused); + } else - *distance = HnswGetDistance(*q, PointerGetDatum(&etup->data), procinfo, collation); + { + value = PointerGetDatum(&etup->data); + } + + *distance = HnswGetDistance(itup, value, *q, qtup, keyData, index, procinfo, collation, matches); } /* Load element */ @@ -525,7 +695,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat if (*element == NULL) *element = HnswInitElementFromBlock(blkno, offno); - HnswLoadElementFromTuple(*element, etup, true, loadVec); + HnswLoadElementFromTuple(*element, etup, true, loadVec, index); } UnlockReleaseBuffer(buf); @@ -535,35 +705,36 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance) +HnswLoadElement(HnswElement element, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, double *maxDistance) { - HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, procinfo, collation, loadVec, maxDistance, &element); + HnswLoadElementImpl(element->blkno, element->offno, distance, matches, q, qtup, keyData, index, procinfo, collation, loadVec, maxDistance, &element); } /* * Get the distance for an element */ static double -GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation) +GetElementDistance(char *base, HnswElement element, bool *matches, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation) { Datum value = HnswGetValue(base, element); + IndexTuple itup = HnswPtrAccess(base, element->itup); - return HnswGetDistance(q, value, procinfo, collation); + return HnswGetDistance(itup, value, q, qtup, keyData, index, procinfo, collation, matches); } /* * Create a candidate for the entry point */ HnswSearchCandidate * -HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, bool inMemory) { HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate)); HnswPtrStore(base, sc->element, entryPoint); - if (index == NULL) - sc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation); + if (inMemory) + sc->distance = GetElementDistance(base, entryPoint, &sc->matches, q, qtup, keyData, index, procinfo, collation); else - HnswLoadElement(entryPoint, &sc->distance, &q, index, procinfo, collation, loadVec, NULL); + HnswLoadElement(entryPoint, &sc->distance, &sc->matches, &q, qtup, keyData, index, procinfo, collation, loadVec, NULL); return sc; } @@ -604,9 +775,9 @@ CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, * Init visited */ static inline void -InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) +InitVisited(char *base, visited_hash * v, bool inMemory, int ef, int m) { - if (index != NULL) + if (!inMemory) v->tids = tidhash_create(CurrentMemoryContext, ef * m * 2, NULL); else if (base != NULL) v->offsets = offsethash_create(CurrentMemoryContext, ef * m * 2, NULL); @@ -618,9 +789,9 @@ InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) * Add to visited */ static inline void -AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, Relation index, bool *found) +AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, bool inMemory, bool *found) { - if (index != NULL) + if (!inMemory) { HnswElement element = HnswPtrAccess(base, elementPtr); ItemPointerData indextid; @@ -681,7 +852,7 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unv HnswCandidate *hc = &localNeighborhood->items[i]; bool found; - AddToVisited(base, v, hc->element, NULL, &found); + AddToVisited(base, v, hc->element, true, &found); if (!found) unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element); @@ -752,7 +923,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, FmgrInfo **procinfo, Oid *collation, int m, bool inserting, HnswElement skipElement, bool inMemory) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); @@ -765,11 +936,13 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F int lm = HnswGetLayerM(m, lc); HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; + uint64 additional = 0; + uint64 maxAdditional = keyData && lc == 0 ? 10000 : 0; - InitVisited(base, &v, index, ef, m); + InitVisited(base, &v, inMemory, ef, m); /* Create local memory for neighborhood if needed */ - if (index == NULL) + if (inMemory) { neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm); localNeighborhood = palloc(neighborhoodSize); @@ -781,11 +954,15 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F HnswSearchCandidate *sc = (HnswSearchCandidate *) lfirst(lc2); bool found; - AddToVisited(base, &v, sc->element, index, &found); + AddToVisited(base, &v, sc->element, inMemory, &found); pairingheap_add(C, &sc->c_node); pairingheap_add(W, &sc->w_node); + /* Do not count elements that do not match filter towards ef */ + if (!sc->matches && ++additional <= maxAdditional) + continue; + /* * Do not count elements being deleted towards ef when vacuuming. It * would be ideal to do this for inserts as well, but this could @@ -806,7 +983,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F cElement = HnswPtrAccess(base, c->element); - if (index == NULL) + if (inMemory) HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize); else HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc); @@ -816,14 +993,15 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F HnswElement eElement; HnswSearchCandidate *e; double eDistance; + bool eMatches; bool alwaysAdd = wlen < ef; f = HnswGetSearchCandidate(w_node, pairingheap_first(W)); - if (index == NULL) + if (inMemory) { eElement = unvisited[i].element; - eDistance = GetElementDistance(base, eElement, q, procinfo, collation); + eDistance = GetElementDistance(base, eElement, &eMatches, q, qtup, keyData, index, procinfo, collation); } else { @@ -833,7 +1011,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement); + HnswLoadElementImpl(blkno, offno, &eDistance, &eMatches, &q, qtup, keyData, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement); if (eElement == NULL) continue; @@ -852,6 +1030,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F e = palloc(sizeof(HnswSearchCandidate)); HnswPtrStore(base, e->element, eElement); e->distance = eDistance; + e->matches = eMatches; pairingheap_add(C, &e->c_node); pairingheap_add(W, &e->w_node); @@ -862,6 +1041,10 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F */ if (CountElement(skipElement, eElement)) { + /* Do not count elements that do not match filter towards ef */ + if (!e->matches && ++additional <= maxAdditional) + continue; + wlen++; /* No need to decrement wlen */ @@ -934,10 +1117,11 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b) * Check if an element is closer to q than any element from R */ static bool -CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation) +CheckElementCloser(char *base, HnswCandidate * e, List *r, Relation index, FmgrInfo **procinfo, Oid *collation) { HnswElement eElement = HnswPtrAccess(base, e->element); Datum eValue = HnswGetValue(base, eElement); + IndexTuple etup = HnswPtrAccess(base, eElement->itup); ListCell *lc2; foreach(lc2, r) @@ -945,7 +1129,9 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, O HnswCandidate *ri = lfirst(lc2); HnswElement riElement = HnswPtrAccess(base, ri->element); Datum riValue = HnswGetValue(base, riElement); - float distance = HnswGetDistance(eValue, riValue, procinfo, collation); + IndexTuple ritup = HnswPtrAccess(base, riElement->itup); + bool matches; + float distance = HnswGetDistance(etup, eValue, riValue, ritup, NULL, index, procinfo, collation, &matches); if (distance <= e->distance) return false; @@ -958,7 +1144,7 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, O * Algorithm 4 from paper */ static List * -SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) +SelectNeighbors(char *base, List *c, int lm, Relation index, FmgrInfo **procinfo, Oid *collation, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) { List *r = NIL; List *w = list_copy(c); @@ -992,7 +1178,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation, /* Use previous state of r and wd to skip work when possible */ if (mustCalculate) - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, index, procinfo, collation); else if (list_length(added) > 0) { /* Keep Valgrind happy for in-memory, parallel builds */ @@ -1005,7 +1191,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation, */ if (e->closer) { - e->closer = CheckElementCloser(base, e, added, procinfo, collation); + e->closer = CheckElementCloser(base, e, added, index, procinfo, collation); if (!e->closer) removedAny = true; @@ -1018,7 +1204,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation, */ if (removedAny) { - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, index, procinfo, collation); if (e->closer) added = lappend(added, e); } @@ -1026,7 +1212,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation, } else if (e == newCandidate) { - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, index, procinfo, collation); if (e->closer) added = lappend(added, e); } @@ -1077,7 +1263,7 @@ AddConnections(char *base, HnswElement element, List *neighbors, int lc) * Update connections */ void -HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation) +HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo **procinfo, Oid *collation) { HnswCandidate newHc; @@ -1103,7 +1289,7 @@ HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newE c = lappend(c, &neighbors->items[i]); c = lappend(c, &newHc); - SelectNeighbors(base, c, lm, procinfo, collation, &neighbors->closerSet, &newHc, &pruned, true); + SelectNeighbors(base, c, lm, index, procinfo, collation, &neighbors->closerSet, &newHc, &pruned, true); /* Should not happen */ if (pruned == NULL) @@ -1174,17 +1360,19 @@ PrecomputeHash(char *base, HnswElement element) * Algorithm 1 from paper */ void -HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing) +HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo **procinfo, Oid *collation, int m, int efConstruction, bool existing, bool inMemory) { List *ep; List *w; int level = element->level; int entryLevel; Datum q = HnswGetValue(base, element); + IndexTuple qtup = HnswPtrAccess(base, element->itup); + ScanKeyData *keyData = NULL; HnswElement skipElement = existing ? element : NULL; /* Precompute hash */ - if (index == NULL) + if (inMemory) PrecomputeHash(base, element); /* No neighbors if no entry point */ @@ -1192,13 +1380,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, true)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, q, qtup, keyData, index, procinfo, collation, true, inMemory)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, q, qtup, keyData, ep, 1, lc, index, procinfo, collation, m, true, skipElement, inMemory); ep = w; } @@ -1217,7 +1405,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, q, qtup, keyData, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, inMemory); /* Convert search candidates to candidates */ foreach(lc2, w) @@ -1233,7 +1421,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */ - if (index != NULL) + if (!inMemory) lw = RemoveElements(base, lw, skipElement); /* @@ -1241,7 +1429,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint * sortCandidates to true for in-memory builds to enable closer * caching, but there does not seem to be a difference in performance. */ - neighbors = SelectNeighbors(base, lw, lm, procinfo, collation, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false); + neighbors = SelectNeighbors(base, lw, lm, index, procinfo, collation, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false); AddConnections(base, element, neighbors, lc); diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 67cc645..09cf8e8 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -189,8 +189,8 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme GenericXLogState *state; int m = vacuumstate->m; int efConstruction = vacuumstate->efConstruction; - FmgrInfo *procinfo = vacuumstate->procinfo; - Oid collation = vacuumstate->collation; + FmgrInfo **procinfo = vacuumstate->procinfo; + Oid *collation = vacuumstate->collation; BufferAccessStrategy bas = vacuumstate->bas; HnswNeighborTuple ntup = vacuumstate->ntup; Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); @@ -205,7 +205,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme element->heaptidsLength = 0; /* Find neighbors for element, skipping itself */ - HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, true); + HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, true, false); /* Zero memory for each element */ MemSet(ntup, 0, HNSW_TUPLE_ALLOC_SIZE); @@ -256,7 +256,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); + HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -294,7 +294,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); + HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); if (NeedsUpdated(vacuumstate, entryPoint)) { @@ -370,7 +370,7 @@ RepairGraph(HnswVacuumState * vacuumstate) /* Create an element */ element = HnswInitElementFromBlock(blkno, offno); - HnswLoadElementFromTuple(element, etup, false, true); + HnswLoadElementFromTuple(element, etup, false, true, index); elements = lappend(elements, element); } @@ -440,6 +440,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) BlockNumber insertPage = InvalidBlockNumber; Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; + bool useIndexTuple = HnswUseIndexTuple(index); /* * Wait for index scans to complete. Scans before this point may contain @@ -521,7 +522,14 @@ MarkDeleted(HnswVacuumState * vacuumstate) /* Overwrite element */ etup->deleted = 1; - MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data)); + if (useIndexTuple) + { + IndexTuple itup = (IndexTuple) &etup->data; + + MemSet(itup, 0, IndexTupleSize(itup)); + } + else + MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data)); /* Overwrite neighbors */ for (int i = 0; i < ntup->count; i++) @@ -573,8 +581,8 @@ InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkD vacuumstate->callback_state = callback_state; vacuumstate->efConstruction = HnswGetEfConstruction(index); vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); - vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - vacuumstate->collation = index->rd_indcollation[0]; + HnswInitProcinfo(vacuumstate->procinfo, index); + vacuumstate->collation = index->rd_indcollation; vacuumstate->ntup = palloc0(HNSW_TUPLE_ALLOC_SIZE); vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw vacuum temporary context", diff --git a/test/t/041_hnsw_filtering.pl b/test/t/041_hnsw_filtering.pl new file mode 100644 index 0000000..a1be049 --- /dev/null +++ b/test/t/041_hnsw_filtering.pl @@ -0,0 +1,109 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @cs = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my $nc = 50; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $cs[0] ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Cond/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + is(scalar(@actual_ids), $limit); + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc FROM generate_series(1, 20000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); + push(@cs, int(rand() * $nc)); +} + +# Get exact results +@expected = (); +for my $i (0 .. $#queries) +{ + my $res = $node->safe_psql("postgres", "SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v <-> '$queries[$i]' LIMIT $limit;"); + push(@expected, $res); +} + +# Add index +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c);"); + +# Test recall +test_recall(0.99, '<->'); + +# Test vacuum +$node->safe_psql("postgres", "DELETE FROM tst WHERE c > 5;"); +$node->safe_psql("postgres", "VACUUM tst;"); + +# Test columns +my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c, v vector_l2_ops);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c, c);"); +like($stderr, qr/index cannot have more than two columns/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, v vector_l2_ops);"); +like($stderr, qr/column 2 cannot be a vector/); + +done_testing();