mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Improved CheckCenters code [skip ci]
This commit is contained in:
@@ -498,18 +498,13 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
|
||||
}
|
||||
|
||||
/*
|
||||
* Detect issues with centers
|
||||
* Ensure no NaN or infinite values
|
||||
*/
|
||||
static void
|
||||
CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
CheckElements(VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
FmgrInfo *normprocinfo;
|
||||
float *scratch = palloc(sizeof(float) * centers->dim);
|
||||
|
||||
if (centers->length != centers->maxlen)
|
||||
elog(ERROR, "Not enough centers. Please report a bug.");
|
||||
|
||||
/* Ensure no NaN or infinite values */
|
||||
for (int i = 0; i < centers->length; i++)
|
||||
{
|
||||
for (int j = 0; j < centers->dim; j++)
|
||||
@@ -527,24 +522,43 @@ CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeIn
|
||||
elog(ERROR, "Infinite value detected. Please report a bug.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Ensure no zero vectors for cosine distance */
|
||||
/*
|
||||
* Ensure no zero vectors for cosine distance
|
||||
*/
|
||||
static void
|
||||
CheckNorms(VectorArray centers, Relation index)
|
||||
{
|
||||
/* Check NORM_PROC instead of KMEANS_NORM_PROC */
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
|
||||
if (normprocinfo == NULL)
|
||||
return;
|
||||
|
||||
for (int i = 0; i < centers->length; i++)
|
||||
{
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
|
||||
|
||||
for (int i = 0; i < centers->length; i++)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
|
||||
|
||||
if (norm == 0)
|
||||
elog(ERROR, "Zero norm detected. Please report a bug.");
|
||||
}
|
||||
if (norm == 0)
|
||||
elog(ERROR, "Zero norm detected. Please report a bug.");
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Detect issues with centers
|
||||
*/
|
||||
static void
|
||||
CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
if (centers->length != centers->maxlen)
|
||||
elog(ERROR, "Not enough centers. Please report a bug.");
|
||||
|
||||
CheckElements(centers, typeInfo);
|
||||
CheckNorms(centers, index);
|
||||
}
|
||||
|
||||
/*
|
||||
* Perform naive k-means centering
|
||||
* We use spherical k-means for inner product and cosine
|
||||
|
||||
Reference in New Issue
Block a user