mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Switched to mini-batch k-means
This commit is contained in:
@@ -1,3 +1,7 @@
|
||||
## 0.2.6 (unreleased)
|
||||
|
||||
- Switched to mini-batch k-means
|
||||
|
||||
## 0.2.5 (2022-02-11)
|
||||
|
||||
- Reduced memory usage during index creation
|
||||
|
||||
@@ -119,10 +119,9 @@ SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;
|
||||
The phases are:
|
||||
|
||||
1. `initializing`
|
||||
2. `sampling table`
|
||||
3. `performing k-means`
|
||||
4. `sorting tuples`
|
||||
5. `loading tuples`
|
||||
2. `performing k-means`
|
||||
3. `sorting tuples`
|
||||
4. `loading tuples`
|
||||
|
||||
Note: `tuples_done` and `tuples_total` are only populated during the `loading tuples` phase
|
||||
|
||||
@@ -264,7 +263,7 @@ Thanks to:
|
||||
|
||||
- [PASE: PostgreSQL Ultra-High-Dimensional Approximate Nearest Neighbor Search Extension](https://dl.acm.org/doi/pdf/10.1145/3318464.3386131)
|
||||
- [Faiss: A Library for Efficient Similarity Search and Clustering of Dense Vectors](https://github.com/facebookresearch/faiss)
|
||||
- [Using the Triangle Inequality to Accelerate k-means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf)
|
||||
- [Web-Scale k-means Clustering](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf)
|
||||
- [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf)
|
||||
- [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf)
|
||||
|
||||
|
||||
117
src/ivfbuild.c
117
src/ivfbuild.c
@@ -47,87 +47,6 @@ UpdateProgress(int index, int64 val)
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Callback for sampling
|
||||
*/
|
||||
static void
|
||||
SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
|
||||
bool *isnull, bool tupleIsAlive, void *state)
|
||||
{
|
||||
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
|
||||
VectorArray samples = buildstate->samples;
|
||||
int targsamples = samples->maxlen;
|
||||
Datum value = values[0];
|
||||
|
||||
/* Skip nulls */
|
||||
if (isnull[0])
|
||||
return;
|
||||
|
||||
/*
|
||||
* Normalize with KMEANS_NORM_PROC since spherical distance function
|
||||
* expects unit vectors
|
||||
*/
|
||||
if (buildstate->kmeansnormprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec))
|
||||
return;
|
||||
}
|
||||
|
||||
if (samples->length < targsamples)
|
||||
{
|
||||
VectorArraySet(samples, samples->length, DatumGetVector(value));
|
||||
samples->length++;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (buildstate->rowstoskip < 0)
|
||||
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
|
||||
|
||||
if (buildstate->rowstoskip <= 0)
|
||||
{
|
||||
int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate));
|
||||
|
||||
Assert(k >= 0 && k < targsamples);
|
||||
VectorArraySet(samples, k, DatumGetVector(value));
|
||||
}
|
||||
|
||||
buildstate->rowstoskip -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Sample rows with same logic as ANALYZE
|
||||
*/
|
||||
static void
|
||||
SampleRows(IvfflatBuildState * buildstate)
|
||||
{
|
||||
int targsamples = buildstate->samples->maxlen;
|
||||
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
|
||||
|
||||
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_SAMPLE);
|
||||
|
||||
buildstate->rowstoskip = -1;
|
||||
|
||||
BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random());
|
||||
|
||||
reservoir_init_selection_state(&buildstate->rstate, targsamples);
|
||||
while (BlockSampler_HasMore(&buildstate->bs))
|
||||
{
|
||||
BlockNumber targblock = BlockSampler_Next(&buildstate->bs);
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
|
||||
#elif PG_VERSION_NUM >= 110000
|
||||
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
|
||||
#else
|
||||
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, true, targblock, 1, SampleCallback, (void *) buildstate);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Callback for table_index_build_scan
|
||||
*/
|
||||
@@ -371,38 +290,6 @@ FreeBuildState(IvfflatBuildState * buildstate)
|
||||
pfree(buildstate->normvec);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compute centers
|
||||
*/
|
||||
static void
|
||||
ComputeCenters(IvfflatBuildState * buildstate)
|
||||
{
|
||||
int numSamples;
|
||||
|
||||
/* Target 50 samples per list, with at least 10000 samples */
|
||||
/* The number of samples has a large effect on index build time */
|
||||
numSamples = buildstate->lists * 50;
|
||||
if (numSamples < 10000)
|
||||
numSamples = 10000;
|
||||
|
||||
/* Skip samples for unlogged table */
|
||||
if (buildstate->heap == NULL)
|
||||
numSamples = 1;
|
||||
|
||||
/* Sample rows */
|
||||
/* TODO Ensure within maintenance_work_mem */
|
||||
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
|
||||
if (buildstate->heap != NULL)
|
||||
SampleRows(buildstate);
|
||||
|
||||
/* Calculate centers */
|
||||
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_KMEANS);
|
||||
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
|
||||
|
||||
/* Free samples before we allocate more memory */
|
||||
pfree(buildstate->samples);
|
||||
}
|
||||
|
||||
/*
|
||||
* Create the metapage
|
||||
*/
|
||||
@@ -531,7 +418,9 @@ BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo,
|
||||
{
|
||||
InitBuildState(buildstate, heap, index, indexInfo);
|
||||
|
||||
ComputeCenters(buildstate);
|
||||
/* Perform k-means clustering */
|
||||
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_KMEANS);
|
||||
IvfflatBench("k-means", IvfflatKmeans(buildstate));
|
||||
|
||||
/* Create pages */
|
||||
CreateMetaPage(index, buildstate->dimensions, buildstate->lists, forkNum);
|
||||
|
||||
@@ -45,8 +45,6 @@ ivfflatbuildphasename(int64 phasenum)
|
||||
{
|
||||
case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE:
|
||||
return "initializing";
|
||||
case PROGRESS_IVFFLAT_PHASE_SAMPLE:
|
||||
return "sampling table";
|
||||
case PROGRESS_IVFFLAT_PHASE_KMEANS:
|
||||
return "performing k-means";
|
||||
case PROGRESS_IVFFLAT_PHASE_SORT:
|
||||
|
||||
@@ -37,10 +37,9 @@
|
||||
|
||||
/* Build phases */
|
||||
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
|
||||
#define PROGRESS_IVFFLAT_PHASE_SAMPLE 2
|
||||
#define PROGRESS_IVFFLAT_PHASE_KMEANS 3
|
||||
#define PROGRESS_IVFFLAT_PHASE_SORT 4
|
||||
#define PROGRESS_IVFFLAT_PHASE_LOAD 5
|
||||
#define PROGRESS_IVFFLAT_PHASE_KMEANS 2
|
||||
#define PROGRESS_IVFFLAT_PHASE_SORT 3
|
||||
#define PROGRESS_IVFFLAT_PHASE_LOAD 4
|
||||
|
||||
#define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim))
|
||||
|
||||
@@ -200,7 +199,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
|
||||
void _PG_init(void);
|
||||
VectorArray VectorArrayInit(int maxlen, int dimensions);
|
||||
void PrintVectorArray(char *msg, VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
|
||||
void IvfflatKmeans(IvfflatBuildState * buildstate);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
|
||||
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
|
||||
int IvfflatGetLists(Relation index);
|
||||
|
||||
472
src/ivfkmeans.c
472
src/ivfkmeans.c
@@ -4,6 +4,17 @@
|
||||
|
||||
#include "ivfflat.h"
|
||||
#include "miscadmin.h"
|
||||
#include "storage/bufmgr.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
#include "access/tableam.h"
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
#define CALLBACK_ITEM_POINTER ItemPointer tid
|
||||
#else
|
||||
#define CALLBACK_ITEM_POINTER HeapTuple hup
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Initialize with kmeans++
|
||||
@@ -11,7 +22,7 @@
|
||||
* https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
|
||||
*/
|
||||
static void
|
||||
InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound)
|
||||
InitCenters(Relation index, VectorArray samples, VectorArray centers)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
Oid collation;
|
||||
@@ -35,7 +46,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
|
||||
for (j = 0; j < numSamples; j++)
|
||||
weight[j] = DBL_MAX;
|
||||
|
||||
for (i = 0; i < numCenters; i++)
|
||||
for (i = 0; i < numCenters - 1; i++)
|
||||
{
|
||||
CHECK_FOR_INTERRUPTS();
|
||||
|
||||
@@ -49,9 +60,6 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
|
||||
/* TODO Use triangle inequality to reduce distance calculations */
|
||||
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
|
||||
|
||||
/* Set lower bound */
|
||||
lowerBound[j * numCenters + i] = distance;
|
||||
|
||||
/* Use distance squared for weighted probability distribution */
|
||||
distance *= distance;
|
||||
|
||||
@@ -61,10 +69,6 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
|
||||
sum += weight[j];
|
||||
}
|
||||
|
||||
/* Only compute lower bound on last iteration */
|
||||
if (i + 1 == numCenters)
|
||||
break;
|
||||
|
||||
/* Choose new center using weighted probability distribution. */
|
||||
choice = sum * (((double) random()) / MAX_RANDOM_VALUE);
|
||||
for (j = 0; j < numSamples - 1; j++)
|
||||
@@ -156,299 +160,202 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
|
||||
}
|
||||
|
||||
/*
|
||||
* Use Elkan for performance. This requires distance function to satisfy triangle inequality.
|
||||
* Callback for sampling
|
||||
*/
|
||||
static void
|
||||
SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
|
||||
bool *isnull, bool tupleIsAlive, void *state)
|
||||
{
|
||||
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
|
||||
VectorArray samples = buildstate->samples;
|
||||
int targsamples = samples->maxlen;
|
||||
Datum value = values[0];
|
||||
|
||||
/* Skip nulls */
|
||||
if (isnull[0])
|
||||
return;
|
||||
|
||||
/*
|
||||
* Normalize with KMEANS_NORM_PROC since spherical distance function
|
||||
* expects unit vectors
|
||||
*/
|
||||
if (buildstate->kmeansnormprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec))
|
||||
return;
|
||||
}
|
||||
|
||||
if (samples->length < targsamples)
|
||||
{
|
||||
VectorArraySet(samples, samples->length, DatumGetVector(value));
|
||||
samples->length++;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (buildstate->rowstoskip < 0)
|
||||
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
|
||||
|
||||
if (buildstate->rowstoskip <= 0)
|
||||
{
|
||||
int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate));
|
||||
|
||||
Assert(k >= 0 && k < targsamples);
|
||||
VectorArraySet(samples, k, DatumGetVector(value));
|
||||
}
|
||||
|
||||
buildstate->rowstoskip -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Sample rows with same logic as ANALYZE
|
||||
*/
|
||||
static void
|
||||
SampleRows(IvfflatBuildState * buildstate)
|
||||
{
|
||||
int targsamples = buildstate->samples->maxlen;
|
||||
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
|
||||
|
||||
buildstate->rowstoskip = -1;
|
||||
|
||||
BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random());
|
||||
|
||||
reservoir_init_selection_state(&buildstate->rstate, targsamples);
|
||||
while (BlockSampler_HasMore(&buildstate->bs))
|
||||
{
|
||||
BlockNumber targblock = BlockSampler_Next(&buildstate->bs);
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
|
||||
#elif PG_VERSION_NUM >= 110000
|
||||
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
|
||||
#else
|
||||
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, true, targblock, 1, SampleCallback, (void *) buildstate);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Use mini-batch k-means
|
||||
*
|
||||
* We use L2 distance for L2 (not L2 squared like index scan)
|
||||
* and angular distance for inner product and cosine distance
|
||||
*
|
||||
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
|
||||
* https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
|
||||
*/
|
||||
static void
|
||||
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
MiniBatchKmeans(IvfflatBuildState * buildstate)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
Oid collation;
|
||||
Vector *vec;
|
||||
Vector *newCenter;
|
||||
int iteration;
|
||||
int j;
|
||||
int k;
|
||||
int dimensions = centers->dim;
|
||||
int numCenters = centers->maxlen;
|
||||
int numSamples = samples->length;
|
||||
VectorArray newCenters;
|
||||
int *centerCounts;
|
||||
int *closestCenters;
|
||||
float *lowerBound;
|
||||
float *upperBound;
|
||||
float *s;
|
||||
float *halfcdist;
|
||||
float *newcdist;
|
||||
int changes;
|
||||
VectorArray centers = buildstate->centers;
|
||||
int b = buildstate->samples->maxlen;
|
||||
int t = 20;
|
||||
double distance;
|
||||
double minDistance;
|
||||
int closestCenter;
|
||||
double distance;
|
||||
bool rj;
|
||||
bool rjreset;
|
||||
double dxcx;
|
||||
double dxc;
|
||||
|
||||
/* Calculate allocation sizes */
|
||||
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim);
|
||||
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim);
|
||||
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions);
|
||||
Size centerCountsSize = sizeof(int) * numCenters;
|
||||
Size closestCentersSize = sizeof(int) * numSamples;
|
||||
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
|
||||
Size upperBoundSize = sizeof(float) * numSamples;
|
||||
Size sSize = sizeof(float) * numCenters;
|
||||
Size halfcdistSize = sizeof(float) * numCenters * numCenters;
|
||||
Size newcdistSize = sizeof(float) * numCenters;
|
||||
|
||||
/* Calculate total size */
|
||||
Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
|
||||
|
||||
/* Check memory requirements */
|
||||
/* Add one to error message to ceil */
|
||||
if (totalSize / 1024 > maintenance_work_mem)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
||||
errmsg("memory required is %zu MB, maintenance_work_mem is %d MB",
|
||||
totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024)));
|
||||
int i;
|
||||
int j;
|
||||
int k;
|
||||
VectorArray m;
|
||||
Vector *c;
|
||||
Vector *x;
|
||||
int *v;
|
||||
int *d;
|
||||
float eta;
|
||||
|
||||
/* Set support functions */
|
||||
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
collation = index->rd_indcollation[0];
|
||||
|
||||
/* Allocate space */
|
||||
/* Use float instead of double to save memory */
|
||||
centerCounts = palloc(centerCountsSize);
|
||||
closestCenters = palloc(closestCentersSize);
|
||||
lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);
|
||||
upperBound = palloc(upperBoundSize);
|
||||
s = palloc(sSize);
|
||||
halfcdist = palloc(halfcdistSize);
|
||||
newcdist = palloc(newcdistSize);
|
||||
|
||||
newCenters = VectorArrayInit(numCenters, dimensions);
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
vec = VectorArrayGet(newCenters, j);
|
||||
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
}
|
||||
FmgrInfo *procinfo = index_getprocinfo(buildstate->index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
||||
FmgrInfo *normprocinfo = buildstate->kmeansnormprocinfo;
|
||||
Oid collation = buildstate->index->rd_indcollation[0];
|
||||
|
||||
/* Pick initial centers */
|
||||
InitCenters(index, samples, centers, lowerBound);
|
||||
InitCenters(buildstate->index, buildstate->samples, buildstate->centers);
|
||||
|
||||
/* Assign each x to its closest initial center c(x) = argmin d(x,c) */
|
||||
for (j = 0; j < numSamples; j++)
|
||||
v = palloc(sizeof(int) * centers->maxlen);
|
||||
d = palloc(sizeof(int) * b);
|
||||
|
||||
for (int i = 0; i < centers->length; i++)
|
||||
v[i] = 0;
|
||||
|
||||
for (i = 0; i < t; i++)
|
||||
{
|
||||
minDistance = DBL_MAX;
|
||||
closestCenter = -1;
|
||||
/* Reset samples */
|
||||
buildstate->samples->length = 0;
|
||||
|
||||
vec = VectorArrayGet(samples, j);
|
||||
/* Get b examples picked randomly from X */
|
||||
SampleRows(buildstate);
|
||||
m = buildstate->samples;
|
||||
|
||||
/* Find closest center */
|
||||
for (k = 0; k < numCenters; k++)
|
||||
{
|
||||
/* TODO Use Lemma 1 in k-means++ initialization */
|
||||
distance = lowerBound[j * numCenters + k];
|
||||
|
||||
if (distance < minDistance)
|
||||
{
|
||||
minDistance = distance;
|
||||
closestCenter = k;
|
||||
}
|
||||
}
|
||||
|
||||
upperBound[j] = minDistance;
|
||||
closestCenters[j] = closestCenter;
|
||||
}
|
||||
|
||||
/* Give 500 iterations to converge */
|
||||
for (iteration = 0; iteration < 500; iteration++)
|
||||
{
|
||||
/* Can take a while, so ensure we can interrupt */
|
||||
CHECK_FOR_INTERRUPTS();
|
||||
|
||||
changes = 0;
|
||||
|
||||
/* Step 1: For all centers, compute distance */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
vec = VectorArrayGet(centers, j);
|
||||
|
||||
for (k = j + 1; k < numCenters; k++)
|
||||
{
|
||||
distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
|
||||
halfcdist[j * numCenters + k] = distance;
|
||||
halfcdist[k * numCenters + j] = distance;
|
||||
}
|
||||
}
|
||||
|
||||
/* For all centers c, compute s(c) */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
/* Cache nearest center to x */
|
||||
for (j = 0; j < m->length; j++)
|
||||
{
|
||||
/* compute closest */
|
||||
minDistance = DBL_MAX;
|
||||
closestCenter = -1;
|
||||
|
||||
for (k = 0; k < numCenters; k++)
|
||||
x = VectorArrayGet(m, j);
|
||||
|
||||
/* Find closest center */
|
||||
for (k = 0; k < centers->length; k++)
|
||||
{
|
||||
if (j == k)
|
||||
continue;
|
||||
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(x), PointerGetDatum(VectorArrayGet(centers, k))));
|
||||
|
||||
distance = halfcdist[j * numCenters + k];
|
||||
if (distance < minDistance)
|
||||
{
|
||||
minDistance = distance;
|
||||
closestCenter = k;
|
||||
}
|
||||
}
|
||||
|
||||
s[j] = minDistance;
|
||||
d[j] = closestCenter;
|
||||
}
|
||||
|
||||
rjreset = iteration != 0;
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
for (j = 0; j < m->length; j++)
|
||||
{
|
||||
/* Step 2: Identify all points x such that u(x) <= s(c(x)) */
|
||||
if (upperBound[j] <= s[closestCenters[j]])
|
||||
continue;
|
||||
x = VectorArrayGet(m, j);
|
||||
|
||||
rj = rjreset;
|
||||
/* Get cached center for this x */
|
||||
c = VectorArrayGet(centers, d[j]);
|
||||
|
||||
for (k = 0; k < numCenters; k++)
|
||||
/* Update per-center counts */
|
||||
v[d[j]]++;
|
||||
|
||||
/* Get per-center learning rate */
|
||||
eta = 1.0 / v[d[j]];
|
||||
|
||||
/* Take gradient step */
|
||||
for (k = 0; k < buildstate->dimensions; k++)
|
||||
c->x[k] = (1 - eta) * c->x[k] + eta * x->x[k];
|
||||
}
|
||||
|
||||
/* Check for empty centers (likely duplicates) */
|
||||
if (i == 0)
|
||||
{
|
||||
for (j = 0; j < centers->length; j++)
|
||||
{
|
||||
/* Step 3: For all remaining points x and centers c */
|
||||
if (k == closestCenters[j])
|
||||
continue;
|
||||
|
||||
if (upperBound[j] <= lowerBound[j * numCenters + k])
|
||||
continue;
|
||||
|
||||
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
|
||||
continue;
|
||||
|
||||
vec = VectorArrayGet(samples, j);
|
||||
|
||||
/* Step 3a */
|
||||
if (rj)
|
||||
if (v[j] == 0)
|
||||
{
|
||||
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
|
||||
|
||||
/* d(x,c(x)) computed, which is a form of d(x,c) */
|
||||
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
|
||||
upperBound[j] = dxcx;
|
||||
|
||||
rj = false;
|
||||
}
|
||||
else
|
||||
dxcx = upperBound[j];
|
||||
|
||||
/* Step 3b */
|
||||
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
|
||||
{
|
||||
dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
|
||||
|
||||
/* d(x,c) calculated */
|
||||
lowerBound[j * numCenters + k] = dxc;
|
||||
|
||||
if (dxc < dxcx)
|
||||
{
|
||||
closestCenters[j] = k;
|
||||
|
||||
/* c(x) changed */
|
||||
upperBound[j] = dxc;
|
||||
|
||||
changes++;
|
||||
}
|
||||
c = VectorArrayGet(centers, j);
|
||||
|
||||
/* TODO Handle empty centers properly */
|
||||
for (k = 0; k < c->dim; k++)
|
||||
c->x[k] = ((double) random()) / MAX_RANDOM_VALUE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
vec = VectorArrayGet(newCenters, j);
|
||||
for (k = 0; k < dimensions; k++)
|
||||
vec->x[k] = 0.0;
|
||||
|
||||
centerCounts[j] = 0;
|
||||
for (j = 0; j < centers->length; j++)
|
||||
ApplyNorm(normprocinfo, collation, VectorArrayGet(centers, j));
|
||||
}
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
vec = VectorArrayGet(samples, j);
|
||||
closestCenter = closestCenters[j];
|
||||
|
||||
/* Increment sum and count of closest center */
|
||||
newCenter = VectorArrayGet(newCenters, closestCenter);
|
||||
for (k = 0; k < dimensions; k++)
|
||||
newCenter->x[k] += vec->x[k];
|
||||
|
||||
centerCounts[closestCenter] += 1;
|
||||
}
|
||||
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
vec = VectorArrayGet(newCenters, j);
|
||||
|
||||
if (centerCounts[j] > 0)
|
||||
{
|
||||
for (k = 0; k < dimensions; k++)
|
||||
vec->x[k] /= centerCounts[j];
|
||||
}
|
||||
else
|
||||
{
|
||||
/* TODO Handle empty centers properly */
|
||||
for (k = 0; k < dimensions; k++)
|
||||
vec->x[k] = ((double) random()) / MAX_RANDOM_VALUE;
|
||||
}
|
||||
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
ApplyNorm(normprocinfo, collation, vec);
|
||||
}
|
||||
|
||||
/* Step 5 */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j))));
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
for (k = 0; k < numCenters; k++)
|
||||
{
|
||||
distance = lowerBound[j * numCenters + k] - newcdist[k];
|
||||
|
||||
if (distance < 0)
|
||||
distance = 0;
|
||||
|
||||
lowerBound[j * numCenters + k] = distance;
|
||||
}
|
||||
}
|
||||
|
||||
/* Step 6 */
|
||||
/* We reset r(x) before Step 3 in the next iteration */
|
||||
for (j = 0; j < numSamples; j++)
|
||||
upperBound[j] += newcdist[closestCenters[j]];
|
||||
|
||||
/* Step 7 */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions));
|
||||
|
||||
if (changes == 0 && iteration != 0)
|
||||
break;
|
||||
}
|
||||
|
||||
pfree(newCenters);
|
||||
pfree(centerCounts);
|
||||
pfree(closestCenters);
|
||||
pfree(lowerBound);
|
||||
pfree(upperBound);
|
||||
pfree(s);
|
||||
pfree(halfcdist);
|
||||
pfree(newcdist);
|
||||
pfree(v);
|
||||
pfree(d);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -491,16 +398,47 @@ CheckCenters(Relation index, VectorArray centers)
|
||||
}
|
||||
|
||||
/*
|
||||
* Perform naive k-means centering
|
||||
* Perform k-means clustering
|
||||
* We use spherical k-means for inner product and cosine
|
||||
*/
|
||||
void
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
IvfflatKmeans(IvfflatBuildState * buildstate)
|
||||
{
|
||||
if (samples->length <= centers->maxlen)
|
||||
QuickCenters(index, samples, centers);
|
||||
else
|
||||
ElkanKmeans(index, samples, centers);
|
||||
int numSamples;
|
||||
|
||||
CheckCenters(index, centers);
|
||||
/* Target 10 samples per list, with at least 10000 samples */
|
||||
/* The number of samples has a large effect on index build time */
|
||||
numSamples = buildstate->lists * 10;
|
||||
if (numSamples < 10000)
|
||||
numSamples = 10000;
|
||||
|
||||
/* Skip samples for unlogged table */
|
||||
if (buildstate->heap == NULL)
|
||||
numSamples = 1;
|
||||
|
||||
/* Calculate total size */
|
||||
Size totalSize = VECTOR_ARRAY_SIZE(numSamples, buildstate->dimensions);
|
||||
|
||||
/* Check memory requirements */
|
||||
/* Add one to error message to ceil */
|
||||
if (totalSize / 1024 > maintenance_work_mem)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
||||
errmsg("memory required is %zu MB, maintenance_work_mem is %d MB",
|
||||
totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024)));
|
||||
|
||||
/* Sample rows */
|
||||
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
|
||||
if (buildstate->heap != NULL)
|
||||
SampleRows(buildstate);
|
||||
|
||||
if (buildstate->samples->length <= buildstate->centers->maxlen)
|
||||
QuickCenters(buildstate->index, buildstate->samples, buildstate->centers);
|
||||
else
|
||||
MiniBatchKmeans(buildstate);
|
||||
|
||||
CheckCenters(buildstate->index, buildstate->centers);
|
||||
|
||||
/* Free samples before we allocate more memory */
|
||||
pfree(buildstate->samples);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user