From d74139c44714a232525112197de202aa3179dbb4 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 16 Oct 2023 16:11:21 -0700 Subject: [PATCH] Use datums for HNSW [skip ci] --- src/hnsw.c | 7 +++ src/hnsw.h | 12 +++-- src/hnswbuild.c | 32 ++++++----- src/hnswinsert.c | 7 ++- src/hnswutils.c | 42 ++++++++++----- src/hnswvacuum.c | 10 ++-- test/t/019_hnsw_array.pl | 113 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 188 insertions(+), 35 deletions(-) create mode 100644 test/t/019_hnsw_array.pl diff --git a/src/hnsw.c b/src/hnsw.c index 758e418..042d045 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -33,6 +33,12 @@ HnswInit(void) HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION #if PG_VERSION_NUM >= 130000 ,AccessExclusiveLock +#endif + ); + add_int_reloption(hnsw_relopt_kind, "dimensions", "Number of dimensions", + HNSW_DEFAULT_DIMENSIONS, HNSW_MIN_DIMENSIONS, HNSW_MAX_DIMENSIONS +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock #endif ); @@ -125,6 +131,7 @@ hnswoptions(Datum reloptions, bool validate) static const relopt_parse_elt tab[] = { {"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)}, {"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)}, + {"dimensions", RELOPT_TYPE_INT, offsetof(HnswOptions, dimensions)}, }; #if PG_VERSION_NUM >= 130000 diff --git a/src/hnsw.h b/src/hnsw.h index eb2aa9f..9f9dfae 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -42,6 +42,9 @@ #define HNSW_DEFAULT_EF_SEARCH 40 #define HNSW_MIN_EF_SEARCH 1 #define HNSW_MAX_EF_SEARCH 1000 +#define HNSW_DEFAULT_DIMENSIONS -1 +#define HNSW_MIN_DIMENSIONS 1 +#define HNSW_MAX_DIMENSIONS HNSW_MAX_DIM /* Tuple types */ #define HNSW_ELEMENT_TUPLE_TYPE 1 @@ -59,7 +62,7 @@ #define HNSW_MAX_SIZE (BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - sizeof(ItemIdData)) -#define HNSW_ELEMENT_TUPLE_SIZE(_dim) MAXALIGN(offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim)) +#define HNSW_ELEMENT_TUPLE_SIZE(_datum) MAXALIGN(offsetof(HnswElementTupleData, value) + VARSIZE_ANY(_datum)) #define HNSW_NEIGHBOR_TUPLE_SIZE(level, m) MAXALIGN(offsetof(HnswNeighborTupleData, indextids) + ((level) + 2) * (m) * sizeof(ItemPointerData)) #define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) @@ -98,12 +101,13 @@ typedef struct HnswElementData List *heaptids; uint8 level; uint8 deleted; + bool loaded; HnswNeighborArray *neighbors; BlockNumber blkno; OffsetNumber offno; OffsetNumber neighborOffno; BlockNumber neighborPage; - Vector *vec; + Datum value; } HnswElementData; typedef HnswElementData * HnswElement; @@ -134,6 +138,7 @@ typedef struct HnswOptions int32 vl_len_; /* varlena header (do not touch directly!) */ int m; /* number of connections */ int efConstruction; /* size of dynamic candidate list */ + int dimensions; } HnswOptions; typedef struct HnswBuildState @@ -204,7 +209,7 @@ typedef struct HnswElementTupleData ItemPointerData heaptids[HNSW_HEAPTIDS]; ItemPointerData neighbortid; uint16 unused2; - Vector vec; + char value[FLEXIBLE_ARRAY_MEMBER]; } HnswElementTupleData; typedef HnswElementTupleData * HnswElementTuple; @@ -262,6 +267,7 @@ typedef struct HnswVacuumState /* Methods */ int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); +int HnswGetDimensions(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); void HnswCommitBuffer(Buffer buf, GenericXLogState *state); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 18959d5..556f14a 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -8,6 +8,7 @@ #include "lib/pairingheap.h" #include "nodes/pg_list.h" #include "storage/bufmgr.h" +#include "utils/datum.h" #include "utils/memutils.h" #if PG_VERSION_NUM >= 140000 @@ -105,8 +106,6 @@ CreateElementPages(HnswBuildState * buildstate) { Relation index = buildstate->index; ForkNumber forkNum = buildstate->forkNum; - int dimensions = buildstate->dimensions; - Size etupSize; Size maxSize; HnswElementTuple etup; HnswNeighborTuple ntup; @@ -118,10 +117,9 @@ CreateElementPages(HnswBuildState * buildstate) /* Calculate sizes */ maxSize = HNSW_MAX_SIZE; - etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); /* Allocate once */ - etup = palloc0(etupSize); + etup = palloc0(BLCKSZ); ntup = palloc0(BLCKSZ); /* Prepare first page */ @@ -133,12 +131,14 @@ CreateElementPages(HnswBuildState * buildstate) foreach(lc, buildstate->elements) { HnswElement element = lfirst(lc); + Size etupSize; Size ntupSize; Size combinedSize; HnswSetElementTuple(etup, element); /* Calculate sizes */ + etupSize = HNSW_ELEMENT_TUPLE_SIZE(element->value); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); @@ -273,18 +273,15 @@ InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * int m = buildstate->m; /* Detoast once for all calls */ - Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + element->value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec)) + if (!HnswNormValue(buildstate->normprocinfo, collation, &element->value, buildstate->normvec)) return false; } - /* Copy value to element so accessible outside of memory context */ - memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions)); - /* Insert element in graph */ HnswInsertElement(element, entryPoint, NULL, procinfo, collation, m, efConstruction, false); @@ -360,7 +357,6 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, /* Allocate necessary memory outside of memory context */ element = HnswInitElement(tid, buildstate->m, buildstate->ml, buildstate->maxLevel); - element->vec = palloc(VECTOR_SIZE(buildstate->dimensions)); /* Use memory context since detoast can allocate */ oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); @@ -368,9 +364,8 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, /* Insert tuple */ inserted = InsertTuple(index, values, element, buildstate, &dup); - /* Reset memory context */ + /* Switch memory context */ MemoryContextSwitchTo(oldCtx); - MemoryContextReset(buildstate->tmpCtx); /* Add outside memory context */ if (dup != NULL) @@ -378,9 +373,16 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, /* Add to buildstate or free */ if (inserted) + { + element->value = datumCopy(element->value, false, -1); + element->loaded = true; buildstate->elements = lappend(buildstate->elements, element); + } else HnswFreeElement(element); + + /* Reset memory context */ + MemoryContextReset(buildstate->tmpCtx); } /* @@ -395,6 +397,7 @@ HnswGetMaxInMemoryElements(int m, double ml, int dimensions) elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1); elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2)); elementSize += sizeof(ItemPointerData); + /* TODO Handle non-vector types */ elementSize += VECTOR_SIZE(dimensions); return (maintenance_work_mem * 1024L) / elementSize; } @@ -412,7 +415,10 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->m = HnswGetM(index); buildstate->efConstruction = HnswGetEfConstruction(index); - buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + buildstate->dimensions = HnswGetDimensions(index); + + if (buildstate->dimensions < 0) + buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) diff --git a/src/hnswinsert.c b/src/hnswinsert.c index f7cd51f..311dc23 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -123,7 +123,6 @@ WriteNewElementPages(Relation index, HnswElement e, int m, BlockNumber insertPag Size minCombinedSize; HnswElementTuple etup; BlockNumber currentPage = insertPage; - int dimensions = e->vec->dim; HnswNeighborTuple ntup; Buffer nbuf; Page npage; @@ -132,7 +131,7 @@ WriteNewElementPages(Relation index, HnswElement e, int m, BlockNumber insertPag BlockNumber newInsertPage = InvalidBlockNumber; /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(e->value); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); maxSize = HNSW_MAX_SIZE; @@ -405,7 +404,7 @@ HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) Buffer buf; Page page; GenericXLogState *state; - Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(dup->vec->dim); + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(dup->value); HnswElementTuple etup; int i; @@ -515,7 +514,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti /* Create an element */ element = HnswInitElement(heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m)); - element->vec = DatumGetVector(value); + element->value = value; /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) diff --git a/src/hnswutils.c b/src/hnswutils.c index e7d1705..24beae5 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -4,6 +4,7 @@ #include "hnsw.h" #include "storage/bufmgr.h" +#include "utils/datum.h" #include "vector.h" /* @@ -34,6 +35,20 @@ HnswGetEfConstruction(Relation index) return HNSW_DEFAULT_EF_CONSTRUCTION; } +/* + * Get the number of dimensions in the index + */ +int +HnswGetDimensions(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->dimensions; + + return HNSW_DEFAULT_DIMENSIONS; +} + /* * Get proc */ @@ -187,7 +202,8 @@ HnswFreeElement(HnswElement element) { HnswFreeNeighbors(element); list_free_deep(element->heaptids); - pfree(element->vec); + if (element->loaded) + pfree(DatumGetPointer(element->value)); pfree(element); } @@ -214,7 +230,7 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) element->blkno = blkno; element->offno = offno; element->neighbors = NULL; - element->vec = NULL; + element->loaded = false; return element; } @@ -324,7 +340,7 @@ HnswSetElementTuple(HnswElementTuple etup, HnswElement element) else ItemPointerSetInvalid(&etup->heaptids[i]); } - memcpy(&etup->vec, element->vec, VECTOR_SIZE(element->vec->dim)); + memcpy(&etup->value, DatumGetPointer(element->value), VARSIZE_ANY(element->value)); } /* @@ -447,8 +463,10 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe if (loadVec) { - element->vec = palloc(VECTOR_SIZE(etup->vec.dim)); - memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim)); + Datum value = PointerGetDatum(&etup->value); + + element->value = datumCopy(value, false, -1); + element->loaded = true; } } @@ -476,7 +494,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, /* Calculate distance */ if (distance != NULL) - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec))); + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->value))); UnlockReleaseBuffer(buf); } @@ -487,7 +505,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, static float GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) { - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, PointerGetDatum(hc->element->vec))); + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, hc->element->value)); } /* @@ -750,7 +768,7 @@ HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid co } } - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec))); + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a->value, b->value)); } /* @@ -877,7 +895,7 @@ HnswFindDuplicate(HnswElement e) HnswCandidate *neighbor = &neighbors->items[i]; /* Exit early since ordered by distance */ - if (vector_cmp_internal(e->vec, neighbor->element->vec) != 0) + if (!datumIsEqual(e->value, neighbor->element->value, false, -1)) break; /* Check for space */ @@ -930,13 +948,13 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int /* Load elements on insert */ if (index != NULL) { - Datum q = PointerGetDatum(hc->element->vec); + Datum q = hc->element->value; for (int i = 0; i < currentNeighbors->length; i++) { HnswCandidate *hc3 = ¤tNeighbors->items[i]; - if (hc3->element->vec == NULL) + if (!hc3->element->loaded) HnswLoadElement(hc3->element, &hc3->distance, &q, index, procinfo, collation, true); else hc3->distance = GetCandidateDistance(hc3, q, procinfo, collation); @@ -1017,7 +1035,7 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F List *w; int level = element->level; int entryLevel; - Datum q = PointerGetDatum(element->vec); + Datum q = element->value; HnswElement skipElement = existing ? element : NULL; /* No neighbors if no entry point */ diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 29b675f..1188d21 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -93,7 +93,7 @@ RemoveHeapTids(HnswVacuumState * vacuumstate) if (itemUpdated) { - Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(PointerGetDatum(&etup->value)); /* Mark rest as invalid */ for (int i = idx; i < HNSW_HEAPTIDS; i++) @@ -481,6 +481,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) HnswNeighborTuple ntup; Size etupSize; Size ntupSize; + Datum value; Buffer nbuf; Page npage; BlockNumber neighborPage; @@ -504,8 +505,11 @@ MarkDeleted(HnswVacuumState * vacuumstate) if (ItemPointerIsValid(&etup->heaptids[0])) continue; + /* Get datum */ + value = PointerGetDatum(&etup->value); + /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(value); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(etup->level, vacuumstate->m); /* Get neighbor page */ @@ -528,7 +532,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) /* Overwrite element */ etup->deleted = 1; - MemSet(&etup->vec.x, 0, etup->vec.dim * sizeof(float)); + MemSet(&etup->value, 0, VARSIZE_ANY(value)); /* Overwrite neighbors */ for (int i = 0; i < ntup->count; i++) diff --git a/test/t/019_hnsw_array.pl b/test/t/019_hnsw_array.pl new file mode 100644 index 0000000..5478599 --- /dev/null +++ b/test/t/019_hnsw_array.pl @@ -0,0 +1,113 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v float4[3]);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" +); + +$node->safe_psql("postgres", qq( + CREATE FUNCTION float4_l2_distance(float4[], float4[]) RETURNS float8 + AS 'BEGIN RETURN l2_distance(\$1::vector, \$2::vector); END;' + LANGUAGE plpgsql IMMUTABLE STRICT PARALLEL SAFE; + + CREATE FUNCTION float4_l2_squared_distance(float4[], float4[]) RETURNS float8 + AS 'BEGIN RETURN vector_l2_squared_distance(\$1::vector, \$2::vector); END;' + LANGUAGE plpgsql IMMUTABLE STRICT PARALLEL SAFE; + + CREATE OPERATOR <-> ( + LEFTARG = float4[], RIGHTARG = float4[], PROCEDURE = float4_l2_distance, + COMMUTATOR = '<->' + ); + + CREATE OPERATOR CLASS float4_l2_ops + FOR TYPE float4[] USING hnsw AS + OPERATOR 1 <-> (float4[], float4[]) FOR ORDER BY float_ops, + FUNCTION 1 float4_l2_squared_distance(float4[], float4[]); +)); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "{$r1,$r2,$r3}"); +} + +# Check each index type +my @operators = ("<->"); +my @opclasses = ("float4_l2_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Add index + $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass) WITH (dimensions = 3);"); + + my $min = $operator eq "<#>" ? 0.80 : 0.99; + test_recall($min, $operator); +} + +done_testing();