From f57f2b68214daecd54924a31d5bbef46da50e33d Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Fri, 10 Nov 2023 13:28:48 -0800 Subject: [PATCH] Added support for inline filtering with HNSW --- CHANGELOG.md | 4 + README.md | 7 + sql/vector--0.5.1--0.6.0.sql | 10 ++ sql/vector.sql | 10 ++ src/hnsw.c | 18 ++- src/hnsw.h | 35 +++-- src/hnswbuild.c | 45 ++++-- src/hnswinsert.c | 27 ++-- src/hnswscan.c | 21 ++- src/hnswutils.c | 275 +++++++++++++++++++++++++++++------ src/hnswvacuum.c | 33 +++-- test/t/019_hnsw_filtering.pl | 107 ++++++++++++++ 12 files changed, 483 insertions(+), 109 deletions(-) create mode 100644 sql/vector--0.5.1--0.6.0.sql create mode 100644 test/t/019_hnsw_filtering.pl diff --git a/CHANGELOG.md b/CHANGELOG.md index 07040d0..a9d7244 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.6.0 (unreleased) + +- Added support for inline filtering with HNSW + ## 0.5.1 (2023-10-10) - Improved performance of HNSW index builds diff --git a/README.md b/README.md index bca3de8..3ff6407 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,12 @@ Create an index on one [or more](https://www.postgresql.org/docs/current/indexes CREATE INDEX ON items (category_id); ``` +Or a composite HNSW index for approximate search (added in 0.6.0) + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_l2_ops, category_id); +``` + Or a [partial index](https://www.postgresql.org/docs/current/indexes-partial.html) on the vector column for approximate search ```sql @@ -712,6 +718,7 @@ Thanks to: - [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf) - [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf) - [Efficient and Robust Approximate Nearest Neighbor Search using Hierarchical Navigable Small World Graphs](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf) +- [HQANN: Efficient and Robust Similarity Search for Hybrid Queries with Structured and Unstructured Constraints](https://arxiv.org/pdf/2207.07940.pdf) ## History diff --git a/sql/vector--0.5.1--0.6.0.sql b/sql/vector--0.5.1--0.6.0.sql new file mode 100644 index 0000000..3030c31 --- /dev/null +++ b/sql/vector--0.5.1--0.6.0.sql @@ -0,0 +1,10 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.6.0'" to load this file. \quit + +CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer), + FUNCTION 3 hnsw_attribute_distance(integer, integer); diff --git a/sql/vector.sql b/sql/vector.sql index 137931f..9149369 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -290,3 +290,13 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); + +-- hnsw attributes + +CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING hnsw AS + OPERATOR 2 = (integer, integer), + FUNCTION 3 hnsw_attribute_distance(integer, integer); diff --git a/src/hnsw.c b/src/hnsw.c index 758e418..88c8e7d 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -167,7 +167,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 2; + amroutine->amsupport = 3; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif @@ -175,7 +175,7 @@ hnswhandler(PG_FUNCTION_ARGS) amroutine->amcanorderbyop = true; amroutine->amcanbackward = false; /* can change direction mid-scan */ amroutine->amcanunique = false; - amroutine->amcanmulticol = false; + amroutine->amcanmulticol = true; amroutine->amoptionalkey = true; amroutine->amsearcharray = false; amroutine->amsearchnulls = false; @@ -222,3 +222,17 @@ hnswhandler(PG_FUNCTION_ARGS) PG_RETURN_POINTER(amroutine); } + +/* + * Get the distance between two int4 attributes + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_int4_attribute_distance); +Datum +hnsw_int4_attribute_distance(PG_FUNCTION_ARGS) +{ + int32 a = PG_GETARG_INT32(0); + int32 b = PG_GETARG_INT32(1); + double distance = ((double) a) - ((double) b); + + PG_RETURN_FLOAT8(distance); +} diff --git a/src/hnsw.h b/src/hnsw.h index 57cdafe..232fa3c 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -19,6 +19,7 @@ /* Support functions */ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 +#define HNSW_ATTRIBUTE_DISTANCE_PROC 3 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -104,6 +105,7 @@ typedef struct HnswElementData OffsetNumber neighborOffno; BlockNumber neighborPage; Datum value; + IndexTuple itup; } HnswElementData; typedef HnswElementData * HnswElement; @@ -154,9 +156,9 @@ typedef struct HnswBuildState double reltuples; /* Support functions */ - FmgrInfo *procinfo; + FmgrInfo **procinfos; FmgrInfo *normprocinfo; - Oid collation; + Oid *collations; /* Variables */ List *elements; @@ -165,6 +167,7 @@ typedef struct HnswBuildState int maxLevel; long memoryLeft; bool flushed; + bool useIndexTuple; Vector *normvec; /* Memory */ @@ -226,9 +229,9 @@ typedef struct HnswScanOpaqueData MemoryContext tmpCtx; /* Support functions */ - FmgrInfo *procinfo; + FmgrInfo **procinfos; FmgrInfo *normprocinfo; - Oid collation; + Oid *collations; } HnswScanOpaqueData; typedef HnswScanOpaqueData * HnswScanOpaque; @@ -246,8 +249,8 @@ typedef struct HnswVacuumState int efConstruction; /* Support functions */ - FmgrInfo *procinfo; - Oid collation; + FmgrInfo **procinfos; + Oid *collations; /* Variables */ HTAB *deleted; @@ -269,26 +272,27 @@ Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); void HnswInit(void); -List *HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, FmgrInfo **procinfos, Oid *collations, int m, bool loadVec, HnswElement skipElement, bool inMemory); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); HnswElement HnswInitElement(ItemPointer tid, int m, double ml, int maxLevel); void HnswFreeElement(HnswElement element); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); -void HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing); -HnswElement HnswFindDuplicate(HnswElement e); -HnswCandidate *HnswEntryCandidate(HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo **procinfos, Oid *collations, int m, int efConstruction, bool existing, bool inMemory); +HnswElement HnswFindDuplicate(HnswElement e, Relation index); +HnswCandidate *HnswEntryCandidate(HnswElement em, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation rel, FmgrInfo **procinfos, Oid *collations, bool loadVec, bool inMemory); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum); void HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); void HnswInitNeighbors(HnswElement element, int m); bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel); -void HnswUpdateNeighborPages(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting); -void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); -void HnswSetElementTuple(HnswElementTuple etup, HnswElement element); -void HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); +void HnswUpdateNeighborPages(Relation index, FmgrInfo **procinfos, Oid *collations, HnswElement e, int m, bool checkExisting); +void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index); +void HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec); +void HnswSetElementTuple(HnswElementTuple etup, HnswElement element, bool useIndexTuple); +void HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo **procinfos, Oid *collations, bool inMemory); void HnswLoadNeighbors(HnswElement element, Relation index, int m); +void HnswElementSetData(HnswElement element, Relation index, Datum value, Datum *values, bool *isnull); /* Index access methods */ IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); @@ -305,5 +309,6 @@ IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys); void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys); bool hnswgettuple(IndexScanDesc scan, ScanDirection dir); void hnswendscan(IndexScanDesc scan); +FmgrInfo **HnswInitProcinfos(Relation index); #endif diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 7f68c94..cbd5b57 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -106,6 +106,7 @@ CreateElementPages(HnswBuildState * buildstate) { Relation index = buildstate->index; ForkNumber forkNum = buildstate->forkNum; + bool useIndexTuple = buildstate->useIndexTuple; Size etupAllocSize; Size maxSize; HnswElementTuple etup; @@ -141,7 +142,7 @@ CreateElementPages(HnswBuildState * buildstate) MemSet(etup, 0, etupAllocSize); /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(DatumGetPointer(element->value))); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? IndexTupleSize(element->itup) : VARSIZE_ANY(DatumGetPointer(element->value))); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); @@ -149,7 +150,7 @@ CreateElementPages(HnswBuildState * buildstate) if (etupSize > etupAllocSize) elog(ERROR, "index tuple too large"); - HnswSetElementTuple(etup, element); + HnswSetElementTuple(etup, element, useIndexTuple); /* Keep element and neighbors on the same page if possible */ if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize)) @@ -273,13 +274,14 @@ FlushPages(HnswBuildState * buildstate) * Insert tuple */ static bool -InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * buildstate, HnswElement * dup, MemoryContext outerCtx) +InsertTuple(Relation index, Datum *values, bool *isnull, HnswElement element, HnswBuildState * buildstate, HnswElement * dup, MemoryContext outerCtx) { - FmgrInfo *procinfo = buildstate->procinfo; - Oid collation = buildstate->collation; + FmgrInfo **procinfos = buildstate->procinfos; + Oid *collations = buildstate->collations; HnswElement entryPoint = buildstate->entryPoint; int efConstruction = buildstate->efConstruction; int m = buildstate->m; + bool inMemory = true; MemoryContext oldCtx; /* Detoast once for all calls */ @@ -288,20 +290,20 @@ InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec)) + if (!HnswNormValue(buildstate->normprocinfo, collations[0], &value, buildstate->normvec)) return false; } /* Copy value to element so accessible outside of memory context */ oldCtx = MemoryContextSwitchTo(outerCtx); - element->value = datumCopy(value, false, -1); + HnswElementSetData(element, index, value, values, isnull); MemoryContextSwitchTo(oldCtx); /* Insert element in graph */ - HnswInsertElement(element, entryPoint, NULL, procinfo, collation, m, efConstruction, false); + HnswInsertElement(element, entryPoint, index, procinfos, collations, m, efConstruction, false, inMemory); /* Look for duplicate */ - *dup = HnswFindDuplicate(element); + *dup = HnswFindDuplicate(element, index); /* Update neighbors if needed */ if (*dup == NULL) @@ -312,7 +314,7 @@ InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * HnswNeighborArray *neighbors = &element->neighbors[lc]; for (int i = 0; i < neighbors->length; i++) - HnswUpdateConnection(element, &neighbors->items[i], lm, lc, NULL, NULL, procinfo, collation); + HnswUpdateConnection(element, &neighbors->items[i], lm, lc, NULL, index, procinfos, collations, inMemory); } } @@ -336,7 +338,7 @@ HnswElementMemory(HnswElement e, int m) elementSize += sizeof(HnswNeighborArray) * (e->level + 1); elementSize += sizeof(HnswCandidate) * (m * (e->level + 2)); elementSize += sizeof(ItemPointerData); - elementSize += VARSIZE_ANY(DatumGetPointer(e->value)); + elementSize += IndexTupleSize(e->itup); return elementSize; } @@ -392,7 +394,7 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); /* Insert tuple */ - inserted = InsertTuple(index, values, element, buildstate, &dup, oldCtx); + inserted = InsertTuple(index, values, isnull, element, buildstate, &dup, oldCtx); /* Reset memory context */ MemoryContextSwitchTo(oldCtx); @@ -430,6 +432,19 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->efConstruction = HnswGetEfConstruction(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + /* TODO See if needed */ + if (IndexRelationGetNumberOfKeyAttributes(index) > 2) + elog(ERROR, "index cannot have more than two columns"); + + if (!OidIsValid(index_getprocid(index, 1, HNSW_DISTANCE_PROC))) + elog(ERROR, "first column must be a vector"); + + for (int i = 1; i < IndexRelationGetNumberOfKeyAttributes(index); i++) + { + if (!OidIsValid(index_getprocid(index, i + 1, HNSW_ATTRIBUTE_DISTANCE_PROC))) + elog(ERROR, "column %d cannot be a vector", i + 1); + } + /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); @@ -444,9 +459,9 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->indtuples = 0; /* Get support functions */ - buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + buildstate->procinfos = HnswInitProcinfos(index); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - buildstate->collation = index->rd_indcollation[0]; + buildstate->collations = index->rd_indcollation; buildstate->elements = NIL; buildstate->entryPoint = NULL; @@ -454,6 +469,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); buildstate->memoryLeft = maintenance_work_mem * 1024L; buildstate->flushed = false; + buildstate->useIndexTuple = IndexRelationGetNumberOfAttributes(index) > 1; /* Reuse for each tuple */ buildstate->normvec = InitVector(buildstate->dimensions); @@ -469,6 +485,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index static void FreeBuildState(HnswBuildState * buildstate) { + pfree(buildstate->procinfos); pfree(buildstate->normvec); MemoryContextDelete(buildstate->tmpCtx); } diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 873be3a..62c50b8 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -129,9 +129,10 @@ WriteNewElementPages(Relation index, HnswElement e, int m, BlockNumber insertPag OffsetNumber freeOffno = InvalidOffsetNumber; OffsetNumber freeNeighborOffno = InvalidOffsetNumber; BlockNumber newInsertPage = InvalidBlockNumber; + bool useIndexTuple = IndexRelationGetNumberOfAttributes(index) > 1; /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(DatumGetPointer(e->value))); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? IndexTupleSize(e->itup) : VARSIZE_ANY(DatumGetPointer(e->value))); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); maxSize = HNSW_MAX_SIZE; @@ -139,7 +140,7 @@ WriteNewElementPages(Relation index, HnswElement e, int m, BlockNumber insertPag /* Prepare element tuple */ etup = palloc0(etupSize); - HnswSetElementTuple(etup, e); + HnswSetElementTuple(etup, e, useIndexTuple); /* Prepare neighbor tuple */ ntup = palloc0(ntupSize); @@ -301,7 +302,7 @@ ConnectionExists(HnswElement e, HnswNeighborTuple ntup, int startIdx, int lm) * Update neighbors */ void -HnswUpdateNeighborPages(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting) +HnswUpdateNeighborPages(Relation index, FmgrInfo **procinfos, Oid *collations, HnswElement e, int m, bool checkExisting) { for (int lc = e->level; lc >= 0; lc--) { @@ -333,7 +334,7 @@ HnswUpdateNeighborPages(Relation index, FmgrInfo *procinfo, Oid collation, HnswE */ /* Select neighbors */ - HnswUpdateConnection(e, hc, lm, lc, &idx, index, procinfo, collation); + HnswUpdateConnection(e, hc, lm, lc, &idx, index, procinfos, collations, false); /* New element was not selected as a neighbor */ if (idx == -1) @@ -451,7 +452,7 @@ HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) * Write changes to disk */ static void -WriteElement(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement dup, HnswElement entryPoint) +WriteElement(Relation index, FmgrInfo **procinfos, Oid *collations, HnswElement element, int m, int efConstruction, HnswElement dup, HnswElement entryPoint) { BlockNumber newInsertPage = InvalidBlockNumber; @@ -470,7 +471,7 @@ WriteElement(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement elem HnswUpdateMetaPage(index, 0, NULL, newInsertPage, MAIN_FORKNUM); /* Update neighbors */ - HnswUpdateNeighborPages(index, procinfo, collation, element, m, false); + HnswUpdateNeighborPages(index, procinfos, collations, element, m, false); /* Update metapage if needed */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -489,8 +490,8 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti HnswElement element; int m; int efConstruction = HnswGetEfConstruction(index); - FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - Oid collation = index->rd_indcollation[0]; + FmgrInfo **procinfos = HnswInitProcinfos(index); + Oid *collations = index->rd_indcollation; HnswElement dup; LOCKMODE lockmode = ShareLock; @@ -501,7 +502,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); if (normprocinfo != NULL) { - if (!HnswNormValue(normprocinfo, collation, &value, NULL)) + if (!HnswNormValue(normprocinfo, collations[0], &value, NULL)) return false; } @@ -517,7 +518,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti /* Create an element */ element = HnswInitElement(heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m)); - element->value = value; + HnswElementSetData(element, index, value, values, isnull); /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -534,13 +535,13 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti } /* Insert element in graph */ - HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, false); + HnswInsertElement(element, entryPoint, index, procinfos, collations, m, efConstruction, false, false); /* Look for duplicate */ - dup = HnswFindDuplicate(element); + dup = HnswFindDuplicate(element, index); /* Write to disk */ - WriteElement(index, procinfo, collation, element, m, efConstruction, dup, entryPoint); + WriteElement(index, procinfos, collations, element, m, efConstruction, dup, entryPoint); /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); diff --git a/src/hnswscan.c b/src/hnswscan.c index 7cf2bf0..87a06ad 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -15,12 +15,13 @@ GetScanItems(IndexScanDesc scan, Datum q) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; - FmgrInfo *procinfo = so->procinfo; - Oid collation = so->collation; + FmgrInfo **procinfos = so->procinfos; + Oid *collations = so->collations; List *ep; List *w; int m; HnswElement entryPoint; + ScanKeyData *keyData = scan->keyData; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); @@ -28,15 +29,15 @@ GetScanItems(IndexScanDesc scan, Datum q) if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(entryPoint, q, index, procinfo, collation, false)); + ep = list_make1(HnswEntryCandidate(entryPoint, q, NULL, keyData, index, procinfos, collations, false, false)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, m, false, NULL); + w = HnswSearchLayer(q, NULL, keyData, ep, 1, lc, index, procinfos, collations, m, false, NULL, false); ep = w; } - return HnswSearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); + return HnswSearchLayer(q, NULL, keyData, ep, hnsw_ef_search, 0, index, procinfos, collations, m, false, NULL, false); } /* @@ -83,7 +84,7 @@ GetScanValue(IndexScanDesc scan) /* Fine if normalization fails */ if (so->normprocinfo != NULL) - HnswNormValue(so->normprocinfo, so->collation, &value, NULL); + HnswNormValue(so->normprocinfo, so->collations[0], &value, NULL); } return value; @@ -107,9 +108,9 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) ALLOCSET_DEFAULT_SIZES); /* Set support functions */ - so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + so->procinfos = HnswInitProcinfos(index); so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - so->collation = index->rd_indcollation[0]; + so->collations = index->rd_indcollation; scan->opaque = so; @@ -206,6 +207,9 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) scan->xs_ctup.t_self = *heaptid; #endif + /* TODO Check during scan */ + scan->xs_recheck = scan->numberOfKeys > 0; + scan->xs_recheckorderby = false; return true; } @@ -222,6 +226,7 @@ hnswendscan(IndexScanDesc scan) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + pfree(so->procinfos); MemoryContextDelete(so->tmpCtx); pfree(so); diff --git a/src/hnswutils.c b/src/hnswutils.c index 08d867b..6cfdab0 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -7,6 +7,10 @@ #include "utils/datum.h" #include "vector.h" +#if PG_VERSION_NUM < 130000 +#define TYPSTORAGE_PLAIN 'p' +#endif + /* * Get the max number of connections in an upper layer for each element in the index */ @@ -47,6 +51,22 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } +/* + * Init procs + */ +FmgrInfo ** +HnswInitProcinfos(Relation index) +{ + int keyAttributes = IndexRelationGetNumberOfKeyAttributes(index); + FmgrInfo **procinfos = palloc(keyAttributes * sizeof(FmgrInfo *)); + + procinfos[0] = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + for (int i = 1; i < keyAttributes; i++) + procinfos[i] = index_getprocinfo(index, i + 1, HNSW_ATTRIBUTE_DISTANCE_PROC); + + return procinfos; +} + /* * Divide by the norm * @@ -174,6 +194,7 @@ HnswInitElement(ItemPointer heaptid, int m, double ml, int maxLevel) element->level = level; element->deleted = 0; + element->itup = NULL; HnswInitNeighbors(element, m); @@ -188,8 +209,8 @@ HnswFreeElement(HnswElement element) { HnswFreeNeighbors(element); list_free_deep(element->heaptids); - if (DatumGetPointer(element->value)) - pfree(DatumGetPointer(element->value)); + if (element->itup) + pfree(element->itup); pfree(element); } @@ -217,6 +238,7 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) element->offno = offno; element->neighbors = NULL; element->value = PointerGetDatum(NULL); + element->itup = NULL; return element; } @@ -314,7 +336,7 @@ HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, Bloc * Set element tuple, except for neighbor info */ void -HnswSetElementTuple(HnswElementTuple etup, HnswElement element) +HnswSetElementTuple(HnswElementTuple etup, HnswElement element, bool useIndexTuple) { etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; @@ -326,7 +348,11 @@ HnswSetElementTuple(HnswElementTuple etup, HnswElement element) else ItemPointerSetInvalid(&etup->heaptids[i]); } - memcpy(&etup->data, DatumGetPointer(element->value), VARSIZE_ANY(DatumGetPointer(element->value))); + + if (useIndexTuple) + memcpy(&etup->data, element->itup, IndexTupleSize(element->itup)); + else + memcpy(&etup->data, DatumGetPointer(element->value), VARSIZE_ANY(DatumGetPointer(element->value))); } /* @@ -427,7 +453,7 @@ HnswLoadNeighbors(HnswElement element, Relation index, int m) * Load an element from a tuple */ void -HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec) +HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index) { element->level = etup->level; element->deleted = etup->deleted; @@ -449,18 +475,149 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe if (loadVec) { - Vector *vec = palloc(VARSIZE_ANY(&etup->data)); + if (IndexRelationGetNumberOfAttributes(index) > 1) + { + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; - memcpy(vec, &etup->data, VARSIZE_ANY(&etup->data)); - element->value = PointerGetDatum(vec); + element->itup = CopyIndexTuple((IndexTuple) &etup->data); + element->value = index_getattr(element->itup, 1, tupdesc, &unused); + } + else + { + Vector *vec = palloc(VARSIZE_ANY(&etup->data)); + + memcpy(vec, &etup->data, VARSIZE_ANY(&etup->data)); + element->value = PointerGetDatum(vec); + } } } +/* + * Get the tuple descriptor + */ +static TupleDesc +HnswTupleDesc(Relation index) +{ + TupleDesc tupdesc = CreateTupleDescCopyConstr(RelationGetDescr(index)); + + /* Prevent compression */ + TupleDescAttr(tupdesc, 0)->attstorage = TYPSTORAGE_PLAIN; + + return tupdesc; +} + +/* + * Set element data + */ +void +HnswElementSetData(HnswElement element, Relation index, Datum value, Datum *values, bool *isnull) +{ + /* TODO Create once per index build */ + TupleDesc tupdesc = HnswTupleDesc(index); + bool unused; + Datum tmp; + + tmp = values[0]; + values[0] = value; + element->itup = index_form_tuple(tupdesc, values, isnull); + values[0] = tmp; + + element->value = index_getattr(element->itup, 1, tupdesc, &unused); + + FreeTupleDesc(tupdesc); +} + +/* + * Get the attribute distance + */ +static inline double +AttributeDistance(double e) +{ + /* TODO Better bias */ + /* must be >> max(w * g) + 1 / log10(2) */ + double bias = 4.32; + + return e > 0 ? bias - 1.0 / log10(e + 1) : 0; +} + +/* + * Get the distance + */ +static double +GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations) +{ + double g = DatumGetFloat8(FunctionCall2Coll(procinfos[0], collations[0], q, vec)); + + if (IndexRelationGetNumberOfKeyAttributes(index) > 1) + { + double w = 0.25; + double e = 0.0; + TupleDesc tupdesc = RelationGetDescr(index); + + if (keyData) + { + /* TODO need to pass length of key data */ + int keyCount = 1; + + for (int i = 0; i < keyCount; i++) + { + ScanKey key = &keyData[i]; + bool isnull; + Datum value = index_getattr(itup, key->sk_attno, tupdesc, &isnull); + bool attnull = key->sk_flags & SK_ISNULL; + + if (isnull || attnull) + { + if (isnull != attnull) + e += 1000; + } + else if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument))) + { + double ei = fabs(DatumGetFloat8(FunctionCall2Coll(procinfos[key->sk_attno - 1], collations[key->sk_attno - 1], value, key->sk_argument))); + + if (ei > 0) + e += ei; + else + /* Distance is zero for inequality */ + e += 1000; + } + } + + return w * g + AttributeDistance(e); + } + else if (qtup) + { + int keyCount = IndexRelationGetNumberOfKeyAttributes(index) - 1; + + for (int i = 0; i < keyCount; i++) + { + bool isnull; + bool attnull; + Datum value = index_getattr(itup, i + 2, tupdesc, &isnull); + Datum value2 = index_getattr(qtup, i + 2, tupdesc, &attnull); + + if (isnull || attnull) + { + if (isnull != attnull) + e += 1000; + } + else + e += fabs(DatumGetFloat8(FunctionCall2Coll(procinfos[i + 1], collations[i + 1], value, value2))); + } + + return w * g + AttributeDistance(e); + } + } + + return g; +} + /* * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec) { Buffer buf; Page page; @@ -476,11 +633,27 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, Assert(HnswIsElementTuple(etup)); /* Load element */ - HnswLoadElementFromTuple(element, etup, true, loadVec); + HnswLoadElementFromTuple(element, etup, true, loadVec, index); /* Calculate distance */ if (distance != NULL) - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + { + IndexTuple itup = NULL; + Datum value; + + if (IndexRelationGetNumberOfAttributes(index) > 1) + { + TupleDesc tupdesc = RelationGetDescr(index); + bool unused; + + itup = (IndexTuple) &etup->data; + value = index_getattr(itup, 1, tupdesc, &unused); + } + else + value = PointerGetDatum(&etup->data); + + *distance = GetDistance(itup, value, *q, qtup, keyData, index, procinfos, collations); + } UnlockReleaseBuffer(buf); } @@ -489,24 +662,24 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, * Get the distance for a candidate */ static float -GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +GetCandidateDistance(HnswCandidate * hc, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations) { - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, hc->element->value)); + return GetDistance(hc->element->itup, hc->element->value, q, qtup, keyData, index, procinfos, collations); } /* * Create a candidate for the entry point */ HnswCandidate * -HnswEntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswEntryCandidate(HnswElement entryPoint, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec, bool inMemory) { HnswCandidate *hc = palloc(sizeof(HnswCandidate)); hc->element = entryPoint; - if (index == NULL) - hc->distance = GetCandidateDistance(hc, q, procinfo, collation); + if (inMemory) + hc->distance = GetCandidateDistance(hc, q, qtup, keyData, index, procinfos, collations); else - HnswLoadElement(hc->element, &hc->distance, &q, index, procinfo, collation, loadVec); + HnswLoadElement(hc->element, &hc->distance, &q, qtup, keyData, index, procinfos, collations, loadVec); return hc; } @@ -556,9 +729,9 @@ CreatePairingHeapNode(HnswCandidate * c) * Add to visited */ static inline void -AddToVisited(HTAB *v, HnswCandidate * hc, Relation index, bool *found) +AddToVisited(HTAB *v, HnswCandidate * hc, bool inMemory, bool *found) { - if (index == NULL) + if (inMemory) hash_search(v, &hc->element, HASH_ENTER, found); else { @@ -573,7 +746,7 @@ AddToVisited(HTAB *v, HnswCandidate * hc, Relation index, bool *found) * Algorithm 2 from paper */ List * -HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, FmgrInfo **procinfos, Oid *collations, int m, bool loadVec, HnswElement skipElement, bool inMemory) { ListCell *lc2; @@ -585,7 +758,7 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro HTAB *v; /* Create hash table */ - if (index == NULL) + if (inMemory) { hash_ctl.keysize = sizeof(HnswElement *); hash_ctl.entrysize = sizeof(HnswElement *); @@ -604,7 +777,7 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro { HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); - AddToVisited(v, hc, index, NULL); + AddToVisited(v, hc, inMemory, NULL); pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); @@ -638,7 +811,7 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro HnswCandidate *e = &neighborhood->items[i]; bool visited; - AddToVisited(v, e, index, &visited); + AddToVisited(v, e, inMemory, &visited); if (!visited) { @@ -646,10 +819,10 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; - if (index == NULL) - eDistance = GetCandidateDistance(e, q, procinfo, collation); + if (inMemory) + eDistance = GetCandidateDistance(e, q, qtup, keyData, index, procinfos, collations); else - HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, inserting); + HnswLoadElement(e->element, &eDistance, &q, qtup, keyData, index, procinfos, collations, loadVec); Assert(!e->element->deleted); @@ -729,7 +902,7 @@ CompareCandidateDistances(const void *a, const void *b) * Calculate the distance between elements */ static float -HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid collation) +HnswGetCachedDistance(HnswElement a, HnswElement b, int lc, Relation index, FmgrInfo **procinfos, Oid *collations) { /* Look for cached distance */ if (a->neighbors != NULL) @@ -754,21 +927,21 @@ HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid co } } - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a->value, b->value)); + return GetDistance(a->itup, a->value, b->value, b->itup, NULL, index, procinfos, collations); } /* * Check if an element is closer to q than any element from R */ static bool -CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation) +CheckElementCloser(HnswCandidate * e, List *r, int lc, Relation index, FmgrInfo **procinfos, Oid *collations) { ListCell *lc2; foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); - float distance = HnswGetDistance(e->element, ri->element, lc, procinfo, collation); + float distance = HnswGetCachedDistance(e->element, ri->element, lc, index, procinfos, collations); if (distance <= e->distance) return false; @@ -781,7 +954,7 @@ CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid c * Algorithm 4 from paper */ static List * -SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) +SelectNeighbors(List *c, int m, int lc, Relation index, FmgrInfo **procinfos, Oid *collations, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) { List *r = NIL; List *w = list_copy(c); @@ -808,7 +981,7 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswE /* Use previous state of r and wd to skip work when possible */ if (mustCalculate) - e->closer = CheckElementCloser(e, r, lc, procinfo, collation); + e->closer = CheckElementCloser(e, r, lc, index, procinfos, collations); else if (list_length(added) > 0) { /* @@ -817,7 +990,7 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswE */ if (e->closer) { - e->closer = CheckElementCloser(e, added, lc, procinfo, collation); + e->closer = CheckElementCloser(e, added, lc, index, procinfos, collations); if (!e->closer) removedAny = true; @@ -830,7 +1003,7 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswE */ if (removedAny) { - e->closer = CheckElementCloser(e, r, lc, procinfo, collation); + e->closer = CheckElementCloser(e, r, lc, index, procinfos, collations); if (e->closer) added = lappend(added, e); } @@ -838,7 +1011,7 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswE } else if (e == newCandidate) { - e->closer = CheckElementCloser(e, r, lc, procinfo, collation); + e->closer = CheckElementCloser(e, r, lc, index, procinfos, collations); if (e->closer) added = lappend(added, e); } @@ -872,10 +1045,14 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswE * Find duplicate element */ HnswElement -HnswFindDuplicate(HnswElement e) +HnswFindDuplicate(HnswElement e, Relation index) { HnswNeighborArray *neighbors = &e->neighbors[0]; + /* TODO Implement */ + if (IndexRelationGetNumberOfAttributes(index) > 1) + return NULL; + for (int i = 0; i < neighbors->length; i++) { HnswCandidate *neighbor = &neighbors->items[i]; @@ -909,7 +1086,7 @@ AddConnections(HnswElement element, List *neighbors, int m, int lc) * Update connections */ void -HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation) +HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo **procinfos, Oid *collations, bool inMemory) { HnswNeighborArray *currentNeighbors = &hc->element->neighbors[lc]; @@ -932,18 +1109,20 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int HnswCandidate *pruned = NULL; /* Load elements on insert */ - if (index != NULL) + if (!inMemory) { Datum q = hc->element->value; + IndexTuple qtup = hc->element->itup; + ScanKeyData *keyData = NULL; for (int i = 0; i < currentNeighbors->length; i++) { HnswCandidate *hc3 = ¤tNeighbors->items[i]; if (DatumGetPointer(hc3->element->value) == NULL) - HnswLoadElement(hc3->element, &hc3->distance, &q, index, procinfo, collation, true); + HnswLoadElement(hc3->element, &hc3->distance, &q, qtup, keyData, index, procinfos, collations, true); else - hc3->distance = GetCandidateDistance(hc3, q, procinfo, collation); + hc3->distance = GetCandidateDistance(hc3, q, qtup, keyData, index, procinfos, collations); /* Prune element if being deleted */ if (list_length(hc3->element->heaptids) == 0) @@ -963,7 +1142,7 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int c = lappend(c, ¤tNeighbors->items[i]); c = lappend(c, &hc2); - SelectNeighbors(c, m, lc, procinfo, collation, hc->element, &hc2, &pruned, true); + SelectNeighbors(c, m, lc, index, procinfos, collations, hc->element, &hc2, &pruned, true); /* Should not happen */ if (pruned == NULL) @@ -1015,13 +1194,15 @@ RemoveElements(List *w, HnswElement skipElement) * Algorithm 1 from paper */ void -HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing) +HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo **procinfos, Oid *collations, int m, int efConstruction, bool existing, bool inMemory) { List *ep; List *w; int level = element->level; int entryLevel; Datum q = element->value; + IndexTuple qtup = element->itup; + ScanKeyData *keyData = NULL; HnswElement skipElement = existing ? element : NULL; /* No neighbors if no entry point */ @@ -1029,13 +1210,13 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(entryPoint, q, index, procinfo, collation, true)); + ep = list_make1(HnswEntryCandidate(entryPoint, q, qtup, keyData, index, procinfos, collations, true, inMemory)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(q, qtup, keyData, ep, 1, lc, index, procinfos, collations, m, true, skipElement, inMemory); ep = w; } @@ -1053,11 +1234,11 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F List *neighbors; List *lw; - w = HnswSearchLayer(q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(q, qtup, keyData, ep, efConstruction, lc, index, procinfos, collations, m, true, skipElement, inMemory); /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */ - if (index != NULL) + if (!inMemory) lw = RemoveElements(w, skipElement); else lw = w; @@ -1067,7 +1248,7 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F * sortCandidates to true for in-memory builds to enable closer * caching, but there does not seem to be a difference in performance. */ - neighbors = SelectNeighbors(lw, lm, lc, procinfo, collation, element, NULL, NULL, false); + neighbors = SelectNeighbors(lw, lm, lc, index, procinfos, collations, element, NULL, NULL, false); AddConnections(element, neighbors, lm, lc); diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 8fd5d4f..ce05643 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -195,8 +195,8 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme GenericXLogState *state; int m = vacuumstate->m; int efConstruction = vacuumstate->efConstruction; - FmgrInfo *procinfo = vacuumstate->procinfo; - Oid collation = vacuumstate->collation; + FmgrInfo **procinfos = vacuumstate->procinfos; + Oid *collations = vacuumstate->collations; BufferAccessStrategy bas = vacuumstate->bas; HnswNeighborTuple ntup = vacuumstate->ntup; Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); @@ -210,7 +210,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme element->heaptids = NIL; /* Add element to graph, skipping itself */ - HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, true); + HnswInsertElement(element, entryPoint, index, procinfos, collations, m, efConstruction, true, false); /* Update neighbor tuple */ /* Do this before getting page to minimize locking */ @@ -231,7 +231,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme UnlockReleaseBuffer(buf); /* Update neighbors */ - HnswUpdateNeighborPages(index, procinfo, collation, element, m, true); + HnswUpdateNeighborPages(index, procinfos, collations, element, m, true); } /* @@ -258,7 +258,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -296,7 +296,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true); if (NeedsUpdated(vacuumstate, entryPoint)) { @@ -372,7 +372,7 @@ RepairGraph(HnswVacuumState * vacuumstate) /* Create an element */ element = HnswInitElementFromBlock(blkno, offno); - HnswLoadElementFromTuple(element, etup, false, true); + HnswLoadElementFromTuple(element, etup, false, true, index); elements = lappend(elements, element); } @@ -442,6 +442,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) BlockNumber insertPage = InvalidBlockNumber; Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; + bool useIndexTuple = IndexRelationGetNumberOfAttributes(index); /* * Wait for index scans to complete. Scans before this point may contain @@ -530,7 +531,18 @@ MarkDeleted(HnswVacuumState * vacuumstate) /* Overwrite element */ etup->deleted = 1; - MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data)); + if (useIndexTuple) + { + IndexTuple itup = (IndexTuple) &etup->data; + + MemSet(itup, 0, IndexTupleSize(itup)); + } + else + { + Vector *vec = (Vector *) (&etup->data); + + MemSet(vec, 0, VARSIZE_ANY(vec)); + } /* Overwrite neighbors */ for (int i = 0; i < ntup->count; i++) @@ -586,8 +598,8 @@ InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkD vacuumstate->callback_state = callback_state; vacuumstate->efConstruction = HnswGetEfConstruction(index); vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); - vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - vacuumstate->collation = index->rd_indcollation[0]; + vacuumstate->procinfos = HnswInitProcinfos(index); + vacuumstate->collations = index->rd_indcollation; vacuumstate->ntup = palloc0(BLCKSZ); vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw vacuum temporary context", @@ -611,6 +623,7 @@ FreeVacuumState(HnswVacuumState * vacuumstate) { hash_destroy(vacuumstate->deleted); FreeAccessStrategy(vacuumstate->bas); + pfree(vacuumstate->procinfos); pfree(vacuumstate->ntup); MemoryContextDelete(vacuumstate->tmpCtx); } diff --git a/test/t/019_hnsw_filtering.pl b/test/t/019_hnsw_filtering.pl new file mode 100644 index 0000000..d7cbfda --- /dev/null +++ b/test/t/019_hnsw_filtering.pl @@ -0,0 +1,107 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @cs = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my $nc = 50; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $cs[0] ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Cond/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc FROM generate_series(1, 20000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); + push(@cs, int(rand() * $nc)); +} + +# Get exact results +@expected = (); +for my $i (0 .. $#queries) +{ + my $res = $node->safe_psql("postgres", "SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v <-> '$queries[$i]' LIMIT $limit;"); + push(@expected, $res); +} + +# Add index +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c);"); + +# Test recall +test_recall(0.99, '<->'); + +# Test vacuum +$node->safe_psql("postgres", "DELETE FROM tst WHERE c > 50;"); +$node->safe_psql("postgres", "VACUUM tst;"); + +# Test columns +my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c, v vector_l2_ops);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c, c);"); +like($stderr, qr/index cannot have more than two columns/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, v vector_l2_ops);"); +like($stderr, qr/column 2 cannot be a vector/); + +done_testing();