Improved k-means code [skip ci]

This commit is contained in:
Andrew Kane
2024-04-11 17:15:20 -07:00
parent 626bc053e5
commit c581db9f98

View File

@@ -206,6 +206,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
int numCenters = centers->maxlen;
int numSamples = samples->length;
VectorArray newCenters;
VectorArray aggCenters;
int *centerCounts;
int *closestCenters;
float *lowerBound;
@@ -264,15 +265,22 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
newcdist = palloc(newcdistSize);
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
aggCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
for (int64 j = 0; j < numCenters; j++)
{
Vector *vec = (Vector *) VectorArrayGet(newCenters, j);
if (type == IVFFLAT_TYPE_VECTOR)
{
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
}
else
elog(ERROR, "Unsupported type");
}
newCenters = aggCenters;
#ifdef IVFFLAT_MEMORY
ShowMemoryUsage(totalSize);
#endif
@@ -413,7 +421,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
/* Step 4: For each center c, let m(c) be mean of all points assigned */
for (int64 j = 0; j < numCenters; j++)
{
Vector *vec = (Vector *) VectorArrayGet(newCenters, j);
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
for (int64 k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
@@ -425,7 +433,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
{
int closestCenter = closestCenters[j];
Vector *vec = (Vector *) VectorArrayGet(samples, j);
Vector *newCenter = (Vector *) VectorArrayGet(newCenters, closestCenter);
Vector *newCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter);
/* Increment sum and count of closest center */
for (int64 k = 0; k < dimensions; k++)
@@ -436,7 +444,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
for (int64 j = 0; j < numCenters; j++)
{
Datum center = PointerGetDatum(VectorArrayGet(newCenters, j));
Datum center = PointerGetDatum(VectorArrayGet(aggCenters, j));
Vector *vec = DatumGetVector(center);
if (centerCounts[j] > 0)
@@ -512,16 +520,21 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
/* Ensure no NaN or infinite values */
for (int i = 0; i < centers->length; i++)
{
Vector *vec = (Vector *) VectorArrayGet(centers, i);
for (int j = 0; j < vec->dim; j++)
if (type == IVFFLAT_TYPE_VECTOR)
{
if (isnan(vec->x[j]))
elog(ERROR, "NaN detected. Please report a bug.");
Vector *vec = (Vector *) VectorArrayGet(centers, i);
if (isinf(vec->x[j]))
elog(ERROR, "Infinite value detected. Please report a bug.");
for (int j = 0; j < vec->dim; j++)
{
if (isnan(vec->x[j]))
elog(ERROR, "NaN detected. Please report a bug.");
if (isinf(vec->x[j]))
elog(ERROR, "Infinite value detected. Please report a bug.");
}
}
else
elog(ERROR, "Unsupported type");
}
/* Ensure no duplicate centers */