Switched to mini-batch k-means

This commit is contained in:
Andrew Kane
2022-02-12 22:56:00 -08:00
parent 41d11c62d6
commit 8ee6d0e596
6 changed files with 220 additions and 393 deletions

View File

@@ -1,3 +1,7 @@
## 0.2.6 (unreleased)
- Switched to mini-batch k-means
## 0.2.5 (2022-02-11)
- Reduced memory usage during index creation

View File

@@ -119,10 +119,9 @@ SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;
The phases are:
1. `initializing`
2. `sampling table`
3. `performing k-means`
4. `sorting tuples`
5. `loading tuples`
2. `performing k-means`
3. `sorting tuples`
4. `loading tuples`
Note: `tuples_done` and `tuples_total` are only populated during the `loading tuples` phase
@@ -264,7 +263,7 @@ Thanks to:
- [PASE: PostgreSQL Ultra-High-Dimensional Approximate Nearest Neighbor Search Extension](https://dl.acm.org/doi/pdf/10.1145/3318464.3386131)
- [Faiss: A Library for Efficient Similarity Search and Clustering of Dense Vectors](https://github.com/facebookresearch/faiss)
- [Using the Triangle Inequality to Accelerate k-means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf)
- [Web-Scale k-means Clustering](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf)
- [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf)
- [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf)

View File

@@ -47,87 +47,6 @@ UpdateProgress(int index, int64 val)
#endif
}
/*
* Callback for sampling
*/
static void
SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
bool *isnull, bool tupleIsAlive, void *state)
{
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
VectorArray samples = buildstate->samples;
int targsamples = samples->maxlen;
Datum value = values[0];
/* Skip nulls */
if (isnull[0])
return;
/*
* Normalize with KMEANS_NORM_PROC since spherical distance function
* expects unit vectors
*/
if (buildstate->kmeansnormprocinfo != NULL)
{
if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec))
return;
}
if (samples->length < targsamples)
{
VectorArraySet(samples, samples->length, DatumGetVector(value));
samples->length++;
}
else
{
if (buildstate->rowstoskip < 0)
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
if (buildstate->rowstoskip <= 0)
{
int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate));
Assert(k >= 0 && k < targsamples);
VectorArraySet(samples, k, DatumGetVector(value));
}
buildstate->rowstoskip -= 1;
}
}
/*
* Sample rows with same logic as ANALYZE
*/
static void
SampleRows(IvfflatBuildState * buildstate)
{
int targsamples = buildstate->samples->maxlen;
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_SAMPLE);
buildstate->rowstoskip = -1;
BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random());
reservoir_init_selection_state(&buildstate->rstate, targsamples);
while (BlockSampler_HasMore(&buildstate->bs))
{
BlockNumber targblock = BlockSampler_Next(&buildstate->bs);
#if PG_VERSION_NUM >= 120000
table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
#elif PG_VERSION_NUM >= 110000
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
#else
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, targblock, 1, SampleCallback, (void *) buildstate);
#endif
}
}
/*
* Callback for table_index_build_scan
*/
@@ -371,38 +290,6 @@ FreeBuildState(IvfflatBuildState * buildstate)
pfree(buildstate->normvec);
}
/*
* Compute centers
*/
static void
ComputeCenters(IvfflatBuildState * buildstate)
{
int numSamples;
/* Target 50 samples per list, with at least 10000 samples */
/* The number of samples has a large effect on index build time */
numSamples = buildstate->lists * 50;
if (numSamples < 10000)
numSamples = 10000;
/* Skip samples for unlogged table */
if (buildstate->heap == NULL)
numSamples = 1;
/* Sample rows */
/* TODO Ensure within maintenance_work_mem */
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
if (buildstate->heap != NULL)
SampleRows(buildstate);
/* Calculate centers */
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_KMEANS);
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
/* Free samples before we allocate more memory */
pfree(buildstate->samples);
}
/*
* Create the metapage
*/
@@ -531,7 +418,9 @@ BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo,
{
InitBuildState(buildstate, heap, index, indexInfo);
ComputeCenters(buildstate);
/* Perform k-means clustering */
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_KMEANS);
IvfflatBench("k-means", IvfflatKmeans(buildstate));
/* Create pages */
CreateMetaPage(index, buildstate->dimensions, buildstate->lists, forkNum);

View File

