diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 4713929..774f530 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -20,6 +20,7 @@ typedef struct KmeansState void (*sumCenter) (Pointer v, float *x); int (*comp) (const void *a, const void *b); bool separateAgg; + bool checkDuplicates; } KmeansState; /* @@ -724,7 +725,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type, KmeansState elog(ERROR, "Unsupported type"); } - if (type != IVFFLAT_TYPE_BIT) + if (kmeansstate->checkDuplicates) { /* Ensure no duplicate centers */ SortVectorArray(centers, kmeansstate); @@ -763,6 +764,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type) kmeansstate->sumCenter = VectorSumCenter; kmeansstate->comp = CompareVectors; kmeansstate->separateAgg = false; + kmeansstate->checkDuplicates = true; } else if (type == IVFFLAT_TYPE_HALFVEC) { @@ -771,6 +773,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type) kmeansstate->sumCenter = HalfvecSumCenter; kmeansstate->comp = CompareHalfVectors; kmeansstate->separateAgg = true; + kmeansstate->checkDuplicates = true; } else if (type == IVFFLAT_TYPE_BIT) { @@ -779,6 +782,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type) kmeansstate->sumCenter = BitSumCenter; kmeansstate->comp = CompareBitVectors; kmeansstate->separateAgg = true; + kmeansstate->checkDuplicates = false; } else elog(ERROR, "Unsupported type");