diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 268e8f7..147b149 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -170,6 +170,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) FmgrInfo *normprocinfo; Oid collation; Vector *vec; + Vector *newCenter; int iteration; int64 j; int64 k; @@ -177,7 +178,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) int numCenters = centers->maxlen; int numSamples = samples->length; VectorArray newCenters; - double *centerSums; int *centerCounts; int *closestCenters; float *lowerBound; @@ -198,7 +198,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim); Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim); Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions); - Size centerSumsSize = sizeof(double) * numCenters * dimensions; Size centerCountsSize = sizeof(int) * numCenters; Size closestCentersSize = sizeof(int) * numSamples; Size lowerBoundSize = sizeof(float) * numSamples * numCenters; @@ -208,7 +207,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) Size newcdistSize = sizeof(float) * numCenters; /* Calculate total size */ - /* TODO Add centerSumsSize in 0.5.0 */ Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; /* Check memory requirements */ @@ -229,8 +227,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) collation = index->rd_indcollation[0]; /* Allocate space */ - /* Use float instead of double when possible to save memory */ - centerSums = palloc(centerSumsSize); + /* Use float instead of double to save memory */ centerCounts = palloc(centerCountsSize); closestCenters = palloc(closestCentersSize); lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE); @@ -373,11 +370,14 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - for (j = 0; j < numCenters * dimensions; j++) - centerSums[j] = 0; - for (j = 0; j < numCenters; j++) + { + vec = VectorArrayGet(newCenters, j); + for (k = 0; k < dimensions; k++) + vec->x[k] = 0.0; + centerCounts[j] = 0; + } for (j = 0; j < numSamples; j++) { @@ -385,8 +385,9 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) closestCenter = closestCenters[j]; /* Increment sum and count of closest center */ + newCenter = VectorArrayGet(newCenters, closestCenter); for (k = 0; k < dimensions; k++) - centerSums[closestCenter * dimensions + k] += vec->x[k]; + newCenter->x[k] += vec->x[k]; centerCounts[closestCenter] += 1; } @@ -398,7 +399,13 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) if (centerCounts[j] > 0) { for (k = 0; k < dimensions; k++) - vec->x[k] = centerSums[j * dimensions + k] / centerCounts[j]; + vec->x[k] /= centerCounts[j]; + + /* Double avoids overflow, but requires more memory */ + /* TODO Update bounds */ + for (k = 0; k < dimensions; k++) + if (isinf(vec->x[k])) + vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX; } else { @@ -443,7 +450,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) } VectorArrayFree(newCenters); - pfree(centerSums); pfree(centerCounts); pfree(closestCenters); pfree(lowerBound);