Moved type-specific code to separate functions

This commit is contained in:
Andrew Kane
2024-04-23 16:32:10 -07:00
parent bbfb3f200a
commit b609c343b4

View File

@@ -251,27 +251,14 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
#endif
/*
* Compute new centers
* Sum centers
*/
static void
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type)
SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, IvfflatType type)
{
int dimensions = aggCenters->dim;
int numCenters = aggCenters->maxlen;
int numSamples = samples->length;
/* Reset sum and count */
for (int j = 0; j < numCenters; j++)
{
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
for (int k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
centerCounts[j] = 0;
}
/* Increment sum of closest center */
if (type == IVFFLAT_TYPE_VECTOR)
{
for (int j = 0; j < numSamples; j++)
@@ -307,6 +294,68 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
}
else
elog(ERROR, "Unsupported type");
}
/*
* Set new centers
*/
static void
SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type)
{
int dimensions = aggCenters->dim;
int numCenters = aggCenters->maxlen;
if (type == IVFFLAT_TYPE_HALFVEC)
{
for (int j = 0; j < numCenters; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j);
for (int k = 0; k < dimensions; k++)
newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]);
}
}
else if (type == IVFFLAT_TYPE_BIT)
{
for (int j = 0; j < numCenters; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
VarBit *newCenter = (VarBit *) VectorArrayGet(newCenters, j);
unsigned char *nx = VARBITS(newCenter);
for (uint32 k = 0; k < VARBITBYTES(newCenter); k++)
nx[k] = 0;
for (int k = 0; k < dimensions; k++)
nx[k / 8] |= (aggCenter->x[k] > 0.5) << (7 - (k % 8));
}
}
}
/*
* Compute new centers
*/
static void
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type)
{
int dimensions = aggCenters->dim;
int numCenters = aggCenters->maxlen;
int numSamples = samples->length;
/* Reset sum and count */
for (int j = 0; j < numCenters; j++)
{
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
for (int k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
centerCounts[j] = 0;
}
/* Increment sum of closest center */
SumCenters(samples, aggCenters, closestCenters, type);
/* Increment count of closest center */
for (int j = 0; j < numSamples; j++)
@@ -339,32 +388,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
}
/* Set new centers if different from agg centers */
if (type == IVFFLAT_TYPE_HALFVEC)
{
for (int j = 0; j < numCenters; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j);
for (int k = 0; k < dimensions; k++)
newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]);
}
}
else if (type == IVFFLAT_TYPE_BIT)
{
for (int j = 0; j < numCenters; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
VarBit *newCenter = (VarBit *) VectorArrayGet(newCenters, j);
unsigned char *nx = VARBITS(newCenter);
for (uint32 k = 0; k < VARBITBYTES(newCenter); k++)
nx[k] = 0;
for (int k = 0; k < dimensions; k++)
nx[k / 8] |= (aggCenter->x[k] > 0.5) << (7 - (k % 8));
}
}
SetNewCenters(aggCenters, newCenters, type);
/* Normalize if needed */
if (normprocinfo != NULL)