From 58ec5296b076819b7cfac1a2070bc8d69984c284 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 25 Apr 2024 13:21:24 -0700 Subject: [PATCH] Reduced support functions for HNSW - #527 --- sql/vector--0.6.2--0.7.0.sql | 22 ++++++++++------------ sql/vector.sql | 22 ++++++++++------------ src/hnsw.c | 2 +- src/hnsw.h | 8 +++----- src/hnswbuild.c | 8 ++++---- src/hnswinsert.c | 2 +- src/hnswscan.c | 4 ++-- src/hnswutils.c | 20 ++++++++------------ 8 files changed, 39 insertions(+), 49 deletions(-) diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index e6f3788..219463c 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -69,13 +69,13 @@ CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), - FUNCTION 4 hnsw_bit_support(internal); + FUNCTION 3 hnsw_bit_support(internal); CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 jaccard_distance(bit, bit), - FUNCTION 4 hnsw_bit_support(internal); + FUNCTION 3 hnsw_bit_support(internal); CREATE TYPE halfvec; @@ -355,27 +355,26 @@ CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), - FUNCTION 3 l2_normalize(halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); CREATE TYPE sparsevec; @@ -547,24 +546,23 @@ CREATE OPERATOR CLASS sparsevec_l2_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_ip_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), FUNCTION 2 l2_norm(sparsevec), - FUNCTION 3 l2_normalize(sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); diff --git a/sql/vector.sql b/sql/vector.sql index 527ad67..08f4ffa 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -364,13 +364,13 @@ CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), - FUNCTION 4 hnsw_bit_support(internal); + FUNCTION 3 hnsw_bit_support(internal); CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 jaccard_distance(bit, bit), - FUNCTION 4 hnsw_bit_support(internal); + FUNCTION 3 hnsw_bit_support(internal); -- halfvec type @@ -666,27 +666,26 @@ CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), - FUNCTION 3 l2_normalize(halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_support(internal); + FUNCTION 3 hnsw_halfvec_support(internal); --- sparsevec type @@ -872,24 +871,23 @@ CREATE OPERATOR CLASS sparsevec_l2_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_ip_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), FUNCTION 2 l2_norm(sparsevec), - FUNCTION 3 l2_normalize(sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_support(internal); + FUNCTION 3 hnsw_sparsevec_support(internal); diff --git a/src/hnsw.c b/src/hnsw.c index 9f32260..b56ab71 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 4; + amroutine->amsupport = 3; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/hnsw.h b/src/hnsw.h index 4400555..2c1c495 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -22,8 +22,7 @@ /* Support functions */ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 -#define HNSW_NORMALIZE_PROC 3 -#define HNSW_TYPE_INFO_PROC 4 +#define HNSW_TYPE_INFO_PROC 3 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -241,6 +240,7 @@ typedef struct HnswAllocator typedef struct HnswTypeInfo { int maxDimensions; + Datum (*normalize) (PG_FUNCTION_ARGS); void (*checkValue) (Pointer v); } HnswTypeInfo; @@ -265,7 +265,6 @@ typedef struct HnswBuildState /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; - FmgrInfo *normalizeprocinfo; Oid collation; /* Variables */ @@ -335,6 +334,7 @@ typedef HnswNeighborTupleData * HnswNeighborTuple; typedef struct HnswScanOpaqueData { + const HnswTypeInfo *typeInfo; bool first; List *w; MemoryContext tmpCtx; @@ -342,7 +342,6 @@ typedef struct HnswScanOpaqueData /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; - FmgrInfo *normalizeprocinfo; Oid collation; } HnswScanOpaqueData; @@ -378,7 +377,6 @@ typedef struct HnswVacuumState int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); -Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value); bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index dfb5ab9..bfe1e75 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -476,6 +476,7 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) static bool InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, HnswBuildState * buildstate) { + const HnswTypeInfo *typeInfo = buildstate->typeInfo; HnswGraph *graph = buildstate->graph; HnswElement element; HnswAllocator *allocator = &buildstate->allocator; @@ -488,8 +489,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Check value */ - if (buildstate->typeInfo->checkValue != NULL) - buildstate->typeInfo->checkValue(DatumGetPointer(value)); + if (typeInfo->checkValue != NULL) + typeInfo->checkValue(DatumGetPointer(value)); /* Normalize if needed */ if (buildstate->normprocinfo != NULL) @@ -497,7 +498,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) return false; - value = HnswNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); + value = DirectFunctionCall1(typeInfo->normalize, value); } /* Get datum size */ @@ -708,7 +709,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index /* Get support functions */ buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); buildstate->collation = index->rd_indcollation[0]; InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index c0bc436..8c3c491 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -630,7 +630,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti if (!HnswCheckNorm(normprocinfo, collation, value)) return; - value = HnswNormValue(HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC), collation, value); + value = DirectFunctionCall1(typeInfo->normalize, value); } HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false); diff --git a/src/hnswscan.c b/src/hnswscan.c index 0aa6df4..b14eb7f 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan) /* Fine if normalization fails */ if (so->normprocinfo != NULL) - value = HnswNormValue(so->normalizeprocinfo, so->collation, value); + value = DirectFunctionCall1(so->typeInfo->normalize, value); } return value; @@ -79,6 +79,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) scan = RelationGetIndexScan(index, nkeys, norderbys); so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); + so->typeInfo = HnswGetTypeInfo(index); so->first = true; so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw scan temporary context", @@ -87,7 +88,6 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - so->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); so->collation = index->rd_indcollation[0]; scan->opaque = so; diff --git a/src/hnswutils.c b/src/hnswutils.c index c8c05dc..645faaf 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -152,18 +152,6 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } -/* - * Normalize value - */ -Datum -HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value) -{ - if (procinfo == NULL) - return DirectFunctionCall1(l2_normalize, value); - - return FunctionCall1Coll(procinfo, collation, value); -} - /* * Check if non-zero norm */ @@ -1267,6 +1255,10 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint } } +PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS); + static void SparsevecCheckValue(Pointer v) { @@ -1288,6 +1280,7 @@ HnswGetTypeInfo(Relation index) { static const HnswTypeInfo typeInfo = { .maxDimensions = HNSW_MAX_DIM, + .normalize = l2_normalize, .checkValue = NULL }; @@ -1303,6 +1296,7 @@ hnsw_halfvec_support(PG_FUNCTION_ARGS) { static const HnswTypeInfo typeInfo = { .maxDimensions = HNSW_MAX_DIM * 2, + .normalize = halfvec_l2_normalize, .checkValue = NULL }; @@ -1315,6 +1309,7 @@ hnsw_bit_support(PG_FUNCTION_ARGS) { static const HnswTypeInfo typeInfo = { .maxDimensions = HNSW_MAX_DIM * 32, + .normalize = NULL, .checkValue = NULL }; @@ -1327,6 +1322,7 @@ hnsw_sparsevec_support(PG_FUNCTION_ARGS) { static const HnswTypeInfo typeInfo = { .maxDimensions = SPARSEVEC_MAX_DIM, + .normalize = sparsevec_l2_normalize, .checkValue = SparsevecCheckValue };