Removed IvfflatType from more functions [skip ci]

This commit is contained in:
Andrew Kane
2024-04-24 17:45:48 -07:00
parent 6bb5de3d1b
commit cd8a25bc9a

View File

@@ -19,6 +19,7 @@ typedef struct KmeansState
void (*setNewCenter) (Pointer v, float *x);
void (*sumCenter) (Pointer v, float *x);
int (*comp) (const void *a, const void *b);
bool separateAgg;
} KmeansState;
/*
@@ -222,7 +223,7 @@ BitSetNewCenter(Pointer v, float *x)
* Quick approach if we have little data
*/
static void
QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate)
{
int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
@@ -330,7 +331,7 @@ SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, Kme
* Set new centers
*/
static void
SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type, KmeansState * kmeansstate)
SetNewCenters(VectorArray aggCenters, VectorArray newCenters, KmeansState * kmeansstate)
{
for (int j = 0; j < aggCenters->length; j++)
{
@@ -344,7 +345,7 @@ SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type,
* Compute new centers
*/
static void
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type, KmeansState * kmeansstate)
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate)
{
int dimensions = aggCenters->dim;
int numCenters = aggCenters->maxlen;
@@ -395,8 +396,8 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
}
/* Set new centers if different from agg centers */
if (type != IVFFLAT_TYPE_VECTOR)
SetNewCenters(aggCenters, newCenters, type, kmeansstate);
if (kmeansstate->separateAgg)
SetNewCenters(aggCenters, newCenters, kmeansstate);
/* Normalize if needed */
if (normprocinfo != NULL)
@@ -412,7 +413,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
*/
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
@@ -437,7 +438,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);
Size aggCentersSize = type == IVFFLAT_TYPE_VECTOR ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions));
Size aggCentersSize = !kmeansstate->separateAgg ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions));
Size centerCountsSize = sizeof(int) * numCenters;
Size closestCentersSize = sizeof(int) * numSamples;
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
@@ -491,7 +492,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
kmeansstate->initNewCenter(VectorArrayGet(newCenters, j), dimensions);
/* Initialize agg centers */
if (type == IVFFLAT_TYPE_VECTOR)
if (!kmeansstate->separateAgg)
{
/* Use same centers to save memory */
aggCenters = newCenters;
@@ -643,7 +644,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
}
/* Step 4: For each center c, let m(c) be the mean of all points assigned to c */
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type, kmeansstate);
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate);
/* Step 5 */
for (int j = 0; j < numCenters; j++)
@@ -761,6 +762,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
kmeansstate->setNewCenter = VectorSetNewCenter;
kmeansstate->sumCenter = VectorSumCenter;
kmeansstate->comp = CompareVectors;
kmeansstate->separateAgg = false;
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
@@ -768,6 +770,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
kmeansstate->setNewCenter = HalfvecSetNewCenter;
kmeansstate->sumCenter = HalfvecSumCenter;
kmeansstate->comp = CompareHalfVectors;
kmeansstate->separateAgg = true;
}
else if (type == IVFFLAT_TYPE_BIT)
{
@@ -775,6 +778,7 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
kmeansstate->setNewCenter = BitSetNewCenter;
kmeansstate->sumCenter = BitSumCenter;
kmeansstate->comp = CompareBitVectors;
kmeansstate->separateAgg = true;
}
else
elog(ERROR, "Unsupported type");
@@ -792,9 +796,9 @@ IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatT
InitKmeansState(&kmeansstate, type);
if (samples->length <= centers->maxlen)
QuickCenters(index, samples, centers, type, &kmeansstate);
QuickCenters(index, samples, centers, &kmeansstate);
else
ElkanKmeans(index, samples, centers, type, &kmeansstate);
ElkanKmeans(index, samples, centers, &kmeansstate);
CheckCenters(index, centers, type, &kmeansstate);
}