diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 8d2c793..f61ba1d 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -143,6 +143,22 @@ CompareBitVectors(const void *a, const void *b) return DirectFunctionCall2(bitcmp, VarBitPGetDatum((VarBit *) a), VarBitPGetDatum((VarBit *) b)); } +/* + * Sort vector array + */ +static void +SortVectorArray(VectorArray arr, IvfflatType type) +{ + if (type == IVFFLAT_TYPE_VECTOR) + qsort(arr->items, arr->length, arr->itemsize, CompareVectors); + else if (type == IVFFLAT_TYPE_HALFVEC) + qsort(arr->items, arr->length, arr->itemsize, CompareHalfVectors); + else if (type == IVFFLAT_TYPE_BIT) + qsort(arr->items, arr->length, arr->itemsize, CompareBitVectors); + else + elog(ERROR, "Unsupported type"); +} + /* * Quick approach if we have little data */ @@ -157,14 +173,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy /* Copy existing vectors while avoiding duplicates */ if (samples->length > 0) { - if (type == IVFFLAT_TYPE_VECTOR) - qsort(samples->items, samples->length, samples->itemsize, CompareVectors); - else if (type == IVFFLAT_TYPE_HALFVEC) - qsort(samples->items, samples->length, samples->itemsize, CompareHalfVectors); - else if (type == IVFFLAT_TYPE_BIT) - qsort(samples->items, samples->length, samples->itemsize, CompareBitVectors); - else - elog(ERROR, "Unsupported type"); + SortVectorArray(samples, type); for (int i = 0; i < samples->length; i++) { @@ -708,13 +717,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) if (type != IVFFLAT_TYPE_BIT) { /* Ensure no duplicate centers */ - /* Fine to sort in-place */ - if (type == IVFFLAT_TYPE_VECTOR) - qsort(centers->items, centers->length, centers->itemsize, CompareVectors); - else if (type == IVFFLAT_TYPE_HALFVEC) - qsort(centers->items, centers->length, centers->itemsize, CompareHalfVectors); - else - elog(ERROR, "Unsupported type"); + SortVectorArray(centers, type); for (int i = 1; i < centers->length; i++) {