From 0f4c2407dd93baf8ebdeb1235ba0e6e51c3466cb Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 24 Apr 2024 18:13:01 -0700 Subject: [PATCH] Removed IvfflatType from CheckCenters [skip ci] --- src/ivfkmeans.c | 42 +++++++++++++++--------------------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 774f530..14a03ed 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -685,9 +685,10 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat * Detect issues with centers */ static void -CheckCenters(Relation index, VectorArray centers, IvfflatType type, KmeansState * kmeansstate) +CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) { FmgrInfo *normprocinfo; + float *scratch = palloc(sizeof(float) * centers->dim); if (centers->length != centers->maxlen) elog(ERROR, "Not enough centers. Please report a bug."); @@ -695,34 +696,19 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type, KmeansState /* Ensure no NaN or infinite values */ for (int i = 0; i < centers->length; i++) { - if (type == IVFFLAT_TYPE_VECTOR) + for (int j = 0; j < centers->dim; j++) + scratch[j] = 0; + + kmeansstate->sumCenter(VectorArrayGet(centers, i), scratch); + + for (int j = 0; j < centers->dim; j++) { - Vector *vec = (Vector *) VectorArrayGet(centers, i); + if (isnan(scratch[j])) + elog(ERROR, "NaN detected. Please report a bug."); - for (int j = 0; j < vec->dim; j++) - { - if (isnan(vec->x[j])) - elog(ERROR, "NaN detected. Please report a bug."); - - if (isinf(vec->x[j])) - elog(ERROR, "Infinite value detected. Please report a bug."); - } + if (isinf(scratch[j])) + elog(ERROR, "Infinite value detected. Please report a bug."); } - else if (type == IVFFLAT_TYPE_HALFVEC) - { - HalfVector *vec = (HalfVector *) VectorArrayGet(centers, i); - - for (int j = 0; j < vec->dim; j++) - { - if (HalfIsNan(vec->x[j])) - elog(ERROR, "NaN detected. Please report a bug."); - - if (HalfIsInf(vec->x[j])) - elog(ERROR, "Infinite value detected. Please report a bug."); - } - } - else if (type != IVFFLAT_TYPE_BIT) - elog(ERROR, "Unsupported type"); } if (kmeansstate->checkDuplicates) @@ -752,6 +738,8 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type, KmeansState elog(ERROR, "Zero norm detected. Please report a bug."); } } + + pfree(scratch); } static void @@ -804,5 +792,5 @@ IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatT else ElkanKmeans(index, samples, centers, &kmeansstate); - CheckCenters(index, centers, type, &kmeansstate); + CheckCenters(index, centers, &kmeansstate); }