Improved CheckCenters code [skip ci]

This commit is contained in:
Andrew Kane
2024-04-25 17:41:53 -07:00
parent dc88135515
commit cd95d6dfa4

View File

@@ -498,18 +498,13 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
}
/*
* Detect issues with centers
* Ensure no NaN or infinite values
*/
static void
CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
CheckElements(VectorArray centers, const IvfflatTypeInfo * typeInfo)
{
FmgrInfo *normprocinfo;
float *scratch = palloc(sizeof(float) * centers->dim);
if (centers->length != centers->maxlen)
elog(ERROR, "Not enough centers. Please report a bug.");
/* Ensure no NaN or infinite values */
for (int i = 0; i < centers->length; i++)
{
for (int j = 0; j < centers->dim; j++)
@@ -527,24 +522,43 @@ CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeIn
elog(ERROR, "Infinite value detected. Please report a bug.");
}
}
}
/* Ensure no zero vectors for cosine distance */
/*
* Ensure no zero vectors for cosine distance
*/
static void
CheckNorms(VectorArray centers, Relation index)
{
/* Check NORM_PROC instead of KMEANS_NORM_PROC */
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
if (normprocinfo != NULL)
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
Oid collation = index->rd_indcollation[0];
if (normprocinfo == NULL)
return;
for (int i = 0; i < centers->length; i++)
{
Oid collation = index->rd_indcollation[0];
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
for (int i = 0; i < centers->length; i++)
{
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
if (norm == 0)
elog(ERROR, "Zero norm detected. Please report a bug.");
}
if (norm == 0)
elog(ERROR, "Zero norm detected. Please report a bug.");
}
}
/*
* Detect issues with centers
*/
static void
CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
{
if (centers->length != centers->maxlen)
elog(ERROR, "Not enough centers. Please report a bug.");
CheckElements(centers, typeInfo);
CheckNorms(centers, index);
}
/*
* Perform naive k-means centering
* We use spherical k-means for inner product and cosine