Fixed infinite values with list centers

This commit is contained in:
Andrew Kane
2023-06-04 10:42:55 -07:00
parent 6330abb7df
commit e971fdd4fd

View File

@@ -170,7 +170,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
FmgrInfo *normprocinfo;
Oid collation;
Vector *vec;
Vector *newCenter;
int iteration;
int64 j;
int64 k;
@@ -178,6 +177,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
int numCenters = centers->maxlen;
int numSamples = samples->length;
VectorArray newCenters;
double *centerSums;
int *centerCounts;
int *closestCenters;
float *lowerBound;
@@ -198,6 +198,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim);
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim);
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions);
Size centerSumsSize = sizeof(double) * numCenters * dimensions;
Size centerCountsSize = sizeof(int) * numCenters;
Size closestCentersSize = sizeof(int) * numSamples;
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
@@ -207,6 +208,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
Size newcdistSize = sizeof(float) * numCenters;
/* Calculate total size */
/* TODO Add centerSumsSize in 0.5.0 */
Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
/* Check memory requirements */
@@ -227,7 +229,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
collation = index->rd_indcollation[0];
/* Allocate space */
/* Use float instead of double to save memory */
/* Use float instead of double when possible to save memory */
centerSums = palloc(centerSumsSize);
centerCounts = palloc(centerCountsSize);
closestCenters = palloc(closestCentersSize);
lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);
@@ -370,14 +373,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
}
/* Step 4: For each center c, let m(c) be mean of all points assigned */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
for (k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
for (j = 0; j < numCenters * dimensions; j++)
centerSums[j] = 0;
for (j = 0; j < numCenters; j++)
centerCounts[j] = 0;
}
for (j = 0; j < numSamples; j++)
{
@@ -385,9 +385,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
closestCenter = closestCenters[j];
/* Increment sum and count of closest center */
newCenter = VectorArrayGet(newCenters, closestCenter);
for (k = 0; k < dimensions; k++)
newCenter->x[k] += vec->x[k];
centerSums[closestCenter * dimensions + k] += vec->x[k];
centerCounts[closestCenter] += 1;
}
@@ -399,7 +398,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
if (centerCounts[j] > 0)
{
for (k = 0; k < dimensions; k++)
vec->x[k] /= centerCounts[j];
vec->x[k] = centerSums[j * dimensions + k] / centerCounts[j];
}
else
{
@@ -444,6 +443,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
}
VectorArrayFree(newCenters);
pfree(centerSums);
pfree(centerCounts);
pfree(closestCenters);
pfree(lowerBound);
@@ -461,12 +461,26 @@ CheckCenters(Relation index, VectorArray centers)
{
FmgrInfo *normprocinfo;
Oid collation;
Vector *vec;
int i;
int j;
double norm;
if (centers->length != centers->maxlen)
elog(ERROR, "Not enough centers. Please report a bug.");
/* Ensure no infinite values */
for (i = 0; i < centers->length; i++)
{
vec = VectorArrayGet(centers, i);
for (j = 0; j < vec->dim; j++)
{
if (isinf(vec->x[j]))
elog(ERROR, "Infinite value detected. Please report a bug.");
}
}
/* Ensure no duplicate centers */
/* Fine to sort in-place */
qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors);