diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 8fb8be6..ab8f7d4 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -498,18 +498,13 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff } /* - * Detect issues with centers + * Ensure no NaN or infinite values */ static void -CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo) +CheckElements(VectorArray centers, const IvfflatTypeInfo * typeInfo) { - FmgrInfo *normprocinfo; float *scratch = palloc(sizeof(float) * centers->dim); - if (centers->length != centers->maxlen) - elog(ERROR, "Not enough centers. Please report a bug."); - - /* Ensure no NaN or infinite values */ for (int i = 0; i < centers->length; i++) { for (int j = 0; j < centers->dim; j++) @@ -527,24 +522,43 @@ CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeIn elog(ERROR, "Infinite value detected. Please report a bug."); } } +} - /* Ensure no zero vectors for cosine distance */ +/* + * Ensure no zero vectors for cosine distance + */ +static void +CheckNorms(VectorArray centers, Relation index) +{ /* Check NORM_PROC instead of KMEANS_NORM_PROC */ - normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); - if (normprocinfo != NULL) + FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); + Oid collation = index->rd_indcollation[0]; + + if (normprocinfo == NULL) + return; + + for (int i = 0; i < centers->length; i++) { - Oid collation = index->rd_indcollation[0]; + double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i)))); - for (int i = 0; i < centers->length; i++) - { - double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i)))); - - if (norm == 0) - elog(ERROR, "Zero norm detected. Please report a bug."); - } + if (norm == 0) + elog(ERROR, "Zero norm detected. Please report a bug."); } } +/* + * Detect issues with centers + */ +static void +CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo) +{ + if (centers->length != centers->maxlen) + elog(ERROR, "Not enough centers. Please report a bug."); + + CheckElements(centers, typeInfo); + CheckNorms(centers, index); +} + /* * Perform naive k-means centering * We use spherical k-means for inner product and cosine