Added KmeansState [skip ci]

This commit is contained in:
Andrew Kane
2024-04-24 17:40:21 -07:00
parent 15ee38456f
commit 6bb5de3d1b

View File

@@ -13,6 +13,14 @@
#include "utils/memutils.h"
#include "vector.h"
/*
 * Type-specific callbacks used by the k-means routines.
 *
 * InitKmeansState fills every member according to the IvfflatType, so the
 * individual routines can call through these pointers instead of
 * re-dispatching on the type themselves.
 */
typedef struct KmeansState
{
/* Called on each new center slot together with the number of dimensions */
void (*initNewCenter) (Pointer v, int dimensions);
/* Sets center v from the float array x */
void (*setNewCenter) (Pointer v, float *x);
/* Called with a sample v and the float sums x of its closest center */
void (*sumCenter) (Pointer v, float *x);
/* Comparator handed to qsort when sorting a vector array */
int (*comp) (const void *a, const void *b);
} KmeansState;
/*
* Initialize with kmeans++
*
@@ -147,20 +155,9 @@ CompareBitVectors(const void *a, const void *b)
* Sort vector array
*/
static void
SortVectorArray(VectorArray arr, IvfflatType type)
SortVectorArray(VectorArray arr, KmeansState * kmeansstate)
{
int (*comp) (const void *a, const void *b);
if (type == IVFFLAT_TYPE_VECTOR)
comp = CompareVectors;
else if (type == IVFFLAT_TYPE_HALFVEC)
comp = CompareHalfVectors;
else if (type == IVFFLAT_TYPE_BIT)
comp = CompareBitVectors;
else
elog(ERROR, "Unsupported type");
qsort(arr->items, arr->length, arr->itemsize, comp);
qsort(arr->items, arr->length, arr->itemsize, kmeansstate->comp);
}
static void
@@ -225,38 +222,18 @@ BitSetNewCenter(Pointer v, float *x)
* Quick approach if we have little data
*/
static void
QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
{
int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
void (*initNewCenter) (Pointer v, int dimensions);
void (*setNewCenter) (Pointer v, float *x);
float *x = (float *) palloc(sizeof(float) * dimensions);
if (type == IVFFLAT_TYPE_VECTOR)
{
initNewCenter = VectorInitNewCenter;
setNewCenter = VectorSetNewCenter;
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
initNewCenter = HalfvecInitNewCenter;
setNewCenter = HalfvecSetNewCenter;
}
else if (type == IVFFLAT_TYPE_BIT)
{
initNewCenter = BitInitNewCenter;
setNewCenter = BitSetNewCenter;
}
else
elog(ERROR, "Unsupported type");
/* Copy existing vectors while avoiding duplicates */
if (samples->length > 0)
{
SortVectorArray(samples, type);
SortVectorArray(samples, kmeansstate);
for (int i = 0; i < samples->length; i++)
{
@@ -278,8 +255,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
for (int i = 0; i < dimensions; i++)
x[i] = (float) RandomDouble();
initNewCenter(center, dimensions);
setNewCenter(center, x);
kmeansstate->initNewCenter(center, dimensions);
kmeansstate->setNewCenter(center, x);
centers->length++;
}
@@ -339,24 +316,13 @@ BitSumCenter(Pointer v, float *x)
* Sum centers
*/
static void
SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, IvfflatType type)
SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, KmeansState * kmeansstate)
{
void (*sumCenter) (Pointer v, float *x);
if (type == IVFFLAT_TYPE_VECTOR)
sumCenter = VectorSumCenter;
else if (type == IVFFLAT_TYPE_HALFVEC)
sumCenter = HalfvecSumCenter;
else if (type == IVFFLAT_TYPE_BIT)
sumCenter = BitSumCenter;
else
elog(ERROR, "Unsupported type");
for (int j = 0; j < samples->length; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
sumCenter(VectorArrayGet(samples, j), aggCenter->x);
kmeansstate->sumCenter(VectorArrayGet(samples, j), aggCenter->x);
}
}
@@ -364,22 +330,13 @@ SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, Ivf
* Set new centers
*/
static void
SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type)
SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type, KmeansState * kmeansstate)
{
void (*setNewCenter) (Pointer v, float *x);
if (type == IVFFLAT_TYPE_HALFVEC)
setNewCenter = HalfvecSetNewCenter;
else if (type == IVFFLAT_TYPE_BIT)
setNewCenter = BitSetNewCenter;
else
return;
for (int j = 0; j < aggCenters->length; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
setNewCenter(VectorArrayGet(newCenters, j), aggCenter->x);
kmeansstate->setNewCenter(VectorArrayGet(newCenters, j), aggCenter->x);
}
}
@@ -387,7 +344,7 @@ SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type)
* Compute new centers
*/
static void
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type)
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type, KmeansState * kmeansstate)
{
int dimensions = aggCenters->dim;
int numCenters = aggCenters->maxlen;
@@ -405,7 +362,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
}
/* Increment sum of closest center */
SumCenters(samples, aggCenters, closestCenters, type);
SumCenters(samples, aggCenters, closestCenters, kmeansstate);
/* Increment count of closest center */
for (int j = 0; j < numSamples; j++)
@@ -438,7 +395,8 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
}
/* Set new centers if different from agg centers */
SetNewCenters(aggCenters, newCenters, type);
if (type != IVFFLAT_TYPE_VECTOR)
SetNewCenters(aggCenters, newCenters, type, kmeansstate);
/* Normalize if needed */
if (normprocinfo != NULL)
@@ -454,7 +412,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
*/
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
@@ -474,7 +432,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
float *newcdist;
MemoryContext kmeansCtx;
MemoryContext oldCtx;
void (*initNewCenter) (Pointer v, int dimensions);
/* Calculate allocation sizes */
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
@@ -527,20 +484,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
newcdist = palloc(newcdistSize);
/* Initialize new centers */
if (type == IVFFLAT_TYPE_VECTOR)
initNewCenter = VectorInitNewCenter;
else if (type == IVFFLAT_TYPE_HALFVEC)
initNewCenter = HalfvecInitNewCenter;
else if (type == IVFFLAT_TYPE_BIT)
initNewCenter = BitInitNewCenter;
else
elog(ERROR, "Unsupported type");
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
newCenters->length = numCenters;
for (int j = 0; j < numCenters; j++)
initNewCenter(VectorArrayGet(newCenters, j), dimensions);
kmeansstate->initNewCenter(VectorArrayGet(newCenters, j), dimensions);
/* Initialize agg centers */
if (type == IVFFLAT_TYPE_VECTOR)
@@ -695,7 +643,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
}
/* Step 4: For each center c, let m(c) be mean of all points assigned */
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type);
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type, kmeansstate);
/* Step 5 */
for (int j = 0; j < numCenters; j++)
@@ -735,7 +683,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
* Detect issues with centers
*/
static void
CheckCenters(Relation index, VectorArray centers, IvfflatType type)
CheckCenters(Relation index, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
{
FmgrInfo *normprocinfo;
@@ -778,7 +726,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
if (type != IVFFLAT_TYPE_BIT)
{
/* Ensure no duplicate centers */
SortVectorArray(centers, type);
SortVectorArray(centers, kmeansstate);
for (int i = 1; i < centers->length; i++)
{
@@ -804,6 +752,34 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
}
}
/*
 * Select the k-means callbacks for a vector type
 *
 * Fills all four members of kmeansstate for vector, halfvec or bit, and
 * reports an error for any other type.
 */
static void
InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
{
	switch (type)
	{
		case IVFFLAT_TYPE_VECTOR:
			kmeansstate->comp = CompareVectors;
			kmeansstate->sumCenter = VectorSumCenter;
			kmeansstate->setNewCenter = VectorSetNewCenter;
			kmeansstate->initNewCenter = VectorInitNewCenter;
			break;

		case IVFFLAT_TYPE_HALFVEC:
			kmeansstate->comp = CompareHalfVectors;
			kmeansstate->sumCenter = HalfvecSumCenter;
			kmeansstate->setNewCenter = HalfvecSetNewCenter;
			kmeansstate->initNewCenter = HalfvecInitNewCenter;
			break;

		case IVFFLAT_TYPE_BIT:
			kmeansstate->comp = CompareBitVectors;
			kmeansstate->sumCenter = BitSumCenter;
			kmeansstate->setNewCenter = BitSetNewCenter;
			kmeansstate->initNewCenter = BitInitNewCenter;
			break;

		default:
			elog(ERROR, "Unsupported type");
	}
}
/*
* Perform naive k-means centering
* We use spherical k-means for inner product and cosine
@@ -811,10 +787,14 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
{
	KmeansState kmeansstate;

	/*
	 * Resolve the type-specific callbacks once, before anything uses them;
	 * the routines below receive them through kmeansstate.
	 */
	InitKmeansState(&kmeansstate, type);

	/*
	 * With no more samples than requested centers, take the quick path;
	 * otherwise run Elkan's k-means. Exactly one of the two runs.
	 */
	if (samples->length <= centers->maxlen)
		QuickCenters(index, samples, centers, type, &kmeansstate);
	else
		ElkanKmeans(index, samples, centers, type, &kmeansstate);

	/* Validate the resulting centers once, after they have been computed */
	CheckCenters(index, centers, type, &kmeansstate);
}