diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c
index 65d312b..7c55bee 100644
--- a/src/ivfkmeans.c
+++ b/src/ivfkmeans.c
@@ -89,33 +89,27 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
 }
 
 /*
- * Apply norm to vector
+ * Norm centers: overwrite each center in place with its IvfflatNormValue() result, resetting a scratch memory context per center
  */
-static inline void
-ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Datum value, IvfflatType type)
+static void
+NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
 {
-	double		norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, value));
+	MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
+												  "Ivfflat norm temporary context",
+												  ALLOCSET_DEFAULT_SIZES);
+	MemoryContext oldCtx = MemoryContextSwitchTo(normCtx);
 
-	/* TODO Handle zero norm */
-	if (norm > 0)
+	for (int j = 0; j < centers->maxlen; j++)
 	{
-		if (type == IVFFLAT_TYPE_VECTOR)
-		{
-			Vector	   *vec = DatumGetVector(value);
+		Datum		center = PointerGetDatum(VectorArrayGet(centers, j));
+		Datum		newCenter = IvfflatNormValue(normalizeprocinfo, collation, center);
 
-			for (int i = 0; i < vec->dim; i++)
-				vec->x[i] /= norm;
-		}
-		else if (type == IVFFLAT_TYPE_HALFVEC)
-		{
-			HalfVector *vec = DatumGetHalfVector(value);
-
-			for (int i = 0; i < vec->dim; i++)
-				vec->x[i] = Float4ToHalfUnchecked(HalfToFloat4(vec->x[i]) / norm);
-		}
-		else
-			elog(ERROR, "Unsupported type");
+		memcpy(DatumGetPointer(center), DatumGetPointer(newCenter), VARSIZE_ANY(DatumGetPointer(newCenter)));
+		MemoryContextReset(normCtx);
 	}
+
+	MemoryContextSwitchTo(oldCtx);
+	MemoryContextDelete(normCtx);
 }
 
 /*
@@ -154,6 +148,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
 	int			dimensions = centers->dim;
 	Oid			collation = index->rd_indcollation[0];
 	FmgrInfo   *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
+	FmgrInfo   *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
 
 	/* Copy existing vectors while avoiding duplicates */
 	if (samples->length > 0)
@@ -217,12 +212,12 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
 		else
 			elog(ERROR, "Unsupported type");
 
-		/* Normalize if needed (only needed for random centers) */
-		if (normprocinfo != NULL)
-			ApplyNorm(normprocinfo, collation, center, type);
-
 		centers->length++;
 	}
+
+	/* Normalize all centers in one pass; fine if existing vectors are normalized twice. NOTE(review): guard tests normprocinfo but the call uses normalizeprocinfo - confirm both are always defined together */
+	if (normprocinfo != NULL)
+		NormCenters(normalizeprocinfo, collation, centers);
 }
 
 #ifdef IVFFLAT_MEMORY
@@ -246,7 +241,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
  * Compute new centers
  */
 static void
-ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, IvfflatType type)
+ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type)
 {
 	int			dimensions = aggCenters->dim;
 	int			numCenters = aggCenters->maxlen;
@@ -360,14 +355,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
 
 	/* Normalize if needed */
 	if (normprocinfo != NULL)
-	{
-		for (int j = 0; j < numCenters; j++)
-		{
-			Datum		newCenter = PointerGetDatum(VectorArrayGet(newCenters, j));
-
-			ApplyNorm(normprocinfo, collation, newCenter, type);
-		}
-	}
+		NormCenters(normalizeprocinfo, collation, newCenters);
 }
 
 /*
@@ -383,6 +371,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
 {
 	FmgrInfo   *procinfo;
 	FmgrInfo   *normprocinfo;
+	FmgrInfo   *normalizeprocinfo;
 	Oid			collation;
 	int			dimensions = centers->dim;
 	int			numCenters = centers->maxlen;
@@ -430,6 +419,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
 	/* Set support functions */
 	procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
 	normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
+	normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
 	collation = index->rd_indcollation[0];
 
 	/* Use memory context */
@@ -627,7 +617,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
 		}
 
 		/* Step 4: For each center c, let m(c) be mean of all points assigned */
-		ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, collation, type);
+		ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type);
 
 		/* Step 5 */
 		for (int j = 0; j < numCenters; j++)