diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index f61ba1d..20d23c0 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -251,27 +251,14 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize) #endif /* - * Compute new centers + * Sum centers */ static void -ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type) +SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, IvfflatType type) { int dimensions = aggCenters->dim; - int numCenters = aggCenters->maxlen; int numSamples = samples->length; - /* Reset sum and count */ - for (int j = 0; j < numCenters; j++) - { - Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); - - for (int k = 0; k < dimensions; k++) - vec->x[k] = 0.0; - - centerCounts[j] = 0; - } - - /* Increment sum of closest center */ if (type == IVFFLAT_TYPE_VECTOR) { for (int j = 0; j < numSamples; j++) @@ -307,6 +294,68 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe } else elog(ERROR, "Unsupported type"); +} + +/* + * Set new centers + */ +static void +SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type) +{ + int dimensions = aggCenters->dim; + int numCenters = aggCenters->maxlen; + + if (type == IVFFLAT_TYPE_HALFVEC) + { + for (int j = 0; j < numCenters; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); + HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j); + + for (int k = 0; k < dimensions; k++) + newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]); + } + } + else if (type == IVFFLAT_TYPE_BIT) + { + for (int j = 0; j < numCenters; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); + VarBit *newCenter = (VarBit *) VectorArrayGet(newCenters, j); + unsigned char *nx = VARBITS(newCenter); + + for (uint32 k = 0; k < VARBITBYTES(newCenter); k++) + nx[k] = 0; + + for (int k = 0; k < dimensions; k++) + nx[k / 8] |= (aggCenter->x[k] > 0.5) << (7 - (k % 8)); + } + } +} + +/* + * Compute new centers + */ +static void +ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type) +{ + int dimensions = aggCenters->dim; + int numCenters = aggCenters->maxlen; + int numSamples = samples->length; + + /* Reset sum and count */ + for (int j = 0; j < numCenters; j++) + { + Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); + + for (int k = 0; k < dimensions; k++) + vec->x[k] = 0.0; + + centerCounts[j] = 0; + } + + /* Increment sum of closest center */ + SumCenters(samples, aggCenters, closestCenters, type); /* Increment count of closest center */ for (int j = 0; j < numSamples; j++) @@ -339,32 +388,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe } /* Set new centers if different from agg centers */ - if (type == IVFFLAT_TYPE_HALFVEC) - { - for (int j = 0; j < numCenters; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); - HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j); - - for (int k = 0; k < dimensions; k++) - newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]); - } - } - else if (type == IVFFLAT_TYPE_BIT) - { - for (int j = 0; j < numCenters; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); - VarBit *newCenter = (VarBit *) VectorArrayGet(newCenters, j); - unsigned char *nx = VARBITS(newCenter); - - for (uint32 k = 0; k < VARBITBYTES(newCenter); k++) - nx[k] = 0; - - for (int k = 0; k < dimensions; k++) - nx[k / 8] |= (aggCenter->x[k] > 0.5) << (7 - (k % 8)); - } - } + SetNewCenters(aggCenters, newCenters, type); /* Normalize if needed */ if (normprocinfo != NULL)