mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added KmeansState [skip ci]
This commit is contained in:
148
src/ivfkmeans.c
148
src/ivfkmeans.c
@@ -13,6 +13,14 @@
|
||||
#include "utils/memutils.h"
|
||||
#include "vector.h"
|
||||
|
||||
typedef struct KmeansState
|
||||
{
|
||||
void (*initNewCenter) (Pointer v, int dimensions);
|
||||
void (*setNewCenter) (Pointer v, float *x);
|
||||
void (*sumCenter) (Pointer v, float *x);
|
||||
int (*comp) (const void *a, const void *b);
|
||||
} KmeansState;
|
||||
|
||||
/*
|
||||
* Initialize with kmeans++
|
||||
*
|
||||
@@ -147,20 +155,9 @@ CompareBitVectors(const void *a, const void *b)
|
||||
* Sort vector array
|
||||
*/
|
||||
static void
|
||||
SortVectorArray(VectorArray arr, IvfflatType type)
|
||||
SortVectorArray(VectorArray arr, KmeansState * kmeansstate)
|
||||
{
|
||||
int (*comp) (const void *a, const void *b);
|
||||
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
comp = CompareVectors;
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
comp = CompareHalfVectors;
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
comp = CompareBitVectors;
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
qsort(arr->items, arr->length, arr->itemsize, comp);
|
||||
qsort(arr->items, arr->length, arr->itemsize, kmeansstate->comp);
|
||||
}
|
||||
|
||||
static void
|
||||
@@ -225,38 +222,18 @@ BitSetNewCenter(Pointer v, float *x)
|
||||
* Quick approach if we have little data
|
||||
*/
|
||||
static void
|
||||
QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
|
||||
QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
|
||||
{
|
||||
int dimensions = centers->dim;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
void (*initNewCenter) (Pointer v, int dimensions);
|
||||
void (*setNewCenter) (Pointer v, float *x);
|
||||
float *x = (float *) palloc(sizeof(float) * dimensions);
|
||||
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
initNewCenter = VectorInitNewCenter;
|
||||
setNewCenter = VectorSetNewCenter;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
initNewCenter = HalfvecInitNewCenter;
|
||||
setNewCenter = HalfvecSetNewCenter;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
{
|
||||
initNewCenter = BitInitNewCenter;
|
||||
setNewCenter = BitSetNewCenter;
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
/* Copy existing vectors while avoiding duplicates */
|
||||
if (samples->length > 0)
|
||||
{
|
||||
SortVectorArray(samples, type);
|
||||
SortVectorArray(samples, kmeansstate);
|
||||
|
||||
for (int i = 0; i < samples->length; i++)
|
||||
{
|
||||
@@ -278,8 +255,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
|
||||
for (int i = 0; i < dimensions; i++)
|
||||
x[i] = (float) RandomDouble();
|
||||
|
||||
initNewCenter(center, dimensions);
|
||||
setNewCenter(center, x);
|
||||
kmeansstate->initNewCenter(center, dimensions);
|
||||
kmeansstate->setNewCenter(center, x);
|
||||
|
||||
centers->length++;
|
||||
}
|
||||
@@ -339,24 +316,13 @@ BitSumCenter(Pointer v, float *x)
|
||||
* Sum centers
|
||||
*/
|
||||
static void
|
||||
SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, IvfflatType type)
|
||||
SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, KmeansState * kmeansstate)
|
||||
{
|
||||
void (*sumCenter) (Pointer v, float *x);
|
||||
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
sumCenter = VectorSumCenter;
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
sumCenter = HalfvecSumCenter;
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
sumCenter = BitSumCenter;
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
for (int j = 0; j < samples->length; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
|
||||
|
||||
sumCenter(VectorArrayGet(samples, j), aggCenter->x);
|
||||
kmeansstate->sumCenter(VectorArrayGet(samples, j), aggCenter->x);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,22 +330,13 @@ SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, Ivf
|
||||
* Set new centers
|
||||
*/
|
||||
static void
|
||||
SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type)
|
||||
SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type, KmeansState * kmeansstate)
|
||||
{
|
||||
void (*setNewCenter) (Pointer v, float *x);
|
||||
|
||||
if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
setNewCenter = HalfvecSetNewCenter;
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
setNewCenter = BitSetNewCenter;
|
||||
else
|
||||
return;
|
||||
|
||||
for (int j = 0; j < aggCenters->length; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
setNewCenter(VectorArrayGet(newCenters, j), aggCenter->x);
|
||||
kmeansstate->setNewCenter(VectorArrayGet(newCenters, j), aggCenter->x);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -387,7 +344,7 @@ SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type)
|
||||
* Compute new centers
|
||||
*/
|
||||
static void
|
||||
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type)
|
||||
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type, KmeansState * kmeansstate)
|
||||
{
|
||||
int dimensions = aggCenters->dim;
|
||||
int numCenters = aggCenters->maxlen;
|
||||
@@ -405,7 +362,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
|
||||
}
|
||||
|
||||
/* Increment sum of closest center */
|
||||
SumCenters(samples, aggCenters, closestCenters, type);
|
||||
SumCenters(samples, aggCenters, closestCenters, kmeansstate);
|
||||
|
||||
/* Increment count of closest center */
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
@@ -438,7 +395,8 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
|
||||
}
|
||||
|
||||
/* Set new centers if different from agg centers */
|
||||
SetNewCenters(aggCenters, newCenters, type);
|
||||
if (type != IVFFLAT_TYPE_VECTOR)
|
||||
SetNewCenters(aggCenters, newCenters, type, kmeansstate);
|
||||
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
@@ -454,7 +412,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
|
||||
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
|
||||
*/
|
||||
static void
|
||||
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
|
||||
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
@@ -474,7 +432,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
float *newcdist;
|
||||
MemoryContext kmeansCtx;
|
||||
MemoryContext oldCtx;
|
||||
void (*initNewCenter) (Pointer v, int dimensions);
|
||||
|
||||
/* Calculate allocation sizes */
|
||||
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
|
||||
@@ -527,20 +484,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
newcdist = palloc(newcdistSize);
|
||||
|
||||
/* Initialize new centers */
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
initNewCenter = VectorInitNewCenter;
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
initNewCenter = HalfvecInitNewCenter;
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
initNewCenter = BitInitNewCenter;
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
|
||||
newCenters->length = numCenters;
|
||||
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
initNewCenter(VectorArrayGet(newCenters, j), dimensions);
|
||||
kmeansstate->initNewCenter(VectorArrayGet(newCenters, j), dimensions);
|
||||
|
||||
/* Initialize agg centers */
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
@@ -695,7 +643,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
}
|
||||
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type);
|
||||
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type, kmeansstate);
|
||||
|
||||
/* Step 5 */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
@@ -735,7 +683,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
* Detect issues with centers
|
||||
*/
|
||||
static void
|
||||
CheckCenters(Relation index, VectorArray centers, IvfflatType type)
|
||||
CheckCenters(Relation index, VectorArray centers, IvfflatType type, KmeansState * kmeansstate)
|
||||
{
|
||||
FmgrInfo *normprocinfo;
|
||||
|
||||
@@ -778,7 +726,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
|
||||
if (type != IVFFLAT_TYPE_BIT)
|
||||
{
|
||||
/* Ensure no duplicate centers */
|
||||
SortVectorArray(centers, type);
|
||||
SortVectorArray(centers, kmeansstate);
|
||||
|
||||
for (int i = 1; i < centers->length; i++)
|
||||
{
|
||||
@@ -804,6 +752,34 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
|
||||
{
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
kmeansstate->initNewCenter = VectorInitNewCenter;
|
||||
kmeansstate->setNewCenter = VectorSetNewCenter;
|
||||
kmeansstate->sumCenter = VectorSumCenter;
|
||||
kmeansstate->comp = CompareVectors;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
kmeansstate->initNewCenter = HalfvecInitNewCenter;
|
||||
kmeansstate->setNewCenter = HalfvecSetNewCenter;
|
||||
kmeansstate->sumCenter = HalfvecSumCenter;
|
||||
kmeansstate->comp = CompareHalfVectors;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
{
|
||||
kmeansstate->initNewCenter = BitInitNewCenter;
|
||||
kmeansstate->setNewCenter = BitSetNewCenter;
|
||||
kmeansstate->sumCenter = BitSumCenter;
|
||||
kmeansstate->comp = CompareBitVectors;
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
}
|
||||
|
||||
/*
|
||||
* Perform naive k-means centering
|
||||
* We use spherical k-means for inner product and cosine
|
||||
@@ -811,10 +787,14 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
|
||||
void
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
|
||||
{
|
||||
if (samples->length <= centers->maxlen)
|
||||
QuickCenters(index, samples, centers, type);
|
||||
else
|
||||
ElkanKmeans(index, samples, centers, type);
|
||||
KmeansState kmeansstate;
|
||||
|
||||
CheckCenters(index, centers, type);
|
||||
InitKmeansState(&kmeansstate, type);
|
||||
|
||||
if (samples->length <= centers->maxlen)
|
||||
QuickCenters(index, samples, centers, type, &kmeansstate);
|
||||
else
|
||||
ElkanKmeans(index, samples, centers, type, &kmeansstate);
|
||||
|
||||
CheckCenters(index, centers, type, &kmeansstate);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user