Moved code [skip ci]

This commit is contained in:
Andrew Kane
2024-04-11 22:25:53 -07:00
parent 3621a84ef8
commit cc4b01bd49

View File

@@ -222,7 +222,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
* Compute new centers
*/
static void
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, int *centerCounts, int *closestCenters, IvfflatType type)
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, IvfflatType type)
{
int dimensions = aggCenters->dim;
int numCenters = aggCenters->maxlen;
@@ -293,6 +293,30 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, int *centerCounts
vec->x[k] = RandomDouble();
}
}
/* Set new centers if different from agg centers */
if (type == IVFFLAT_TYPE_HALFVEC)
{
for (int j = 0; j < numCenters; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j);
for (int k = 0; k < dimensions; k++)
newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]);
}
}
/* Normalize if needed */
if (normprocinfo != NULL)
{
for (int j = 0; j < numCenters; j++)
{
Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j));
ApplyNorm(normprocinfo, collation, newCenter, type);
}
}
}
/*
@@ -540,31 +564,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
}
/* Step 4: For each center c, let m(c) be mean of all points assigned */
ComputeNewCenters(samples, aggCenters, centerCounts, closestCenters, type);
/* Set new centers if different from agg centers */
if (type == IVFFLAT_TYPE_HALFVEC)
{
for (int j = 0; j < numCenters; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j);
for (int k = 0; k < dimensions; k++)
newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]);
}
}
/* Normalize if needed */
if (normprocinfo != NULL)
{
for (int j = 0; j < numCenters; j++)
{
Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j));
ApplyNorm(normprocinfo, collation, newCenter, type);
}
}
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, collation, type);
/* Step 5 */
for (int j = 0; j < numCenters; j++)