Improved initialization of new centers [skip ci]

This commit is contained in:
Andrew Kane
2024-04-24 16:45:16 -07:00
parent 25b98540c9
commit 15ee38456f

View File

@@ -474,6 +474,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
float *newcdist;
MemoryContext kmeansCtx;
MemoryContext oldCtx;
void (*initNewCenter) (Pointer v, int dimensions);
/* Calculate allocation sizes */
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
@@ -525,50 +526,36 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
newcdist = palloc(newcdistSize);
aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions));
aggCenters->length = numCenters;
/* Initialize new centers */
if (type == IVFFLAT_TYPE_VECTOR)
initNewCenter = VectorInitNewCenter;
else if (type == IVFFLAT_TYPE_HALFVEC)
initNewCenter = HalfvecInitNewCenter;
else if (type == IVFFLAT_TYPE_BIT)
initNewCenter = BitInitNewCenter;
else
elog(ERROR, "Unsupported type");
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
newCenters->length = numCenters;
for (int j = 0; j < numCenters; j++)
{
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
}
initNewCenter(VectorArrayGet(newCenters, j), dimensions);
/* Initialize agg centers */
if (type == IVFFLAT_TYPE_VECTOR)
{
/* Use same centers to save memory */
newCenters = aggCenters;
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
newCenters->length = numCenters;
for (int j = 0; j < numCenters; j++)
{
HalfVector *vec = (HalfVector *) VectorArrayGet(newCenters, j);
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
vec->dim = dimensions;
}
}
else if (type == IVFFLAT_TYPE_BIT)
{
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
newCenters->length = numCenters;
for (int j = 0; j < numCenters; j++)
{
VarBit *vec = (VarBit *) VectorArrayGet(newCenters, j);
SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
VARBITLEN(vec) = dimensions;
}
aggCenters = newCenters;
}
else
elog(ERROR, "Unsupported type");
{
aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions));
aggCenters->length = numCenters;
for (int j = 0; j < numCenters; j++)
VectorInitNewCenter(VectorArrayGet(aggCenters, j), dimensions);
}
#ifdef IVFFLAT_MEMORY
ShowMemoryUsage(oldCtx, totalSize);