From 3eef1ff5c227e26f24d452022545fb3d4a8b1fbc Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 24 Apr 2024 14:53:45 -0700 Subject: [PATCH] Removed type-specific code from HNSW [skip ci] --- sql/vector.sql | 33 ++++++++++++++---------- src/hnsw.c | 2 +- src/hnsw.h | 17 +++---------- src/hnswbuild.c | 25 +++++++++--------- src/hnswinsert.c | 4 ++- src/hnswutils.c | 66 +++++++++++++++++------------------------------- 6 files changed, 64 insertions(+), 83 deletions(-) diff --git a/sql/vector.sql b/sql/vector.sql index da09d1c..f57f296 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -327,7 +327,7 @@ CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8 CREATE FUNCTION bit_ivfflat_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -CREATE FUNCTION bit_hnsw_support(internal) RETURNS internal +CREATE FUNCTION bit_hnsw_max_dims(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -- bit operators @@ -355,13 +355,13 @@ CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), - FUNCTION 4 bit_hnsw_support(internal); + FUNCTION 4 bit_hnsw_max_dims(internal); CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 jaccard_distance(bit, bit), - FUNCTION 4 bit_hnsw_support(internal); + FUNCTION 4 bit_hnsw_max_dims(internal); -- halfvec type @@ -473,7 +473,7 @@ CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec CREATE FUNCTION halfvec_ivfflat_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -CREATE FUNCTION halfvec_hnsw_support(internal) RETURNS internal +CREATE FUNCTION halfvec_hnsw_max_dims(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -- halfvec aggregates @@ -663,13 +663,13 @@ CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), - FUNCTION 4 halfvec_hnsw_support(internal); + FUNCTION 4 halfvec_hnsw_max_dims(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), - FUNCTION 4 halfvec_hnsw_support(internal); + FUNCTION 4 halfvec_hnsw_max_dims(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS @@ -677,13 +677,13 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), FUNCTION 3 l2_normalize(halfvec), - FUNCTION 4 halfvec_hnsw_support(internal); + FUNCTION 4 halfvec_hnsw_max_dims(internal); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(halfvec, halfvec), - FUNCTION 4 halfvec_hnsw_support(internal); + FUNCTION 4 halfvec_hnsw_max_dims(internal); --- sparsevec type @@ -779,7 +779,10 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -CREATE FUNCTION sparsevec_hnsw_support(internal) RETURNS internal +CREATE FUNCTION sparsevec_hnsw_max_dims(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION sparsevec_hnsw_check_value(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -- sparsevec casts @@ -872,13 +875,15 @@ CREATE OPERATOR CLASS sparsevec_l2_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec), - FUNCTION 4 sparsevec_hnsw_support(internal); + FUNCTION 4 sparsevec_hnsw_max_dims(internal), + FUNCTION 5 sparsevec_hnsw_check_value(internal); CREATE OPERATOR CLASS sparsevec_ip_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), - FUNCTION 4 sparsevec_hnsw_support(internal); + FUNCTION 4 sparsevec_hnsw_max_dims(internal), + FUNCTION 5 sparsevec_hnsw_check_value(internal); CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS @@ -886,10 +891,12 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), FUNCTION 2 l2_norm(sparsevec), FUNCTION 3 l2_normalize(sparsevec), - FUNCTION 4 sparsevec_hnsw_support(internal); + FUNCTION 4 sparsevec_hnsw_max_dims(internal), + FUNCTION 5 sparsevec_hnsw_check_value(internal); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), - FUNCTION 4 sparsevec_hnsw_support(internal); + FUNCTION 4 sparsevec_hnsw_max_dims(internal), + FUNCTION 5 sparsevec_hnsw_check_value(internal); diff --git a/src/hnsw.c b/src/hnsw.c index 9f32260..9741a5d 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 4; + amroutine->amsupport = 5; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/hnsw.h b/src/hnsw.h index 6072962..8cd7ab1 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -23,7 +23,8 @@ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 #define HNSW_NORMALIZE_PROC 3 -#define HNSW_TYPE_SUPPORT_PROC 4 +#define HNSW_MAX_DIMS_PROC 4 +#define HNSW_CHECK_VALUE_PROC 5 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -58,15 +59,6 @@ #define HNSW_UPDATE_ENTRY_GREATER 1 #define HNSW_UPDATE_ENTRY_ALWAYS 2 -typedef enum HnswType -{ - HNSW_TYPE_VECTOR, - HNSW_TYPE_HALFVEC, - HNSW_TYPE_BIT, - HNSW_TYPE_SPARSEVEC, - HNSW_TYPE_UNSUPPORTED -} HnswType; - /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ #define PROGRESS_HNSW_PHASE_LOAD 2 @@ -254,7 +246,6 @@ typedef struct HnswBuildState Relation index; IndexInfo *indexInfo; ForkNumber forkNum; - HnswType type; /* Settings */ int dimensions; @@ -269,6 +260,7 @@ typedef struct HnswBuildState FmgrInfo *procinfo; FmgrInfo *normprocinfo; FmgrInfo *normalizeprocinfo; + FmgrInfo *checkvalueprocinfo; Oid collation; /* Variables */ @@ -381,10 +373,9 @@ typedef struct HnswVacuumState int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); -HnswType HnswGetType(Relation index); Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value); bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); -void HnswCheckValue(Datum value, HnswType type); +void HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 222ceb5..34f390b 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -488,7 +488,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Check value */ - HnswCheckValue(value, buildstate->type); + if (buildstate->checkvalueprocinfo != NULL) + HnswCheckValue(buildstate->checkvalueprocinfo, buildstate->collation, value); /* Normalize if needed */ if (buildstate->normprocinfo != NULL) @@ -675,18 +676,14 @@ HnswSharedMemoryAlloc(Size size, void *state) * Get max dimensions */ static int -GetMaxDimensions(HnswType type) +GetMaxDimensions(Relation index) { - int maxDimensions = HNSW_MAX_DIM; + FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_MAX_DIMS_PROC); - if (type == HNSW_TYPE_HALFVEC) - maxDimensions *= 2; - else if (type == HNSW_TYPE_BIT) - maxDimensions *= 32; - else if (type == HNSW_TYPE_SPARSEVEC) - maxDimensions = INT_MAX; + if (procinfo == NULL) + return HNSW_MAX_DIM; - return maxDimensions; + return DatumGetInt32(FunctionCall1(procinfo, PointerGetDatum(NULL))); } /* @@ -701,13 +698,16 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->index = index; buildstate->indexInfo = indexInfo; buildstate->forkNum = forkNum; - buildstate->type = HnswGetType(index); buildstate->m = HnswGetM(index); buildstate->efConstruction = HnswGetEfConstruction(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; - maxDimensions = GetMaxDimensions(buildstate->type); + /* Disallow varbit since require fixed dimensions */ + if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID) + elog(ERROR, "type not supported for hnsw index"); + + maxDimensions = GetMaxDimensions(index); /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) @@ -726,6 +726,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); + buildstate->checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC); buildstate->collation = index->rd_indcollation[0]; InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 50181ea..1ea3aef 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -612,6 +612,7 @@ static void HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid) { Datum value; + FmgrInfo *checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC); FmgrInfo *normprocinfo; Oid collation = index->rd_indcollation[0]; @@ -619,7 +620,8 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Check value */ - HnswCheckValue(value, HnswGetType(index)); + if (checkvalueprocinfo != NULL) + HnswCheckValue(checkvalueprocinfo, collation, value); /* Normalize if needed */ normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); diff --git a/src/hnswutils.c b/src/hnswutils.c index 6fc6dcc..71007c3 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -152,27 +152,6 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } -/* - * Get type - */ -HnswType -HnswGetType(Relation index) -{ - FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_TYPE_SUPPORT_PROC); - Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid; - HnswType result; - - if (procinfo == NULL) - return HNSW_TYPE_VECTOR; - - result = (HnswType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid))); - - if (result == HNSW_TYPE_UNSUPPORTED) - elog(ERROR, "type not supported for hnsw index"); - - return result; -} - /* * Normalize value */ @@ -198,15 +177,9 @@ HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) * Check if a value can be indexed */ void -HnswCheckValue(Datum value, HnswType type) +HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value) { - if (type == HNSW_TYPE_SPARSEVEC) - { - SparseVector *vec = DatumGetSparseVector(value); - - if (vec->nnz > HNSW_MAX_NNZ) - elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ); - } + FunctionCall1Coll(procinfo, collation, value); } /* @@ -1303,28 +1276,35 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint } } -PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_support); +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_max_dims); Datum -halfvec_hnsw_support(PG_FUNCTION_ARGS) +halfvec_hnsw_max_dims(PG_FUNCTION_ARGS) { - PG_RETURN_INT32(HNSW_TYPE_HALFVEC); + PG_RETURN_INT32(HNSW_MAX_DIM * 2); }; -PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_support); +PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_max_dims); Datum -bit_hnsw_support(PG_FUNCTION_ARGS) +bit_hnsw_max_dims(PG_FUNCTION_ARGS) { - Oid typid = PG_GETARG_OID(0); - - if (typid == BITOID) - PG_RETURN_INT32(HNSW_TYPE_BIT); - else - PG_RETURN_INT32(HNSW_TYPE_UNSUPPORTED); + PG_RETURN_INT32(HNSW_MAX_DIM * 32); }; -PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_support); +PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_max_dims); Datum -sparsevec_hnsw_support(PG_FUNCTION_ARGS) +sparsevec_hnsw_max_dims(PG_FUNCTION_ARGS) { - PG_RETURN_INT32(HNSW_TYPE_SPARSEVEC); + PG_RETURN_INT32(SPARSEVEC_MAX_DIM); }; + +PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_check_value); +Datum +sparsevec_hnsw_check_value(PG_FUNCTION_ARGS) +{ + SparseVector *vec = PG_GETARG_SPARSEVEC_P(0); + + if (vec->nnz > HNSW_MAX_NNZ) + elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ); + + PG_RETURN_VOID(); +}