diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index ba5d2d4..723ff39 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -218,6 +218,83 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize) } #endif +/* + * Compute new centers + */ +static void +ComputeNewCenters(VectorArray samples, VectorArray aggCenters, int *centerCounts, int *closestCenters, IvfflatType type) +{ + int dimensions = aggCenters->dim; + int numCenters = aggCenters->maxlen; + int numSamples = samples->length; + + for (int j = 0; j < numCenters; j++) + { + Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); + + for (int k = 0; k < dimensions; k++) + vec->x[k] = 0.0; + + centerCounts[j] = 0; + } + + /* Increment sum of closest center */ + if (type == IVFFLAT_TYPE_VECTOR) + { + for (int j = 0; j < numSamples; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); + Vector *vec = (Vector *) VectorArrayGet(samples, j); + + for (int k = 0; k < dimensions; k++) + aggCenter->x[k] += vec->x[k]; + } + } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + for (int j = 0; j < numSamples; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); + HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j); + + for (int k = 0; k < dimensions; k++) + aggCenter->x[k] += HalfToFloat4(vec->x[k]); + } + } + else + elog(ERROR, "Unsupported type"); + + /* Increment count of closest center */ + for (int j = 0; j < numSamples; j++) + centerCounts[closestCenters[j]] += 1; + + /* Average centers */ + for (int j = 0; j < numCenters; j++) + { + Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); + + if (centerCounts[j] > 0) + { + /* Double avoids overflow, but requires more memory */ + /* TODO Update bounds */ + for (int k = 0; k < dimensions; k++) + { + if (isinf(vec->x[k])) + vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX; + } + + for (int k = 0; k < dimensions; k++) + vec->x[k] /= centerCounts[j]; + } + else + { + /* TODO Handle empty centers properly */ + for (int k = 0; k < dimensions; k++) + vec->x[k] = RandomDouble(); + } + } +} + /* * Use Elkan for performance. This requires distance function to satisfy triangle inequality. * @@ -463,71 +540,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - for (int j = 0; j < numCenters; j++) - { - Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); - - for (int k = 0; k < dimensions; k++) - vec->x[k] = 0.0; - - centerCounts[j] = 0; - } - - /* Increment sum of closest center */ - if (type == IVFFLAT_TYPE_VECTOR) - { - for (int j = 0; j < numSamples; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); - Vector *vec = (Vector *) VectorArrayGet(samples, j); - - for (int k = 0; k < dimensions; k++) - aggCenter->x[k] += vec->x[k]; - } - } - else if (type == IVFFLAT_TYPE_HALFVEC) - { - for (int j = 0; j < numSamples; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); - HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j); - - for (int k = 0; k < dimensions; k++) - aggCenter->x[k] += HalfToFloat4(vec->x[k]); - } - } - else - elog(ERROR, "Unsupported type"); - - /* Increment count of closest center */ - for (int j = 0; j < numSamples; j++) - centerCounts[closestCenters[j]] += 1; - - /* Average centers */ - for (int j = 0; j < numCenters; j++) - { - Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); - - if (centerCounts[j] > 0) - { - /* Double avoids overflow, but requires more memory */ - /* TODO Update bounds */ - for (int k = 0; k < dimensions; k++) - { - if (isinf(vec->x[k])) - vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX; - } - - for (int k = 0; k < dimensions; k++) - vec->x[k] /= centerCounts[j]; - } - else - { - /* TODO Handle empty centers properly */ - for (int k = 0; k < dimensions; k++) - vec->x[k] = RandomDouble(); - } - } + ComputeNewCenters(samples, aggCenters, centerCounts, closestCenters, type); /* Set new centers if different from agg centers */ if (type == IVFFLAT_TYPE_HALFVEC)