diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index d4130f7..548f7ff 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -254,50 +254,58 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize) } #endif +static void +VectorSumCenter(const void *v, Vector * aggCenter) +{ + Vector *vec = (Vector *) v; + + for (int k = 0; k < aggCenter->dim; k++) + aggCenter->x[k] += vec->x[k]; +} + +static void +HalfvecSumCenter(const void *v, Vector * aggCenter) +{ + HalfVector *vec = (HalfVector *) v; + + for (int k = 0; k < aggCenter->dim; k++) + aggCenter->x[k] += HalfToFloat4(vec->x[k]); +} + +static void +BitSumCenter(const void *v, Vector * aggCenter) +{ + VarBit *vec = (VarBit *) v; + + for (int k = 0; k < aggCenter->dim; k++) + aggCenter->x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01); +} + /* * Sum centers */ static void SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, IvfflatType type) { - int dimensions = aggCenters->dim; int numSamples = samples->length; + void (*sumCenter) (const void *v, Vector * aggCenter); if (type == IVFFLAT_TYPE_VECTOR) - { - for (int j = 0; j < numSamples; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); - Vector *vec = (Vector *) VectorArrayGet(samples, j); - - for (int k = 0; k < dimensions; k++) - aggCenter->x[k] += vec->x[k]; - } - } + sumCenter = VectorSumCenter; else if (type == IVFFLAT_TYPE_HALFVEC) - { - for (int j = 0; j < numSamples; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); - HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j); - - for (int k = 0; k < dimensions; k++) - aggCenter->x[k] += HalfToFloat4(vec->x[k]); - } - } + sumCenter = HalfvecSumCenter; else if (type == IVFFLAT_TYPE_BIT) - { - for (int j = 0; j < numSamples; j++) - { - Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); - VarBit *vec = (VarBit *) VectorArrayGet(samples, j); - - for (int k = 0; k < dimensions; k++) - aggCenter->x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01); - } - } + sumCenter = BitSumCenter; else elog(ERROR, "Unsupported type"); + + + for (int j = 0; j < numSamples; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); + + sumCenter(VectorArrayGet(samples, j), aggCenter); + } } /*