diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 1f81501..e14926b 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -474,6 +474,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp float *newcdist; MemoryContext kmeansCtx; MemoryContext oldCtx; + void (*initNewCenter) (Pointer v, int dimensions); /* Calculate allocation sizes */ Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize); @@ -525,50 +526,36 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE); newcdist = palloc(newcdistSize); - aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions)); - aggCenters->length = numCenters; + /* Initialize new centers */ + if (type == IVFFLAT_TYPE_VECTOR) + initNewCenter = VectorInitNewCenter; + else if (type == IVFFLAT_TYPE_HALFVEC) + initNewCenter = HalfvecInitNewCenter; + else if (type == IVFFLAT_TYPE_BIT) + initNewCenter = BitInitNewCenter; + else + elog(ERROR, "Unsupported type"); + + newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); + newCenters->length = numCenters; for (int j = 0; j < numCenters; j++) - { - Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); - - SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); - vec->dim = dimensions; - } + initNewCenter(VectorArrayGet(newCenters, j), dimensions); + /* Initialize agg centers */ if (type == IVFFLAT_TYPE_VECTOR) { /* Use same centers to save memory */ - newCenters = aggCenters; - } - else if (type == IVFFLAT_TYPE_HALFVEC) - { - newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); - newCenters->length = numCenters; - - for (int j = 0; j < numCenters; j++) - { - HalfVector *vec = (HalfVector *) VectorArrayGet(newCenters, j); - - SET_VARSIZE(vec, HALFVEC_SIZE(dimensions)); - vec->dim = dimensions; - } - } - else if (type == IVFFLAT_TYPE_BIT) - { - newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); - newCenters->length = numCenters; - - for (int j = 0; j < numCenters; j++) - { - VarBit *vec = (VarBit *) VectorArrayGet(newCenters, j); - - SET_VARSIZE(vec, VARBITTOTALLEN(dimensions)); - VARBITLEN(vec) = dimensions; - } + aggCenters = newCenters; } else - elog(ERROR, "Unsupported type"); + { + aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions)); + aggCenters->length = numCenters; + + for (int j = 0; j < numCenters; j++) + VectorInitNewCenter(VectorArrayGet(aggCenters, j), dimensions); + } #ifdef IVFFLAT_MEMORY ShowMemoryUsage(oldCtx, totalSize);