mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-01 02:02:10 +08:00
Moved code to separate function [skip ci]
This commit is contained in:
143
src/ivfkmeans.c
143
src/ivfkmeans.c
@@ -218,6 +218,83 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
|
||||
}
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Compute new centers
|
||||
*/
|
||||
static void
|
||||
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, int *centerCounts, int *closestCenters, IvfflatType type)
|
||||
{
|
||||
int dimensions = aggCenters->dim;
|
||||
int numCenters = aggCenters->maxlen;
|
||||
int numSamples = samples->length;
|
||||
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = 0.0;
|
||||
|
||||
centerCounts[j] = 0;
|
||||
}
|
||||
|
||||
/* Increment sum of closest center */
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
|
||||
Vector *vec = (Vector *) VectorArrayGet(samples, j);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
aggCenter->x[k] += vec->x[k];
|
||||
}
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
|
||||
HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
aggCenter->x[k] += HalfToFloat4(vec->x[k]);
|
||||
}
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
/* Increment count of closest center */
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
centerCounts[closestCenters[j]] += 1;
|
||||
|
||||
/* Average centers */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
if (centerCounts[j] > 0)
|
||||
{
|
||||
/* Double avoids overflow, but requires more memory */
|
||||
/* TODO Update bounds */
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
{
|
||||
if (isinf(vec->x[k]))
|
||||
vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX;
|
||||
}
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] /= centerCounts[j];
|
||||
}
|
||||
else
|
||||
{
|
||||
/* TODO Handle empty centers properly */
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = RandomDouble();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Use Elkan's algorithm for performance. This requires the distance function to satisfy the triangle inequality.
|
||||
*
|
||||
@@ -463,71 +540,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
}
|
||||
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = 0.0;
|
||||
|
||||
centerCounts[j] = 0;
|
||||
}
|
||||
|
||||
/* Increment sum of closest center */
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
|
||||
Vector *vec = (Vector *) VectorArrayGet(samples, j);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
aggCenter->x[k] += vec->x[k];
|
||||
}
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
|
||||
HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
aggCenter->x[k] += HalfToFloat4(vec->x[k]);
|
||||
}
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
/* Increment count of closest center */
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
centerCounts[closestCenters[j]] += 1;
|
||||
|
||||
/* Average centers */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
if (centerCounts[j] > 0)
|
||||
{
|
||||
/* Double avoids overflow, but requires more memory */
|
||||
/* TODO Update bounds */
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
{
|
||||
if (isinf(vec->x[k]))
|
||||
vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX;
|
||||
}
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] /= centerCounts[j];
|
||||
}
|
||||
else
|
||||
{
|
||||
/* TODO Handle empty centers properly */
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = RandomDouble();
|
||||
}
|
||||
}
|
||||
ComputeNewCenters(samples, aggCenters, centerCounts, closestCenters, type);
|
||||
|
||||
/* Set new centers if different from agg centers */
|
||||
if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
|
||||
Reference in New Issue
Block a user