diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 2584dd1..e6f3788 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -28,16 +28,13 @@ CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal +CREATE FUNCTION hnsw_bit_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal +CREATE FUNCTION hnsw_halfvec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal +CREATE FUNCTION hnsw_sparsevec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; CREATE OPERATOR CLASS vector_l1_ops @@ -72,13 +69,13 @@ CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), - FUNCTION 4 hnsw_bit_max_dims(internal); + FUNCTION 4 hnsw_bit_support(internal); CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 jaccard_distance(bit, bit), - FUNCTION 4 hnsw_bit_max_dims(internal); + FUNCTION 4 hnsw_bit_support(internal); CREATE TYPE halfvec; @@ -358,13 +355,13 @@ CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS @@ -372,13 +369,13 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), FUNCTION 3 l2_normalize(halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); CREATE TYPE sparsevec; @@ -550,15 +547,13 @@ CREATE OPERATOR CLASS sparsevec_l2_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_ip_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS @@ -566,12 +561,10 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), FUNCTION 2 l2_norm(sparsevec), FUNCTION 3 l2_normalize(sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); diff --git a/sql/vector.sql b/sql/vector.sql index 7d1767a..527ad67 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -269,16 +269,13 @@ CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal +CREATE FUNCTION hnsw_bit_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal +CREATE FUNCTION hnsw_halfvec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal +CREATE FUNCTION hnsw_sparsevec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -- vector opclasses @@ -367,13 +364,13 @@ CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), - FUNCTION 4 hnsw_bit_max_dims(internal); + FUNCTION 4 hnsw_bit_support(internal); CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 jaccard_distance(bit, bit), - FUNCTION 4 hnsw_bit_max_dims(internal); + FUNCTION 4 hnsw_bit_support(internal); -- halfvec type @@ -669,13 +666,13 @@ CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS @@ -683,13 +680,13 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), FUNCTION 3 l2_normalize(halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(halfvec, halfvec), - FUNCTION 4 hnsw_halfvec_max_dims(internal); + FUNCTION 4 hnsw_halfvec_support(internal); --- sparsevec type @@ -875,15 +872,13 @@ CREATE OPERATOR CLASS sparsevec_l2_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_ip_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS @@ -891,12 +886,10 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), FUNCTION 2 l2_norm(sparsevec), FUNCTION 3 l2_normalize(sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), - FUNCTION 4 hnsw_sparsevec_max_dims(internal), - FUNCTION 5 hnsw_sparsevec_check_value(internal); + FUNCTION 4 hnsw_sparsevec_support(internal); diff --git a/src/hnsw.c b/src/hnsw.c index 9741a5d..9f32260 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 5; + amroutine->amsupport = 4; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/hnsw.h b/src/hnsw.h index 8cd7ab1..4400555 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -23,8 +23,7 @@ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 #define HNSW_NORMALIZE_PROC 3 -#define HNSW_MAX_DIMS_PROC 4 -#define HNSW_CHECK_VALUE_PROC 5 +#define HNSW_TYPE_INFO_PROC 4 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -239,6 +238,12 @@ typedef struct HnswAllocator void *state; } HnswAllocator; +typedef struct HnswTypeInfo +{ + int maxDimensions; + void (*checkValue) (Pointer v); +} HnswTypeInfo; + typedef struct HnswBuildState { /* Info */ @@ -246,6 +251,7 @@ typedef struct HnswBuildState Relation index; IndexInfo *indexInfo; ForkNumber forkNum; + const HnswTypeInfo *typeInfo; /* Settings */ int dimensions; @@ -260,7 +266,6 @@ typedef struct HnswBuildState FmgrInfo *procinfo; FmgrInfo *normprocinfo; FmgrInfo *normalizeprocinfo; - FmgrInfo *checkvalueprocinfo; Oid collation; /* Variables */ @@ -375,7 +380,6 @@ int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value); bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); -void HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); @@ -399,6 +403,7 @@ void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); void HnswLoadNeighbors(HnswElement element, Relation index, int m); void HnswInitLockTranche(void); +const HnswTypeInfo *HnswGetTypeInfo(Relation index); PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 34f390b..dfb5ab9 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -488,8 +488,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Check value */ - if (buildstate->checkvalueprocinfo != NULL) - HnswCheckValue(buildstate->checkvalueprocinfo, buildstate->collation, value); + if (buildstate->typeInfo->checkValue != NULL) + buildstate->typeInfo->checkValue(DatumGetPointer(value)); /* Normalize if needed */ if (buildstate->normprocinfo != NULL) @@ -672,32 +672,17 @@ HnswSharedMemoryAlloc(Size size, void *state) return chunk; } -/* - * Get max dimensions - */ -static int -GetMaxDimensions(Relation index) -{ - FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_MAX_DIMS_PROC); - - if (procinfo == NULL) - return HNSW_MAX_DIM; - - return DatumGetInt32(FunctionCall1(procinfo, PointerGetDatum(NULL))); -} - /* * Initialize the build state */ static void InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum) { - int maxDimensions; - buildstate->heap = heap; buildstate->index = index; buildstate->indexInfo = indexInfo; buildstate->forkNum = forkNum; + buildstate->typeInfo = HnswGetTypeInfo(index); buildstate->m = HnswGetM(index); buildstate->efConstruction = HnswGetEfConstruction(index); @@ -707,14 +692,12 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID) elog(ERROR, "type not supported for hnsw index"); - maxDimensions = GetMaxDimensions(index); - /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); - if (buildstate->dimensions > maxDimensions) - elog(ERROR, "column cannot have more than %d dimensions for hnsw index", maxDimensions); + if (buildstate->dimensions > buildstate->typeInfo->maxDimensions) + elog(ERROR, "column cannot have more than %d dimensions for hnsw index", buildstate->typeInfo->maxDimensions); if (buildstate->efConstruction < 2 * buildstate->m) elog(ERROR, "ef_construction must be greater than or equal to 2 * m"); @@ -726,7 +709,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); - buildstate->checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC); buildstate->collation = index->rd_indcollation[0]; InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 1ea3aef..c0bc436 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -612,7 +612,7 @@ static void HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid) { Datum value; - FmgrInfo *checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC); + const HnswTypeInfo *typeInfo = HnswGetTypeInfo(index); FmgrInfo *normprocinfo; Oid collation = index->rd_indcollation[0]; @@ -620,8 +620,8 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Check value */ - if (checkvalueprocinfo != NULL) - HnswCheckValue(checkvalueprocinfo, collation, value); + if (typeInfo->checkValue != NULL) + typeInfo->checkValue(DatumGetPointer(value)); /* Normalize if needed */ normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); diff --git a/src/hnswutils.c b/src/hnswutils.c index bc3c56e..c8c05dc 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -173,15 +173,6 @@ HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0; } -/* - * Check if a value can be indexed - */ -void -HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value) -{ - FunctionCall1Coll(procinfo, collation, value); -} - /* * New buffer */ @@ -1276,35 +1267,68 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint } } -PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_halfvec_max_dims); -Datum -hnsw_halfvec_max_dims(PG_FUNCTION_ARGS) +static void +SparsevecCheckValue(Pointer v) { - PG_RETURN_INT32(HNSW_MAX_DIM * 2); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_bit_max_dims); -Datum -hnsw_bit_max_dims(PG_FUNCTION_ARGS) -{ - PG_RETURN_INT32(HNSW_MAX_DIM * 32); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_sparsevec_max_dims); -Datum -hnsw_sparsevec_max_dims(PG_FUNCTION_ARGS) -{ - PG_RETURN_INT32(SPARSEVEC_MAX_DIM); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_sparsevec_check_value); -Datum -hnsw_sparsevec_check_value(PG_FUNCTION_ARGS) -{ - SparseVector *vec = PG_GETARG_SPARSEVEC_P(0); + SparseVector *vec = (SparseVector *) v; if (vec->nnz > HNSW_MAX_NNZ) elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ); - - PG_RETURN_VOID(); } + +/* + * Get type info + */ +const HnswTypeInfo * +HnswGetTypeInfo(Relation index) +{ + FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_TYPE_INFO_PROC); + + if (procinfo == NULL) + { + static const HnswTypeInfo typeInfo = { + .maxDimensions = HNSW_MAX_DIM, + .checkValue = NULL + }; + + return (&typeInfo); + } + else + return (const HnswTypeInfo *) DatumGetPointer(FunctionCall0Coll(procinfo, InvalidOid)); +} + +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_halfvec_support); +Datum +hnsw_halfvec_support(PG_FUNCTION_ARGS) +{ + static const HnswTypeInfo typeInfo = { + .maxDimensions = HNSW_MAX_DIM * 2, + .checkValue = NULL + }; + + PG_RETURN_POINTER(&typeInfo); +}; + +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_bit_support); +Datum +hnsw_bit_support(PG_FUNCTION_ARGS) +{ + static const HnswTypeInfo typeInfo = { + .maxDimensions = HNSW_MAX_DIM * 32, + .checkValue = NULL + }; + + PG_RETURN_POINTER(&typeInfo); +}; + +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_sparsevec_support); +Datum +hnsw_sparsevec_support(PG_FUNCTION_ARGS) +{ + static const HnswTypeInfo typeInfo = { + .maxDimensions = SPARSEVEC_MAX_DIM, + .checkValue = SparsevecCheckValue + }; + + PG_RETURN_POINTER(&typeInfo); +};