From c581db9f9892f5b172c1749fb50a4830debf031c Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 11 Apr 2024 17:15:20 -0700 Subject: [PATCH] Improved k-means code [skip ci] --- src/ivfkmeans.c | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index c2c18f1..0fcf933 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -206,6 +206,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp int numCenters = centers->maxlen; int numSamples = samples->length; VectorArray newCenters; + VectorArray aggCenters; int *centerCounts; int *closestCenters; float *lowerBound; @@ -264,15 +265,22 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE); newcdist = palloc(newcdistSize); - newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); + aggCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); for (int64 j = 0; j < numCenters; j++) { - Vector *vec = (Vector *) VectorArrayGet(newCenters, j); + if (type == IVFFLAT_TYPE_VECTOR) + { + Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); - SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); - vec->dim = dimensions; + SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); + vec->dim = dimensions; + } + else + elog(ERROR, "Unsupported type"); } + newCenters = aggCenters; + #ifdef IVFFLAT_MEMORY ShowMemoryUsage(totalSize); #endif @@ -413,7 +421,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp /* Step 4: For each center c, let m(c) be mean of all points assigned */ for (int64 j = 0; j < numCenters; j++) { - Vector *vec = (Vector *) VectorArrayGet(newCenters, j); + Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); for (int64 k = 0; k < dimensions; k++) vec->x[k] = 0.0; @@ -425,7 +433,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp { int closestCenter = closestCenters[j]; Vector *vec = (Vector *) VectorArrayGet(samples, j); - Vector *newCenter = (Vector *) VectorArrayGet(newCenters, closestCenter); + Vector *newCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter); /* Increment sum and count of closest center */ for (int64 k = 0; k < dimensions; k++) @@ -436,7 +444,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp for (int64 j = 0; j < numCenters; j++) { - Datum center = PointerGetDatum(VectorArrayGet(newCenters, j)); + Datum center = PointerGetDatum(VectorArrayGet(aggCenters, j)); Vector *vec = DatumGetVector(center); if (centerCounts[j] > 0) @@ -512,16 +520,21 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) /* Ensure no NaN or infinite values */ for (int i = 0; i < centers->length; i++) { - Vector *vec = (Vector *) VectorArrayGet(centers, i); - - for (int j = 0; j < vec->dim; j++) + if (type == IVFFLAT_TYPE_VECTOR) { - if (isnan(vec->x[j])) - elog(ERROR, "NaN detected. Please report a bug."); + Vector *vec = (Vector *) VectorArrayGet(centers, i); - if (isinf(vec->x[j])) - elog(ERROR, "Infinite value detected. Please report a bug."); + for (int j = 0; j < vec->dim; j++) + { + if (isnan(vec->x[j])) + elog(ERROR, "NaN detected. Please report a bug."); + + if (isinf(vec->x[j])) + elog(ERROR, "Infinite value detected. Please report a bug."); + } } + else + elog(ERROR, "Unsupported type"); } /* Ensure no duplicate centers */