mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-01 02:02:10 +08:00
Improved k-means code [skip ci]
This commit is contained in:
@@ -206,6 +206,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
int numCenters = centers->maxlen;
|
||||
int numSamples = samples->length;
|
||||
VectorArray newCenters;
|
||||
VectorArray aggCenters;
|
||||
int *centerCounts;
|
||||
int *closestCenters;
|
||||
float *lowerBound;
|
||||
@@ -264,15 +265,22 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
|
||||
newcdist = palloc(newcdistSize);
|
||||
|
||||
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
|
||||
aggCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
|
||||
for (int64 j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(newCenters, j);
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
}
|
||||
|
||||
newCenters = aggCenters;
|
||||
|
||||
#ifdef IVFFLAT_MEMORY
|
||||
ShowMemoryUsage(totalSize);
|
||||
#endif
|
||||
@@ -413,7 +421,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
for (int64 j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(newCenters, j);
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
for (int64 k = 0; k < dimensions; k++)
|
||||
vec->x[k] = 0.0;
|
||||
@@ -425,7 +433,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
{
|
||||
int closestCenter = closestCenters[j];
|
||||
Vector *vec = (Vector *) VectorArrayGet(samples, j);
|
||||
Vector *newCenter = (Vector *) VectorArrayGet(newCenters, closestCenter);
|
||||
Vector *newCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter);
|
||||
|
||||
/* Increment sum and count of closest center */
|
||||
for (int64 k = 0; k < dimensions; k++)
|
||||
@@ -436,7 +444,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
|
||||
for (int64 j = 0; j < numCenters; j++)
|
||||
{
|
||||
Datum center = PointerGetDatum(VectorArrayGet(newCenters, j));
|
||||
Datum center = PointerGetDatum(VectorArrayGet(aggCenters, j));
|
||||
Vector *vec = DatumGetVector(center);
|
||||
|
||||
if (centerCounts[j] > 0)
|
||||
@@ -512,16 +520,21 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
|
||||
/* Ensure no NaN or infinite values */
|
||||
for (int i = 0; i < centers->length; i++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(centers, i);
|
||||
|
||||
for (int j = 0; j < vec->dim; j++)
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
if (isnan(vec->x[j]))
|
||||
elog(ERROR, "NaN detected. Please report a bug.");
|
||||
Vector *vec = (Vector *) VectorArrayGet(centers, i);
|
||||
|
||||
if (isinf(vec->x[j]))
|
||||
elog(ERROR, "Infinite value detected. Please report a bug.");
|
||||
for (int j = 0; j < vec->dim; j++)
|
||||
{
|
||||
if (isnan(vec->x[j]))
|
||||
elog(ERROR, "NaN detected. Please report a bug.");
|
||||
|
||||
if (isinf(vec->x[j]))
|
||||
elog(ERROR, "Infinite value detected. Please report a bug.");
|
||||
}
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
}
|
||||
|
||||
/* Ensure no duplicate centers */
|
||||
|
||||
Reference in New Issue
Block a user