diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 723ff39..ee13b96 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -222,7 +222,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize) * Compute new centers */ static void -ComputeNewCenters(VectorArray samples, VectorArray aggCenters, int *centerCounts, int *closestCenters, IvfflatType type) +ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, IvfflatType type) { int dimensions = aggCenters->dim; int numCenters = aggCenters->maxlen; @@ -293,6 +293,30 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, int *centerCounts vec->x[k] = RandomDouble(); } } + + /* Set new centers if different from agg centers */ + if (type == IVFFLAT_TYPE_HALFVEC) + { + for (int j = 0; j < numCenters; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); + HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j); + + for (int k = 0; k < dimensions; k++) + newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]); + } + } + + /* Normalize if needed */ + if (normprocinfo != NULL) + { + for (int j = 0; j < numCenters; j++) + { + Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j)); + + ApplyNorm(normprocinfo, collation, newCenter, type); + } + } } /* @@ -540,31 +564,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - ComputeNewCenters(samples, aggCenters, centerCounts, closestCenters, type); - - /* Set new centers if different from agg centers */ - if (type == IVFFLAT_TYPE_HALFVEC) - { - for (int j = 0; j < numCenters; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); - HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j); - - for (int k = 0; k < dimensions; k++) - newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]); - } - } - - /* Normalize if needed */ - if (normprocinfo != NULL) - { - for (int j = 0; j < numCenters; j++) - { - Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j)); - - ApplyNorm(normprocinfo, collation, newCenter, type); - } - } + ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, collation, type); /* Step 5 */ for (int j = 0; j < numCenters; j++)