DRY code for sorting vector arrays [skip ci]

This commit is contained in:
Andrew Kane
2024-04-23 15:59:42 -07:00
parent 99d367edc0
commit bbfb3f200a

View File

@@ -143,6 +143,22 @@ CompareBitVectors(const void *a, const void *b)
return DirectFunctionCall2(bitcmp, VarBitPGetDatum((VarBit *) a), VarBitPGetDatum((VarBit *) b));
}
/*
* Sort vector array
*/
static void
SortVectorArray(VectorArray arr, IvfflatType type)
{
if (type == IVFFLAT_TYPE_VECTOR)
qsort(arr->items, arr->length, arr->itemsize, CompareVectors);
else if (type == IVFFLAT_TYPE_HALFVEC)
qsort(arr->items, arr->length, arr->itemsize, CompareHalfVectors);
else if (type == IVFFLAT_TYPE_BIT)
qsort(arr->items, arr->length, arr->itemsize, CompareBitVectors);
else
elog(ERROR, "Unsupported type");
}
/*
* Quick approach if we have little data
*/
@@ -157,14 +173,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
/* Copy existing vectors while avoiding duplicates */
if (samples->length > 0)
{
if (type == IVFFLAT_TYPE_VECTOR)
qsort(samples->items, samples->length, samples->itemsize, CompareVectors);
else if (type == IVFFLAT_TYPE_HALFVEC)
qsort(samples->items, samples->length, samples->itemsize, CompareHalfVectors);
else if (type == IVFFLAT_TYPE_BIT)
qsort(samples->items, samples->length, samples->itemsize, CompareBitVectors);
else
elog(ERROR, "Unsupported type");
SortVectorArray(samples, type);
for (int i = 0; i < samples->length; i++)
{
@@ -708,13 +717,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
if (type != IVFFLAT_TYPE_BIT)
{
/* Ensure no duplicate centers */
/* Fine to sort in-place */
if (type == IVFFLAT_TYPE_VECTOR)
qsort(centers->items, centers->length, centers->itemsize, CompareVectors);
else if (type == IVFFLAT_TYPE_HALFVEC)
qsort(centers->items, centers->length, centers->itemsize, CompareHalfVectors);
else
elog(ERROR, "Unsupported type");
SortVectorArray(centers, type);
for (int i = 1; i < centers->length; i++)
{