diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 61053af..7bba1e1 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -170,7 +170,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) FmgrInfo *normprocinfo; Oid collation; Vector *vec; - Vector *newCenter; int iteration; int64 j; int64 k; @@ -178,6 +177,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) int numCenters = centers->maxlen; int numSamples = samples->length; VectorArray newCenters; + double *centerSums; int *centerCounts; int *closestCenters; float *lowerBound; @@ -198,6 +198,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim); Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim); Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions); + Size centerSumsSize = sizeof(double) * numCenters * dimensions; Size centerCountsSize = sizeof(int) * numCenters; Size closestCentersSize = sizeof(int) * numSamples; Size lowerBoundSize = sizeof(float) * numSamples * numCenters; @@ -207,6 +208,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) Size newcdistSize = sizeof(float) * numCenters; /* Calculate total size */ + /* TODO Add centerSumsSize in 0.5.0 */ Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; /* Check memory requirements */ @@ -227,7 +229,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) collation = index->rd_indcollation[0]; /* Allocate space */ - /* Use float instead of double to save memory */ + /* Use float instead of double when possible to save memory */ + centerSums = palloc(centerSumsSize); centerCounts = palloc(centerCountsSize); closestCenters = palloc(closestCentersSize); lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE); @@ -370,14 +373,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - for (j = 0; j < numCenters; j++) - { - vec = VectorArrayGet(newCenters, j); - for (k = 0; k < dimensions; k++) - vec->x[k] = 0.0; + for (j = 0; j < numCenters * dimensions; j++) + centerSums[j] = 0; + for (j = 0; j < numCenters; j++) centerCounts[j] = 0; - } for (j = 0; j < numSamples; j++) { @@ -385,9 +385,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) closestCenter = closestCenters[j]; /* Increment sum and count of closest center */ - newCenter = VectorArrayGet(newCenters, closestCenter); for (k = 0; k < dimensions; k++) - newCenter->x[k] += vec->x[k]; + centerSums[closestCenter * dimensions + k] += vec->x[k]; centerCounts[closestCenter] += 1; } @@ -399,7 +398,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) if (centerCounts[j] > 0) { for (k = 0; k < dimensions; k++) - vec->x[k] /= centerCounts[j]; + vec->x[k] = centerSums[j * dimensions + k] / centerCounts[j]; } else { @@ -444,6 +443,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) } VectorArrayFree(newCenters); + pfree(centerSums); pfree(centerCounts); pfree(closestCenters); pfree(lowerBound); @@ -461,12 +461,26 @@ CheckCenters(Relation index, VectorArray centers) { FmgrInfo *normprocinfo; Oid collation; + Vector *vec; int i; + int j; double norm; if (centers->length != centers->maxlen) elog(ERROR, "Not enough centers. Please report a bug."); + /* Ensure no infinite values */ + for (i = 0; i < centers->length; i++) + { + vec = VectorArrayGet(centers, i); + + for (j = 0; j < vec->dim; j++) + { + if (isinf(vec->x[j])) + elog(ERROR, "Infinite value detected. Please report a bug."); + } + } + /* Ensure no duplicate centers */ /* Fine to sort in-place */ qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors);