diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index d2ddbce..8e135e6 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -318,7 +318,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), - FUNCTION 4 l2_norm(halfvec); + FUNCTION 4 l2_norm(halfvec), + FUNCTION 5 l2_normalize(halfvec); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -326,7 +327,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), - FUNCTION 4 l2_norm(halfvec); + FUNCTION 4 l2_norm(halfvec), + FUNCTION 5 l2_normalize(halfvec); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS diff --git a/sql/vector.sql b/sql/vector.sql index 2aed8d2..a43871d 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -627,7 +627,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), - FUNCTION 4 l2_norm(halfvec); + FUNCTION 4 l2_norm(halfvec), + FUNCTION 5 l2_normalize(halfvec); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -635,7 +636,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), - FUNCTION 4 l2_norm(halfvec); + FUNCTION 4 l2_norm(halfvec), + FUNCTION 5 l2_normalize(halfvec); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS diff --git a/src/ivfbuild.c b/src/ivfbuild.c index dd89b44..b42372f 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -60,8 +60,10 @@ AddSample(Datum *values, IvfflatBuildState * buildstate) */ if (buildstate->kmeansnormprocinfo != NULL) { - if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->type)) + if (!IvfflatCheckNorm(buildstate->kmeansnormprocinfo, buildstate->collation, value)) return; + + value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); } if (samples->length < targsamples) @@ -156,8 +158,10 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type)) + if (!IvfflatCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) return; + + value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); } /* Find the list that minimizes the distance */ @@ -379,6 +383,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); + buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); buildstate->collation = index->rd_indcollation[0]; /* Require more than one dimension for spherical k-means */ diff --git a/src/ivfflat.c b/src/ivfflat.c index a4d7fab..53dc766 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 4; + amroutine->amsupport = 5; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/ivfflat.h b/src/ivfflat.h index 5393eb0..1fb873a 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -28,6 +28,7 @@ #define IVFFLAT_NORM_PROC 2 #define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_NORM_PROC 4 +#define IVFFLAT_NORMALIZE_PROC 5 #define IVFFLAT_VERSION 1 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 @@ -175,6 +176,7 @@ typedef struct IvfflatBuildState FmgrInfo *procinfo; FmgrInfo *normprocinfo; FmgrInfo *kmeansnormprocinfo; + FmgrInfo *normalizeprocinfo; Oid collation; /* Variables */ @@ -255,6 +257,7 @@ typedef struct IvfflatScanOpaqueData /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; Oid collation; Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2); @@ -276,7 +279,8 @@ void VectorArrayFree(VectorArray arr); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); IvfflatType IvfflatGetType(Relation index); -bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type); +Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value); +bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); int IvfflatGetLists(Relation index); void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions); void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum); diff --git a/src/ivfinsert.c b/src/ivfinsert.c index 6d24ecf..ce23f5c 100644 --- a/src/ivfinsert.c +++ b/src/ivfinsert.c @@ -85,8 +85,12 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); if (normprocinfo != NULL) { - if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, IvfflatGetType(index))) + Oid collation = index->rd_indcollation[0]; + + if (!IvfflatCheckNorm(normprocinfo, collation, value)) return; + + value = IvfflatNormValue(IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC), collation, value); } /* Find the insert page - sets the page and list info */ diff --git a/src/ivfscan.c b/src/ivfscan.c index 018a636..f17faad 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -209,9 +209,9 @@ GetScanValue(IndexScanDesc scan) Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); - /* Fine if normalization fails */ + /* Check normprocinfo since normalizeprocinfo not set for vector */ if (so->normprocinfo != NULL) - IvfflatNormValue(so->normprocinfo, so->collation, &value, IvfflatGetType(scan->indexRelation)); + value = IvfflatNormValue(so->normalizeprocinfo, so->collation, value); } return value; @@ -249,6 +249,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); + so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); so->collation = index->rd_indcollation[0]; /* Create tuple description for sorting */ diff --git a/src/ivfutils.c b/src/ivfutils.c index 728fa72..cca3bd8 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -99,31 +99,24 @@ IvfflatGetType(Relation index) } /* - * Divide by the norm - * - * Returns false if value should not be indexed - * - * The caller needs to free the pointer stored in value - * if it's different than the original value + * Normalize value + */ +Datum +IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value) +{ + if (procinfo == NULL) + return DirectFunctionCall1(l2_normalize, value); + + return FunctionCall1Coll(procinfo, collation, value); +} + +/* + * Check if non-zero norm */ bool -IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type) +IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) { - double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); - - if (norm > 0) - { - if (type == IVFFLAT_TYPE_VECTOR) - *value = DirectFunctionCall1(l2_normalize, *value); - else if (type == IVFFLAT_TYPE_HALFVEC) - *value = DirectFunctionCall1(halfvec_l2_normalize, *value); - else - elog(ERROR, "Unsupported type"); - - return true; - } - - return false; + return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0; } /*