diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f74ea8..e25ec6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.9.0 (unreleased) + +- Added support for inline filtering with HNSW + ## 0.8.2 (unreleased) - Improved `install` target on Windows diff --git a/README.md b/README.md index a108520..e241933 100644 --- a/README.md +++ b/README.md @@ -467,6 +467,12 @@ If filtering by many different values, consider [partitioning](https://www.postg CREATE TABLE items (embedding vector(3), category_id int) PARTITION BY LIST(category_id); ``` +Or a composite HNSW index (added in 0.9.0) + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_l2_ops, category_id); +``` + ## Iterative Index Scans With approximate indexes, queries with filtering can return less results since filtering is applied *after* the index is scanned. Starting with 0.8.0, you can enable iterative index scans, which will automatically scan more of the index until enough results are found (or it reaches `hnsw.max_scan_tuples` or `ivfflat.max_probes`). @@ -1282,6 +1288,7 @@ Thanks to: - [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf) - [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf) - [Efficient and Robust Approximate Nearest Neighbor Search using Hierarchical Navigable Small World Graphs](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf) +- [HQANN: Efficient and Robust Similarity Search for Hybrid Queries with Structured and Unstructured Constraints](https://arxiv.org/pdf/2207.07940.pdf) ## History diff --git a/sql/vector--0.8.1--0.9.0.sql b/sql/vector--0.8.1--0.9.0.sql new file mode 100644 index 0000000..76ad21c --- /dev/null +++ b/sql/vector--0.8.1--0.9.0.sql @@ -0,0 +1,10 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.9.0'" to load this file. \quit + +CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer), + FUNCTION 4 hnsw_attribute_distance(integer, integer); diff --git a/sql/vector.sql b/sql/vector.sql index 7fc3671..adf70d2 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -916,3 +916,13 @@ CREATE OPERATOR CLASS sparsevec_l1_ops OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), FUNCTION 3 hnsw_sparsevec_support(internal); + +-- hnsw attributes + +CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer), + FUNCTION 4 hnsw_attribute_distance(integer, integer); diff --git a/src/hnsw.c b/src/hnsw.c index 1d56ef6..ac35ae2 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -263,7 +263,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 3; + amroutine->amsupport = 4; amroutine->amoptsprocnum = 0; amroutine->amcanorder = false; amroutine->amcanorderbyop = true; @@ -274,7 +274,7 @@ hnswhandler(PG_FUNCTION_ARGS) #endif amroutine->amcanbackward = false; /* can change direction mid-scan */ amroutine->amcanunique = false; - amroutine->amcanmulticol = false; + amroutine->amcanmulticol = true; amroutine->amoptionalkey = true; amroutine->amsearcharray = false; amroutine->amsearchnulls = false; @@ -334,3 +334,17 @@ hnswhandler(PG_FUNCTION_ARGS) PG_RETURN_POINTER(amroutine); } + +/* + * Get the distance between two int4 attributes + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_int4_attribute_distance); +Datum +hnsw_int4_attribute_distance(PG_FUNCTION_ARGS) +{ + int32 a = PG_GETARG_INT32(0); + int32 b = PG_GETARG_INT32(1); + double distance = ((double) a) - ((double) b); + + PG_RETURN_FLOAT8(distance); +} diff --git a/src/hnsw.h b/src/hnsw.h index 5102bfb..82fcd7c 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -19,6 +19,7 @@ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 #define HNSW_TYPE_INFO_PROC 3 +#define HNSW_ATTRIBUTE_DISTANCE_PROC 4 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -107,6 +108,8 @@ #define HnswPtrPointer(hp) (hp).ptr #define HnswPtrOffset(hp) relptr_offset((hp).relptr) +#define HnswUseIndexTuple(index) (IndexRelationGetNumberOfAttributes(index) > 1) + /* Variables */ extern int hnsw_ef_search; extern int hnsw_iterative_scan; @@ -134,6 +137,7 @@ HnswPtrDeclare(HnswElementData, HnswElementRelptr, HnswElementPtr); HnswPtrDeclare(HnswNeighborArray, HnswNeighborArrayRelptr, HnswNeighborArrayPtr); HnswPtrDeclare(HnswNeighborArrayPtr, HnswNeighborsRelptr, HnswNeighborsPtr); HnswPtrDeclare(char, DatumRelptr, DatumPtr); +HnswPtrDeclare(IndexTupleData, IndexTupleRelptr, IndexTuplePtr); struct HnswElementData { @@ -150,6 +154,7 @@ struct HnswElementData OffsetNumber neighborOffno; BlockNumber neighborPage; DatumPtr value; + IndexTuplePtr itup; LWLock lock; }; @@ -175,6 +180,7 @@ typedef struct HnswSearchCandidate pairingheap_node w_node; HnswElementPtr element; double distance; + bool matches; } HnswSearchCandidate; /* HNSW index options */ @@ -253,14 +259,16 @@ typedef struct HnswTypeInfo typedef struct HnswSupport { - FmgrInfo *procinfo; + FmgrInfo *procinfo[2]; FmgrInfo *normprocinfo; - Oid collation; + Oid *collation; } HnswSupport; typedef struct HnswQuery { Datum value; + IndexTuple itup; + ScanKeyData *keyData; } HnswQuery; typedef struct HnswBuildState @@ -289,6 +297,8 @@ typedef struct HnswBuildState HnswGraph *graph; double ml; int maxLevel; + bool useIndexTuple; + TupleDesc tupdesc; /* Memory */ MemoryContext graphCtx; @@ -417,30 +427,32 @@ bool HnswCheckNorm(HnswSupport * support, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples); +List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, bool inMemory, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); -void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing); -HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, HnswQuery * q, Relation rel, HnswSupport * support, bool loadVec); +void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing, bool inMemory); +HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, HnswQuery * q, Relation rel, HnswSupport * support, bool loadVec, bool inMemory); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); HnswNeighborArray *HnswInitNeighborArray(int lm, HnswAllocator * allocator); void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * alloc); -bool HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPointer heaptid, bool building); +bool HnswInsertTupleOnDisk(Relation index, HnswSupport * support, IndexTuple itup, ItemPointer heaptid, bool building, TupleDesc tupdesc); void HnswUpdateNeighborsOnDisk(Relation index, HnswSupport * support, HnswElement e, int m, bool checkExisting, bool building); -void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); -bool HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support); -void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); +void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index); +void HnswLoadElement(HnswElement element, double *distance, bool *matches, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); +void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple); void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, HnswSupport * support); +bool HnswFormIndexTuple(IndexTuple *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support, TupleDesc tupdesc); bool HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation index, int m, int lm, int lc); void HnswInitLockTranche(void); const HnswTypeInfo *HnswGetTypeInfo(Relation index); PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); +Size HnswGetElementTupleSize(char *base, HnswElement element, bool useIndexTuple); +bool HnswIndexTupleIsEqual(IndexTuple a, IndexTuple b, TupleDesc tupdesc); /* Index access methods */ IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 03f0ef4..5a4b6d0 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -152,6 +152,7 @@ CreateGraphPages(HnswBuildState * buildstate) Page page; HnswElementPtr iter = buildstate->graph->head; char *base = buildstate->hnswarea; + bool useIndexTuple = buildstate->useIndexTuple; /* Calculate sizes */ maxSize = HNSW_MAX_SIZE; @@ -171,7 +172,6 @@ CreateGraphPages(HnswBuildState * buildstate) Size etupSize; Size ntupSize; Size combinedSize; - Pointer valuePtr = HnswPtrAccess(base, element->value); /* Update iterator */ iter = element->next; @@ -180,7 +180,7 @@ CreateGraphPages(HnswBuildState * buildstate) MemSet(etup, 0, HNSW_TUPLE_ALLOC_SIZE); /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(valuePtr)); + etupSize = HnswGetElementTupleSize(base, element, useIndexTuple); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); @@ -190,7 +190,7 @@ CreateGraphPages(HnswBuildState * buildstate) (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), errmsg("index tuple too large"))); - HnswSetElementTuple(base, etup, element); + HnswSetElementTuple(base, etup, element, useIndexTuple); /* Keep element and neighbors on the same page if possible */ if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize)) @@ -331,19 +331,18 @@ AddDuplicateInMemory(HnswElement element, HnswElement dup) * Find duplicate element */ static bool -FindDuplicateInMemory(char *base, HnswElement element) +FindDuplicateInMemory(char *base, HnswElement element, bool useIndexTuple, TupleDesc tupdesc) { HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0); - Datum value = HnswGetValue(base, element); + IndexTuple itup = HnswPtrAccess(base, element->itup); for (int i = 0; i < neighbors->length; i++) { HnswCandidate *neighbor = &neighbors->items[i]; HnswElement neighborElement = HnswPtrAccess(base, neighbor->element); - Datum neighborValue = HnswGetValue(base, neighborElement); /* Exit early since ordered by distance */ - if (!datumIsEqual(value, neighborValue, false, -1)) + if (!HnswIndexTupleIsEqual(itup, HnswPtrAccess(base, neighborElement->itup), tupdesc)) return false; /* Check for space */ @@ -370,7 +369,7 @@ AddElementInMemory(char *base, HnswGraph * graph, HnswElement element) * Update neighbors */ static void -UpdateNeighborsInMemory(char *base, HnswSupport * support, HnswElement e, int m) +UpdateNeighborsInMemory(char *base, Relation index, HnswSupport * support, HnswElement e, int m) { for (int lc = e->level; lc >= 0; lc--) { @@ -392,7 +391,7 @@ UpdateNeighborsInMemory(char *base, HnswSupport * support, HnswElement e, int m) Assert(neighborElement); LWLockAcquire(&neighborElement->lock, LW_EXCLUSIVE); - HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, NULL, support); + HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, index, support); LWLockRelease(&neighborElement->lock); } } @@ -408,14 +407,14 @@ UpdateGraphInMemory(HnswSupport * support, HnswElement element, int m, HnswEleme char *base = buildstate->hnswarea; /* Look for duplicate */ - if (FindDuplicateInMemory(base, element)) + if (FindDuplicateInMemory(base, element, buildstate->useIndexTuple, buildstate->tupdesc)) return; /* Add element */ AddElementInMemory(base, graph, element); /* Update neighbors */ - UpdateNeighborsInMemory(base, support, element, m); + UpdateNeighborsInMemory(base, buildstate->index, support, element, m); /* Update entry point if needed (already have lock) */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -428,6 +427,7 @@ UpdateGraphInMemory(HnswSupport * support, HnswElement element, int m, HnswEleme static void InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) { + Relation index = buildstate->index; HnswGraph *graph = buildstate->graph; HnswSupport *support = &buildstate->support; HnswElement entryPoint; @@ -461,7 +461,7 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, NULL, support, m, efConstruction, false); + HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, false, true); /* Update graph in memory */ UpdateGraphInMemory(support, element, m, entryPoint, buildstate); @@ -480,18 +480,20 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn HnswElement element; HnswAllocator *allocator = &buildstate->allocator; HnswSupport *support = &buildstate->support; - Size valueSize; - Pointer valuePtr; LWLock *flushLock = &graph->flushLock; char *base = buildstate->hnswarea; - Datum value; + TupleDesc tupdesc = buildstate->tupdesc; + IndexTuple itup; + Size itupSize; + IndexTuple itupShared; + bool unused; - /* Form index value */ - if (!HnswFormIndexValue(&value, values, isnull, buildstate->typeInfo, support)) + /* Form index tuple */ + if (!HnswFormIndexTuple(&itup, values, isnull, buildstate->typeInfo, support, tupdesc)) return false; - /* Get datum size */ - valueSize = VARSIZE_ANY(DatumGetPointer(value)); + /* Get tuple size */ + itupSize = IndexTupleSize(itup); /* Ensure graph not flushed when inserting */ LWLockAcquire(flushLock, LW_SHARED); @@ -501,7 +503,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn { LWLockRelease(flushLock); - return HnswInsertTupleOnDisk(index, support, value, heaptid, true); + return HnswInsertTupleOnDisk(index, support, itup, heaptid, true, tupdesc); } /* @@ -533,12 +535,12 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn LWLockRelease(flushLock); - return HnswInsertTupleOnDisk(index, support, value, heaptid, true); + return HnswInsertTupleOnDisk(index, support, itup, heaptid, true, tupdesc); } /* Ok, we can proceed to allocate the element */ element = HnswInitElement(base, heaptid, buildstate->m, buildstate->ml, buildstate->maxLevel, allocator); - valuePtr = HnswAlloc(allocator, valueSize); + itupShared = HnswAlloc(allocator, itupSize); /* * We have now allocated the space needed for the element, so we don't @@ -547,9 +549,10 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn */ LWLockRelease(&graph->allocatorLock); - /* Copy the datum */ - memcpy(valuePtr, DatumGetPointer(value), valueSize); - HnswPtrStore(base, element->value, valuePtr); + /* Copy the tuple */ + memcpy(itupShared, itup, itupSize); + HnswPtrStore(base, element->itup, itupShared); + HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itupShared, 1, tupdesc, &unused))); /* Create a lock for the element */ LWLockInitialize(&element->lock, hnsw_lock_tranche_id); @@ -676,6 +679,19 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), errmsg("type not supported for hnsw index"))); + /* TODO See if needed */ + if (IndexRelationGetNumberOfKeyAttributes(index) > 2) + elog(ERROR, "index cannot have more than two columns"); + + if (!OidIsValid(index_getprocid(index, 1, HNSW_DISTANCE_PROC))) + elog(ERROR, "first column must be a vector"); + + for (int i = 1; i < IndexRelationGetNumberOfKeyAttributes(index); i++) + { + if (!OidIsValid(index_getprocid(index, i + 1, HNSW_ATTRIBUTE_DISTANCE_PROC))) + elog(ERROR, "column %d cannot be a vector", i + 1); + } + /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) ereport(ERROR, @@ -702,6 +718,8 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->graph = &buildstate->graphData; buildstate->ml = HnswGetMl(buildstate->m); buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); + buildstate->useIndexTuple = HnswUseIndexTuple(index); + buildstate->tupdesc = RelationGetDescr(index); buildstate->graphCtx = GenerationContextCreate(CurrentMemoryContext, "Hnsw build graph context", diff --git a/src/hnswinsert.c b/src/hnswinsert.c index a4d2885..fc9d975 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -160,9 +160,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B BlockNumber newInsertPage = InvalidBlockNumber; uint8 tupleVersion; char *base = NULL; + bool useIndexTuple = HnswUseIndexTuple(index); /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(HnswPtrAccess(base, e->value))); + etupSize = HnswGetElementTupleSize(base, e, useIndexTuple); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); maxSize = HNSW_MAX_SIZE; @@ -170,7 +171,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B /* Prepare element tuple */ etup = palloc0(etupSize); - HnswSetElementTuple(base, etup, e); + HnswSetElementTuple(base, etup, e, useIndexTuple); /* Prepare neighbor tuple */ ntup = palloc0(ntupSize); @@ -387,8 +388,9 @@ LoadElementsForInsert(HnswNeighborArray * neighbors, HnswQuery * q, int *idx, Re HnswCandidate *hc = &neighbors->items[i]; HnswElement element = HnswPtrAccess(base, hc->element); double distance; + bool matches; - HnswLoadElement(element, &distance, q, index, support, true, NULL); + HnswLoadElement(element, &distance, &matches, q, index, support, true, NULL); hc->distance = distance; /* Prune element if being deleted */ @@ -432,6 +434,8 @@ GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int HnswQuery q; q.value = HnswGetValue(base, element); + q.itup = HnswPtrAccess(base, element->itup); + q.keyData = NULL; LoadElementsForInsert(neighbors, &q, &idx, index, support); @@ -637,21 +641,30 @@ AddDuplicateOnDisk(Relation index, HnswElement element, HnswElement dup, bool bu * Find duplicate element */ static bool -FindDuplicateOnDisk(Relation index, HnswElement element, bool building) +FindDuplicateOnDisk(Relation index, HnswElement element, bool building, TupleDesc tupdesc) { char *base = NULL; HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0); Datum value = HnswGetValue(base, element); + IndexTuple itup = HnswPtrAccess(base, element->itup); for (int i = 0; i < neighbors->length; i++) { HnswCandidate *neighbor = &neighbors->items[i]; HnswElement neighborElement = HnswPtrAccess(base, neighbor->element); - Datum neighborValue = HnswGetValue(base, neighborElement); - /* Exit early since ordered by distance */ - if (!datumIsEqual(value, neighborValue, false, -1)) - return false; + if (HnswUseIndexTuple(index)) + { + /* Exit early since ordered by distance */ + if (!HnswIndexTupleIsEqual(itup, HnswPtrAccess(base, neighborElement->itup), tupdesc)) + return false; + } + else + { + /* Exit early since ordered by distance */ + if (!datumIsEqual(value, HnswGetValue(base, neighborElement), false, -1)) + return false; + } if (AddDuplicateOnDisk(index, element, neighborElement, building)) return true; @@ -664,12 +677,12 @@ FindDuplicateOnDisk(Relation index, HnswElement element, bool building) * Update graph on disk */ static void -UpdateGraphOnDisk(Relation index, HnswSupport * support, HnswElement element, int m, HnswElement entryPoint, bool building) +UpdateGraphOnDisk(Relation index, HnswSupport * support, HnswElement element, int m, HnswElement entryPoint, bool building, TupleDesc tupdesc) { BlockNumber newInsertPage = InvalidBlockNumber; /* Look for duplicate */ - if (FindDuplicateOnDisk(index, element, building)) + if (FindDuplicateOnDisk(index, element, building, tupdesc)) return; /* Add element */ @@ -691,7 +704,7 @@ UpdateGraphOnDisk(Relation index, HnswSupport * support, HnswElement element, in * Insert a tuple into the index */ bool -HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPointer heaptid, bool building) +HnswInsertTupleOnDisk(Relation index, HnswSupport * support, IndexTuple itup, ItemPointer heaptid, bool building, TupleDesc tupdesc) { HnswElement entryPoint; HnswElement element; @@ -699,6 +712,7 @@ HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPo int efConstruction = HnswGetEfConstruction(index); LOCKMODE lockmode = ShareLock; char *base = NULL; + bool unused; /* * Get a shared lock. This allows vacuum to ensure no in-flight inserts @@ -712,7 +726,8 @@ HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPo /* Create an element */ element = HnswInitElement(base, heaptid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); - HnswPtrStore(base, element->value, DatumGetPointer(value)); + HnswPtrStore(base, element->itup, itup); + HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itup, 1, tupdesc, &unused))); /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -729,10 +744,10 @@ HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPo } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, false); + HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, false, false); /* Update graph on disk */ - UpdateGraphOnDisk(index, support, element, m, entryPoint, building); + UpdateGraphOnDisk(index, support, element, m, entryPoint, building, tupdesc); /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); @@ -746,17 +761,18 @@ HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPo static void HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid) { - Datum value; + IndexTuple itup; const HnswTypeInfo *typeInfo = HnswGetTypeInfo(index); + TupleDesc tupdesc = RelationGetDescr(index); HnswSupport support; HnswInitSupport(&support, index); - /* Form index value */ - if (!HnswFormIndexValue(&value, values, isnull, typeInfo, &support)) + /* Form index tuple */ + if (!HnswFormIndexTuple(&itup, values, isnull, typeInfo, &support, tupdesc)) return; - HnswInsertTupleOnDisk(index, &support, value, heaptid, false); + HnswInsertTupleOnDisk(index, &support, itup, heaptid, false, tupdesc); } /* diff --git a/src/hnswscan.c b/src/hnswscan.c index 5c526f4..5f88248 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -22,26 +22,30 @@ GetScanItems(IndexScanDesc scan, Datum value) int m; HnswElement entryPoint; char *base = NULL; + bool inMemory = false; HnswQuery *q = &so->q; + q->value = value; + q->itup = NULL; + q->keyData = scan->keyData; + /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); - q->value = value; so->m = m; if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, false)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, false, inMemory)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, false, NULL, NULL, NULL, true, NULL); + w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, false, NULL, inMemory, NULL, NULL, true, NULL); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, support, m, false, NULL, &so->v, hnsw_iterative_scan != HNSW_ITERATIVE_SCAN_OFF ? &so->discarded : NULL, true, &so->tuples); + return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, support, m, false, NULL, inMemory, &so->v, hnsw_iterative_scan != HNSW_ITERATIVE_SCAN_OFF ? &so->discarded : NULL, true, &so->tuples); } /* @@ -72,7 +76,7 @@ ResumeScanItems(IndexScanDesc scan) ep = lappend(ep, sc); } - return HnswSearchLayer(base, &so->q, ep, batch_size, 0, index, &so->support, so->m, false, NULL, &so->v, &so->discarded, false, &so->tuples); + return HnswSearchLayer(base, &so->q, ep, batch_size, 0, index, &so->support, so->m, false, NULL, false, &so->v, &so->discarded, false, &so->tuples); } /* @@ -96,7 +100,7 @@ GetScanValue(IndexScanDesc scan) /* Normalize if needed */ if (so->support.normprocinfo != NULL) - value = HnswNormValue(so->typeInfo, so->support.collation, value); + value = HnswNormValue(so->typeInfo, so->support.collation[0], value); } return value; @@ -283,7 +287,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) element = HnswPtrAccess(base, sc->element); /* Move to next element if no valid heap TIDs */ - if (element->heaptidsLength == 0) + if (!sc->matches || element->heaptidsLength == 0) { so->w = list_delete_last(so->w); diff --git a/src/hnswutils.c b/src/hnswutils.c index 8e2a42c..a7e3cb8 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -150,11 +150,39 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) void HnswInitSupport(HnswSupport * support, Relation index) { - support->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - support->collation = index->rd_indcollation[0]; + support->procinfo[0] = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + + if (IndexRelationGetNumberOfKeyAttributes(index) > 1) + support->procinfo[1] = index_getprocinfo(index, 2, HNSW_ATTRIBUTE_DISTANCE_PROC); + + support->collation = index->rd_indcollation; support->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); } +/* + * Get element tuple size + */ +Size +HnswGetElementTupleSize(char *base, HnswElement element, bool useIndexTuple) +{ + Size size; + + if (useIndexTuple) + { + IndexTuple itup = HnswPtrAccess(base, element->itup); + + size = IndexTupleSize(itup); + } + else + { + Pointer valuePtr = HnswPtrAccess(base, element->value); + + size = VARSIZE_ANY(valuePtr); + } + + return HNSW_ELEMENT_TUPLE_SIZE(size); +} + /* * Normalize value */ @@ -170,7 +198,38 @@ HnswNormValue(const HnswTypeInfo * typeInfo, Oid collation, Datum value) bool HnswCheckNorm(HnswSupport * support, Datum value) { - return DatumGetFloat8(FunctionCall1Coll(support->normprocinfo, support->collation, value)) > 0; + return DatumGetFloat8(FunctionCall1Coll(support->normprocinfo, support->collation[0], value)) > 0; +} + +/* + * Check if index tuples are equal + */ +bool +HnswIndexTupleIsEqual(IndexTuple a, IndexTuple b, TupleDesc tupdesc) +{ + for (int i = 0; i < tupdesc->natts; i++) + { + bool nullA; + bool nullB; + + Datum datumA = index_getattr(a, i + 1, tupdesc, &nullA); + Datum datumB = index_getattr(b, i + 1, tupdesc, &nullB); + + if (nullA || nullB) + { + if (nullA != nullB) + return false; + } + else + { + Form_pg_attribute att = TupleDescAttr(tupdesc, i); + + if (!datumIsEqual(datumA, datumB, att->attbyval, att->attlen)) + return false; + } + } + + return true; } /* @@ -261,6 +320,7 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel, HnswInitNeighbors(base, element, m, allocator); HnswPtrStore(base, element->value, (Pointer) NULL); + HnswPtrStore(base, element->itup, (IndexTuple) NULL); return element; } @@ -287,6 +347,7 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) element->offno = offno; HnswPtrStore(base, element->neighbors, (HnswNeighborArrayPtr *) NULL); HnswPtrStore(base, element->value, (Pointer) NULL); + HnswPtrStore(base, element->itup, (IndexTuple) NULL); return element; } @@ -399,11 +460,13 @@ HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, Bloc } /* - * Form index value + * Form index tuple */ bool -HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support) +HnswFormIndexTuple(IndexTuple *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support, TupleDesc tupdesc) { + Datum newValues[2]; + /* Detoast once for all calls */ Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); @@ -417,10 +480,14 @@ HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * if (!HnswCheckNorm(support, value)) return false; - value = HnswNormValue(typeInfo, support->collation, value); + value = HnswNormValue(typeInfo, support->collation[0], value); } - *out = value; + newValues[0] = value; + for (int i = 1; i < tupdesc->natts; i++) + newValues[i] = values[i]; + + *out = index_form_tuple(tupdesc, newValues, isnull); return true; } @@ -429,10 +496,8 @@ HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * * Set element tuple, except for neighbor info */ void -HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) +HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple) { - Pointer valuePtr = HnswPtrAccess(base, element->value); - etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; etup->deleted = 0; @@ -444,7 +509,19 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) else ItemPointerSetInvalid(&etup->heaptids[i]); } - memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr)); + + if (useIndexTuple) + { + IndexTuple itup = HnswPtrAccess(base, element->itup); + + memcpy(&etup->data, itup, IndexTupleSize(itup)); + } + else + { + Pointer valuePtr = HnswPtrAccess(base, element->value); + + memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr)); + } } /* @@ -486,7 +563,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m) * Load an element from a tuple */ void -HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec) +HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index) { element->level = etup->level; element->deleted = etup->deleted; @@ -510,26 +587,128 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe if (loadVec) { char *base = NULL; - Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1); - HnswPtrStore(base, element->value, DatumGetPointer(value)); + if (HnswUseIndexTuple(index)) + { + IndexTuple itup = CopyIndexTuple((IndexTuple) &etup->data); + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; + + HnswPtrStore(base, element->itup, itup); + HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itup, 1, tupdesc, &unused))); + } + else + { + Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1); + + HnswPtrStore(base, element->value, DatumGetPointer(value)); + } } } +/* + * Get the attribute distance + */ +static inline double +AttributeDistance(double e) +{ + /* TODO Better bias */ + /* must be >> max(w * g) + 1 / log10(2) */ + double bias = 4.32; + + return e > 0 ? bias - 1.0 / log10(e + 1) : 0; +} + /* * Calculate the distance between values */ -static inline double -HnswGetDistance(Datum a, Datum b, HnswSupport * support) +static double +HnswGetDistance(IndexTuple itup, Datum vec, HnswQuery * q, Relation index, HnswSupport * support, bool *matches) { - return DatumGetFloat8(FunctionCall2Coll(support->procinfo, support->collation, a, b)); + double g; + + if (DatumGetPointer(q->value) == NULL) + g = 0; + else + g = DatumGetFloat8(FunctionCall2Coll(support->procinfo[0], support->collation[0], q->value, vec)); + + Assert(PointerIsValid(matches)); + *matches = true; + + if (IndexRelationGetNumberOfKeyAttributes(index) > 1) + { + double w = 0.25; + double e = 0.0; + TupleDesc tupdesc = RelationGetDescr(index); + + if (q->keyData) + { + /* TODO need to pass length of key data */ + int keyCount = 1; + + for (int i = 0; i < keyCount; i++) + { + ScanKey key = &q->keyData[i]; + bool isnull; + Datum value = index_getattr(itup, key->sk_attno, tupdesc, &isnull); + bool attnull = key->sk_flags & SK_ISNULL; + + if (isnull || attnull) + { + if (isnull != attnull) + { + e += 1000; + *matches = false; + } + } + else if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument))) + { + double ei = fabs(DatumGetFloat8(FunctionCall2Coll(support->procinfo[key->sk_attno - 1], support->collation[key->sk_attno - 1], value, key->sk_argument))); + + if (ei > 0) + e += ei; + else + /* Distance is zero for inequality */ + e += 1000; + + *matches = false; + } + } + + return w * g + AttributeDistance(e); + } + else if (q->itup) + { + int keyCount = IndexRelationGetNumberOfKeyAttributes(index) - 1; + + for (int i = 0; i < keyCount; i++) + { + bool isnull; + bool attnull; + Datum value = index_getattr(itup, i + 2, tupdesc, &isnull); + Datum value2 = index_getattr(q->itup, i + 2, tupdesc, &attnull); + + if (isnull || attnull) + { + if (isnull != attnull) + e += 1000; + } + else + e += fabs(DatumGetFloat8(FunctionCall2Coll(support->procinfo[i + 1], support->collation[i + 1], value, value2))); + } + + return w * g + AttributeDistance(e); + } + } + + return g; } /* * Load an element and optionally get its distance from q */ static void -HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, bool *matches, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) { Buffer buf; Page page; @@ -547,10 +726,23 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Hns /* Calculate distance */ if (distance != NULL) { - if (DatumGetPointer(q->value) == NULL) - *distance = 0; + IndexTuple itup = NULL; + Datum value; + + if (HnswUseIndexTuple(index)) + { + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; + + itup = (IndexTuple) &etup->data; + value = index_getattr(itup, 1, tupdesc, &unused); + } else - *distance = HnswGetDistance(q->value, PointerGetDatum(&etup->data), support); + { + value = PointerGetDatum(&etup->data); + } + + *distance = HnswGetDistance(itup, value, q, index, support, matches); } /* Load element */ @@ -559,7 +751,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Hns if (*element == NULL) *element = HnswInitElementFromBlock(blkno, offno); - HnswLoadElementFromTuple(*element, etup, true, loadVec); + HnswLoadElementFromTuple(*element, etup, true, loadVec, index); } UnlockReleaseBuffer(buf); @@ -569,32 +761,34 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Hns * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) +HnswLoadElement(HnswElement element, double *distance, bool *matches, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) { - HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, support, loadVec, maxDistance, &element); + HnswLoadElementImpl(element->blkno, element->offno, distance, matches, q, index, support, loadVec, maxDistance, &element); } /* * Get the distance for an element */ static double -GetElementDistance(char *base, HnswElement element, HnswQuery * q, HnswSupport * support) +GetElementDistance(char *base, HnswElement element, bool *matches, HnswQuery * q, Relation index, HnswSupport * support) { Datum value = HnswGetValue(base, element); + IndexTuple itup = HnswPtrAccess(base, element->itup); - return HnswGetDistance(q->value, value, support); + return HnswGetDistance(itup, value, q, index, support, matches); } /* * Allocate a search candidate */ static HnswSearchCandidate * -HnswInitSearchCandidate(char *base, HnswElement element, double distance) +HnswInitSearchCandidate(char *base, HnswElement element, double distance, bool matches) { HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate)); HnswPtrStore(base, sc->element, element); sc->distance = distance; + sc->matches = matches; return sc; } @@ -602,17 +796,17 @@ HnswInitSearchCandidate(char *base, HnswElement element, double distance) * Create a candidate for the entry point */ HnswSearchCandidate * -HnswEntryCandidate(char *base, HnswElement entryPoint, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec) +HnswEntryCandidate(char *base, HnswElement entryPoint, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, bool inMemory) { - bool inMemory = index == NULL; double distance; + bool matches; if (inMemory) - distance = GetElementDistance(base, entryPoint, q, support); + distance = GetElementDistance(base, entryPoint, &matches, q, index, support); else - HnswLoadElement(entryPoint, &distance, q, index, support, loadVec, NULL); + HnswLoadElement(entryPoint, &distance, &matches, q, index, support, loadVec, NULL); - return HnswInitSearchCandidate(base, entryPoint, distance); + return HnswInitSearchCandidate(base, entryPoint, distance, matches); } /* @@ -815,7 +1009,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples) +HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, bool inMemory, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); @@ -828,7 +1022,8 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in int lm = HnswGetLayerM(m, lc); HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; - bool inMemory = index == NULL; + uint64 additional = 0; + uint64 maxAdditional = q->keyData && lc == 0 ? 10000 : 0; if (v == NULL) { @@ -869,6 +1064,10 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in pairingheap_add(C, &sc->c_node); pairingheap_add(W, &sc->w_node); + /* Do not count elements that do not match filter towards ef */ + if (!sc->matches && ++additional <= maxAdditional) + continue; + /* * Do not count elements being deleted towards ef when vacuuming. It * would be ideal to do this for inserts as well, but this could @@ -903,6 +1102,7 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in HnswElement eElement; HnswSearchCandidate *e; double eDistance; + bool eMatches; bool alwaysAdd = wlen < ef; f = HnswGetSearchCandidate(w_node, pairingheap_first(W)); @@ -910,7 +1110,7 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in if (inMemory) { eElement = unvisited[i].element; - eDistance = GetElementDistance(base, eElement, q, support); + eDistance = GetElementDistance(base, eElement, &eMatches, q, index, support); } else { @@ -920,7 +1120,7 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement); + HnswLoadElementImpl(blkno, offno, &eDistance, &eMatches, q, index, support, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement); if (eElement == NULL) continue; @@ -931,7 +1131,7 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in if (discarded != NULL) { /* Create a new candidate */ - e = HnswInitSearchCandidate(base, eElement, eDistance); + e = HnswInitSearchCandidate(base, eElement, eDistance, eMatches); pairingheap_add(*discarded, &e->w_node); } @@ -943,7 +1143,7 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in continue; /* Create a new candidate */ - e = HnswInitSearchCandidate(base, eElement, eDistance); + e = HnswInitSearchCandidate(base, eElement, eDistance, eMatches); pairingheap_add(C, &e->c_node); pairingheap_add(W, &e->w_node); @@ -954,6 +1154,10 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in */ if (CountElement(skipElement, eElement)) { + /* Do not count elements that do not match filter towards ef */ + if (!e->matches && ++additional <= maxAdditional) + continue; + wlen++; /* No need to decrement wlen */ @@ -1031,18 +1235,24 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b) * Check if an element is closer to q than any element from R */ static bool -CheckElementCloser(char *base, HnswCandidate * e, List *r, HnswSupport * support) +CheckElementCloser(char *base, HnswCandidate * e, List *r, Relation index, HnswSupport * support) { HnswElement eElement = HnswPtrAccess(base, e->element); - Datum eValue = HnswGetValue(base, eElement); + HnswQuery q; ListCell *lc2; + q.value = HnswGetValue(base, eElement); + q.itup = HnswPtrAccess(base, eElement->itup); + q.keyData = NULL; + foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); HnswElement riElement = HnswPtrAccess(base, ri->element); Datum riValue = HnswGetValue(base, riElement); - float distance = HnswGetDistance(eValue, riValue, support); + IndexTuple ritup = HnswPtrAccess(base, riElement->itup); + bool matches; + float distance = HnswGetDistance(ritup, riValue, &q, index, support, &matches); if (distance <= e->distance) return false; @@ -1055,7 +1265,7 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, HnswSupport * support * Algorithm 4 from paper */ static List * -SelectNeighbors(char *base, List *c, int lm, HnswSupport * support, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) +SelectNeighbors(char *base, List *c, int lm, Relation index, HnswSupport * support, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) { List *r = NIL; List *w = list_copy(c); @@ -1089,7 +1299,7 @@ SelectNeighbors(char *base, List *c, int lm, HnswSupport * support, bool *closer /* Use previous state of r and wd to skip work when possible */ if (mustCalculate) - e->closer = CheckElementCloser(base, e, r, support); + e->closer = CheckElementCloser(base, e, r, index, support); else if (list_length(added) > 0) { /* Keep Valgrind happy for in-memory, parallel builds */ @@ -1102,8 +1312,7 @@ SelectNeighbors(char *base, List *c, int lm, HnswSupport * support, bool *closer */ if (e->closer) { - e->closer = CheckElementCloser(base, e, added, support); - + e->closer = CheckElementCloser(base, e, added, index, support); if (!e->closer) removedAny = true; } @@ -1115,7 +1324,7 @@ SelectNeighbors(char *base, List *c, int lm, HnswSupport * support, bool *closer */ if (removedAny) { - e->closer = CheckElementCloser(base, e, r, support); + e->closer = CheckElementCloser(base, e, r, index, support); if (e->closer) added = lappend(added, e); } @@ -1123,7 +1332,7 @@ SelectNeighbors(char *base, List *c, int lm, HnswSupport * support, bool *closer } else if (e == newCandidate) { - e->closer = CheckElementCloser(base, e, r, support); + e->closer = CheckElementCloser(base, e, r, index, support); if (e->closer) added = lappend(added, e); } @@ -1200,7 +1409,7 @@ HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newE c = lappend(c, &neighbors->items[i]); c = lappend(c, &newHc); - SelectNeighbors(base, c, lm, support, &neighbors->closerSet, &newHc, &pruned, true); + SelectNeighbors(base, c, lm, index, support, &neighbors->closerSet, &newHc, &pruned, true); /* Should not happen */ if (pruned == NULL) @@ -1271,17 +1480,19 @@ PrecomputeHash(char *base, HnswElement element) * Algorithm 1 from paper */ void -HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing) +HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing, bool inMemory) { List *ep; List *w; int level = element->level; int entryLevel; HnswQuery q; + HnswElement skipElement = existing ? element : NULL; - bool inMemory = index == NULL; q.value = HnswGetValue(base, element); + q.itup = HnswPtrAccess(base, element->itup); + q.keyData = NULL; /* Precompute hash */ if (inMemory) @@ -1292,13 +1503,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, true)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, true, inMemory)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement, inMemory, NULL, NULL, true, NULL); ep = w; } @@ -1317,7 +1528,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); + w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement, inMemory, NULL, NULL, true, NULL); /* Convert search candidates to candidates */ foreach(lc2, w) @@ -1341,7 +1552,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint * sortCandidates to true for in-memory builds to enable closer * caching, but there does not seem to be a difference in performance. */ - neighbors = SelectNeighbors(base, lw, lm, support, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false); + neighbors = SelectNeighbors(base, lw, lm, index, support, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false); AddConnections(base, element, neighbors, lc); diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 3a8ee26..f28271c 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -212,7 +212,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme element->heaptidsLength = 0; /* Find neighbors for element, skipping itself */ - HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, true); + HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, true, false); /* Zero memory for each element */ MemSet(ntup, 0, HNSW_TUPLE_ALLOC_SIZE); @@ -264,7 +264,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, index, support, true, NULL); + HnswLoadElement(highestPoint, NULL, NULL, NULL, index, support, true, NULL); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -302,7 +302,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, index, support, true, NULL); + HnswLoadElement(entryPoint, NULL, NULL, NULL, index, support, true, NULL); if (NeedsUpdated(vacuumstate, entryPoint)) { @@ -378,7 +378,7 @@ RepairGraph(HnswVacuumState * vacuumstate) /* Create an element */ element = HnswInitElementFromBlock(blkno, offno); - HnswLoadElementFromTuple(element, etup, false, true); + HnswLoadElementFromTuple(element, etup, false, true, index); elements = lappend(elements, element); } @@ -448,6 +448,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) BlockNumber insertPage = InvalidBlockNumber; Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; + bool useIndexTuple = HnswUseIndexTuple(index); /* * Wait for index scans to complete. Scans before this point may contain @@ -529,7 +530,14 @@ MarkDeleted(HnswVacuumState * vacuumstate) /* Overwrite element */ etup->deleted = 1; - MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data)); + if (useIndexTuple) + { + IndexTuple itup = (IndexTuple) &etup->data; + + MemSet(itup, 0, IndexTupleSize(itup)); + } + else + MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data)); /* Overwrite neighbors */ for (int i = 0; i < ntup->count; i++) diff --git a/test/t/045_hnsw_hqann.pl b/test/t/045_hnsw_hqann.pl new file mode 100644 index 0000000..fc6d892 --- /dev/null +++ b/test/t/045_hnsw_hqann.pl @@ -0,0 +1,113 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @cs = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my $nc = 1000; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $cs[0] ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Cond/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + is(scalar(@actual_ids), $limit); + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc FROM generate_series(1, 20000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); + push(@cs, int(rand() * $nc)); +} + +# Get exact results +@expected = (); +for my $i (0 .. $#queries) +{ + my $res = $node->safe_psql("postgres", "SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v <=> '$queries[$i]' LIMIT $limit;"); + push(@expected, $res); +} + +# Add index +$node->safe_psql("postgres", qq( + SET maintenance_work_mem = '256MB'; + SET max_parallel_maintenance_workers = 2; + CREATE INDEX ON tst USING hnsw (v vector_cosine_ops, c); +)); + +# Test recall +test_recall(0.99, '<=>'); + +# Test vacuum +$node->safe_psql("postgres", "DELETE FROM tst WHERE c > 5;"); +$node->safe_psql("postgres", "VACUUM tst;"); + +# Test columns +my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c, v vector_cosine_ops);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_cosine_ops, c, c);"); +like($stderr, qr/index cannot have more than two columns/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_cosine_ops, v vector_cosine_ops);"); +like($stderr, qr/column 2 cannot be a vector/); + +done_testing();