From 5dec500879011132706e1b365aa23917bd75b816 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 25 Apr 2024 13:56:20 -0700 Subject: [PATCH] Reduced support functions for IVFFlat - #527 --- sql/vector--0.6.2--0.7.0.sql | 10 ++++------ sql/vector.sql | 10 ++++------ src/ivfbuild.c | 5 ++--- src/ivfflat.c | 2 +- src/ivfflat.h | 9 ++++----- src/ivfinsert.c | 3 ++- src/ivfkmeans.c | 17 +++++++---------- src/ivfscan.c | 6 +++--- src/ivfutils.c | 14 +++++++++----- src/vector.h | 1 - 10 files changed, 36 insertions(+), 41 deletions(-) diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 219463c..0d0a07c 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -63,7 +63,7 @@ CREATE OPERATOR CLASS bit_hamming_ops OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit), - FUNCTION 6 ivfflat_bit_support(internal); + FUNCTION 5 ivfflat_bit_support(internal); CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS @@ -330,7 +330,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec), - FUNCTION 6 ivfflat_halfvec_support(internal); + FUNCTION 5 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING ivfflat AS @@ -338,8 +338,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_support(internal); + FUNCTION 5 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -348,8 +347,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 2 l2_norm(halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_support(internal); + FUNCTION 5 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS diff --git a/sql/vector.sql b/sql/vector.sql index 08f4ffa..75d27b6 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -358,7 +358,7 @@ CREATE OPERATOR CLASS bit_hamming_ops OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit), - FUNCTION 6 ivfflat_bit_support(internal); + FUNCTION 5 ivfflat_bit_support(internal); CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS @@ -641,7 +641,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec), - FUNCTION 6 ivfflat_halfvec_support(internal); + FUNCTION 5 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING ivfflat AS @@ -649,8 +649,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_support(internal); + FUNCTION 5 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -659,8 +658,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 2 l2_norm(halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_support(internal); + FUNCTION 5 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 7b87962..8a7b52d 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -63,7 +63,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate) if (!IvfflatCheckNorm(buildstate->kmeansnormprocinfo, buildstate->collation, value)) return; - value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); + value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value); } if (samples->length < targsamples) @@ -161,7 +161,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState if (!IvfflatCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) return; - value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); + value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value); } /* Find the list that minimizes the distance */ @@ -351,7 +351,6 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); - buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); buildstate->collation = index->rd_indcollation[0]; /* Require more than one dimension for spherical k-means */ diff --git a/src/ivfflat.c b/src/ivfflat.c index 6bb2422..53dc766 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 6; + amroutine->amsupport = 5; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/ivfflat.h b/src/ivfflat.h index 41d5239..dd96f57 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -28,8 +28,7 @@ #define IVFFLAT_NORM_PROC 2 #define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_NORM_PROC 4 -#define IVFFLAT_NORMALIZE_PROC 5 -#define IVFFLAT_TYPE_INFO_PROC 6 +#define IVFFLAT_TYPE_INFO_PROC 5 #define IVFFLAT_VERSION 1 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 @@ -153,6 +152,7 @@ typedef struct IvfflatLeader typedef struct IvfflatTypeInfo { int maxDimensions; + Datum (*normalize) (PG_FUNCTION_ARGS); void (*updateCenter) (Pointer v, int dimensions, float *x); void (*sumCenter) (Pointer v, float *x); } IvfflatTypeInfo; @@ -177,7 +177,6 @@ typedef struct IvfflatBuildState FmgrInfo *procinfo; FmgrInfo *normprocinfo; FmgrInfo *kmeansnormprocinfo; - FmgrInfo *normalizeprocinfo; Oid collation; /* Variables */ @@ -245,6 +244,7 @@ typedef struct IvfflatScanList typedef struct IvfflatScanOpaqueData { + const IvfflatTypeInfo *typeInfo; int probes; int dimensions; bool first; @@ -258,7 +258,6 @@ typedef struct IvfflatScanOpaqueData /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; - FmgrInfo *normalizeprocinfo; Oid collation; Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2); @@ -279,7 +278,7 @@ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize); void VectorArrayFree(VectorArray arr); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); -Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value); +Datum IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value); bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); int IvfflatGetLists(Relation index); void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions); diff --git a/src/ivfinsert.c b/src/ivfinsert.c index ce23f5c..fd1ea17 100644 --- a/src/ivfinsert.c +++ b/src/ivfinsert.c @@ -67,6 +67,7 @@ FindInsertPage(Relation index, Datum *values, BlockNumber *insertPage, ListInfo static void InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) { + const IvfflatTypeInfo *typeInfo = IvfflatGetTypeInfo(index); IndexTuple itup; Datum value; FmgrInfo *normprocinfo; @@ -90,7 +91,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R if (!IvfflatCheckNorm(normprocinfo, collation, value)) return; - value = IvfflatNormValue(IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC), collation, value); + value = IvfflatNormValue(typeInfo, collation, value); } /* Find the insert page - sets the page and list info */ diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 89b5cf7..1ea7a45 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -92,7 +92,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low * Norm centers */ static void -NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers) +NormCenters(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray centers) { MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext, "Ivfflat norm temporary context", @@ -102,7 +102,7 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers) for (int j = 0; j < centers->length; j++) { Datum center = PointerGetDatum(VectorArrayGet(centers, j)); - Datum newCenter = IvfflatNormValue(normalizeprocinfo, collation, center); + Datum newCenter = IvfflatNormValue(typeInfo, collation, center); Size size = VARSIZE_ANY(DatumGetPointer(newCenter)); if (size > centers->itemsize) @@ -123,9 +123,8 @@ static void RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo) { int dimensions = centers->dim; - Oid collation = index->rd_indcollation[0]; FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); - FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); + Oid collation = index->rd_indcollation[0]; float *x = (float *) palloc(sizeof(float) * dimensions); /* Fill with random data */ @@ -142,7 +141,7 @@ RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeI } if (normprocinfo != NULL) - NormCenters(normalizeprocinfo, collation, centers); + NormCenters(typeInfo, collation, centers); pfree(x); } @@ -196,7 +195,7 @@ UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo) * Compute new centers */ static void -ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo) +ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo) { int dimensions = newCenters->dim; int numCenters = newCenters->length; @@ -251,7 +250,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int * /* Normalize if needed */ if (normprocinfo != NULL) - NormCenters(normalizeprocinfo, collation, newCenters); + NormCenters(typeInfo, collation, newCenters); } /* @@ -267,7 +266,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff { FmgrInfo *procinfo; FmgrInfo *normprocinfo; - FmgrInfo *normalizeprocinfo; Oid collation; int dimensions = centers->dim; int numCenters = centers->maxlen; @@ -315,7 +313,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff /* Set support functions */ procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); - normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); collation = index->rd_indcollation[0]; /* Use memory context */ @@ -477,7 +474,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, typeInfo); + ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, collation, typeInfo); /* Step 5 */ for (int j = 0; j < numCenters; j++) diff --git a/src/ivfscan.c b/src/ivfscan.c index f17faad..03f743d 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -209,9 +209,9 @@ GetScanValue(IndexScanDesc scan) Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); - /* Check normprocinfo since normalizeprocinfo not set for vector */ + /* Normalize if needed */ if (so->normprocinfo != NULL) - value = IvfflatNormValue(so->normalizeprocinfo, so->collation, value); + value = IvfflatNormValue(so->typeInfo, so->collation, value); } return value; @@ -242,6 +242,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) probes = lists; so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); + so->typeInfo = IvfflatGetTypeInfo(index); so->first = true; so->probes = probes; so->dimensions = dimensions; @@ -249,7 +250,6 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); - so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); so->collation = index->rd_indcollation[0]; /* Create tuple description for sorting */ diff --git a/src/ivfutils.c b/src/ivfutils.c index 7b5f152..6ecf85a 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -68,12 +68,9 @@ IvfflatOptionalProcInfo(Relation index, uint16 procnum) * Normalize value */ Datum -IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value) +IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value) { - if (procinfo == NULL) - return DirectFunctionCall1(l2_normalize, value); - - return FunctionCall1Coll(procinfo, collation, value); + return DirectFunctionCall1Coll(typeInfo->normalize, collation, value); } /* @@ -228,6 +225,10 @@ IvfflatUpdateList(Relation index, ListInfo listInfo, } } +PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS); + static void VectorUpdateCenter(Pointer v, int dimensions, float *x) { @@ -307,6 +308,7 @@ IvfflatGetTypeInfo(Relation index) { static const IvfflatTypeInfo typeInfo = { .maxDimensions = IVFFLAT_MAX_DIM, + .normalize = l2_normalize, .updateCenter = VectorUpdateCenter, .sumCenter = VectorSumCenter }; @@ -323,6 +325,7 @@ ivfflat_halfvec_support(PG_FUNCTION_ARGS) { static const IvfflatTypeInfo typeInfo = { .maxDimensions = IVFFLAT_MAX_DIM * 2, + .normalize = halfvec_l2_normalize, .updateCenter = HalfvecUpdateCenter, .sumCenter = HalfvecSumCenter }; @@ -336,6 +339,7 @@ ivfflat_bit_support(PG_FUNCTION_ARGS) { static const IvfflatTypeInfo typeInfo = { .maxDimensions = IVFFLAT_MAX_DIM * 32, + .normalize = NULL, .updateCenter = BitUpdateCenter, .sumCenter = BitSumCenter }; diff --git a/src/vector.h b/src/vector.h index 6662a59..570a758 100644 --- a/src/vector.h +++ b/src/vector.h @@ -21,6 +21,5 @@ typedef struct Vector Vector *InitVector(int dim); void PrintVector(char *msg, Vector * vector); int vector_cmp_internal(Vector * a, Vector * b); -PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS); #endif