From 6bb5de3d1b047c193c1ee5425e33220107d5ceee Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 24 Apr 2024 17:40:21 -0700 Subject: [PATCH] Added KmeansState [skip ci] --- src/ivfkmeans.c | 148 +++++++++++++++++++++--------------------------- 1 file changed, 64 insertions(+), 84 deletions(-) diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index e14926b..b5b29ef 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -13,6 +13,14 @@ #include "utils/memutils.h" #include "vector.h" +typedef struct KmeansState +{ + void (*initNewCenter) (Pointer v, int dimensions); + void (*setNewCenter) (Pointer v, float *x); + void (*sumCenter) (Pointer v, float *x); + int (*comp) (const void *a, const void *b); +} KmeansState; + /* * Initialize with kmeans++ * @@ -147,20 +155,9 @@ CompareBitVectors(const void *a, const void *b) * Sort vector array */ static void -SortVectorArray(VectorArray arr, IvfflatType type) +SortVectorArray(VectorArray arr, KmeansState * kmeansstate) { - int (*comp) (const void *a, const void *b); - - if (type == IVFFLAT_TYPE_VECTOR) - comp = CompareVectors; - else if (type == IVFFLAT_TYPE_HALFVEC) - comp = CompareHalfVectors; - else if (type == IVFFLAT_TYPE_BIT) - comp = CompareBitVectors; - else - elog(ERROR, "Unsupported type"); - - qsort(arr->items, arr->length, arr->itemsize, comp); + qsort(arr->items, arr->length, arr->itemsize, kmeansstate->comp); } static void @@ -225,38 +222,18 @@ BitSetNewCenter(Pointer v, float *x) * Quick approach if we have little data */ static void -QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type) +QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate) { int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); - void (*initNewCenter) (Pointer v, int dimensions); - void (*setNewCenter) (Pointer v, float *x); float *x = (float *) palloc(sizeof(float) * dimensions); - if (type == IVFFLAT_TYPE_VECTOR) - { - initNewCenter = VectorInitNewCenter; - setNewCenter = VectorSetNewCenter; - } - else if (type == IVFFLAT_TYPE_HALFVEC) - { - initNewCenter = HalfvecInitNewCenter; - setNewCenter = HalfvecSetNewCenter; - } - else if (type == IVFFLAT_TYPE_BIT) - { - initNewCenter = BitInitNewCenter; - setNewCenter = BitSetNewCenter; - } - else - elog(ERROR, "Unsupported type"); - /* Copy existing vectors while avoiding duplicates */ if (samples->length > 0) { - SortVectorArray(samples, type); + SortVectorArray(samples, kmeansstate); for (int i = 0; i < samples->length; i++) { @@ -278,8 +255,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy for (int i = 0; i < dimensions; i++) x[i] = (float) RandomDouble(); - initNewCenter(center, dimensions); - setNewCenter(center, x); + kmeansstate->initNewCenter(center, dimensions); + kmeansstate->setNewCenter(center, x); centers->length++; } @@ -339,24 +316,13 @@ BitSumCenter(Pointer v, float *x) * Sum centers */ static void -SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, IvfflatType type) +SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, KmeansState * kmeansstate) { - void (*sumCenter) (Pointer v, float *x); - - if (type == IVFFLAT_TYPE_VECTOR) - sumCenter = VectorSumCenter; - else if (type == IVFFLAT_TYPE_HALFVEC) - sumCenter = HalfvecSumCenter; - else if (type == IVFFLAT_TYPE_BIT) - sumCenter = BitSumCenter; - else - elog(ERROR, "Unsupported type"); - for (int j = 0; j < samples->length; j++) { Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); - sumCenter(VectorArrayGet(samples, j), aggCenter->x); + kmeansstate->sumCenter(VectorArrayGet(samples, j), aggCenter->x); } } @@ -364,22 +330,13 @@ SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, Ivf * Set new centers */ static void -SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type) +SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type, KmeansState * kmeansstate) { - void (*setNewCenter) (Pointer v, float *x); - - if (type == IVFFLAT_TYPE_HALFVEC) - setNewCenter = HalfvecSetNewCenter; - else if (type == IVFFLAT_TYPE_BIT) - setNewCenter = BitSetNewCenter; - else - return; - for (int j = 0; j < aggCenters->length; j++) { Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); - setNewCenter(VectorArrayGet(newCenters, j), aggCenter->x); + kmeansstate->setNewCenter(VectorArrayGet(newCenters, j), aggCenter->x); } } @@ -387,7 +344,7 @@ SetNewCenters(VectorArray aggCenters, VectorArray newCenters, IvfflatType type) * Compute new centers */ static void -ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type) +ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type, KmeansState * kmeansstate) { int dimensions = aggCenters->dim; int numCenters = aggCenters->maxlen; @@ -405,7 +362,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe } /* Increment sum of closest center */ - SumCenters(samples, aggCenters, closestCenters, type); + SumCenters(samples, aggCenters, closestCenters, kmeansstate); /* Increment count of closest center */ for (int j = 0; j < numSamples; j++) @@ -438,7 +395,8 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe } /* Set new centers if different from agg centers */ - SetNewCenters(aggCenters, newCenters, type); + if (type != IVFFLAT_TYPE_VECTOR) + SetNewCenters(aggCenters, newCenters, type, kmeansstate); /* Normalize if needed */ if (normprocinfo != NULL) @@ -454,7 +412,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type) +ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type, KmeansState * kmeansstate) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; @@ -474,7 +432,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp float *newcdist; MemoryContext kmeansCtx; MemoryContext oldCtx; - void (*initNewCenter) (Pointer v, int dimensions); /* Calculate allocation sizes */ Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize); @@ -527,20 +484,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp newcdist = palloc(newcdistSize); /* Initialize new centers */ - if (type == IVFFLAT_TYPE_VECTOR) - initNewCenter = VectorInitNewCenter; - else if (type == IVFFLAT_TYPE_HALFVEC) - initNewCenter = HalfvecInitNewCenter; - else if (type == IVFFLAT_TYPE_BIT) - initNewCenter = BitInitNewCenter; - else - elog(ERROR, "Unsupported type"); - newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); newCenters->length = numCenters; for (int j = 0; j < numCenters; j++) - initNewCenter(VectorArrayGet(newCenters, j), dimensions); + kmeansstate->initNewCenter(VectorArrayGet(newCenters, j), dimensions); /* Initialize agg centers */ if (type == IVFFLAT_TYPE_VECTOR) @@ -695,7 +643,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type); + ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type, kmeansstate); /* Step 5 */ for (int j = 0; j < numCenters; j++) @@ -735,7 +683,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp * Detect issues with centers */ static void -CheckCenters(Relation index, VectorArray centers, IvfflatType type) +CheckCenters(Relation index, VectorArray centers, IvfflatType type, KmeansState * kmeansstate) { FmgrInfo *normprocinfo; @@ -778,7 +726,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) if (type != IVFFLAT_TYPE_BIT) { /* Ensure no duplicate centers */ - SortVectorArray(centers, type); + SortVectorArray(centers, kmeansstate); for (int i = 1; i < centers->length; i++) { @@ -804,6 +752,34 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) } } +static void +InitKmeansState(KmeansState * kmeansstate, IvfflatType type) +{ + if (type == IVFFLAT_TYPE_VECTOR) + { + kmeansstate->initNewCenter = VectorInitNewCenter; + kmeansstate->setNewCenter = VectorSetNewCenter; + kmeansstate->sumCenter = VectorSumCenter; + kmeansstate->comp = CompareVectors; + } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + kmeansstate->initNewCenter = HalfvecInitNewCenter; + kmeansstate->setNewCenter = HalfvecSetNewCenter; + kmeansstate->sumCenter = HalfvecSumCenter; + kmeansstate->comp = CompareHalfVectors; + } + else if (type == IVFFLAT_TYPE_BIT) + { + kmeansstate->initNewCenter = BitInitNewCenter; + kmeansstate->setNewCenter = BitSetNewCenter; + kmeansstate->sumCenter = BitSumCenter; + kmeansstate->comp = CompareBitVectors; + } + else + elog(ERROR, "Unsupported type"); +} + /* * Perform naive k-means centering * We use spherical k-means for inner product and cosine @@ -811,10 +787,14 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type) { - if (samples->length <= centers->maxlen) - QuickCenters(index, samples, centers, type); - else - ElkanKmeans(index, samples, centers, type); + KmeansState kmeansstate; - CheckCenters(index, centers, type); + InitKmeansState(&kmeansstate, type); + + if (samples->length <= centers->maxlen) + QuickCenters(index, samples, centers, type, &kmeansstate); + else + ElkanKmeans(index, samples, centers, type, &kmeansstate); + + CheckCenters(index, centers, type, &kmeansstate); }