From 8ee6d0e596c7dfc29901ce9adeedf81936323cd0 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 12 Feb 2022 22:56:00 -0800 Subject: [PATCH] Switched to mini-batch k-means --- CHANGELOG.md | 4 + README.md | 9 +- src/ivfbuild.c | 117 +----------- src/ivfflat.c | 2 - src/ivfflat.h | 9 +- src/ivfkmeans.c | 472 +++++++++++++++++++++--------------------------- 6 files changed, 220 insertions(+), 393 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 240efd8..29896b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.2.6 (unreleased) + +- Switched to mini-batch k-means + ## 0.2.5 (2022-02-11) - Reduced memory usage during index creation diff --git a/README.md b/README.md index 7f25647..c3b9993 100644 --- a/README.md +++ b/README.md @@ -119,10 +119,9 @@ SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index; The phases are: 1. `initializing` -2. `sampling table` -3. `performing k-means` -4. `sorting tuples` -5. `loading tuples` +2. `performing k-means` +3. `sorting tuples` +4. `loading tuples` Note: `tuples_done` and `tuples_total` are only populated during the `loading tuples` phase @@ -264,7 +263,7 @@ Thanks to: - [PASE: PostgreSQL Ultra-High-Dimensional Approximate Nearest Neighbor Search Extension](https://dl.acm.org/doi/pdf/10.1145/3318464.3386131) - [Faiss: A Library for Efficient Similarity Search and Clustering of Dense Vectors](https://github.com/facebookresearch/faiss) -- [Using the Triangle Inequality to Accelerate k-means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf) +- [Web-Scale k-means Clustering](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf) - [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf) - [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf) diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 65e4f38..15538c6 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -47,87 +47,6 @@ UpdateProgress(int index, int64 val) #endif } -/* - * Callback for sampling - */ -static void -SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, - bool *isnull, bool tupleIsAlive, void *state) -{ - IvfflatBuildState *buildstate = (IvfflatBuildState *) state; - VectorArray samples = buildstate->samples; - int targsamples = samples->maxlen; - Datum value = values[0]; - - /* Skip nulls */ - if (isnull[0]) - return; - - /* - * Normalize with KMEANS_NORM_PROC since spherical distance function - * expects unit vectors - */ - if (buildstate->kmeansnormprocinfo != NULL) - { - if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec)) - return; - } - - if (samples->length < targsamples) - { - VectorArraySet(samples, samples->length, DatumGetVector(value)); - samples->length++; - } - else - { - if (buildstate->rowstoskip < 0) - buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples); - - if (buildstate->rowstoskip <= 0) - { - int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate)); - - Assert(k >= 0 && k < targsamples); - VectorArraySet(samples, k, DatumGetVector(value)); - } - - buildstate->rowstoskip -= 1; - } -} - -/* - * Sample rows with same logic as ANALYZE - */ -static void -SampleRows(IvfflatBuildState * buildstate) -{ - int targsamples = buildstate->samples->maxlen; - BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap); - - UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_SAMPLE); - - buildstate->rowstoskip = -1; - - BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random()); - - reservoir_init_selection_state(&buildstate->rstate, targsamples); - while (BlockSampler_HasMore(&buildstate->bs)) - { - BlockNumber targblock = BlockSampler_Next(&buildstate->bs); - -#if PG_VERSION_NUM >= 120000 - table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, - false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL); -#elif PG_VERSION_NUM >= 110000 - IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo, - true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL); -#else - IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo, - true, true, targblock, 1, SampleCallback, (void *) buildstate); -#endif - } -} - /* * Callback for table_index_build_scan */ @@ -371,38 +290,6 @@ FreeBuildState(IvfflatBuildState * buildstate) pfree(buildstate->normvec); } -/* - * Compute centers - */ -static void -ComputeCenters(IvfflatBuildState * buildstate) -{ - int numSamples; - - /* Target 50 samples per list, with at least 10000 samples */ - /* The number of samples has a large effect on index build time */ - numSamples = buildstate->lists * 50; - if (numSamples < 10000) - numSamples = 10000; - - /* Skip samples for unlogged table */ - if (buildstate->heap == NULL) - numSamples = 1; - - /* Sample rows */ - /* TODO Ensure within maintenance_work_mem */ - buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions); - if (buildstate->heap != NULL) - SampleRows(buildstate); - - /* Calculate centers */ - UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_KMEANS); - IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers)); - - /* Free samples before we allocate more memory */ - pfree(buildstate->samples); -} - /* * Create the metapage */ @@ -531,7 +418,9 @@ BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, { InitBuildState(buildstate, heap, index, indexInfo); - ComputeCenters(buildstate); + /* Perform k-means clustering */ + UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_KMEANS); + IvfflatBench("k-means", IvfflatKmeans(buildstate)); /* Create pages */ CreateMetaPage(index, buildstate->dimensions, buildstate->lists, forkNum); diff --git a/src/ivfflat.c b/src/ivfflat.c index 35c0e48..95ea9c8 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -45,8 +45,6 @@ ivfflatbuildphasename(int64 phasenum) { case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE: return "initializing"; - case PROGRESS_IVFFLAT_PHASE_SAMPLE: - return "sampling table"; case PROGRESS_IVFFLAT_PHASE_KMEANS: return "performing k-means"; case PROGRESS_IVFFLAT_PHASE_SORT: diff --git a/src/ivfflat.h b/src/ivfflat.h index 1337941..7760784 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -37,10 +37,9 @@ /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ -#define PROGRESS_IVFFLAT_PHASE_SAMPLE 2 -#define PROGRESS_IVFFLAT_PHASE_KMEANS 3 -#define PROGRESS_IVFFLAT_PHASE_SORT 4 -#define PROGRESS_IVFFLAT_PHASE_LOAD 5 +#define PROGRESS_IVFFLAT_PHASE_KMEANS 2 +#define PROGRESS_IVFFLAT_PHASE_SORT 3 +#define PROGRESS_IVFFLAT_PHASE_LOAD 4 #define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim)) @@ -200,7 +199,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque; void _PG_init(void); VectorArray VectorArrayInit(int maxlen, int dimensions); void PrintVectorArray(char *msg, VectorArray arr); -void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); +void IvfflatKmeans(IvfflatBuildState * buildstate); FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum); bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); int IvfflatGetLists(Relation index); diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 60a07f7..ec754ef 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -4,6 +4,17 @@ #include "ivfflat.h" #include "miscadmin.h" +#include "storage/bufmgr.h" + +#if PG_VERSION_NUM >= 120000 +#include "access/tableam.h" +#endif + +#if PG_VERSION_NUM >= 130000 +#define CALLBACK_ITEM_POINTER ItemPointer tid +#else +#define CALLBACK_ITEM_POINTER HeapTuple hup +#endif /* * Initialize with kmeans++ @@ -11,7 +22,7 @@ * https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf */ static void -InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound) +InitCenters(Relation index, VectorArray samples, VectorArray centers) { FmgrInfo *procinfo; Oid collation; @@ -35,7 +46,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low for (j = 0; j < numSamples; j++) weight[j] = DBL_MAX; - for (i = 0; i < numCenters; i++) + for (i = 0; i < numCenters - 1; i++) { CHECK_FOR_INTERRUPTS(); @@ -49,9 +60,6 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low /* TODO Use triangle inequality to reduce distance calculations */ distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i)))); - /* Set lower bound */ - lowerBound[j * numCenters + i] = distance; - /* Use distance squared for weighted probability distribution */ distance *= distance; @@ -61,10 +69,6 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low sum += weight[j]; } - /* Only compute lower bound on last iteration */ - if (i + 1 == numCenters) - break; - /* Choose new center using weighted probability distribution. */ choice = sum * (((double) random()) / MAX_RANDOM_VALUE); for (j = 0; j < numSamples - 1; j++) @@ -156,299 +160,202 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) } /* - * Use Elkan for performance. This requires distance function to satisfy triangle inequality. + * Callback for sampling + */ +static void +SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, + bool *isnull, bool tupleIsAlive, void *state) +{ + IvfflatBuildState *buildstate = (IvfflatBuildState *) state; + VectorArray samples = buildstate->samples; + int targsamples = samples->maxlen; + Datum value = values[0]; + + /* Skip nulls */ + if (isnull[0]) + return; + + /* + * Normalize with KMEANS_NORM_PROC since spherical distance function + * expects unit vectors + */ + if (buildstate->kmeansnormprocinfo != NULL) + { + if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec)) + return; + } + + if (samples->length < targsamples) + { + VectorArraySet(samples, samples->length, DatumGetVector(value)); + samples->length++; + } + else + { + if (buildstate->rowstoskip < 0) + buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples); + + if (buildstate->rowstoskip <= 0) + { + int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate)); + + Assert(k >= 0 && k < targsamples); + VectorArraySet(samples, k, DatumGetVector(value)); + } + + buildstate->rowstoskip -= 1; + } +} + +/* + * Sample rows with same logic as ANALYZE + */ +static void +SampleRows(IvfflatBuildState * buildstate) +{ + int targsamples = buildstate->samples->maxlen; + BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap); + + buildstate->rowstoskip = -1; + + BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random()); + + reservoir_init_selection_state(&buildstate->rstate, targsamples); + while (BlockSampler_HasMore(&buildstate->bs)) + { + BlockNumber targblock = BlockSampler_Next(&buildstate->bs); + +#if PG_VERSION_NUM >= 120000 + table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL); +#elif PG_VERSION_NUM >= 110000 + IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL); +#else + IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, targblock, 1, SampleCallback, (void *) buildstate); +#endif + } +} + +/* + * Use mini-batch k-means * * We use L2 distance for L2 (not L2 squared like index scan) * and angular distance for inner product and cosine distance * - * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf + * https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) +MiniBatchKmeans(IvfflatBuildState * buildstate) { - FmgrInfo *procinfo; - FmgrInfo *normprocinfo; - Oid collation; - Vector *vec; - Vector *newCenter; - int iteration; - int j; - int k; - int dimensions = centers->dim; - int numCenters = centers->maxlen; - int numSamples = samples->length; - VectorArray newCenters; - int *centerCounts; - int *closestCenters; - float *lowerBound; - float *upperBound; - float *s; - float *halfcdist; - float *newcdist; - int changes; + VectorArray centers = buildstate->centers; + int b = buildstate->samples->maxlen; + int t = 20; + double distance; double minDistance; int closestCenter; - double distance; - bool rj; - bool rjreset; - double dxcx; - double dxc; - - /* Calculate allocation sizes */ - Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim); - Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim); - Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions); - Size centerCountsSize = sizeof(int) * numCenters; - Size closestCentersSize = sizeof(int) * numSamples; - Size lowerBoundSize = sizeof(float) * numSamples * numCenters; - Size upperBoundSize = sizeof(float) * numSamples; - Size sSize = sizeof(float) * numCenters; - Size halfcdistSize = sizeof(float) * numCenters * numCenters; - Size newcdistSize = sizeof(float) * numCenters; - - /* Calculate total size */ - Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; - - /* Check memory requirements */ - /* Add one to error message to ceil */ - if (totalSize / 1024 > maintenance_work_mem) - ereport(ERROR, - (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), - errmsg("memory required is %zu MB, maintenance_work_mem is %d MB", - totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024))); + int i; + int j; + int k; + VectorArray m; + Vector *c; + Vector *x; + int *v; + int *d; + float eta; /* Set support functions */ - procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); - normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); - collation = index->rd_indcollation[0]; - - /* Allocate space */ - /* Use float instead of double to save memory */ - centerCounts = palloc(centerCountsSize); - closestCenters = palloc(closestCentersSize); - lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE); - upperBound = palloc(upperBoundSize); - s = palloc(sSize); - halfcdist = palloc(halfcdistSize); - newcdist = palloc(newcdistSize); - - newCenters = VectorArrayInit(numCenters, dimensions); - for (j = 0; j < numCenters; j++) - { - vec = VectorArrayGet(newCenters, j); - SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); - vec->dim = dimensions; - } + FmgrInfo *procinfo = index_getprocinfo(buildstate->index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); + FmgrInfo *normprocinfo = buildstate->kmeansnormprocinfo; + Oid collation = buildstate->index->rd_indcollation[0]; /* Pick initial centers */ - InitCenters(index, samples, centers, lowerBound); + InitCenters(buildstate->index, buildstate->samples, buildstate->centers); - /* Assign each x to its closest initial center c(x) = argmin d(x,c) */ - for (j = 0; j < numSamples; j++) + v = palloc(sizeof(int) * centers->maxlen); + d = palloc(sizeof(int) * b); + + for (int i = 0; i < centers->length; i++) + v[i] = 0; + + for (i = 0; i < t; i++) { - minDistance = DBL_MAX; - closestCenter = -1; + /* Reset samples */ + buildstate->samples->length = 0; - vec = VectorArrayGet(samples, j); + /* Get b examples picked randomly from X */ + SampleRows(buildstate); + m = buildstate->samples; - /* Find closest center */ - for (k = 0; k < numCenters; k++) - { - /* TODO Use Lemma 1 in k-means++ initialization */ - distance = lowerBound[j * numCenters + k]; - - if (distance < minDistance) - { - minDistance = distance; - closestCenter = k; - } - } - - upperBound[j] = minDistance; - closestCenters[j] = closestCenter; - } - - /* Give 500 iterations to converge */ - for (iteration = 0; iteration < 500; iteration++) - { - /* Can take a while, so ensure we can interrupt */ - CHECK_FOR_INTERRUPTS(); - - changes = 0; - - /* Step 1: For all centers, compute distance */ - for (j = 0; j < numCenters; j++) - { - vec = VectorArrayGet(centers, j); - - for (k = j + 1; k < numCenters; k++) - { - distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); - halfcdist[j * numCenters + k] = distance; - halfcdist[k * numCenters + j] = distance; - } - } - - /* For all centers c, compute s(c) */ - for (j = 0; j < numCenters; j++) + /* Cache nearest center to x */ + for (j = 0; j < m->length; j++) { + /* compute closest */ minDistance = DBL_MAX; + closestCenter = -1; - for (k = 0; k < numCenters; k++) + x = VectorArrayGet(m, j); + + /* Find closest center */ + for (k = 0; k < centers->length; k++) { - if (j == k) - continue; + distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(x), PointerGetDatum(VectorArrayGet(centers, k)))); - distance = halfcdist[j * numCenters + k]; if (distance < minDistance) + { minDistance = distance; + closestCenter = k; + } } - s[j] = minDistance; + d[j] = closestCenter; } - rjreset = iteration != 0; - - for (j = 0; j < numSamples; j++) + for (j = 0; j < m->length; j++) { - /* Step 2: Identify all points x such that u(x) <= s(c(x)) */ - if (upperBound[j] <= s[closestCenters[j]]) - continue; + x = VectorArrayGet(m, j); - rj = rjreset; + /* Get cached center for this x */ + c = VectorArrayGet(centers, d[j]); - for (k = 0; k < numCenters; k++) + /* Update per-center counts */ + v[d[j]]++; + + /* Get per-center learning rate */ + eta = 1.0 / v[d[j]]; + + /* Take gradient step */ + for (k = 0; k < buildstate->dimensions; k++) + c->x[k] = (1 - eta) * c->x[k] + eta * x->x[k]; + } + + /* Check for empty centers (likely duplicates) */ + if (i == 0) + { + for (j = 0; j < centers->length; j++) { - /* Step 3: For all remaining points x and centers c */ - if (k == closestCenters[j]) - continue; - - if (upperBound[j] <= lowerBound[j * numCenters + k]) - continue; - - if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k]) - continue; - - vec = VectorArrayGet(samples, j); - - /* Step 3a */ - if (rj) + if (v[j] == 0) { - dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j])))); - - /* d(x,c(x)) computed, which is a form of d(x,c) */ - lowerBound[j * numCenters + closestCenters[j]] = dxcx; - upperBound[j] = dxcx; - - rj = false; - } - else - dxcx = upperBound[j]; - - /* Step 3b */ - if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k]) - { - dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); - - /* d(x,c) calculated */ - lowerBound[j * numCenters + k] = dxc; - - if (dxc < dxcx) - { - closestCenters[j] = k; - - /* c(x) changed */ - upperBound[j] = dxc; - - changes++; - } + c = VectorArrayGet(centers, j); + /* TODO Handle empty centers properly */ + for (k = 0; k < c->dim; k++) + c->x[k] = ((double) random()) / MAX_RANDOM_VALUE; } } } - /* Step 4: For each center c, let m(c) be mean of all points assigned */ - for (j = 0; j < numCenters; j++) + /* Normalize if needed */ + if (normprocinfo != NULL) { - vec = VectorArrayGet(newCenters, j); - for (k = 0; k < dimensions; k++) - vec->x[k] = 0.0; - - centerCounts[j] = 0; + for (j = 0; j < centers->length; j++) + ApplyNorm(normprocinfo, collation, VectorArrayGet(centers, j)); } - - for (j = 0; j < numSamples; j++) - { - vec = VectorArrayGet(samples, j); - closestCenter = closestCenters[j]; - - /* Increment sum and count of closest center */ - newCenter = VectorArrayGet(newCenters, closestCenter); - for (k = 0; k < dimensions; k++) - newCenter->x[k] += vec->x[k]; - - centerCounts[closestCenter] += 1; - } - - for (j = 0; j < numCenters; j++) - { - vec = VectorArrayGet(newCenters, j); - - if (centerCounts[j] > 0) - { - for (k = 0; k < dimensions; k++) - vec->x[k] /= centerCounts[j]; - } - else - { - /* TODO Handle empty centers properly */ - for (k = 0; k < dimensions; k++) - vec->x[k] = ((double) random()) / MAX_RANDOM_VALUE; - } - - /* Normalize if needed */ - if (normprocinfo != NULL) - ApplyNorm(normprocinfo, collation, vec); - } - - /* Step 5 */ - for (j = 0; j < numCenters; j++) - newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j)))); - - for (j = 0; j < numSamples; j++) - { - for (k = 0; k < numCenters; k++) - { - distance = lowerBound[j * numCenters + k] - newcdist[k]; - - if (distance < 0) - distance = 0; - - lowerBound[j * numCenters + k] = distance; - } - } - - /* Step 6 */ - /* We reset r(x) before Step 3 in the next iteration */ - for (j = 0; j < numSamples; j++) - upperBound[j] += newcdist[closestCenters[j]]; - - /* Step 7 */ - for (j = 0; j < numCenters; j++) - memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions)); - - if (changes == 0 && iteration != 0) - break; } - pfree(newCenters); - pfree(centerCounts); - pfree(closestCenters); - pfree(lowerBound); - pfree(upperBound); - pfree(s); - pfree(halfcdist); - pfree(newcdist); + pfree(v); + pfree(d); } /* @@ -491,16 +398,47 @@ CheckCenters(Relation index, VectorArray centers) } /* - * Perform naive k-means centering + * Perform k-means clustering * We use spherical k-means for inner product and cosine */ void -IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers) +IvfflatKmeans(IvfflatBuildState * buildstate) { - if (samples->length <= centers->maxlen) - QuickCenters(index, samples, centers); - else - ElkanKmeans(index, samples, centers); + int numSamples; - CheckCenters(index, centers); + /* Target 10 samples per list, with at least 10000 samples */ + /* The number of samples has a large effect on index build time */ + numSamples = buildstate->lists * 10; + if (numSamples < 10000) + numSamples = 10000; + + /* Skip samples for unlogged table */ + if (buildstate->heap == NULL) + numSamples = 1; + + /* Calculate total size */ + Size totalSize = VECTOR_ARRAY_SIZE(numSamples, buildstate->dimensions); + + /* Check memory requirements */ + /* Add one to error message to ceil */ + if (totalSize / 1024 > maintenance_work_mem) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("memory required is %zu MB, maintenance_work_mem is %d MB", + totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024))); + + /* Sample rows */ + buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions); + if (buildstate->heap != NULL) + SampleRows(buildstate); + + if (buildstate->samples->length <= buildstate->centers->maxlen) + QuickCenters(buildstate->index, buildstate->samples, buildstate->centers); + else + MiniBatchKmeans(buildstate); + + CheckCenters(buildstate->index, buildstate->centers); + + /* Free samples before we allocate more memory */ + pfree(buildstate->samples); }