From f858705293efc5f2eb9f4886a209a7e7cd098ab4 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 16 Oct 2023 16:20:10 -0700 Subject: [PATCH] No dimensions --- src/hnsw.h | 7 +-- src/hnswbuild.c | 32 +++++------ src/hnswinsert.c | 7 ++- src/hnswutils.c | 28 +++++----- src/hnswvacuum.c | 10 ++-- test/t/019_hnsw_array.pl | 113 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 159 insertions(+), 38 deletions(-) create mode 100644 test/t/019_hnsw_array.pl diff --git a/src/hnsw.h b/src/hnsw.h index eb2aa9f..ad8acbd 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -59,7 +59,7 @@ #define HNSW_MAX_SIZE (BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - sizeof(ItemIdData)) -#define HNSW_ELEMENT_TUPLE_SIZE(_dim) MAXALIGN(offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim)) +#define HNSW_ELEMENT_TUPLE_SIZE(_datum) MAXALIGN(offsetof(HnswElementTupleData, value) + VARSIZE_ANY(_datum)) #define HNSW_NEIGHBOR_TUPLE_SIZE(level, m) MAXALIGN(offsetof(HnswNeighborTupleData, indextids) + ((level) + 2) * (m) * sizeof(ItemPointerData)) #define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) @@ -98,12 +98,13 @@ typedef struct HnswElementData List *heaptids; uint8 level; uint8 deleted; + bool loaded; HnswNeighborArray *neighbors; BlockNumber blkno; OffsetNumber offno; OffsetNumber neighborOffno; BlockNumber neighborPage; - Vector *vec; + Datum value; } HnswElementData; typedef HnswElementData * HnswElement; @@ -204,7 +205,7 @@ typedef struct HnswElementTupleData ItemPointerData heaptids[HNSW_HEAPTIDS]; ItemPointerData neighbortid; uint16 unused2; - Vector vec; + char value[FLEXIBLE_ARRAY_MEMBER]; } HnswElementTupleData; typedef HnswElementTupleData * HnswElementTuple; diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 18959d5..fb20804 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -8,6 +8,7 @@ #include "lib/pairingheap.h" #include "nodes/pg_list.h" #include "storage/bufmgr.h" +#include "utils/datum.h" #include "utils/memutils.h" #if PG_VERSION_NUM >= 140000 @@ -105,8 +106,6 @@ CreateElementPages(HnswBuildState * buildstate) { Relation index = buildstate->index; ForkNumber forkNum = buildstate->forkNum; - int dimensions = buildstate->dimensions; - Size etupSize; Size maxSize; HnswElementTuple etup; HnswNeighborTuple ntup; @@ -118,10 +117,9 @@ CreateElementPages(HnswBuildState * buildstate) /* Calculate sizes */ maxSize = HNSW_MAX_SIZE; - etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); /* Allocate once */ - etup = palloc0(etupSize); + etup = palloc0(BLCKSZ); ntup = palloc0(BLCKSZ); /* Prepare first page */ @@ -133,12 +131,14 @@ CreateElementPages(HnswBuildState * buildstate) foreach(lc, buildstate->elements) { HnswElement element = lfirst(lc); + Size etupSize; Size ntupSize; Size combinedSize; HnswSetElementTuple(etup, element); /* Calculate sizes */ + etupSize = HNSW_ELEMENT_TUPLE_SIZE(element->value); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); @@ -273,18 +273,15 @@ InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * int m = buildstate->m; /* Detoast once for all calls */ - Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + element->value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec)) + if (!HnswNormValue(buildstate->normprocinfo, collation, &element->value, buildstate->normvec)) return false; } - /* Copy value to element so accessible outside of memory context */ - memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions)); - /* Insert element in graph */ HnswInsertElement(element, entryPoint, NULL, procinfo, collation, m, efConstruction, false); @@ -360,7 +357,6 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, /* Allocate necessary memory outside of memory context */ element = HnswInitElement(tid, buildstate->m, buildstate->ml, buildstate->maxLevel); - element->vec = palloc(VECTOR_SIZE(buildstate->dimensions)); /* Use memory context since detoast can allocate */ oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); @@ -368,9 +364,8 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, /* Insert tuple */ inserted = InsertTuple(index, values, element, buildstate, &dup); - /* Reset memory context */ + /* Switch memory context */ MemoryContextSwitchTo(oldCtx); - MemoryContextReset(buildstate->tmpCtx); /* Add outside memory context */ if (dup != NULL) @@ -378,9 +373,16 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, /* Add to buildstate or free */ if (inserted) + { + element->value = datumCopy(element->value, false, -1); + element->loaded = true; buildstate->elements = lappend(buildstate->elements, element); + } else HnswFreeElement(element); + + /* Reset memory context */ + MemoryContextReset(buildstate->tmpCtx); } /* @@ -395,6 +397,7 @@ HnswGetMaxInMemoryElements(int m, double ml, int dimensions) elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1); elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2)); elementSize += sizeof(ItemPointerData); + /* TODO Handle non-vector types */ elementSize += VECTOR_SIZE(dimensions); return (maintenance_work_mem * 1024L) / elementSize; } @@ -414,10 +417,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->efConstruction = HnswGetEfConstruction(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; - /* Require column to have dimensions to be indexed */ - if (buildstate->dimensions < 0) - elog(ERROR, "column does not have dimensions"); - if (buildstate->dimensions > HNSW_MAX_DIM) elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM); @@ -440,6 +439,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->flushed = false; /* Reuse for each tuple */ + /* TODO Fix / replace with support function */ buildstate->normvec = InitVector(buildstate->dimensions); buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, diff --git a/src/hnswinsert.c b/src/hnswinsert.c index f7cd51f..311dc23 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -123,7 +123,6 @@ WriteNewElementPages(Relation index, HnswElement e, int m, BlockNumber insertPag Size minCombinedSize; HnswElementTuple etup; BlockNumber currentPage = insertPage; - int dimensions = e->vec->dim; HnswNeighborTuple ntup; Buffer nbuf; Page npage; @@ -132,7 +131,7 @@ WriteNewElementPages(Relation index, HnswElement e, int m, BlockNumber insertPag BlockNumber newInsertPage = InvalidBlockNumber; /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(e->value); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); maxSize = HNSW_MAX_SIZE; @@ -405,7 +404,7 @@ HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) Buffer buf; Page page; GenericXLogState *state; - Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(dup->vec->dim); + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(dup->value); HnswElementTuple etup; int i; @@ -515,7 +514,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti /* Create an element */ element = HnswInitElement(heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m)); - element->vec = DatumGetVector(value); + element->value = value; /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) diff --git a/src/hnswutils.c b/src/hnswutils.c index e7d1705..1454709 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -4,6 +4,7 @@ #include "hnsw.h" #include "storage/bufmgr.h" +#include "utils/datum.h" #include "vector.h" /* @@ -187,7 +188,8 @@ HnswFreeElement(HnswElement element) { HnswFreeNeighbors(element); list_free_deep(element->heaptids); - pfree(element->vec); + if (element->loaded) + pfree(DatumGetPointer(element->value)); pfree(element); } @@ -214,7 +216,7 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) element->blkno = blkno; element->offno = offno; element->neighbors = NULL; - element->vec = NULL; + element->loaded = false; return element; } @@ -324,7 +326,7 @@ HnswSetElementTuple(HnswElementTuple etup, HnswElement element) else ItemPointerSetInvalid(&etup->heaptids[i]); } - memcpy(&etup->vec, element->vec, VECTOR_SIZE(element->vec->dim)); + memcpy(&etup->value, DatumGetPointer(element->value), VARSIZE_ANY(element->value)); } /* @@ -447,8 +449,10 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe if (loadVec) { - element->vec = palloc(VECTOR_SIZE(etup->vec.dim)); - memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim)); + Datum value = PointerGetDatum(&etup->value); + + element->value = datumCopy(value, false, -1); + element->loaded = true; } } @@ -476,7 +480,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, /* Calculate distance */ if (distance != NULL) - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec))); + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->value))); UnlockReleaseBuffer(buf); } @@ -487,7 +491,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, static float GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) { - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, PointerGetDatum(hc->element->vec))); + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, hc->element->value)); } /* @@ -750,7 +754,7 @@ HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid co } } - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec))); + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a->value, b->value)); } /* @@ -877,7 +881,7 @@ HnswFindDuplicate(HnswElement e) HnswCandidate *neighbor = &neighbors->items[i]; /* Exit early since ordered by distance */ - if (vector_cmp_internal(e->vec, neighbor->element->vec) != 0) + if (!datumIsEqual(e->value, neighbor->element->value, false, -1)) break; /* Check for space */ @@ -930,13 +934,13 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int /* Load elements on insert */ if (index != NULL) { - Datum q = PointerGetDatum(hc->element->vec); + Datum q = hc->element->value; for (int i = 0; i < currentNeighbors->length; i++) { HnswCandidate *hc3 = ¤tNeighbors->items[i]; - if (hc3->element->vec == NULL) + if (!hc3->element->loaded) HnswLoadElement(hc3->element, &hc3->distance, &q, index, procinfo, collation, true); else hc3->distance = GetCandidateDistance(hc3, q, procinfo, collation); @@ -1017,7 +1021,7 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F List *w; int level = element->level; int entryLevel; - Datum q = PointerGetDatum(element->vec); + Datum q = element->value; HnswElement skipElement = existing ? element : NULL; /* No neighbors if no entry point */ diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 29b675f..1188d21 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -93,7 +93,7 @@ RemoveHeapTids(HnswVacuumState * vacuumstate) if (itemUpdated) { - Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(PointerGetDatum(&etup->value)); /* Mark rest as invalid */ for (int i = idx; i < HNSW_HEAPTIDS; i++) @@ -481,6 +481,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) HnswNeighborTuple ntup; Size etupSize; Size ntupSize; + Datum value; Buffer nbuf; Page npage; BlockNumber neighborPage; @@ -504,8 +505,11 @@ MarkDeleted(HnswVacuumState * vacuumstate) if (ItemPointerIsValid(&etup->heaptids[0])) continue; + /* Get datum */ + value = PointerGetDatum(&etup->value); + /* Calculate sizes */ - etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(value); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(etup->level, vacuumstate->m); /* Get neighbor page */ @@ -528,7 +532,7 @@ MarkDeleted(HnswVacuumState * vacuumstate) /* Overwrite element */ etup->deleted = 1; - MemSet(&etup->vec.x, 0, etup->vec.dim * sizeof(float)); + MemSet(&etup->value, 0, VARSIZE_ANY(value)); /* Overwrite neighbors */ for (int i = 0; i < ntup->count; i++) diff --git a/test/t/019_hnsw_array.pl b/test/t/019_hnsw_array.pl new file mode 100644 index 0000000..6f11ed8 --- /dev/null +++ b/test/t/019_hnsw_array.pl @@ -0,0 +1,113 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v float4[3]);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" +); + +$node->safe_psql("postgres", qq( + CREATE FUNCTION float4_l2_distance(float4[], float4[]) RETURNS float8 + AS 'BEGIN RETURN l2_distance(\$1::vector, \$2::vector); END;' + LANGUAGE plpgsql IMMUTABLE STRICT PARALLEL SAFE; + + CREATE FUNCTION float4_l2_squared_distance(float4[], float4[]) RETURNS float8 + AS 'BEGIN RETURN vector_l2_squared_distance(\$1::vector, \$2::vector); END;' + LANGUAGE plpgsql IMMUTABLE STRICT PARALLEL SAFE; + + CREATE OPERATOR <-> ( + LEFTARG = float4[], RIGHTARG = float4[], PROCEDURE = float4_l2_distance, + COMMUTATOR = '<->' + ); + + CREATE OPERATOR CLASS float4_l2_ops + FOR TYPE float4[] USING hnsw AS + OPERATOR 1 <-> (float4[], float4[]) FOR ORDER BY float_ops, + FUNCTION 1 float4_l2_squared_distance(float4[], float4[]); +)); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "{$r1,$r2,$r3}"); +} + +# Check each index type +my @operators = ("<->"); +my @opclasses = ("float4_l2_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Add index + $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);"); + + my $min = $operator eq "<#>" ? 0.80 : 0.99; + test_recall($min, $operator); +} + +done_testing();