diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 96b1694..c5409cb 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -473,31 +473,35 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp centerCounts[j] = 0; } - for (int64 j = 0; j < numSamples; j++) + /* Increment sum of closest center */ + if (type == IVFFLAT_TYPE_VECTOR) { - int closestCenter = closestCenters[j]; - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter); - - /* Increment sum and count of closest center */ - if (type == IVFFLAT_TYPE_VECTOR) + for (int j = 0; j < numSamples; j++) { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); Vector *vec = (Vector *) VectorArrayGet(samples, j); for (int k = 0; k < dimensions; k++) aggCenter->x[k] += vec->x[k]; } - else if (type == IVFFLAT_TYPE_HALFVEC) + } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + for (int j = 0; j < numSamples; j++) { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j); for (int k = 0; k < dimensions; k++) aggCenter->x[k] += HalfToFloat4(vec->x[k]); } - else - elog(ERROR, "Unsupported type"); - - centerCounts[closestCenter] += 1; } + else + elog(ERROR, "Unsupported type"); + + /* Increment count of closest center */ + for (int j = 0; j < numSamples; j++) + centerCounts[closestCenters[j]] += 1; for (int j = 0; j < numCenters; j++) {