From cd8a25bc9a4ee5d9f473099cf3211c2862609f17 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 24 Apr 2024 17:45:48 -0700 Subject: [PATCH] Removed IvfflatType from more functions [skip ci] --- src/ivfkmeans.c | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index b5b29ef..4713929 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -19,6 +19,7 @@ typedef struct KmeansState void (*setNewCenter) (Pointer v, float *x); void (*sumCenter) (Pointer v, float *x); int (*comp) (const void *a, const void *b); + bool separateAgg; } KmeansState; /* @@ -222,7 +223,7 @@ BitSetNewCenter(Pointer v, float *x) * Quick approach if we have little data */ static void -QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate) +QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate) { int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; @@ -330,7 +331,7 @@ SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, Kme * Set new centers */ static void -SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type, KmeansState * kmeansstate) +SetNewCenters(VectorArray aggCenters, VectorArray newCenters, KmeansState * kmeansstate) { for (int j = 0; j < aggCenters->length; j++) { @@ -344,7 +345,7 @@ SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type, * Compute new centers */ static void -ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type, KmeansState * kmeansstate) +ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate) { int dimensions = aggCenters->dim; int numCenters = aggCenters->maxlen; @@ -395,8 +396,8 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe } /* Set new centers if different from agg centers */ - if (type != IVFFLAT_TYPE_VECTOR) - SetNewCenters(aggCenters, newCenters, type, kmeansstate); + if (kmeansstate->separateAgg) + SetNewCenters(aggCenters, newCenters, kmeansstate); /* Normalize if needed */ if (normprocinfo != NULL) @@ -412,7 +413,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate) +ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; @@ -437,7 +438,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize); Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize); Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize); - Size aggCentersSize = type == IVFFLAT_TYPE_VECTOR ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions)); + Size aggCentersSize = !kmeansstate->separateAgg ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions)); Size centerCountsSize = sizeof(int) * numCenters; Size closestCentersSize = sizeof(int) * numSamples; Size lowerBoundSize = sizeof(float) * numSamples * numCenters; @@ -491,7 +492,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp kmeansstate->initNewCenter(VectorArrayGet(newCenters, j), dimensions); /* Initialize agg centers */ - if (type == IVFFLAT_TYPE_VECTOR) + if (!kmeansstate->separateAgg) { /* Use same centers to save memory */ aggCenters = newCenters; @@ -643,7 +644,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type, kmeansstate); + ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate); /* Step 5 */ for (int j = 0; j < numCenters; j++) @@ -761,6 +762,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type) kmeansstate->setNewCenter = VectorSetNewCenter; kmeansstate->sumCenter = VectorSumCenter; kmeansstate->comp = CompareVectors; + kmeansstate->separateAgg = false; } else if (type == IVFFLAT_TYPE_HALFVEC) { @@ -768,6 +770,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type) kmeansstate->setNewCenter = HalfvecSetNewCenter; kmeansstate->sumCenter = HalfvecSumCenter; kmeansstate->comp = CompareHalfVectors; + kmeansstate->separateAgg = true; } else if (type == IVFFLAT_TYPE_BIT) { @@ -775,6 +778,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type) kmeansstate->setNewCenter = BitSetNewCenter; kmeansstate->sumCenter = BitSumCenter; kmeansstate->comp = CompareBitVectors; + kmeansstate->separateAgg = true; } else elog(ERROR, "Unsupported type"); @@ -792,9 +796,9 @@ IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatT InitKmeansState(&kmeansstate, type); if (samples->length <= centers->maxlen) - QuickCenters(index, samples, centers, type, &kmeansstate); + QuickCenters(index, samples, centers, &kmeansstate); else - ElkanKmeans(index, samples, centers, type, &kmeansstate); + ElkanKmeans(index, samples, centers, &kmeansstate); CheckCenters(index, centers, type, &kmeansstate); }