From f14c21748b9003d34d0e4f7d08e4d59c059fb328 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 22 Apr 2024 18:36:47 -0700 Subject: [PATCH] Added support function for l2_normalize [skip ci] --- sql/vector--0.6.2--0.7.0.sql | 6 ++++-- sql/vector.sql | 6 ++++-- src/hnsw.c | 2 +- src/hnsw.h | 5 ++++- src/hnswbuild.c | 3 ++- src/hnswinsert.c | 2 +- src/hnswscan.c | 3 ++- src/hnswutils.c | 13 ++++--------- 8 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 4b62d6f..d2ddbce 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -342,7 +342,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), - FUNCTION 2 l2_norm(halfvec); + FUNCTION 2 l2_norm(halfvec), + FUNCTION 3 l2_normalize(halfvec); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS @@ -529,7 +530,8 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), - FUNCTION 2 l2_norm(sparsevec); + FUNCTION 2 l2_norm(sparsevec), + FUNCTION 3 l2_normalize(sparsevec); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS diff --git a/sql/vector.sql b/sql/vector.sql index 95d136d..2aed8d2 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -651,7 +651,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), - FUNCTION 2 l2_norm(halfvec); + FUNCTION 2 l2_norm(halfvec), + FUNCTION 3 l2_normalize(halfvec); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS @@ -852,7 +853,8 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), - FUNCTION 2 l2_norm(sparsevec); + FUNCTION 2 l2_norm(sparsevec), + FUNCTION 3 l2_normalize(sparsevec); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS diff --git a/src/hnsw.c b/src/hnsw.c index fa08cf5..b56ab71 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 2; + amroutine->amsupport = 3; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/hnsw.h b/src/hnsw.h index c1a7f12..d02522b 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -22,6 +22,7 @@ /* Support functions */ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 +#define HNSW_NORMALIZE_PROC 3 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -265,6 +266,7 @@ typedef struct HnswBuildState /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; Oid collation; /* Variables */ @@ -341,6 +343,7 @@ typedef struct HnswScanOpaqueData /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; Oid collation; } HnswScanOpaqueData; @@ -377,7 +380,7 @@ int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); HnswType HnswGetType(Relation index); -Datum HnswNormValue(Datum value, HnswType type); +Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value); bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); void HnswCheckValue(Datum value, HnswType type); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index bd32e06..222ceb5 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -496,7 +496,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) return false; - value = HnswNormValue(value, buildstate->type); + value = HnswNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); } /* Get datum size */ @@ -725,6 +725,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index /* Get support functions */ buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); buildstate->collation = index->rd_indcollation[0]; InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 88fc6c3..5d6e061 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -629,7 +629,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti if (!HnswCheckNorm(normprocinfo, collation, value)) return; - value = HnswNormValue(value, type); + value = HnswNormValue(HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC), collation, value); } HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false); diff --git a/src/hnswscan.c b/src/hnswscan.c index ad00852..0aa6df4 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan) /* Fine if normalization fails */ if (so->normprocinfo != NULL) - value = HnswNormValue(value, HnswGetType(scan->indexRelation)); + value = HnswNormValue(so->normalizeprocinfo, so->collation, value); } return value; @@ -87,6 +87,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + so->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); so->collation = index->rd_indcollation[0]; scan->opaque = so; diff --git a/src/hnswutils.c b/src/hnswutils.c index 23f8d0b..b8dfd04 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -195,17 +195,12 @@ HnswGetType(Relation index) * Normalize value */ Datum -HnswNormValue(Datum value, HnswType type) +HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value) { - /* TODO Remove type-specific code */ - if (type == HNSW_TYPE_VECTOR) + if (procinfo == NULL) return DirectFunctionCall1(l2_normalize, value); - else if (type == HNSW_TYPE_HALFVEC) - return DirectFunctionCall1(halfvec_l2_normalize, value); - else if (type == HNSW_TYPE_SPARSEVEC) - return DirectFunctionCall1(sparsevec_l2_normalize, value); - else - elog(ERROR, "Unsupported type"); + + return FunctionCall1Coll(procinfo, collation, value); } /*