@@ -45,8 +45,6 @@ ivfflatbuildphasename(int64 phasenum)
{
case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE:
return "initializing";
case PROGRESS_IVFFLAT_PHASE_SAMPLE:
return "sampling table";
case PROGRESS_IVFFLAT_PHASE_KMEANS:
return "performing k-means";
case PROGRESS_IVFFLAT_PHASE_SORT:

View File

@@ -37,10 +37,9 @@
/* Build phases */
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
#define PROGRESS_IVFFLAT_PHASE_SAMPLE 2
#define PROGRESS_IVFFLAT_PHASE_KMEANS 3
#define PROGRESS_IVFFLAT_PHASE_SORT 4
#define PROGRESS_IVFFLAT_PHASE_LOAD 5
#define PROGRESS_IVFFLAT_PHASE_KMEANS 2
#define PROGRESS_IVFFLAT_PHASE_SORT 3
#define PROGRESS_IVFFLAT_PHASE_LOAD 4
#define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim))
@@ -200,7 +199,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
void _PG_init(void);
VectorArray VectorArrayInit(int maxlen, int dimensions);
void PrintVectorArray(char *msg, VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
void IvfflatKmeans(IvfflatBuildState * buildstate);
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
int IvfflatGetLists(Relation index);

View File

@@ -4,6 +4,17 @@
#include "ivfflat.h"
#include "miscadmin.h"
#include "storage/bufmgr.h"
#if PG_VERSION_NUM >= 120000
#include "access/tableam.h"
#endif
#if PG_VERSION_NUM >= 130000
#define CALLBACK_ITEM_POINTER ItemPointer tid
#else
#define CALLBACK_ITEM_POINTER HeapTuple hup
#endif
/*
* Initialize with kmeans++
@@ -11,7 +22,7 @@
* https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
*/
static void
InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound)
InitCenters(Relation index, VectorArray samples, VectorArray centers)
{
FmgrInfo *procinfo;
Oid collation;
@@ -35,7 +46,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
for (j = 0; j < numSamples; j++)
weight[j] = DBL_MAX;
for (i = 0; i < numCenters; i++)
for (i = 0; i < numCenters - 1; i++)
{
CHECK_FOR_INTERRUPTS();
@@ -49,9 +60,6 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
/* TODO Use triangle inequality to reduce distance calculations */
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
/* Set lower bound */
lowerBound[j * numCenters + i] = distance;
/* Use distance squared for weighted probability distribution */
distance *= distance;
@@ -61,10 +69,6 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
sum += weight[j];
}
/* Only compute lower bound on last iteration */
if (i + 1 == numCenters)
break;
/* Choose new center using weighted probability distribution. */
choice = sum * (((double) random()) / MAX_RANDOM_VALUE);
for (j = 0; j < numSamples - 1; j++)
@@ -156,299 +160,202 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
}
/*
* Use Elkan for performance. This requires distance function to satisfy triangle inequality.
* Callback for sampling
*/
static void
SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
bool *isnull, bool tupleIsAlive, void *state)
{
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
VectorArray samples = buildstate->samples;
int targsamples = samples->maxlen;
Datum value = values[0];
/* Skip nulls */
if (isnull[0])
return;
/*
* Normalize with KMEANS_NORM_PROC since spherical distance function
* expects unit vectors
*/
if (buildstate->kmeansnormprocinfo != NULL)
{
if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec))
return;
}
if (samples->length < targsamples)
{
VectorArraySet(samples, samples->length, DatumGetVector(value));
samples->length++;
}
else
{
if (buildstate->rowstoskip < 0)
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
if (buildstate->rowstoskip <= 0)
{
int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate));
Assert(k >= 0 && k < targsamples);
VectorArraySet(samples, k, DatumGetVector(value));
}
buildstate->rowstoskip -= 1;
}
}
/*
* Sample rows with same logic as ANALYZE
*/
static void
SampleRows(IvfflatBuildState * buildstate)
{
int targsamples = buildstate->samples->maxlen;
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
buildstate->rowstoskip = -1;
BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random());
reservoir_init_selection_state(&buildstate->rstate, targsamples);
while (BlockSampler_HasMore(&buildstate->bs))
{
BlockNumber targblock = BlockSampler_Next(&buildstate->bs);
#if PG_VERSION_NUM >= 120000
table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
#elif PG_VERSION_NUM >= 110000
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
#else
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, targblock, 1, SampleCallback, (void *) buildstate);
#endif
}
}
/*
* Use mini-batch k-means
*
* We use L2 distance for L2 (not L2 squared like index scan)
* and angular distance for inner product and cosine distance
*
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
* https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
*/
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
MiniBatchKmeans(IvfflatBuildState * buildstate)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
Vector *vec;
Vector *newCenter;
int iteration;
int j;
int k;
int dimensions = centers->dim;
int numCenters = centers->maxlen;
int numSamples = samples->length;
VectorArray newCenters;
int *centerCounts;
int *closestCenters;
float *lowerBound;
float *upperBound;
float *s;
float *halfcdist;
float *newcdist;
int changes;
VectorArray centers = buildstate->centers;
int b = buildstate->samples->maxlen;
int t = 20;
double distance;
double minDistance;
int closestCenter;
double distance;
bool rj;
bool rjreset;
double dxcx;
double dxc;
/* Calculate allocation sizes */
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim);
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim);
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions);
Size centerCountsSize = sizeof(int) * numCenters;
Size closestCentersSize = sizeof(int) * numSamples;
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
Size upperBoundSize = sizeof(float) * numSamples;
Size sSize = sizeof(float) * numCenters;
Size halfcdistSize = sizeof(float) * numCenters * numCenters;
Size newcdistSize = sizeof(float) * numCenters;
/* Calculate total size */
Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
/* Check memory requirements */
/* Add one to error message to ceil */
if (totalSize / 1024 > maintenance_work_mem)
ereport(ERROR,
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
errmsg("memory required is %zu MB, maintenance_work_mem is %d MB",
totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024)));
int i;
int j;
int k;
VectorArray m;
Vector *c;
Vector *x;
int *v;
int *d;
float eta;
/* Set support functions */
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
collation = index->rd_indcollation[0];
/* Allocate space */
/* Use float instead of double to save memory */
centerCounts = palloc(centerCountsSize);
closestCenters = palloc(closestCentersSize);
lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);
upperBound = palloc(upperBoundSize);
s = palloc(sSize);
halfcdist = palloc(halfcdistSize);
newcdist = palloc(newcdistSize);
newCenters = VectorArrayInit(numCenters, dimensions);
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
}
FmgrInfo *procinfo = index_getprocinfo(buildstate->index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
FmgrInfo *normprocinfo = buildstate->kmeansnormprocinfo;
Oid collation = buildstate->index->rd_indcollation[0];
/* Pick initial centers */
InitCenters(index, samples, centers, lowerBound);
InitCenters(buildstate->index, buildstate->samples, buildstate->centers);
/* Assign each x to its closest initial center c(x) = argmin d(x,c) */
for (j = 0; j < numSamples; j++)
v = palloc(sizeof(int) * centers->maxlen);
d = palloc(sizeof(int) * b);
for (int i = 0; i < centers->length; i++)
v[i] = 0;
for (i = 0; i < t; i++)
{
minDistance = DBL_MAX;
closestCenter = -1;
/* Reset samples */
buildstate->samples->length = 0;
vec = VectorArrayGet(samples, j);
/* Get b examples picked randomly from X */
SampleRows(buildstate);
m = buildstate->samples;
/* Find closest center */
for (k = 0; k < numCenters; k++)
{
/* TODO Use Lemma 1 in k-means++ initialization */
distance = lowerBound[j * numCenters + k];
if (distance < minDistance)
{
minDistance = distance;
closestCenter = k;
}
}
upperBound[j] = minDistance;
closestCenters[j] = closestCenter;
}
/* Give 500 iterations to converge */
for (iteration = 0; iteration < 500; iteration++)
{
/* Can take a while, so ensure we can interrupt */
CHECK_FOR_INTERRUPTS();
changes = 0;
/* Step 1: For all centers, compute distance */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(centers, j);
for (k = j + 1; k < numCenters; k++)
{
distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
halfcdist[j * numCenters + k] = distance;
halfcdist[k * numCenters + j] = distance;
}
}
/* For all centers c, compute s(c) */
for (j = 0; j < numCenters; j++)
/* Cache nearest center to x */
for (j = 0; j < m->length; j++)
{
/* compute closest */
minDistance = DBL_MAX;
closestCenter = -1;
for (k = 0; k < numCenters; k++)
x = VectorArrayGet(m, j);
/* Find closest center */
for (k = 0; k < centers->length; k++)
{
if (j == k)
continue;
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(x), PointerGetDatum(VectorArrayGet(centers, k))));
distance = halfcdist[j * numCenters + k];
if (distance < minDistance)
{
minDistance = distance;
closestCenter = k;
}
}
s[j] = minDistance;
d[j] = closestCenter;
}
rjreset = iteration != 0;
for (j = 0; j < numSamples; j++)
for (j = 0; j < m->length; j++)
{
/* Step 2: Identify all points x such that u(x) <= s(c(x)) */
if (upperBound[j] <= s[closestCenters[j]])
continue;
x = VectorArrayGet(m, j);
rj = rjreset;
/* Get cached center for this x */
c = VectorArrayGet(centers, d[j]);
for (k = 0; k < numCenters; k++)
/* Update per-center counts */
v[d[j]]++;
/* Get per-center learning rate */
eta = 1.0 / v[d[j]];
/* Take gradient step */
for (k = 0; k < buildstate->dimensions; k++)
c->x[k] = (1 - eta) * c->x[k] + eta * x->x[k];
}
/* Check for empty centers (likely duplicates) */
if (i == 0)
{
for (j = 0; j < centers->length; j++)
{
/* Step 3: For all remaining points x and centers c */
if (k == closestCenters[j])
continue;
if (upperBound[j] <= lowerBound[j * numCenters + k])
continue;
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
continue;
vec = VectorArrayGet(samples, j);
/* Step 3a */
if (rj)
if (v[j] == 0)
{
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
/* d(x,c(x)) computed, which is a form of d(x,c) */
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
upperBound[j] = dxcx;
rj = false;
}
else
dxcx = upperBound[j];
/* Step 3b */
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
{
dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
/* d(x,c) calculated */
lowerBound[j * numCenters + k] = dxc;
if (dxc < dxcx)
{
closestCenters[j] = k;
/* c(x) changed */
upperBound[j] = dxc;
changes++;
}
c = VectorArrayGet(centers, j);
/* TODO Handle empty centers properly */
for (k = 0; k < c->dim; k++)
c->x[k] = ((double) random()) / MAX_RANDOM_VALUE;
}
}
}
/* Step 4: For each center c, let m(c) be mean of all points assigned */
for (j = 0; j < numCenters; j++)
/* Normalize if needed */
if (normprocinfo != NULL)
{
vec = VectorArrayGet(newCenters, j);
for (k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
centerCounts[j] = 0;
for (j = 0; j < centers->length; j++)
ApplyNorm(normprocinfo, collation, VectorArrayGet(centers, j));
}
for (j = 0; j < numSamples; j++)
{
vec = VectorArrayGet(samples, j);
closestCenter = closestCenters[j];
/* Increment sum and count of closest center */
newCenter = VectorArrayGet(newCenters, closestCenter);
for (k = 0; k < dimensions; k++)
newCenter->x[k] += vec->x[k];
centerCounts[closestCenter] += 1;
}
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
if (centerCounts[j] > 0)
{
for (k = 0; k < dimensions; k++)
vec->x[k] /= centerCounts[j];
}
else
{
/* TODO Handle empty centers properly */
for (k = 0; k < dimensions; k++)
vec->x[k] = ((double) random()) / MAX_RANDOM_VALUE;
}
/* Normalize if needed */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, vec);
}
/* Step 5 */
for (j = 0; j < numCenters; j++)
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j))));
for (j = 0; j < numSamples; j++)
{
for (k = 0; k < numCenters; k++)
{
distance = lowerBound[j * numCenters + k] - newcdist[k];
if (distance < 0)
distance = 0;
lowerBound[j * numCenters + k] = distance;
}
}
/* Step 6 */
/* We reset r(x) before Step 3 in the next iteration */
for (j = 0; j < numSamples; j++)
upperBound[j] += newcdist[closestCenters[j]];
/* Step 7 */
for (j = 0; j < numCenters; j++)
memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions));
if (changes == 0 && iteration != 0)
break;
}
pfree(newCenters);
pfree(centerCounts);
pfree(closestCenters);
pfree(lowerBound);
pfree(upperBound);
pfree(s);
pfree(halfcdist);
pfree(newcdist);
pfree(v);
pfree(d);
}
/*
@@ -491,16 +398,47 @@ CheckCenters(Relation index, VectorArray centers)
}
/*
* Perform naive k-means centering
* Perform k-means clustering
* We use spherical k-means for inner product and cosine
*/
void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
IvfflatKmeans(IvfflatBuildState * buildstate)
{
if (samples->length <= centers->maxlen)
QuickCenters(index, samples, centers);
else
ElkanKmeans(index, samples, centers);
int numSamples;
CheckCenters(index, centers);
/* Target 10 samples per list, with at least 10000 samples */
/* The number of samples has a large effect on index build time */
numSamples = buildstate->lists * 10;
if (numSamples < 10000)
numSamples = 10000;
/* Skip samples for unlogged table */
if (buildstate->heap == NULL)
numSamples = 1;
/* Calculate total size */
Size totalSize = VECTOR_ARRAY_SIZE(numSamples, buildstate->dimensions);
/* Check memory requirements */
/* Add one to error message to ceil */
if (totalSize / 1024 > maintenance_work_mem)
ereport(ERROR,
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
errmsg("memory required is %zu MB, maintenance_work_mem is %d MB",
totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024)));
/* Sample rows */
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
if (buildstate->heap != NULL)
SampleRows(buildstate);
if (buildstate->samples->length <= buildstate->centers->maxlen)
QuickCenters(buildstate->index, buildstate->samples, buildstate->centers);
else
MiniBatchKmeans(buildstate);
CheckCenters(buildstate->index, buildstate->centers);
/* Free samples before we allocate more memory */
pfree(buildstate->samples);
}